use sqlmodel_core::{Model, Value};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock, Weak};
/// Folds a (possibly composite) primary key into a single `u64`.
///
/// Every value is fed, in order, into one `DefaultHasher` through
/// [`hash_single_value`]; the hasher's final state is the result.
fn hash_pk_values(values: &[Value]) -> u64 {
    use std::hash::Hasher;

    let mut state = std::collections::hash_map::DefaultHasher::new();
    values
        .iter()
        .for_each(|value| hash_single_value(value, &mut state));
    state.finish()
}
/// Feeds one [`Value`] into `hasher`.
///
/// Each variant first writes a distinct `u8` tag and then its payload, so two
/// variants carrying the same payload still produce different input streams.
/// Hashing a `(tag, payload)` tuple hashes the tag first and the payload
/// second, which is exactly the order the hasher needs to see.
fn hash_single_value(v: &Value, hasher: &mut impl std::hash::Hasher) {
    use std::hash::Hash;

    match v {
        // Payload-free variants: the tag alone identifies them.
        Value::Null => 0u8.hash(hasher),
        Value::Default => 18u8.hash(hasher),

        Value::Bool(flag) => (1u8, flag).hash(hasher),
        Value::TinyInt(n) => (2u8, n).hash(hasher),
        Value::SmallInt(n) => (3u8, n).hash(hasher),
        Value::Int(n) => (4u8, n).hash(hasher),
        Value::BigInt(n) => (5u8, n).hash(hasher),

        // Floats do not implement `Hash`; their raw bit pattern is hashed.
        Value::Float(x) => (6u8, x.to_bits()).hash(hasher),
        Value::Double(x) => (7u8, x.to_bits()).hash(hasher),

        Value::Decimal(text) => (8u8, text).hash(hasher),
        Value::Text(text) => (9u8, text).hash(hasher),
        Value::Bytes(raw) => (10u8, raw).hash(hasher),
        Value::Date(date) => (11u8, date).hash(hasher),
        Value::Time(time) => (12u8, time).hash(hasher),
        Value::Timestamp(stamp) => (13u8, stamp).hash(hasher),
        Value::TimestampTz(stamp) => (14u8, stamp).hash(hasher),
        Value::Uuid(uuid) => (15u8, uuid).hash(hasher),

        // JSON is hashed through its string rendering.
        Value::Json(json) => (16u8, json.to_string()).hash(hasher),

        // Arrays write their length before recursing into each element.
        Value::Array(elements) => {
            (17u8, elements.len()).hash(hasher);
            for element in elements.iter() {
                hash_single_value(element, hasher);
            }
        }
    }
}
/// One slot of an `IdentityMap`: a type-erased shared handle plus the
/// primary key values it was stored under.
struct IdentityEntry {
/// Type-erased handle. `IdentityMap::insert` boxes an `Arc<RwLock<M>>` here,
/// and lookups recover it with `downcast_ref`.
arc: Box<dyn Any + Send + Sync>,
/// Primary key values supplied when the entry was created.
#[allow(dead_code)]
pk_values: Vec<Value>,
}
/// Map from `(model type, primary key hash)` to a shared, lock-protected
/// model instance, holding a strong reference to every tracked object.
#[derive(Default)]
pub struct IdentityMap {
/// Entries keyed by the model's `TypeId` and the `u64` hash of its primary
/// key values.
entries: HashMap<(TypeId, u64), IdentityEntry>,
}
impl IdentityMap {
#[must_use]
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
pub fn insert<M: Model + Send + Sync + 'static>(&mut self, model: M) -> Arc<RwLock<M>> {
let pk_values = model.primary_key_value();
let pk_hash = hash_pk_values(&pk_values);
let type_id = TypeId::of::<M>();
let key = (type_id, pk_hash);
if let Some(entry) = self.entries.get(&key) {
if let Some(existing_arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
return Arc::clone(existing_arc);
}
}
let arc: Arc<RwLock<M>> = Arc::new(RwLock::new(model));
let type_erased: Box<dyn Any + Send + Sync> = Box::new(Arc::clone(&arc));
self.entries.insert(
key,
IdentityEntry {
arc: type_erased,
pk_values,
},
);
arc
}
pub fn get<M: Model + Send + Sync + 'static>(
&self,
pk_values: &[Value],
) -> Option<Arc<RwLock<M>>> {
let pk_hash = hash_pk_values(pk_values);
let type_id = TypeId::of::<M>();
let key = (type_id, pk_hash);
let entry = self.entries.get(&key)?;
let arc = entry.arc.downcast_ref::<Arc<RwLock<M>>>()?;
Some(Arc::clone(arc))
}
pub fn contains<M: Model + 'static>(&self, pk_values: &[Value]) -> bool {
let pk_hash = hash_pk_values(pk_values);
let type_id = TypeId::of::<M>();
self.entries.contains_key(&(type_id, pk_hash))
}
pub fn contains_model<M: Model + 'static>(&self, model: &M) -> bool {
let pk_values = model.primary_key_value();
self.contains::<M>(&pk_values)
}
pub fn remove<M: Model + 'static>(&mut self, pk_values: &[Value]) -> bool {
let pk_hash = hash_pk_values(pk_values);
let type_id = TypeId::of::<M>();
self.entries.remove(&(type_id, pk_hash)).is_some()
}
pub fn remove_model<M: Model + 'static>(&mut self, model: &M) -> bool {
let pk_values = model.primary_key_value();
self.remove::<M>(&pk_values)
}
pub fn clear(&mut self) {
self.entries.clear();
}
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn get_or_insert<M: Model + Clone + Send + Sync + 'static>(
&mut self,
model: M,
) -> Arc<RwLock<M>> {
let pk_values = model.primary_key_value();
if let Some(existing) = self.get::<M>(&pk_values) {
return existing;
}
self.insert(model)
}
pub fn update<M: Model + Clone + Send + Sync + 'static>(&mut self, model: &M) -> bool {
let pk_values = model.primary_key_value();
let pk_hash = hash_pk_values(&pk_values);
let type_id = TypeId::of::<M>();
let key = (type_id, pk_hash);
if let Some(entry) = self.entries.get(&key) {
if let Some(arc) = entry.arc.downcast_ref::<Arc<RwLock<M>>>() {
let mut guard = arc.write().expect("lock poisoned");
*guard = model.clone();
return true;
}
}
false
}
}
/// Non-owning handle to a lock-protected, type-erased value; the value type
/// stored in `WeakIdentityMap`.
type WeakEntryValue = Weak<RwLock<Box<dyn Any + Send + Sync>>>;
/// Identity map that stores only `Weak` references, so an entry does not by
/// itself keep the referenced object alive.
#[derive(Default)]
pub struct WeakIdentityMap {
/// Entries keyed by the model's `TypeId` and the `u64` hash of its primary
/// key values.
entries: HashMap<(TypeId, u64), WeakEntryValue>,
}
impl WeakIdentityMap {
#[must_use]
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
pub fn register<M: Model + 'static>(
&mut self,
arc: &Arc<RwLock<Box<dyn Any + Send + Sync>>>,
pk_values: &[Value],
) {
let pk_hash = hash_pk_values(pk_values);
let type_id = TypeId::of::<M>();
let key = (type_id, pk_hash);
self.entries.insert(key, Arc::downgrade(arc));
}
pub fn get<M: Model + Clone + Send + Sync + 'static>(
&self,
pk_values: &[Value],
) -> Option<Arc<RwLock<Box<dyn Any + Send + Sync>>>> {
let pk_hash = hash_pk_values(pk_values);
let type_id = TypeId::of::<M>();
let key = (type_id, pk_hash);
self.entries.get(&key)?.upgrade()
}
pub fn prune(&mut self) {
self.entries.retain(|_, weak| weak.strong_count() > 0);
}
pub fn clear(&mut self) {
self.entries.clear();
}
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
/// Shared, lock-protected handle to a model instance (`Arc<RwLock<M>>`).
pub type ModelRef<M> = Arc<RwLock<M>>;
/// Read guard obtained from an `RwLock<M>`.
pub type ModelReadGuard<'a, M> = std::sync::RwLockReadGuard<'a, M>;
/// Write guard obtained from an `RwLock<M>`.
pub type ModelWriteGuard<'a, M> = std::sync::RwLockWriteGuard<'a, M>;
#[cfg(test)]
mod tests {
    use super::*;
    use sqlmodel_core::{FieldInfo, Row, SqlType};

    /// Minimal model with a single nullable `id` primary key.
    ///
    /// Its fields (`Option<i64>`, `String`) are `Send + Sync`, so the struct
    /// is automatically `Send + Sync`; no `unsafe impl` is required.
    #[derive(Debug, Clone, PartialEq)]
    struct TestUser {
        id: Option<i64>,
        name: String,
    }

    impl Model for TestUser {
        const TABLE_NAME: &'static str = "users";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];

        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
            ];
            FIELDS
        }

        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
            ]
        }

        fn from_row(row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: row.get_named("id").ok(),
                name: row.get_named("name")?,
            })
        }

        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }

        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    /// An inserted object can be fetched again by its primary key.
    #[test]
    fn test_identity_map_insert_and_get() {
        let mut map = IdentityMap::new();
        let user = TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        };
        let ref1 = map.insert(user);
        assert_eq!(ref1.read().unwrap().name, "Alice");
        let ref2 = map.get::<TestUser>(&[Value::BigInt(1)]);
        assert!(ref2.is_some());
        assert_eq!(ref2.unwrap().read().unwrap().name, "Alice");
    }

    /// `update` overwrites the tracked object, and the change is visible
    /// through handles fetched afterwards.
    #[test]
    fn test_identity_map_modifications_visible() {
        let mut map = IdentityMap::new();
        let user = TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        };
        let ref1 = map.insert(user);
        ref1.write().unwrap().name = "Bob".to_string();
        assert!(map.update(&TestUser {
            id: Some(1),
            name: "Charlie".to_string(),
        }));
        let ref2 = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(ref2.read().unwrap().name, "Charlie");
    }

    /// `contains` / `contains_model` reflect what has been inserted.
    #[test]
    fn test_identity_map_contains() {
        let mut map = IdentityMap::new();
        let user = TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        };
        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));
        map.insert(user.clone());
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains_model(&user));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(2)]));
    }

    /// `remove` reports whether an entry was actually removed.
    #[test]
    fn test_identity_map_remove() {
        let mut map = IdentityMap::new();
        let user = TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        };
        map.insert(user);
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.remove::<TestUser>(&[Value::BigInt(1)]));
        assert!(!map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(!map.remove::<TestUser>(&[Value::BigInt(1)]));
    }

    /// `clear` empties the map.
    #[test]
    fn test_identity_map_clear() {
        let mut map = IdentityMap::new();
        map.insert(TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        });
        map.insert(TestUser {
            id: Some(2),
            name: "Bob".to_string(),
        });
        assert_eq!(map.len(), 2);
        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    /// A second `get_or_insert` with the same key returns the first object.
    #[test]
    fn test_identity_map_get_or_insert() {
        let mut map = IdentityMap::new();
        let user1 = TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        };
        let ref1 = map.get_or_insert(user1);
        assert_eq!(ref1.read().unwrap().name, "Alice");
        let user2 = TestUser {
            id: Some(1),
            name: "Bob".to_string(),
        };
        let ref2 = map.get_or_insert(user2);
        assert_eq!(ref2.read().unwrap().name, "Alice");
    }

    /// Equal composite keys hash equally; differing ones hash differently.
    #[test]
    fn test_composite_pk_hashing() {
        let pk1 = vec![Value::BigInt(1), Value::Text("a".to_string())];
        let pk2 = vec![Value::BigInt(1), Value::Text("a".to_string())];
        let pk3 = vec![Value::BigInt(1), Value::Text("b".to_string())];
        assert_eq!(hash_pk_values(&pk1), hash_pk_values(&pk2));
        assert_ne!(hash_pk_values(&pk1), hash_pk_values(&pk3));
    }

    /// A model without an id is tracked under the `Null` key.
    #[test]
    fn test_null_pk_handling() {
        let mut map = IdentityMap::new();
        let user = TestUser {
            id: None,
            name: "Anonymous".to_string(),
        };
        let _ = map.insert(user);
        assert!(map.contains::<TestUser>(&[Value::Null]));
    }

    /// Two model types with the same primary key value do not clash.
    #[test]
    fn test_different_types_same_pk() {
        #[derive(Debug, Clone)]
        struct TestTeam {
            id: Option<i64>,
            name: String,
        }

        impl Model for TestTeam {
            const TABLE_NAME: &'static str = "teams";
            const PRIMARY_KEY: &'static [&'static str] = &["id"];

            fn fields() -> &'static [FieldInfo] {
                &[]
            }

            fn to_row(&self) -> Vec<(&'static str, Value)> {
                vec![]
            }

            fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
                Ok(Self {
                    id: None,
                    name: String::new(),
                })
            }

            fn primary_key_value(&self) -> Vec<Value> {
                vec![self.id.map_or(Value::Null, Value::BigInt)]
            }

            fn is_new(&self) -> bool {
                self.id.is_none()
            }
        }

        let mut map = IdentityMap::new();
        map.insert(TestUser {
            id: Some(1),
            name: "Alice".to_string(),
        });
        map.insert(TestTeam {
            id: Some(1),
            name: "Engineering".to_string(),
        });
        assert!(map.contains::<TestUser>(&[Value::BigInt(1)]));
        assert!(map.contains::<TestTeam>(&[Value::BigInt(1)]));
        let user = map.get::<TestUser>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(user.read().unwrap().name, "Alice");
        let team = map.get::<TestTeam>(&[Value::BigInt(1)]).unwrap();
        assert_eq!(team.read().unwrap().name, "Engineering");
    }
}