use super::{KeyValueStore, StoreResult, error::StoreError};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// An in-process, thread-safe key-value store holding values of arbitrary types.
///
/// Each value is kept type-erased behind an `Arc<dyn Any + Send + Sync>` inside a `HashMap`
/// keyed by `String`, and the whole map lives behind an `Arc<RwLock<..>>`. Because of that
/// outer `Arc`, cloning a `MemoryStore` produces another handle onto the *same* map rather
/// than a copy of its contents.
#[derive(Clone, Default)]
pub struct MemoryStore {
    /// Key → shared, type-erased value, guarded by a reader/writer lock.
    data: Arc<RwLock<HashMap<String, Arc<dyn std::any::Any + Sync + Send>>>>,
}
impl MemoryStore {
pub fn new() -> Self {
Self {
data: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn get_shared<TState: 'static + Send + Sync>(&self, key: &str) -> StoreResult<Arc<TState>> {
let data = self
.data
.read()
.map_err(|e| StoreError::lock_error(format!("Read lock poisoned: {e}")))?;
match data.get(key) {
Some(value) => value.clone().downcast::<TState>().map_err(|_| {
StoreError::type_mismatch(format!(
"Cannot downcast value for key '{key}' to requested type"
))
}),
None => Err(StoreError::key_not_found(key)),
}
}
}
impl KeyValueStore for MemoryStore {
    /// Returns a clone of the value stored under `key`.
    ///
    /// # Errors
    ///
    /// Fails with a lock error if the read lock is poisoned, a key-not-found error if `key`
    /// is absent, or a type-mismatch error if the entry does not hold a `TState`.
    fn get<TState: 'static + Clone>(&self, key: &str) -> StoreResult<TState> {
        let data = read_map(self)?;
        match data.get(key) {
            Some(value) => value.downcast_ref::<TState>().cloned().ok_or_else(|| {
                StoreError::type_mismatch(format!(
                    "Cannot downcast value for key '{key}' to requested type"
                ))
            }),
            None => Err(StoreError::key_not_found(key)),
        }
    }

    /// Returns the shared `Arc` stored under `key` without cloning the value itself.
    ///
    /// Delegates to the inherent `MemoryStore::get_shared`; inherent methods take priority in
    /// path resolution, so this call does not recurse into the trait method.
    fn get_shared<TState: 'static + Send + Sync>(&self, key: &str) -> StoreResult<Arc<TState>> {
        MemoryStore::get_shared(self, key)
    }

    /// Stores `value` under `key`, replacing any previous entry (of any type).
    ///
    /// The displaced entry, if any, is dropped only after the write lock has been released,
    /// so its destructor never runs while the store is locked.
    fn put<TState: 'static + Send + Sync + Clone>(
        &self,
        key: &str,
        value: TState,
    ) -> StoreResult<()> {
        // The guard is a temporary of this statement and is released at its end.
        let previous = write_map(self)?.insert(key.to_string(), Arc::new(value));
        drop(previous);
        Ok(())
    }

    /// Removes the entry for `key`; removing an absent key is not an error.
    ///
    /// The removed value is dropped after the write lock has been released.
    fn remove(&self, key: &str) -> StoreResult<()> {
        let removed = write_map(self)?.remove(key);
        drop(removed);
        Ok(())
    }

    /// Appends `item` to the `Vec<TState>` stored under `key`, creating `vec![item]` if the
    /// key is absent.
    ///
    /// If no other `Arc` to the vector is outstanding it is extended in place; otherwise
    /// (e.g. a caller still holds a `get_shared` handle) the vector is cloned, extended and
    /// swapped in, leaving that caller's snapshot unchanged.
    ///
    /// # Errors
    ///
    /// Fails with a lock error if the write lock is poisoned, or an append-type-mismatch
    /// error if the existing entry is not a `Vec<TState>`; the entry is left untouched.
    fn append<TState: 'static + Send + Sync + Clone>(
        &self,
        key: &str,
        item: TState,
    ) -> StoreResult<()> {
        let mut data = write_map(self)?;
        match data.get_mut(key) {
            Some(slot) => push_into_slot(slot, key, item),
            None => {
                data.insert(key.to_string(), Arc::new(vec![item]));
                Ok(())
            }
        }
    }

    /// Reports whether an entry exists for `key`, regardless of its type.
    fn contains_key(&self, key: &str) -> StoreResult<bool> {
        Ok(read_map(self)?.contains_key(key))
    }

    /// Returns a snapshot of all keys, in the map's (unspecified) iteration order.
    fn keys(&self) -> StoreResult<Vec<String>> {
        Ok(read_map(self)?.keys().cloned().collect())
    }

    /// Returns the number of entries currently stored.
    fn len(&self) -> StoreResult<usize> {
        Ok(read_map(self)?.len())
    }

    /// Removes every entry.
    ///
    /// The entries are moved out under the lock and dropped after it has been released.
    fn clear(&self) -> StoreResult<()> {
        let entries = std::mem::take(&mut *write_map(self)?);
        drop(entries);
        Ok(())
    }
}

/// The map guarded by `MemoryStore`'s lock: key → shared, type-erased value.
type SlotMap = HashMap<String, Arc<dyn std::any::Any + Send + Sync>>;

/// Acquires the store's read lock, turning lock poisoning into a `StoreError`.
fn read_map(store: &MemoryStore) -> StoreResult<std::sync::RwLockReadGuard<'_, SlotMap>> {
    store
        .data
        .read()
        .map_err(|e| StoreError::lock_error(format!("Read lock poisoned: {e}")))
}

/// Acquires the store's write lock, turning lock poisoning into a `StoreError`.
fn write_map(store: &MemoryStore) -> StoreResult<std::sync::RwLockWriteGuard<'_, SlotMap>> {
    store
        .data
        .write()
        .map_err(|e| StoreError::lock_error(format!("Write lock poisoned: {e}")))
}

/// Pushes `item` onto the `Vec<TState>` held in `slot`; leaves `slot` untouched on mismatch.
fn push_into_slot<TState: 'static + Send + Sync + Clone>(
    slot: &mut Arc<dyn std::any::Any + Send + Sync>,
    key: &str,
    item: TState,
) -> StoreResult<()> {
    // Sole owner: mutate the vector in place — no re-hash, no key allocation, no new Arc.
    if let Some(inner) = Arc::get_mut(slot) {
        return match inner.downcast_mut::<Vec<TState>>() {
            Some(vec) => {
                vec.push(item);
                Ok(())
            }
            None => Err(StoreError::append_type_mismatch(key)),
        };
    }
    // Shared with an outstanding reader: copy first (a panicking clone leaves the entry
    // intact), then extend the copy and swap it in; the reader keeps its old snapshot.
    let mut vec = slot
        .downcast_ref::<Vec<TState>>()
        .ok_or_else(|| StoreError::append_type_mismatch(key))?
        .clone();
    vec.push(item);
    *slot = Arc::new(vec);
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MemoryStore`: basic CRUD, error variants, `append` semantics
    //! (including copy-on-write when a `get_shared` handle is outstanding), `get_shared`
    //! identity guarantees, and behaviour under concurrent access.

    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_new_store() {
        let store = MemoryStore::new();
        assert_eq!(store.len().unwrap(), 0);
        assert!(store.is_empty().unwrap());
    }

    #[test]
    fn test_default_store() {
        let store = MemoryStore::default();
        assert_eq!(store.len().unwrap(), 0);
        assert!(store.is_empty().unwrap());
    }

    #[test]
    fn test_put_and_get_string() {
        let store = MemoryStore::new();
        let key = "test_string";
        let value = "Hello, World!".to_string();
        store.put(key, value.clone()).unwrap();
        let retrieved: String = store.get(key).unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn test_put_and_get_integer() {
        let store = MemoryStore::new();
        let key = "test_int";
        let value = 42i32;
        store.put(key, value).unwrap();
        let retrieved: i32 = store.get(key).unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn test_put_and_get_vector() {
        let store = MemoryStore::new();
        let key = "test_vec";
        let value = vec![1, 2, 3, 4, 5];
        store.put(key, value.clone()).unwrap();
        let retrieved: Vec<i32> = store.get(key).unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn test_put_and_get_custom_struct() {
        #[derive(Debug, Clone, PartialEq)]
        struct TestData {
            id: u32,
            name: String,
        }
        let store = MemoryStore::new();
        let key = "test_struct";
        let value = TestData {
            id: 123,
            name: "Test".to_string(),
        };
        store.put(key, value.clone()).unwrap();
        let retrieved: TestData = store.get(key).unwrap();
        assert_eq!(retrieved, value);
    }

    #[test]
    fn test_get_nonexistent_key() {
        let store = MemoryStore::new();
        let result: Result<String, StoreError> = store.get("nonexistent");
        assert!(result.is_err());
        match result.unwrap_err() {
            StoreError::KeyNotFound(msg) => {
                assert_eq!(msg, "Key 'nonexistent' not found in store")
            }
            _ => panic!("Expected KeyNotFound error"),
        }
    }

    #[test]
    fn test_type_mismatch() {
        let store = MemoryStore::new();
        let key = "test_type";
        store.put(key, "hello".to_string()).unwrap();
        let result: Result<i32, StoreError> = store.get(key);
        assert!(result.is_err());
        match result.unwrap_err() {
            StoreError::TypeMismatch(_) => (),
            _ => panic!("Expected TypeMismatch error"),
        }
    }

    #[test]
    fn test_remove_existing_key() {
        let store = MemoryStore::new();
        let key = "test_remove";
        let value = "to_be_removed".to_string();
        store.put(key, value).unwrap();
        assert!(store.get::<String>(key).is_ok());
        store.remove(key).unwrap();
        assert!(store.get::<String>(key).is_err());
    }

    #[test]
    fn test_remove_nonexistent_key() {
        let store = MemoryStore::new();
        let result = store.remove("nonexistent");
        assert!(result.is_ok());
    }

    #[test]
    fn test_append_to_new_key() {
        let store = MemoryStore::new();
        let key = "test_append_new";
        let item = "first_item".to_string();
        store.append(key, item.clone()).unwrap();
        let retrieved: Vec<String> = store.get(key).unwrap();
        assert_eq!(retrieved, vec![item]);
    }

    #[test]
    fn test_append_to_existing_vector() {
        let store = MemoryStore::new();
        let key = "test_append_existing";
        let initial_vec = vec!["first".to_string(), "second".to_string()];
        store.put(key, initial_vec.clone()).unwrap();
        let new_item = "third".to_string();
        store.append(key, new_item.clone()).unwrap();
        let retrieved: Vec<String> = store.get(key).unwrap();
        let expected = vec![
            "first".to_string(),
            "second".to_string(),
            "third".to_string(),
        ];
        assert_eq!(retrieved, expected);
    }

    #[test]
    fn test_append_to_non_vector() {
        let store = MemoryStore::new();
        let key = "test_append_error";
        store.put(key, "not_a_vector".to_string()).unwrap();
        let result = store.append(key, "item".to_string());
        assert!(result.is_err());
        match result.unwrap_err() {
            StoreError::AppendTypeMismatch(msg) => assert_eq!(
                msg,
                "Cannot append to key 'test_append_error': existing value is not a Vec<TState>"
            ),
            _ => panic!("Expected AppendTypeMismatch error"),
        }
    }

    #[test]
    fn test_append_non_destructive_on_type_mismatch() {
        let store = MemoryStore::new();
        let key = "test_non_destructive";
        let original_value = "original".to_string();
        store.put(key, original_value.clone()).unwrap();
        let result = store.append::<String>(key, "item".to_string());
        assert!(result.is_err());
        let retrieved: String = store
            .get(key)
            .expect("Value should still exist after append error");
        assert_eq!(
            retrieved, original_value,
            "Original value should be preserved after type mismatch"
        );
    }

    #[test]
    fn test_append_wrong_element_type_preserves_vector() {
        let store = MemoryStore::new();
        store.put("nums", vec![1i32, 2]).unwrap();
        let result = store.append("nums", "three".to_string());
        assert!(matches!(result, Err(StoreError::AppendTypeMismatch(_))));
        assert_eq!(store.get::<Vec<i32>>("nums").unwrap(), vec![1, 2]);
    }

    #[test]
    fn test_append_does_not_mutate_outstanding_shared_handle() {
        let store = MemoryStore::new();
        store.append("log", 1u32).unwrap();
        let snapshot: Arc<Vec<u32>> = store.get_shared("log").unwrap();
        store.append("log", 2u32).unwrap();
        assert_eq!(
            *snapshot,
            vec![1],
            "an Arc obtained via get_shared must not observe later appends"
        );
        assert_eq!(store.get::<Vec<u32>>("log").unwrap(), vec![1, 2]);
        drop(snapshot);
        store.append("log", 3u32).unwrap();
        assert_eq!(store.get::<Vec<u32>>("log").unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_keys_empty_store() {
        let store = MemoryStore::new();
        let keys = store.keys().unwrap();
        assert!(keys.is_empty());
    }

    #[test]
    fn test_keys_with_data() {
        let store = MemoryStore::new();
        store.put("key1", "value1".to_string()).unwrap();
        store.put("key2", 42i32).unwrap();
        store.put("key3", vec![1, 2, 3]).unwrap();
        let keys = store.keys().unwrap();
        assert_eq!(keys.len(), 3);
        assert!(keys.contains(&"key1".to_string()));
        assert!(keys.contains(&"key2".to_string()));
        assert!(keys.contains(&"key3".to_string()));
    }

    #[test]
    fn test_len_and_is_empty() {
        let store = MemoryStore::new();
        assert_eq!(store.len().unwrap(), 0);
        assert!(store.is_empty().unwrap());
        store.put("key1", "value1".to_string()).unwrap();
        assert_eq!(store.len().unwrap(), 1);
        assert!(!store.is_empty().unwrap());
        store.put("key2", 42i32).unwrap();
        store.put("key3", vec![1, 2, 3]).unwrap();
        assert_eq!(store.len().unwrap(), 3);
        assert!(!store.is_empty().unwrap());
        store.remove("key2").unwrap();
        assert_eq!(store.len().unwrap(), 2);
        assert!(!store.is_empty().unwrap());
    }

    #[test]
    fn test_clear() {
        let store = MemoryStore::new();
        store.put("key1", "value1".to_string()).unwrap();
        store.put("key2", 42i32).unwrap();
        store.put("key3", vec![1, 2, 3]).unwrap();
        assert_eq!(store.len().unwrap(), 3);
        store.clear().unwrap();
        assert_eq!(store.len().unwrap(), 0);
        assert!(store.is_empty().unwrap());
        assert!(store.get::<String>("key1").is_err());
        assert!(store.get::<i32>("key2").is_err());
        assert!(store.get::<Vec<i32>>("key3").is_err());
    }

    #[test]
    fn test_overwrite_existing_key() {
        let store = MemoryStore::new();
        let key = "test_overwrite";
        store.put(key, "initial".to_string()).unwrap();
        assert_eq!(store.get::<String>(key).unwrap(), "initial");
        store.put(key, "overwritten".to_string()).unwrap();
        assert_eq!(store.get::<String>(key).unwrap(), "overwritten");
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_overwrite_with_different_type() {
        let store = MemoryStore::new();
        let key = "test_type_overwrite";
        store.put(key, "string_value".to_string()).unwrap();
        assert_eq!(store.get::<String>(key).unwrap(), "string_value");
        store.put(key, 42i32).unwrap();
        assert_eq!(store.get::<i32>(key).unwrap(), 42);
        assert!(store.get::<String>(key).is_err());
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_clone_store() {
        let store1 = MemoryStore::new();
        store1.put("key1", "value1".to_string()).unwrap();
        let store2 = store1.clone();
        assert_eq!(store2.get::<String>("key1").unwrap(), "value1");
        store2.put("key2", "value2".to_string()).unwrap();
        assert_eq!(store1.get::<String>("key2").unwrap(), "value2");
        assert_eq!(store1.len().unwrap(), 2);
        assert_eq!(store2.len().unwrap(), 2);
    }

    #[test]
    fn test_thread_safety_concurrent_reads() {
        let store = Arc::new(MemoryStore::new());
        store.put("shared_key", "shared_value".to_string()).unwrap();
        let mut handles = vec![];
        for _ in 0..10 {
            let store_clone = Arc::clone(&store);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    let value: String = store_clone.get("shared_key").unwrap();
                    assert_eq!(value, "shared_value");
                    let len = store_clone.len().unwrap();
                    assert!(len > 0);
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_thread_safety_concurrent_writes() {
        let store = Arc::new(MemoryStore::new());
        let mut handles = vec![];
        for i in 0..10 {
            let store_clone = Arc::clone(&store);
            let handle = thread::spawn(move || {
                for j in 0..10 {
                    let key = format!("key_{i}_{j}");
                    let value = format!("value_{i}_{j}");
                    store_clone.put(&key, value).unwrap();
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(store.len().unwrap(), 100);
        assert_eq!(store.get::<String>("key_0_0").unwrap(), "value_0_0");
        assert_eq!(store.get::<String>("key_5_7").unwrap(), "value_5_7");
        assert_eq!(store.get::<String>("key_9_9").unwrap(), "value_9_9");
    }

    #[test]
    fn test_thread_safety_mixed_operations() {
        let store = Arc::new(MemoryStore::new());
        for i in 0..50 {
            store.put(&format!("initial_{i}"), i).unwrap();
        }
        let mut handles = vec![];
        for _ in 0..5 {
            let store_clone = Arc::clone(&store);
            let handle = thread::spawn(move || {
                for i in 0..50 {
                    if let Ok(value) = store_clone.get::<i32>(&format!("initial_{i}")) {
                        assert_eq!(value, i);
                    }
                    thread::sleep(Duration::from_millis(1));
                }
            });
            handles.push(handle);
        }
        for thread_id in 0..3 {
            let store_clone = Arc::clone(&store);
            let handle = thread::spawn(move || {
                for i in 0..20 {
                    let key = format!("new_{thread_id}_{i}");
                    let value = format!("new_value_{thread_id}_{i}");
                    store_clone.put(&key, value).unwrap();
                    thread::sleep(Duration::from_millis(1));
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let final_len = store.len().unwrap();
        // 50 initial keys + 3 writer threads x 20 new keys each.
        assert_eq!(final_len, 110);
    }

    #[test]
    fn test_append_thread_safety() {
        let store = Arc::new(MemoryStore::new());
        let mut handles = vec![];
        for thread_id in 0..5 {
            let store_clone = Arc::clone(&store);
            let handle = thread::spawn(move || {
                for i in 0..10 {
                    let value = format!("thread_{thread_id}_item_{i}");
                    store_clone.append("shared_vector", value).unwrap();
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let final_vec: Vec<String> = store.get("shared_vector").unwrap();
        assert_eq!(final_vec.len(), 50);
        for thread_id in 0..5 {
            for i in 0..10 {
                let expected_value = format!("thread_{thread_id}_item_{i}");
                assert!(final_vec.contains(&expected_value));
            }
        }
    }

    #[test]
    fn test_complex_data_types() {
        #[derive(Debug, Clone, PartialEq)]
        struct ComplexData {
            numbers: Vec<i32>,
            text: String,
            optional: Option<String>,
            nested: std::collections::HashMap<String, i32>,
        }
        let store = MemoryStore::new();
        let mut nested = std::collections::HashMap::new();
        nested.insert("nested_key".to_string(), 42);
        let complex_data = ComplexData {
            numbers: vec![1, 2, 3, 4, 5],
            text: "Complex data structure".to_string(),
            optional: Some("Some value".to_string()),
            nested,
        };
        store.put("complex", complex_data.clone()).unwrap();
        let retrieved: ComplexData = store.get("complex").unwrap();
        assert_eq!(retrieved, complex_data);
    }

    #[test]
    fn test_large_data_handling() {
        let store = MemoryStore::new();
        let large_vec: Vec<i32> = (0..10000).collect();
        store.put("large_data", large_vec.clone()).unwrap();
        let retrieved: Vec<i32> = store.get("large_data").unwrap();
        assert_eq!(retrieved, large_vec);
        assert_eq!(retrieved.len(), 10000);
    }

    #[test]
    fn test_contains_key_present() {
        let store = MemoryStore::new();
        store.put("exists", 42i32).unwrap();
        assert!(store.contains_key("exists").unwrap());
    }

    #[test]
    fn test_contains_key_absent() {
        let store = MemoryStore::new();
        assert!(!store.contains_key("missing").unwrap());
    }

    #[test]
    fn test_contains_key_after_remove() {
        let store = MemoryStore::new();
        store.put("key", "val".to_string()).unwrap();
        assert!(store.contains_key("key").unwrap());
        store.remove("key").unwrap();
        assert!(!store.contains_key("key").unwrap());
    }

    #[test]
    fn test_get_shared_returns_arc() {
        let store = MemoryStore::new();
        store.put("msg", "hello".to_string()).unwrap();
        let arc1: Arc<String> = store.get_shared("msg").unwrap();
        let arc2: Arc<String> = store.get_shared("msg").unwrap();
        assert_eq!(*arc1, "hello");
        assert!(
            Arc::ptr_eq(&arc1, &arc2),
            "get_shared must return clones of the same Arc, not fresh allocations"
        );
    }

    #[test]
    fn test_get_shared_missing_key() {
        let store = MemoryStore::new();
        let result: StoreResult<Arc<String>> = store.get_shared("nope");
        assert!(result.is_err());
        match result.unwrap_err() {
            StoreError::KeyNotFound(_) => (),
            _ => panic!("Expected KeyNotFound error"),
        }
    }

    #[test]
    fn test_get_shared_trait_bound_no_clone_required() {
        fn shared_via_trait<T: 'static + Send + Sync, S: KeyValueStore>(
            store: &S,
            key: &str,
        ) -> StoreResult<Arc<T>> {
            store.get_shared::<T>(key)
        }
        let store = MemoryStore::new();
        store.put("n", 42u32).unwrap();
        let got: Arc<u32> = shared_via_trait::<u32, _>(&store, "n").unwrap();
        assert_eq!(*got, 42);
        struct NotClone(#[allow(dead_code)] u32);
        match shared_via_trait::<NotClone, _>(&store, "missing") {
            Err(StoreError::KeyNotFound(_)) => (),
            Err(e) => panic!("expected KeyNotFound for non-Clone type, got: {e:?}"),
            Ok(_) => panic!("expected error for missing key"),
        }
    }
}