1use sqlmodel_core::{Model, Value};
36use std::any::{Any, TypeId};
37use std::collections::HashMap;
38use std::sync::{Arc, RwLock, Weak};
39
40fn hash_pk_values(values: &[Value]) -> u64 {
42 use std::collections::hash_map::DefaultHasher;
43 use std::hash::Hasher;
44
45 let mut hasher = DefaultHasher::new();
46 for v in values {
47 hash_single_value(v, &mut hasher);
48 }
49 hasher.finish()
50}
51
52fn hash_single_value(v: &Value, hasher: &mut impl std::hash::Hasher) {
54 use std::hash::Hash;
55
56 match v {
57 Value::Null => 0u8.hash(hasher),
58 Value::Bool(b) => {
59 1u8.hash(hasher);
60 b.hash(hasher);
61 }
62 Value::TinyInt(i) => {
63 2u8.hash(hasher);
64 i.hash(hasher);
65 }
66 Value::SmallInt(i) => {
67 3u8.hash(hasher);
68 i.hash(hasher);
69 }
70 Value::Int(i) => {
71 4u8.hash(hasher);
72 i.hash(hasher);
73 }
74 Value::BigInt(i) => {
75 5u8.hash(hasher);
76 i.hash(hasher);
77 }
78 Value::Float(f) => {
79 6u8.hash(hasher);
80 f.to_bits().hash(hasher);
81 }
82 Value::Double(f) => {
83 7u8.hash(hasher);
84 f.to_bits().hash(hasher);
85 }
86 Value::Decimal(s) => {
87 8u8.hash(hasher);
88 s.hash(hasher);
89 }
90 Value::Text(s) => {
91 9u8.hash(hasher);
92 s.hash(hasher);
93 }
94 Value::Bytes(b) => {
95 10u8.hash(hasher);
96 b.hash(hasher);
97 }
98 Value::Date(d) => {
99 11u8.hash(hasher);
100 d.hash(hasher);
101 }
102 Value::Time(t) => {
103 12u8.hash(hasher);
104 t.hash(hasher);
105 }
106 Value::Timestamp(ts) => {
107 13u8.hash(hasher);
108 ts.hash(hasher);
109 }
110 Value::TimestampTz(ts) => {
111 14u8.hash(hasher);
112 ts.hash(hasher);
113 }
114 Value::Uuid(u) => {
115 15u8.hash(hasher);
116 u.hash(hasher);
117 }
118 Value::Json(j) => {
119 16u8.hash(hasher);
120 j.to_string().hash(hasher);
121 }
122 Value::Array(arr) => {
123 17u8.hash(hasher);
124 arr.len().hash(hasher);
125 for item in arr {
126 hash_single_value(item, hasher);
127 }
128 }
129 Value::Default => {
130 18u8.hash(hasher);
131 }
132 }
133}
134
/// One slot of `IdentityMap`: the shared handle for a single model instance.
///
/// The handle is stored type-erased so models of different types can live in one
/// `HashMap`.
struct IdentityEntry {
    /// The model's `Arc<RwLock<M>>`, boxed as `dyn Any`; `IdentityMap` recovers it
    /// with `downcast_ref::<Arc<RwLock<M>>>()`.
    arc: Box<dyn Any + Send + Sync>,
    /// The primary-key values the entry was inserted with.
    // Not read anywhere in the visible code (hence the allow): lookups compare only
    // the 64-bit key hash. NOTE(review): presumably kept for debugging or a future
    // hash-collision check; confirm before removing.
    #[allow(dead_code)]
    pk_values: Vec<Value>,
}
147
148#[derive(Default)]
153pub struct IdentityMap {
154 entries: HashMap<(TypeId, u64), IdentityEntry>,
156}
157
158impl IdentityMap {
159 #[must_use]
161 pub fn new() -> Self {
162 Self {
163 entries: HashMap::new(),
164 }
165 }
166
167 pub fn insert<M: Model + Send + Sync + 'static>(&mut self, model: M) -> Arc<RwLock<M>> {
177 let pk_values = model.primary_key_value();
178 let pk_hash = hash_pk_values(&pk_values);
179 let type_id = TypeId::of::<M>();
180 let key = (type_id, pk_hash);
181
182 if let Some(entry) = self.entries.get(&key) {
184 if let Some(existing_arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
185 return Arc::clone(existing_arc);
187 }
188 }
189
190 let arc: Arc<RwLock<M>> = Arc::new(RwLock::new(model));
192 let type_erased: Box<dyn Any + Send + Sync> = Box::new(Arc::clone(&arc));
193
194 self.entries.insert(
195 key,
196 IdentityEntry {
197 arc: type_erased,
198 pk_values,
199 },
200 );
201
202 arc
203 }
204
205 pub fn get<M: Model + Send + Sync + 'static>(
212 &self,
213 pk_values: &[Value],
214 ) -> Option<Arc<RwLock<M>>> {
215 let pk_hash = hash_pk_values(pk_values);
216 let type_id = TypeId::of::<M>();
217 let key = (type_id, pk_hash);
218
219 let entry = self.entries.get(&key)?;
220
221 let arc = entry.arc.downcast_ref::<Arc<RwLock<M>>>()?;
223 Some(Arc::clone(arc))
224 }
225
226 pub fn contains<M: Model + 'static>(&self, pk_values: &[Value]) -> bool {
228 let pk_hash = hash_pk_values(pk_values);
229 let type_id = TypeId::of::<M>();
230 self.entries.contains_key(&(type_id, pk_hash))
231 }
232
233 pub fn contains_model<M: Model + 'static>(&self, model: &M) -> bool {
235 let pk_values = model.primary_key_value();
236 self.contains::<M>(&pk_values)
237 }
238
239 pub fn remove<M: Model + 'static>(&mut self, pk_values: &[Value]) -> bool {
245 let pk_hash = hash_pk_values(pk_values);
246 let type_id = TypeId::of::<M>();
247 self.entries.remove(&(type_id, pk_hash)).is_some()
248 }
249
250 pub fn remove_model<M: Model + 'static>(&mut self, model: &M) -> bool {
252 let pk_values = model.primary_key_value();
253 self.remove::<M>(&pk_values)
254 }
255
256 pub fn clear(&mut self) {
258 self.entries.clear();
259 }
260
261 #[must_use]
263 pub fn len(&self) -> usize {
264 self.entries.len()
265 }
266
267 #[must_use]
269 pub fn is_empty(&self) -> bool {
270 self.entries.is_empty()
271 }
272
273 pub fn get_or_insert<M: Model + Clone + Send + Sync + 'static>(
281 &mut self,
282 model: M,
283 ) -> Arc<RwLock<M>> {
284 let pk_values = model.primary_key_value();
285
286 if let Some(existing) = self.get::<M>(&pk_values) {
288 return existing;
289 }
290
291 self.insert(model)
293 }
294
295 pub fn update<M: Model + Clone + Send + Sync + 'static>(&mut self, model: &M) -> bool {
300 let pk_values = model.primary_key_value();
301 let pk_hash = hash_pk_values(&pk_values);
302 let type_id = TypeId::of::<M>();
303 let key = (type_id, pk_hash);
304
305 if let Some(entry) = self.entries.get(&key) {
306 if let Some(arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
308 let mut guard = arc.write().expect("lock poisoned");
309 *guard = model.clone();
310 return true;
311 }
312 }
313
314 false
315 }
316}
317
318type WeakEntryValue = Weak<RwLock<Box<dyn Any + Send + Sync>>>;
320
321#[derive(Default)]
327pub struct WeakIdentityMap {
328 entries: HashMap<(TypeId, u64), WeakEntryValue>,
330}
331
332impl WeakIdentityMap {
333 #[must_use]
335 pub fn new() -> Self {
336 Self {
337 entries: HashMap::new(),
338 }
339 }
340
341 pub fn register<M: Model + 'static>(
346 &mut self,
347 arc: &Arc<RwLock<Box<dyn Any + Send + Sync>>>,
348 pk_values: &[Value],
349 ) {
350 let pk_hash = hash_pk_values(pk_values);
351 let type_id = TypeId::of::<M>();
352 let key = (type_id, pk_hash);
353 self.entries.insert(key, Arc::downgrade(arc));
354 }
355
356 pub fn get<M: Model + Clone + Send + Sync + 'static>(
360 &self,
361 pk_values: &[Value],
362 ) -> Option<Arc<RwLock<Box<dyn Any + Send + Sync>>>> {
363 let pk_hash = hash_pk_values(pk_values);
364 let type_id = TypeId::of::<M>();
365 let key = (type_id, pk_hash);
366
367 self.entries.get(&key)?.upgrade()
368 }
369
370 pub fn prune(&mut self) {
374 self.entries.retain(|_, weak| weak.strong_count() > 0);
375 }
376
377 pub fn clear(&mut self) {
379 self.entries.clear();
380 }
381
382 #[must_use]
384 pub fn len(&self) -> usize {
385 self.entries.len()
386 }
387
388 #[must_use]
390 pub fn is_empty(&self) -> bool {
391 self.entries.is_empty()
392 }
393}
394
/// Shared, lockable handle to a single model instance, as handed out by the identity map.
pub type ModelRef<M> = Arc<RwLock<M>>;

/// Shared-read guard obtained by calling `read()` on a [`ModelRef`].
pub type ModelReadGuard<'guard, M> = std::sync::RwLockReadGuard<'guard, M>;

/// Exclusive-write guard obtained by calling `write()` on a [`ModelRef`].
pub type ModelWriteGuard<'guard, M> = std::sync::RwLockWriteGuard<'guard, M>;
407
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{FieldInfo, Row, SqlType};

    /// Minimal model with a single nullable `BIGINT` primary key.
    ///
    /// Its fields are `Option<i64>` and `String`, so it is automatically `Send + Sync`;
    /// no `unsafe impl` is needed to store it in the maps.
    #[derive(Debug, Clone, PartialEq)]
    struct TestUser {
        id: Option<i64>,
        name: String,
    }

    impl Model for TestUser {
        const TABLE_NAME: &'static str = "users";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
            ]
        }

        fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: row.get_named("id").ok(),
                name: row.get_named("name")?,
            })
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }

        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    /// Builds a persisted `TestUser` with the given id and name.
    fn make_user(id: i64, name: &str) -> TestUser {
        TestUser {
            id: Some(id),
            name: name.to_string(),
        }
    }

    #[test]
    fn test_identity_map_insert_and_get() {
        let mut map = IdentityMap::new();

        let ref1 = map.insert(make_user(1, "Alice"));
        assert_eq!(ref1.read().unwrap().name, "Alice");

        let ref2 = map
            .get::<TestUser>(&[Value::BigInt(1)])
            .expect("inserted user should be found");
        assert_eq!(ref2.read().unwrap().name, "Alice");
        // Both handles must point at the very same shared instance.
        assert!(Arc::ptr_eq(&ref1, &ref2));
    }

    #[test]
    fn test_identity_map_insert_does_not_overwrite() {
        let mut map = IdentityMap::new();

        let first = map.insert(make_user(1, "Alice"));
        let second = map.insert(make_user(1, "Bob"));

        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.read().unwrap().name, "Alice");
        assert_eq!(map.len(), 1);
    }

    #[test]
    fn test_identity_map_modifications_visible() {
        let mut map = IdentityMap::new();

        let ref1 = map.insert(make_user(1, "Alice"));

        // A write through one handle is visible through every other handle.
        ref1.write().unwrap().name = "Bob".to_string();
        let ref2 = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(ref2.read().unwrap().name, "Bob");

        // `update` overwrites the shared value in place, so existing handles see it.
        assert!(map.update(&make_user(1, "Charlie")));
        assert_eq!(ref1.read().unwrap().name, "Charlie");
        assert_eq!(ref2.read().unwrap().name, "Charlie");
        assert!(Arc::ptr_eq(&ref1, &ref2));
    }

    #[test]
    fn test_identity_map_update_missing_returns_false() {
        let mut map = IdentityMap::new();

        assert!(!map.update(&make_user(2, "Nobody")));
        // A failed update must not insert anything.
        assert!(map.is_empty());
    }

    #[test]
    fn test_identity_map_contains() {
        let mut map = IdentityMap::new();
        let alice = make_user(1, "Alice");

        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));

        map.insert(alice.clone());

        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains_model(&alice));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(2)]));
    }

    #[test]
    fn test_identity_map_remove() {
        let mut map = IdentityMap::new();

        let handle = map.insert(make_user(1, "Alice"));
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));

        assert!(map.remove::<TestUser>(&[Value::BigInt(1)]));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));

        // Removing an absent key reports `false`.
        assert!(!map.remove::<TestUser>(&[Value::BigInt(1)]));

        // Handles obtained before removal remain usable.
        assert_eq!(handle.read().unwrap().name, "Alice");
    }

    #[test]
    fn test_identity_map_clear() {
        let mut map = IdentityMap::new();

        map.insert(make_user(1, "Alice"));
        map.insert(make_user(2, "Bob"));

        assert_eq!(map.len(), 2);

        map.clear();

        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn test_identity_map_get_or_insert() {
        let mut map = IdentityMap::new();

        // First call inserts.
        let ref1 = map.get_or_insert(make_user(1, "Alice"));
        assert_eq!(ref1.read().unwrap().name, "Alice");

        // Second call with the same PK returns the cached instance, not the new value.
        let ref2 = map.get_or_insert(make_user(1, "Bob"));
        assert_eq!(ref2.read().unwrap().name, "Alice");
        assert!(Arc::ptr_eq(&ref1, &ref2));
    }

    #[test]
    fn test_composite_pk_hashing() {
        let pk1 = [Value::BigInt(1), Value::Text("a".to_string())];
        let pk2 = [Value::BigInt(1), Value::Text("a".to_string())];
        let pk3 = [Value::BigInt(1), Value::Text("b".to_string())];

        assert_eq!(hash_pk_values(&pk1), hash_pk_values(&pk2));
        assert_ne!(hash_pk_values(&pk1), hash_pk_values(&pk3));
    }

    #[test]
    fn test_null_pk_handling() {
        let mut map = IdentityMap::new();

        // An unsaved model (no id) is stored under a `Null` primary key.
        map.insert(TestUser {
            id: None,
            name: "Anonymous".to_string(),
        });
        assert!(map.contains::<TestUser>(&[Value::Null]));
    }

    #[test]
    fn test_different_types_same_pk() {
        /// Second model type sharing `TestUser`'s primary-key shape.
        #[derive(Debug, Clone)]
        struct TestTeam {
            id: Option<i64>,
            name: String,
        }

        impl Model for TestTeam {
            const TABLE_NAME: &'static str = "teams";
            const PRIMARY_KEY: &'static [&'static str] = &["id"];

            fn fields() -> &'static [FieldInfo] {
                &[]
            }

            fn to_row(&self) -> Vec<(&'static str, Value)> {
                vec![]
            }

            fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
                Ok(Self {
                    id: None,
                    name: String::new(),
                })
            }

            fn primary_key_value(&self) -> Vec<Value> {
                vec![self.id.map_or(Value::Null, Value::BigInt)]
            }

            fn is_new(&self) -> bool {
                self.id.is_none()
            }
        }

        let mut map = IdentityMap::new();

        map.insert(make_user(1, "Alice"));
        map.insert(TestTeam {
            id: Some(1),
            name: "Engineering".to_string(),
        });

        // Same PK, different types: two independent entries.
        assert_eq!(map.len(), 2);
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains::<TestTeam>(&[Value::BigInt(1)]));

        let user = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(user.read().unwrap().name, "Alice");

        let team = map.get::<TestTeam>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(team.read().unwrap().name, "Engineering");
    }
}