use converge_core::capability::{
CapabilityError, VectorMatch, VectorQuery, VectorRecall, VectorRecord,
};
use std::collections::HashMap;
use std::sync::RwLock;
use super::cosine_similarity;
pub struct InMemoryVectorStore {
records: RwLock<HashMap<String, VectorRecord>>,
}
impl Default for InMemoryVectorStore {
fn default() -> Self {
Self::new()
}
}
impl InMemoryVectorStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            records: RwLock::new(HashMap::new()),
        }
    }

    /// Creates a store pre-populated with `records`.
    ///
    /// Records are keyed by their `id`. When several records share an id, the one
    /// appearing last in `records` wins, exactly as if they had been upserted in order.
    #[must_use]
    pub fn with_records(records: Vec<VectorRecord>) -> Self {
        // We own the records, so move them straight into the map instead of going
        // through `upsert(&record)`, which would clone each record a second time.
        let map: HashMap<String, VectorRecord> = records
            .into_iter()
            .map(|record| (record.id.clone(), record))
            .collect();
        Self {
            records: RwLock::new(map),
        }
    }

    /// Returns a snapshot (clone) of every stored record.
    ///
    /// The order is unspecified: it follows the internal `HashMap` iteration order.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    #[must_use]
    pub fn all_records(&self) -> Vec<VectorRecord> {
        self.records
            .read()
            .expect("Lock poisoned")
            .values()
            .cloned()
            .collect()
    }
}
impl VectorRecall for InMemoryVectorStore {
    /// Identifier reported for this backend.
    fn name(&self) -> &'static str {
        "in-memory"
    }

    /// Inserts `record`, replacing any stored record with the same `id`.
    ///
    /// Always returns `Ok(())`; panics if the internal lock is poisoned.
    fn upsert(&self, record: &VectorRecord) -> Result<(), CapabilityError> {
        // Clone before taking the write lock so the critical section is only the insert.
        let owned = record.clone();
        let id = owned.id.clone();
        self.records
            .write()
            .expect("Lock poisoned")
            .insert(id, owned);
        Ok(())
    }

    /// Scores every stored record against `query.vector` with `cosine_similarity`.
    ///
    /// Records scoring below `query.min_score` (when set) are dropped. The rest are
    /// returned best-first, at most `query.top_k` of them.
    ///
    /// Ordering is total and deterministic:
    /// - Equal scores are ordered by ascending `id`, so results do not depend on
    ///   `HashMap` iteration order.
    /// - NaN scores sort after every real score. Whether `cosine_similarity` can
    ///   produce NaN depends on its implementation.
    ///
    /// Always returns `Ok`; panics if the internal lock is poisoned.
    fn query(&self, query: &VectorQuery) -> Result<Vec<VectorMatch>, CapabilityError> {
        use std::cmp::Ordering;

        // Best-first total order. A comparator built on `partial_cmp(..).unwrap_or(Equal)`
        // is not a total order when NaN is present, and `sort*_by` may panic on such
        // comparators (Rust >= 1.81).
        fn best_first(a: &VectorMatch, b: &VectorMatch) -> Ordering {
            match (a.score.is_nan(), b.score.is_nan()) {
                (false, false) => b.score.total_cmp(&a.score),
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (true, true) => Ordering::Equal,
            }
            .then_with(|| a.id.cmp(&b.id))
        }

        let records = self.records.read().expect("Lock poisoned");
        // Only records that pass the threshold pay for cloning their id and payload.
        let mut matches: Vec<VectorMatch> = records
            .values()
            .filter_map(|record| {
                let score = f64::from(cosine_similarity(&query.vector, &record.vector));
                query
                    .min_score
                    .is_none_or(|min| score >= min)
                    .then(|| VectorMatch {
                        id: record.id.clone(),
                        score,
                        payload: record.payload.clone(),
                    })
            })
            .collect();
        // Sorting no longer touches the store; release the lock so writers are not blocked.
        drop(records);

        // Ids are unique map keys, so the id tie-break makes stability irrelevant.
        matches.sort_unstable_by(best_first);
        matches.truncate(query.top_k);
        Ok(matches)
    }

    /// Removes the record with `id`, if present. Deleting an unknown id is not an error.
    fn delete(&self, id: &str) -> Result<(), CapabilityError> {
        let mut records = self.records.write().expect("Lock poisoned");
        records.remove(id);
        Ok(())
    }

    /// Removes every stored record.
    fn clear(&self) -> Result<(), CapabilityError> {
        let mut records = self.records.write().expect("Lock poisoned");
        records.clear();
        Ok(())
    }

    /// Returns the number of stored records.
    fn count(&self) -> Result<usize, CapabilityError> {
        let records = self.records.read().expect("Lock poisoned");
        Ok(records.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn upsert_and_query() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(&VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({"title": "Document 1"}),
            })
            .unwrap();
        store
            .upsert(&VectorRecord {
                id: "doc-2".into(),
                vector: vec![0.9, 0.1, 0.0],
                payload: json!({"title": "Document 2"}),
            })
            .unwrap();
        store
            .upsert(&VectorRecord {
                id: "doc-3".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: json!({"title": "Document 3"}),
            })
            .unwrap();
        assert_eq!(store.count().unwrap(), 3);

        let matches = store
            .query(&VectorQuery::new(vec![1.0, 0.0, 0.0], 2))
            .unwrap();
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].id, "doc-1");
        assert_eq!(matches[1].id, "doc-2");
    }

    #[test]
    fn query_with_min_score() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(&VectorRecord {
                id: "close".into(),
                vector: vec![0.95, 0.05, 0.0],
                payload: json!({}),
            })
            .unwrap();
        store
            .upsert(&VectorRecord {
                id: "far".into(),
                vector: vec![0.0, 0.0, 1.0],
                payload: json!({}),
            })
            .unwrap();

        let matches = store
            .query(&VectorQuery::new(vec![1.0, 0.0, 0.0], 10).with_min_score(0.5))
            .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id, "close");
    }

    #[test]
    fn query_empty_store_and_zero_top_k() {
        let store = InMemoryVectorStore::new();
        assert_eq!(store.name(), "in-memory");
        assert!(store
            .query(&VectorQuery::new(vec![1.0, 0.0, 0.0], 5))
            .unwrap()
            .is_empty());

        store
            .upsert(&VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({}),
            })
            .unwrap();
        assert!(store
            .query(&VectorQuery::new(vec![1.0, 0.0, 0.0], 0))
            .unwrap()
            .is_empty());
    }

    #[test]
    fn upsert_overwrites() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(&VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({"version": 1}),
            })
            .unwrap();
        store
            .upsert(&VectorRecord {
                id: "doc-1".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: json!({"version": 2}),
            })
            .unwrap();
        assert_eq!(store.count().unwrap(), 1);

        let records = store.all_records();
        assert_eq!(records[0].payload["version"], 2);
    }

    #[test]
    fn with_records_keeps_last_duplicate() {
        let store = InMemoryVectorStore::with_records(vec![
            VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({"version": 1}),
            },
            VectorRecord {
                id: "doc-2".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: json!({}),
            },
            VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({"version": 2}),
            },
        ]);
        assert_eq!(store.count().unwrap(), 2);

        let doc1 = store
            .all_records()
            .into_iter()
            .find(|r| r.id == "doc-1")
            .expect("doc-1 should be stored");
        assert_eq!(doc1.payload["version"], 2);
    }

    #[test]
    fn delete_and_clear() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(&VectorRecord {
                id: "doc-1".into(),
                vector: vec![1.0, 0.0, 0.0],
                payload: json!({}),
            })
            .unwrap();
        store
            .upsert(&VectorRecord {
                id: "doc-2".into(),
                vector: vec![0.0, 1.0, 0.0],
                payload: json!({}),
            })
            .unwrap();
        assert_eq!(store.count().unwrap(), 2);

        store.delete("doc-1").unwrap();
        assert_eq!(store.count().unwrap(), 1);

        // Deleting an id that is not present is a successful no-op.
        store.delete("missing").unwrap();
        assert_eq!(store.count().unwrap(), 1);

        store.clear().unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }
}