Skip to main content

sqlmodel_session/
identity_map.rs

1//! Identity Map pattern for tracking unique object instances per primary key.
2//!
3//! The Identity Map ensures that each database row corresponds to exactly one
4//! object instance within a session. This provides:
5//!
6//! - **Uniqueness**: Same PK always returns the same object reference
7//! - **Cache**: Avoids redundant queries for the same object
8//! - **Consistency**: Changes to an object are visible everywhere it's used
9//!
10//! # Design
11//!
12//! Unlike the simple clone-based approach, this implementation uses `Arc<RwLock<T>>`
13//! to provide true shared references. When you get an object twice with the same PK,
14//! you get references to the same underlying object.
15//!
16//! # Example
17//!
18//! ```ignore
//! let mut map = IdentityMap::new();
//!
//! // Insert a new object; `insert` returns the shared `Arc<RwLock<User>>`.
//! let user_ref = map.insert(user);
//!
//! // Get the same object by PK (`get` returns an `Option`).
//! let user_ref2 = map.get::<User>(&pk_values).unwrap();
//!
//! // Both references point to the same object
//! assert!(Arc::ptr_eq(&user_ref, &user_ref2));
//!
//! // Modifications are visible through both references
//! user_ref.write().unwrap().name = "Changed".to_string();
//! assert_eq!(user_ref2.read().unwrap().name, "Changed");
33//! ```
34
35use sqlmodel_core::{Model, Value};
36use std::any::{Any, TypeId};
37use std::collections::HashMap;
38use std::sync::{Arc, RwLock, Weak};
39
40/// Hash a slice of values for use as a primary key identifier.
41fn hash_pk_values(values: &[Value]) -> u64 {
42    use std::collections::hash_map::DefaultHasher;
43    use std::hash::Hasher;
44
45    let mut hasher = DefaultHasher::new();
46    for v in values {
47        hash_single_value(v, &mut hasher);
48    }
49    hasher.finish()
50}
51
52/// Hash a single Value into the hasher.
53fn hash_single_value(v: &Value, hasher: &mut impl std::hash::Hasher) {
54    use std::hash::Hash;
55
56    match v {
57        Value::Null => 0u8.hash(hasher),
58        Value::Bool(b) => {
59            1u8.hash(hasher);
60            b.hash(hasher);
61        }
62        Value::TinyInt(i) => {
63            2u8.hash(hasher);
64            i.hash(hasher);
65        }
66        Value::SmallInt(i) => {
67            3u8.hash(hasher);
68            i.hash(hasher);
69        }
70        Value::Int(i) => {
71            4u8.hash(hasher);
72            i.hash(hasher);
73        }
74        Value::BigInt(i) => {
75            5u8.hash(hasher);
76            i.hash(hasher);
77        }
78        Value::Float(f) => {
79            6u8.hash(hasher);
80            f.to_bits().hash(hasher);
81        }
82        Value::Double(f) => {
83            7u8.hash(hasher);
84            f.to_bits().hash(hasher);
85        }
86        Value::Decimal(s) => {
87            8u8.hash(hasher);
88            s.hash(hasher);
89        }
90        Value::Text(s) => {
91            9u8.hash(hasher);
92            s.hash(hasher);
93        }
94        Value::Bytes(b) => {
95            10u8.hash(hasher);
96            b.hash(hasher);
97        }
98        Value::Date(d) => {
99            11u8.hash(hasher);
100            d.hash(hasher);
101        }
102        Value::Time(t) => {
103            12u8.hash(hasher);
104            t.hash(hasher);
105        }
106        Value::Timestamp(ts) => {
107            13u8.hash(hasher);
108            ts.hash(hasher);
109        }
110        Value::TimestampTz(ts) => {
111            14u8.hash(hasher);
112            ts.hash(hasher);
113        }
114        Value::Uuid(u) => {
115            15u8.hash(hasher);
116            u.hash(hasher);
117        }
118        Value::Json(j) => {
119            16u8.hash(hasher);
120            j.to_string().hash(hasher);
121        }
122        Value::Array(arr) => {
123            17u8.hash(hasher);
124            arr.len().hash(hasher);
125            for item in arr {
126                hash_single_value(item, hasher);
127            }
128        }
129        Value::Default => {
130            18u8.hash(hasher);
131        }
132    }
133}
134
/// A type-erased entry in the identity map.
///
/// This wrapper holds a type-erased `Arc<RwLock<M>>` which can be downcast
/// to recover the concrete model type. `IdentityMap` stores it under a key
/// whose `TypeId` is `TypeId::of::<M>()`, so a lookup for `M` downcasts the
/// box back to `Arc<RwLock<M>>` and hands out clones of that same `Arc`.
struct IdentityEntry {
    /// Type-erased Arc. Actually stores `Arc<RwLock<M>>` for some M.
    /// We type-erase the Arc itself (not the model value) so lookups can return
    /// clones of the same Arc, and therefore share a single allocation.
    arc: Box<dyn Any + Send + Sync>,
    /// The primary key values for this entry (stored for debugging/introspection).
    ///
    /// The map's lookup logic never reads this field (hence `dead_code`):
    /// lookups match on the PK hash in the map key, not on these values.
    #[allow(dead_code)]
    pk_values: Vec<Value>,
}
147
/// Identity Map for tracking unique object instances.
///
/// The map is keyed by (TypeId, pk_hash) to ensure each model type has its own
/// namespace, and objects with the same PK return the same reference.
///
/// Only the 64-bit hash of the PK values is part of the key; the values
/// themselves are not compared on lookup. Two distinct PKs of the same model
/// type whose hashes collide would therefore share one entry.
///
/// The map itself is a plain `HashMap` with no internal locking: mutation
/// requires `&mut self`, and each stored `RwLock` guards only one model's data.
#[derive(Default)]
pub struct IdentityMap {
    /// Map from (TypeId, pk_hash) to the entry.
    entries: HashMap<(TypeId, u64), IdentityEntry>,
}
157
158impl IdentityMap {
159    /// Create a new empty identity map.
160    #[must_use]
161    pub fn new() -> Self {
162        Self {
163            entries: HashMap::new(),
164        }
165    }
166
167    /// Insert a model into the identity map.
168    ///
169    /// If an object with the same PK already exists, returns the existing reference
170    /// (the new object is ignored). Otherwise, inserts the new object and returns
171    /// a reference to it.
172    ///
173    /// # Returns
174    ///
175    /// An `Arc<RwLock<M>>` pointing to the object in the map.
176    pub fn insert<M: Model + Send + Sync + 'static>(&mut self, model: M) -> Arc<RwLock<M>> {
177        let pk_values = model.primary_key_value();
178        let pk_hash = hash_pk_values(&pk_values);
179        let type_id = TypeId::of::<M>();
180        let key = (type_id, pk_hash);
181
182        // Check if already exists - return clone of the existing Arc
183        if let Some(entry) = self.entries.get(&key) {
184            if let Some(existing_arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
185                // Return clone of the same Arc (not a new Arc with cloned value)
186                return Arc::clone(existing_arc);
187            }
188        }
189
190        // Insert new entry - store the Arc itself in type-erased form
191        let arc: Arc<RwLock<M>> = Arc::new(RwLock::new(model));
192        let type_erased: Box<dyn Any + Send + Sync> = Box::new(Arc::clone(&arc));
193
194        self.entries.insert(
195            key,
196            IdentityEntry {
197                arc: type_erased,
198                pk_values,
199            },
200        );
201
202        arc
203    }
204
205    /// Get an object from the identity map by primary key.
206    ///
207    /// # Returns
208    ///
209    /// `Some(Arc<RwLock<M>>)` if found, `None` otherwise.
210    /// The returned Arc is a clone of the stored Arc, so modifications are shared.
211    pub fn get<M: Model + Send + Sync + 'static>(
212        &self,
213        pk_values: &[Value],
214    ) -> Option<Arc<RwLock<M>>> {
215        let pk_hash = hash_pk_values(pk_values);
216        let type_id = TypeId::of::<M>();
217        let key = (type_id, pk_hash);
218
219        let entry = self.entries.get(&key)?;
220
221        // Downcast the type-erased Arc to the concrete type and clone it
222        let arc = entry.arc.downcast_ref::<Arc<RwLock<M>>>()?;
223        Some(Arc::clone(arc))
224    }
225
226    /// Check if an object with the given PK exists in the map.
227    pub fn contains<M: Model + 'static>(&self, pk_values: &[Value]) -> bool {
228        let pk_hash = hash_pk_values(pk_values);
229        let type_id = TypeId::of::<M>();
230        self.entries.contains_key(&(type_id, pk_hash))
231    }
232
233    /// Check if a model instance exists in the map.
234    pub fn contains_model<M: Model + 'static>(&self, model: &M) -> bool {
235        let pk_values = model.primary_key_value();
236        self.contains::<M>(&pk_values)
237    }
238
239    /// Remove an object from the identity map.
240    ///
241    /// # Returns
242    ///
243    /// `true` if the object was removed, `false` if it wasn't in the map.
244    pub fn remove<M: Model + 'static>(&mut self, pk_values: &[Value]) -> bool {
245        let pk_hash = hash_pk_values(pk_values);
246        let type_id = TypeId::of::<M>();
247        self.entries.remove(&(type_id, pk_hash)).is_some()
248    }
249
250    /// Remove a model instance from the identity map.
251    pub fn remove_model<M: Model + 'static>(&mut self, model: &M) -> bool {
252        let pk_values = model.primary_key_value();
253        self.remove::<M>(&pk_values)
254    }
255
256    /// Clear all entries from the identity map.
257    pub fn clear(&mut self) {
258        self.entries.clear();
259    }
260
261    /// Get the number of entries in the map.
262    #[must_use]
263    pub fn len(&self) -> usize {
264        self.entries.len()
265    }
266
267    /// Check if the map is empty.
268    #[must_use]
269    pub fn is_empty(&self) -> bool {
270        self.entries.is_empty()
271    }
272
273    /// Get or insert a model into the identity map.
274    ///
275    /// If an object with the same PK already exists, returns a reference to it.
276    /// Otherwise, inserts the new object and returns a reference.
277    ///
278    /// This is useful when you want to either get an existing object or insert
279    /// a new one atomically.
280    pub fn get_or_insert<M: Model + Clone + Send + Sync + 'static>(
281        &mut self,
282        model: M,
283    ) -> Arc<RwLock<M>> {
284        let pk_values = model.primary_key_value();
285
286        // Check if exists first
287        if let Some(existing) = self.get::<M>(&pk_values) {
288            return existing;
289        }
290
291        // Insert and return
292        self.insert(model)
293    }
294
295    /// Update an object in the identity map.
296    ///
297    /// If the object exists, updates it with the new values and returns true.
298    /// If it doesn't exist, returns false.
299    pub fn update<M: Model + Clone + Send + Sync + 'static>(&mut self, model: &M) -> bool {
300        let pk_values = model.primary_key_value();
301        let pk_hash = hash_pk_values(&pk_values);
302        let type_id = TypeId::of::<M>();
303        let key = (type_id, pk_hash);
304
305        if let Some(entry) = self.entries.get(&key) {
306            // Downcast the Box to get the Arc<RwLock<M>>, then write to the RwLock
307            if let Some(arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
308                let mut guard = arc.write().expect("lock poisoned");
309                *guard = model.clone();
310                return true;
311            }
312        }
313
314        false
315    }
316}
317
/// Type alias for the boxed type-erased value used in weak identity maps.
///
/// Callers own the strong `Arc<RwLock<Box<dyn Any + Send + Sync>>>` and pass it
/// to `WeakIdentityMap::register`, which keeps only this downgraded handle, so
/// the map never keeps the value alive by itself.
type WeakEntryValue = Weak<RwLock<Box<dyn Any + Send + Sync>>>;
320
/// A weak-reference based identity map that lets tracked objects be dropped.
///
/// This variant stores `Weak<RwLock<>>` handles instead of `Arc<RwLock<>>`, so an
/// object is freed once no external strong (`Arc`) references remain.
///
/// Dead entries are *not* removed automatically: `get` simply returns `None`
/// for them, and they keep counting toward `len` until `prune` is called.
#[derive(Default)]
pub struct WeakIdentityMap {
    /// Map from (TypeId, pk_hash) to weak reference.
    entries: HashMap<(TypeId, u64), WeakEntryValue>,
}
331
332impl WeakIdentityMap {
333    /// Create a new empty weak identity map.
334    #[must_use]
335    pub fn new() -> Self {
336        Self {
337            entries: HashMap::new(),
338        }
339    }
340
341    /// Register an object in the weak identity map.
342    ///
343    /// The map holds a weak reference; when all strong references are dropped,
344    /// the object can be garbage collected.
345    pub fn register<M: Model + 'static>(
346        &mut self,
347        arc: &Arc<RwLock<Box<dyn Any + Send + Sync>>>,
348        pk_values: &[Value],
349    ) {
350        let pk_hash = hash_pk_values(pk_values);
351        let type_id = TypeId::of::<M>();
352        let key = (type_id, pk_hash);
353        self.entries.insert(key, Arc::downgrade(arc));
354    }
355
356    /// Try to get an object from the weak map.
357    ///
358    /// Returns `None` if the object was never registered or has been dropped.
359    pub fn get<M: Model + Clone + Send + Sync + 'static>(
360        &self,
361        pk_values: &[Value],
362    ) -> Option<Arc<RwLock<Box<dyn Any + Send + Sync>>>> {
363        let pk_hash = hash_pk_values(pk_values);
364        let type_id = TypeId::of::<M>();
365        let key = (type_id, pk_hash);
366
367        self.entries.get(&key)?.upgrade()
368    }
369
370    /// Remove stale (dropped) entries from the map.
371    ///
372    /// Call this periodically to clean up memory.
373    pub fn prune(&mut self) {
374        self.entries.retain(|_, weak| weak.strong_count() > 0);
375    }
376
377    /// Clear all entries.
378    pub fn clear(&mut self) {
379        self.entries.clear();
380    }
381
382    /// Get the number of entries (including potentially stale ones).
383    #[must_use]
384    pub fn len(&self) -> usize {
385        self.entries.len()
386    }
387
388    /// Check if the map is empty.
389    #[must_use]
390    pub fn is_empty(&self) -> bool {
391        self.entries.is_empty()
392    }
393}
394
395// ============================================================================
396// Convenience type aliases
397// ============================================================================
398
/// A reference to an object in the identity map.
///
/// This is the handle type returned by `IdentityMap::insert` and
/// `IdentityMap::get`; clones share the same underlying object.
pub type ModelRef<M> = Arc<RwLock<M>>;

/// A guard for reading an object from the identity map.
///
/// Returned (inside a `LockResult`) by `ModelRef::read`.
pub type ModelReadGuard<'a, M> = std::sync::RwLockReadGuard<'a, M>;

/// A guard for writing to an object in the identity map.
///
/// Returned (inside a `LockResult`) by `ModelRef::write`; writes through it
/// are visible to every holder of the same `ModelRef`.
pub type ModelWriteGuard<'a, M> = std::sync::RwLockWriteGuard<'a, M>;
407
408// ============================================================================
409// Unit Tests
410// ============================================================================
411
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{FieldInfo, Row, SqlType};

    #[derive(Debug, Clone, PartialEq)]
    struct TestUser {
        id: Option<i64>,
        name: String,
    }

    impl Model for TestUser {
        const TABLE_NAME: &'static str = "users";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
            ]
        }

        fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: row.get_named("id").ok(),
                name: row.get_named("name")?,
            })
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }

        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    // `TestUser` (and `TestTeam` below) only contain `Option<i64>` and `String`,
    // so they are `Send + Sync` automatically; no `unsafe impl` is needed.

    /// Build a persisted `TestUser` with the given id and name.
    fn make_user(id: i64, name: &str) -> TestUser {
        TestUser {
            id: Some(id),
            name: name.to_string(),
        }
    }

    #[test]
    fn test_identity_map_insert_and_get() {
        let mut map = IdentityMap::new();

        let ref1 = map.insert(make_user(1, "Alice"));
        assert_eq!(ref1.read().unwrap().name, "Alice");

        // Getting by PK returns a handle to the very same object.
        let ref2 = map
            .get::<TestUser>(&[Value::BigInt(1)])
            .expect("inserted user should be found");
        assert!(Arc::ptr_eq(&ref1, &ref2));
        assert_eq!(ref2.read().unwrap().name, "Alice");
    }

    #[test]
    fn test_identity_map_insert_existing_pk_returns_existing() {
        let mut map = IdentityMap::new();

        let first = map.insert(make_user(1, "Alice"));
        let second = map.insert(make_user(1, "Bob"));

        // The second object is discarded; the original stays tracked.
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.read().unwrap().name, "Alice");
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn test_identity_map_modifications_visible() {
        let mut map = IdentityMap::new();

        let ref1 = map.insert(make_user(1, "Alice"));
        let ref2 = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();

        // A write through one handle is visible through the other.
        ref1.write().unwrap().name = "Bob".to_string();
        assert_eq!(ref2.read().unwrap().name, "Bob");

        // `update` overwrites the shared object in place.
        assert!(map.update(&make_user(1, "Charlie")));
        assert_eq!(ref1.read().unwrap().name, "Charlie");
        assert_eq!(ref2.read().unwrap().name, "Charlie");
    }

    #[test]
    fn test_identity_map_update_missing_returns_false() {
        let mut map = IdentityMap::new();

        assert!(!map.update(&make_user(1, "Alice")));
        assert!(map.is_empty());
    }

    #[test]
    fn test_identity_map_contains() {
        let mut map = IdentityMap::new();

        let user = make_user(1, "Alice");

        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));

        map.insert(user.clone());

        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains_model(&user));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(2)]));
    }

    #[test]
    fn test_identity_map_remove() {
        let mut map = IdentityMap::new();

        let user = make_user(1, "Alice");

        map.insert(user.clone());
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));

        assert!(map.remove::<TestUser>(&[Value::BigInt(1)]));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));

        // Removing again returns false
        assert!(!map.remove::<TestUser>(&[Value::BigInt(1)]));

        // `remove_model` looks the entry up by the model's own PK.
        map.insert(user.clone());
        assert!(map.remove_model(&user));
        assert!(!map.contains_model(&user));
    }

    #[test]
    fn test_identity_map_clear() {
        let mut map = IdentityMap::new();

        map.insert(make_user(1, "Alice"));
        map.insert(make_user(2, "Bob"));

        assert_eq!(map.len(), 2);

        map.clear();

        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn test_identity_map_get_or_insert() {
        let mut map = IdentityMap::new();

        // First call inserts
        let ref1 = map.get_or_insert(make_user(1, "Alice"));
        assert_eq!(ref1.read().unwrap().name, "Alice");

        // Second call with same PK returns the existing object (doesn't update)
        let ref2 = map.get_or_insert(make_user(1, "Bob"));
        assert!(Arc::ptr_eq(&ref1, &ref2));
        assert_eq!(ref2.read().unwrap().name, "Alice");
    }

    #[test]
    fn test_composite_pk_hashing() {
        // Test that composite PKs hash correctly
        let pk1 = vec![Value::BigInt(1), Value::Text("a".to_string())];
        let pk2 = vec![Value::BigInt(1), Value::Text("a".to_string())];
        let pk3 = vec![Value::BigInt(1), Value::Text("b".to_string())];

        assert_eq!(hash_pk_values(&pk1), hash_pk_values(&pk2));
        assert_ne!(hash_pk_values(&pk1), hash_pk_values(&pk3));
    }

    #[test]
    fn test_null_pk_handling() {
        let mut map = IdentityMap::new();

        // Objects with null PKs should still be insertable
        let user = TestUser {
            id: None,
            name: "Anonymous".to_string(),
        };

        let _ = map.insert(user);
        assert!(map.contains::<TestUser>(&[Value::Null]));
    }

    #[test]
    fn test_different_types_same_pk() {
        // Define a second model type with same PK structure
        #[derive(Debug, Clone)]
        struct TestTeam {
            id: Option<i64>,
            name: String,
        }

        impl Model for TestTeam {
            const TABLE_NAME: &'static str = "teams";
            const PRIMARY_KEY: &'static [&'static str] = &["id"];

            fn fields() -> &'static [FieldInfo] {
                &[]
            }

            fn to_row(&self) -> Vec<(&'static str, Value)> {
                vec![]
            }

            fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
                Ok(Self {
                    id: None,
                    name: String::new(),
                })
            }

            fn primary_key_value(&self) -> Vec<Value> {
                vec![self.id.map_or(Value::Null, Value::BigInt)]
            }

            fn is_new(&self) -> bool {
                self.id.is_none()
            }
        }

        let mut map = IdentityMap::new();

        // Insert user with id=1
        map.insert(make_user(1, "Alice"));

        // Insert team with id=1 (same PK value, different type)
        map.insert(TestTeam {
            id: Some(1),
            name: "Engineering".to_string(),
        });

        // Both should exist independently
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains::<TestTeam>(&[Value::BigInt(1)]));
        assert_eq!(map.len(), 2);

        // Getting each returns the correct type
        let user = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(user.read().unwrap().name, "Alice");

        let team = map.get::<TestTeam>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(team.read().unwrap().name, "Engineering");
    }

    #[test]
    fn test_weak_identity_map_get_and_prune() {
        let mut weak_map = WeakIdentityMap::new();
        let pk = [Value::BigInt(1)];

        let boxed: Box<dyn Any + Send + Sync> = Box::new(make_user(1, "Alice"));
        let strong = Arc::new(RwLock::new(boxed));
        weak_map.register::<TestUser>(&strong, &pk);

        // While a strong handle exists, lookups upgrade to the same allocation.
        let found = weak_map
            .get::<TestUser>(&pk)
            .expect("live entry should upgrade");
        assert!(Arc::ptr_eq(&strong, &found));
        assert_eq!(
            found
                .read()
                .unwrap()
                .downcast_ref::<TestUser>()
                .map(|u| u.name.clone()),
            Some("Alice".to_string())
        );

        // Once every strong handle is gone, the entry can no longer be upgraded...
        drop(found);
        drop(strong);
        assert!(weak_map.get::<TestUser>(&pk).is_none());

        // ...but it stays in the map until explicitly pruned.
        assert_eq!(weak_map.len(), 1);
        weak_map.prune();
        assert!(weak_map.is_empty());
    }
}