use std::{
any::TypeId,
collections::HashMap,
sync::{Arc, Mutex},
};
use serde::{Serialize, de::DeserializeOwned};
use crate::{
repository::{RepositoryItem, RepositoryMigrations},
sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
};
/// In-memory key/value store.
///
/// Values are kept as JSON strings in a two-level map: the outer key is the
/// [`TypeId`] of the stored item type, the inner key is the item's string key.
///
/// Cloning is cheap and yields another handle to the *same* underlying store
/// (the map lives behind an [`Arc`]), so writes made through one handle are
/// visible through every clone.
#[derive(Clone, Default)]
pub struct MemoryDatabase(Arc<Mutex<HashMap<TypeId, HashMap<String, String>>>>);

impl MemoryDatabase {
    /// Creates a new, empty in-memory database.
    ///
    /// Equivalent to [`MemoryDatabase::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// [`Database`] backed by the process-local map inside [`MemoryDatabase`].
///
/// Items are serialized to JSON with `serde_json` and grouped per item type
/// (keyed by [`TypeId`]), so the same string key can be used for different
/// types without collision.
///
/// # Panics
///
/// Every method that touches the store panics if the internal mutex has been
/// poisoned.
impl Database for MemoryDatabase {
    /// Creates an empty in-memory database.
    ///
    /// The migrations argument is not used by this backend.
    ///
    /// # Errors
    ///
    /// Returns [`DatabaseError::UnsupportedConfiguration`] for any
    /// configuration other than [`DatabaseConfiguration::Memory`].
    async fn initialize(
        configuration: DatabaseConfiguration,
        _migrations: RepositoryMigrations,
    ) -> Result<Self, DatabaseError> {
        let DatabaseConfiguration::Memory = configuration else {
            return Err(DatabaseError::UnsupportedConfiguration(configuration));
        };
        Ok(MemoryDatabase::new())
    }

    /// Looks up the item of type `T` stored under `key`.
    ///
    /// Returns `Ok(None)` when nothing is stored for that type/key pair.
    ///
    /// # Errors
    ///
    /// Fails if the stored JSON cannot be deserialized into `T`.
    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
    ) -> Result<Option<T>, DatabaseError> {
        let store = self.0.lock().expect("Mutex is not poisoned");
        match store.get(&TypeId::of::<T>()).and_then(|m| m.get(key)) {
            Some(json) => Ok(Some(serde_json::from_str(json)?)),
            None => Ok(None),
        }
    }

    /// Returns every stored item of type `T`.
    ///
    /// The order follows the underlying `HashMap` iteration and is therefore
    /// unspecified. An empty vector is returned when no item of that type has
    /// been stored.
    ///
    /// # Errors
    ///
    /// Fails on the first stored value that cannot be deserialized into `T`.
    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
    ) -> Result<Vec<T>, DatabaseError> {
        let store = self.0.lock().expect("Mutex is not poisoned");
        let Some(type_map) = store.get(&TypeId::of::<T>()) else {
            return Ok(Vec::new());
        };
        let items = type_map
            .values()
            .map(|json| serde_json::from_str(json))
            .collect::<Result<Vec<T>, serde_json::Error>>()?;
        Ok(items)
    }

    /// Stores `value` under `key`, replacing any previous value for the same
    /// type/key pair.
    ///
    /// # Errors
    ///
    /// Fails if `value` cannot be serialized to JSON; the store is left
    /// untouched in that case.
    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
        value: T,
    ) -> Result<(), DatabaseError> {
        // Serialize before taking the lock so a failure never touches the store.
        let json = serde_json::to_string(&value)?;
        let mut store = self.0.lock().expect("Mutex is not poisoned");
        store
            .entry(TypeId::of::<T>())
            .or_default()
            .insert(key.to_string(), json);
        Ok(())
    }

    /// Stores all `values` in one step, replacing existing entries with the
    /// same key. When a key appears more than once in `values`, the last
    /// occurrence wins.
    ///
    /// The write is all-or-nothing: the whole batch is serialized first, and
    /// the store is only modified once every value has serialized
    /// successfully.
    ///
    /// # Errors
    ///
    /// Fails if any value cannot be serialized to JSON; nothing from the
    /// batch is stored in that case.
    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        values: Vec<(String, T)>,
    ) -> Result<(), DatabaseError> {
        // Serialize up front, outside the lock: an error must not leave a
        // partially applied batch behind, and user-provided `Serialize` code
        // should not run while the mutex is held.
        let serialized = values
            .into_iter()
            .map(|(key, value)| serde_json::to_string(&value).map(|json| (key, json)))
            .collect::<Result<Vec<(String, String)>, serde_json::Error>>()?;

        let mut store = self.0.lock().expect("Mutex is not poisoned");
        store
            .entry(TypeId::of::<T>())
            .or_default()
            .extend(serialized);
        Ok(())
    }

    /// Removes the item of type `T` stored under `key`.
    ///
    /// Removing a key that is not present is not an error.
    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        key: &str,
    ) -> Result<(), DatabaseError> {
        let mut store = self.0.lock().expect("Mutex is not poisoned");
        if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
            type_map.remove(key);
        }
        Ok(())
    }

    /// Removes every item of type `T` whose key is listed in `keys`.
    ///
    /// Keys that are not present are ignored.
    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
        keys: Vec<String>,
    ) -> Result<(), DatabaseError> {
        let mut store = self.0.lock().expect("Mutex is not poisoned");
        if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
            for key in &keys {
                type_map.remove(key);
            }
        }
        Ok(())
    }

    /// Removes all stored items of type `T`; items of other types are kept.
    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
        &self,
    ) -> Result<(), DatabaseError> {
        let mut store = self.0.lock().expect("Mutex is not poisoned");
        store.remove(&TypeId::of::<T>());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// String-valued fixture item.
    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeA(String);
    register_repository_item!(String => TypeA, "MemTypeA");

    /// Integer-valued fixture item, used to check per-type isolation.
    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeB(u64);
    register_repository_item!(String => TypeB, "MemTypeB");

    /// Builds a `TypeA` fixture from a string slice.
    fn item_a(text: &str) -> TypeA {
        TypeA(text.to_string())
    }

    /// Reads a `TypeA` by key, panicking on database errors.
    async fn fetch_a(db: &MemoryDatabase, key: &str) -> Option<TypeA> {
        db.get::<TypeA>(key).await.unwrap()
    }

    /// Reads a `TypeB` by key, panicking on database errors.
    async fn fetch_b(db: &MemoryDatabase, key: &str) -> Option<TypeB> {
        db.get::<TypeB>(key).await.unwrap()
    }

    // A value can be written, read back, and removed again.
    #[tokio::test]
    async fn test_memory_database_get_set_remove() {
        let db = MemoryDatabase::new();
        assert_eq!(fetch_a(&db, "key1").await, None);

        db.set("key1", item_a("hello")).await.unwrap();
        assert_eq!(fetch_a(&db, "key1").await, Some(item_a("hello")));

        db.remove::<TypeA>("key1").await.unwrap();
        assert_eq!(fetch_a(&db, "key1").await, None);
    }

    // The same key under two different item types holds two independent values.
    #[tokio::test]
    async fn test_memory_database_type_isolation() {
        let db = MemoryDatabase::new();
        db.set("key1", item_a("value_a")).await.unwrap();
        db.set("key1", TypeB(42)).await.unwrap();

        assert_eq!(fetch_a(&db, "key1").await, Some(item_a("value_a")));
        assert_eq!(fetch_b(&db, "key1").await, Some(TypeB(42)));
    }

    // A clone observes writes made through the original handle.
    #[tokio::test]
    async fn test_memory_database_clone_shares_store() {
        let original = MemoryDatabase::new();
        let handle = original.clone();

        original.set("key1", item_a("shared")).await.unwrap();

        assert_eq!(fetch_a(&handle, "key1").await, Some(item_a("shared")));
    }

    // `list` returns only the items of the requested type.
    #[tokio::test]
    async fn test_memory_database_list() {
        let db = MemoryDatabase::new();
        db.set("a", item_a("1")).await.unwrap();
        db.set("b", item_a("2")).await.unwrap();
        db.set("c", TypeB(99)).await.unwrap();

        // Listing order is unspecified, so sort before comparing.
        let mut all_a = db.list::<TypeA>().await.unwrap();
        all_a.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        assert_eq!(all_a, vec![item_a("1"), item_a("2")]);

        let all_b = db.list::<TypeB>().await.unwrap();
        assert_eq!(all_b, vec![TypeB(99)]);
    }

    // Every entry of a bulk write is readable afterwards.
    #[tokio::test]
    async fn test_memory_database_set_bulk() {
        let db = MemoryDatabase::new();
        let batch = vec![
            ("x".to_string(), item_a("v1")),
            ("y".to_string(), item_a("v2")),
        ];
        db.set_bulk(batch).await.unwrap();

        assert_eq!(fetch_a(&db, "x").await, Some(item_a("v1")));
        assert_eq!(fetch_a(&db, "y").await, Some(item_a("v2")));
    }

    // Bulk removal deletes the listed keys and leaves the rest alone.
    #[tokio::test]
    async fn test_memory_database_remove_bulk() {
        let db = MemoryDatabase::new();
        db.set("a", item_a("1")).await.unwrap();
        db.set("b", item_a("2")).await.unwrap();
        db.set("c", item_a("3")).await.unwrap();

        let doomed = vec!["a".to_string(), "b".to_string()];
        db.remove_bulk::<TypeA>(doomed).await.unwrap();

        assert_eq!(fetch_a(&db, "a").await, None);
        assert_eq!(fetch_a(&db, "b").await, None);
        assert_eq!(fetch_a(&db, "c").await, Some(item_a("3")));
    }

    // `remove_all` clears one item type without touching another.
    #[tokio::test]
    async fn test_memory_database_remove_all() {
        let db = MemoryDatabase::new();
        db.set("a", item_a("1")).await.unwrap();
        db.set("b", item_a("2")).await.unwrap();
        db.set("z", TypeB(5)).await.unwrap();

        db.remove_all::<TypeA>().await.unwrap();

        assert!(db.list::<TypeA>().await.unwrap().is_empty());
        assert_eq!(db.list::<TypeB>().await.unwrap(), vec![TypeB(5)]);
    }

    // The memory configuration yields a usable database.
    #[tokio::test]
    async fn test_memory_database_initialize_memory_config() {
        let migrations = RepositoryMigrations::new(vec![]);
        let db = MemoryDatabase::initialize(DatabaseConfiguration::Memory, migrations)
            .await
            .unwrap();

        db.set("k", item_a("v")).await.unwrap();
        assert_eq!(fetch_a(&db, "k").await, Some(item_a("v")));
    }

    // Any non-memory configuration is rejected.
    #[tokio::test]
    async fn test_memory_database_initialize_rejects_non_memory_config() {
        let sqlite_config = DatabaseConfiguration::Sqlite {
            db_name: "ignored".to_string(),
            folder_path: std::path::PathBuf::from("/tmp"),
        };
        let outcome =
            MemoryDatabase::initialize(sqlite_config, RepositoryMigrations::new(vec![])).await;

        let rejected = matches!(
            outcome,
            Err(DatabaseError::UnsupportedConfiguration(_))
        );
        assert!(rejected);
    }
}