use std::collections::HashMap;
use std::sync::RwLock;
use crate::error::Result;
use crate::record::{Record, RecordId};
/// In-memory map from [`RecordId`] to [`Record`], shareable across threads.
///
/// Every operation goes through a single [`RwLock`]. Lookups take the shared
/// side of the lock and mutations take the exclusive side. If the lock is
/// poisoned (a thread panicked while holding it), operations do not fail.
/// The guard is recovered and the map is used in whatever state it was left.
#[derive(Default, Debug)]
pub(crate) struct MemoryStore {
    // Only touched through the private `read` / `write` helpers below, so
    // that poison handling stays in one place.
    records: RwLock<HashMap<RecordId, Record>>,
}

impl MemoryStore {
    /// Creates an empty store (same as `MemoryStore::default()`).
    pub(crate) fn new() -> Self {
        <Self as Default>::default()
    }

    /// Stores `record` under its own id, replacing any record already there.
    ///
    /// The replaced record, if any, is discarded. Currently always `Ok(())`.
    pub(crate) fn upsert(&self, record: Record) -> Result<()> {
        // Read the key before `record` is moved into the map.
        let key = record.id();
        self.write().insert(key, record);
        Ok(())
    }

    /// Returns a clone of the record stored under `id`, or `None` if absent.
    ///
    /// Currently always `Ok(..)`.
    pub(crate) fn get(&self, id: RecordId) -> Result<Option<Record>> {
        let found = self.read().get(&id).cloned();
        Ok(found)
    }

    /// Removes the record stored under `id`.
    ///
    /// Yields `true` if a record was removed and `false` if there was nothing
    /// to remove. Currently always `Ok(..)`.
    pub(crate) fn delete(&self, id: RecordId) -> Result<bool> {
        let removed = self.write().remove(&id);
        Ok(removed.is_some())
    }

    /// Number of distinct ids currently stored.
    pub(crate) fn len(&self) -> usize {
        let guard = self.read();
        guard.len()
    }

    /// `true` when no records are stored.
    pub(crate) fn is_empty(&self) -> bool {
        let map = self.read();
        map.is_empty()
    }

    /// Calls `f` with a shared view of the whole map and returns its result.
    ///
    /// The read lock is held for the entire call to `f`, so writers on other
    /// threads wait until `f` returns. `f` must not call back into this store.
    /// The std docs state that re-acquiring an `RwLock` already held by the
    /// current thread may panic or deadlock.
    pub(crate) fn with_records<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&HashMap<RecordId, Record>) -> R,
    {
        f(&*self.read())
    }

    /// Takes the shared lock, recovering the guard if the lock is poisoned.
    fn read(&self) -> std::sync::RwLockReadGuard<'_, HashMap<RecordId, Record>> {
        self.records
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes the exclusive lock, recovering the guard if the lock is poisoned.
    fn write(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<RecordId, Record>> {
        self.records
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `MemoryStore`:
    // - insert/replace/lookup/delete semantics;
    // - the `with_records` visitor;
    // - concurrent use through `Arc`;
    // - recovery after the internal lock has been poisoned, on both the read
    //   and the write path.
    use super::*;
    use crate::vector::Vector;

    /// Builds a record with id `id` whose vector holds `components`.
    ///
    /// Unwraps the `Vector::new` result. Every caller passes a small non-empty
    /// literal, assumed to be accepted. NOTE(review): confirm against
    /// `Vector::new`'s validation rules.
    fn record(id: u64, components: Vec<f32>) -> Record {
        Record::new(RecordId::new(id), Vector::new(components).unwrap())
    }

    #[test]
    fn new_store_is_empty() {
        let s = MemoryStore::new();
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn upsert_inserts_new_record() {
        let s = MemoryStore::new();
        s.upsert(record(1, vec![0.1, 0.2])).unwrap();
        assert_eq!(s.len(), 1);
        assert!(!s.is_empty());
        let got = s.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(got.vector().as_slice(), &[0.1, 0.2]);
    }

    #[test]
    fn upsert_replaces_existing_record() {
        let s = MemoryStore::new();
        s.upsert(record(1, vec![0.1, 0.2])).unwrap();
        s.upsert(record(1, vec![0.9, 0.8])).unwrap();
        assert_eq!(s.len(), 1);
        let got = s.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(got.vector().as_slice(), &[0.9, 0.8]);
    }

    #[test]
    fn get_returns_none_for_missing_id() {
        let s = MemoryStore::new();
        assert!(s.get(RecordId::new(99)).unwrap().is_none());
    }

    #[test]
    fn delete_removes_existing_record() {
        let s = MemoryStore::new();
        s.upsert(record(1, vec![0.1, 0.2])).unwrap();
        assert!(s.delete(RecordId::new(1)).unwrap());
        assert!(s.is_empty());
    }

    #[test]
    fn delete_returns_false_when_absent() {
        let s = MemoryStore::new();
        assert!(!s.delete(RecordId::new(99)).unwrap());
    }

    #[test]
    fn len_reflects_distinct_ids() {
        let s = MemoryStore::new();
        s.upsert(record(1, vec![0.0])).unwrap();
        s.upsert(record(2, vec![0.0])).unwrap();
        s.upsert(record(3, vec![0.0])).unwrap();
        // Re-upserting an existing id must not grow the store.
        s.upsert(record(2, vec![1.0])).unwrap();
        assert_eq!(s.len(), 3);
    }

    #[test]
    fn with_records_sees_every_stored_record() {
        let s = MemoryStore::new();
        s.upsert(record(1, vec![0.1])).unwrap();
        s.upsert(record(2, vec![0.2])).unwrap();
        // The closure's return value must be handed back unchanged.
        let (count, has_one, has_two, has_three) = s.with_records(|map| {
            (
                map.len(),
                map.contains_key(&RecordId::new(1)),
                map.contains_key(&RecordId::new(2)),
                map.contains_key(&RecordId::new(3)),
            )
        });
        assert_eq!(count, 2);
        assert!(has_one && has_two);
        assert!(!has_three);
    }

    #[test]
    fn with_records_on_empty_store_sees_empty_map() {
        let s = MemoryStore::new();
        assert!(s.with_records(|map| map.is_empty()));
    }

    #[test]
    fn store_is_thread_safe_across_arc() {
        use std::sync::Arc;
        use std::thread;
        let store = Arc::new(MemoryStore::new());
        let mut handles = Vec::new();
        for id in 0..16_u64 {
            let s = Arc::clone(&store);
            handles.push(thread::spawn(move || {
                s.upsert(record(id, vec![id as f32; 4])).unwrap();
            }));
        }
        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(store.len(), 16);
        for id in 0..16 {
            let got = store
                .get(RecordId::new(id))
                .unwrap()
                .expect("record present");
            assert_eq!(got.id().get(), id);
        }
    }

    #[test]
    fn store_recovers_from_poisoned_lock() {
        use std::sync::Arc;
        use std::thread;
        let store = Arc::new(MemoryStore::new());
        store.upsert(record(1, vec![0.1, 0.2])).unwrap();
        // Poison the lock by panicking while holding the write guard.
        let poisoner = Arc::clone(&store);
        let _ = thread::spawn(move || {
            let _guard = poisoner.records.write().unwrap();
            panic!("intentional");
        })
        .join();
        assert!(store.records.is_poisoned());
        // Read path (`read` helper) recovers the guard...
        let got = store.get(RecordId::new(1)).unwrap();
        assert!(got.is_some());
        // ...and so does the write path (`write` helper).
        store.upsert(record(2, vec![0.3, 0.4])).unwrap();
        assert_eq!(store.len(), 2);
        assert!(store.delete(RecordId::new(1)).unwrap());
        assert_eq!(store.len(), 1);
    }
}