1use mem_types::{VecSearchHit, VecStore, VecStoreError, VecStoreItem};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Cosine similarity of two vectors, accumulated in `f64` for precision.
///
/// Returns `0.0` instead of dividing when the slices are empty, differ in
/// length, or either one has zero magnitude. Non-finite components (NaN,
/// infinity) are not screened out and propagate through the arithmetic.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| f64::from(x).powi(2))
            .sum::<f64>()
            .sqrt()
    };
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    dot / (norm_a * norm_b)
}
24
25pub struct InMemoryVecStore {
27 store: Arc<RwLock<HashMap<String, HashMap<String, VecStoreItem>>>>,
29 default_collection: String,
30}
31
32impl InMemoryVecStore {
33 pub fn new(default_collection: Option<&str>) -> Self {
34 Self {
35 store: Arc::new(RwLock::new(HashMap::new())),
36 default_collection: default_collection.unwrap_or("memos_memories").to_string(),
37 }
38 }
39
40 fn coll(&self, collection: Option<&str>) -> String {
41 collection.unwrap_or(&self.default_collection).to_string()
42 }
43}
44
45#[async_trait::async_trait]
46impl VecStore for InMemoryVecStore {
47 async fn add(
48 &self,
49 items: &[VecStoreItem],
50 collection: Option<&str>,
51 ) -> Result<(), VecStoreError> {
52 let coll = self.coll(collection);
53 let mut guard = self.store.write().await;
54 let map = guard.entry(coll).or_default();
55 for item in items {
56 map.insert(item.id.clone(), item.clone());
57 }
58 Ok(())
59 }
60
61 async fn search(
62 &self,
63 query_vector: &[f32],
64 top_k: usize,
65 filter: Option<&HashMap<String, serde_json::Value>>,
66 collection: Option<&str>,
67 ) -> Result<Vec<VecSearchHit>, VecStoreError> {
68 let coll = self.coll(collection);
69 let guard = self.store.read().await;
70 let map = guard
71 .get(&coll)
72 .map(|m| m.values().cloned().collect::<Vec<_>>());
73 let items = map.unwrap_or_default();
74 let mut candidates: Vec<(VecStoreItem, f64)> = items
75 .into_iter()
76 .filter(|i| {
77 if let Some(f) = filter {
78 for (k, v) in f.iter() {
79 if let Some(pv) = i.payload.get(k) {
80 if pv != v {
81 return false;
82 }
83 } else {
84 return false;
85 }
86 }
87 }
88 true
89 })
90 .map(|i| {
91 let score = cosine_similarity(query_vector, &i.vector);
92 (i, score)
93 })
94 .collect();
95 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
96 let hits = candidates
97 .into_iter()
98 .take(top_k)
99 .map(|(i, score)| VecSearchHit { id: i.id, score })
100 .collect();
101 Ok(hits)
102 }
103
104 async fn get_by_ids(
105 &self,
106 ids: &[String],
107 collection: Option<&str>,
108 ) -> Result<Vec<VecStoreItem>, VecStoreError> {
109 let coll = self.coll(collection);
110 let guard = self.store.read().await;
111 let map = guard.get(&coll);
112 let mut out = Vec::new();
113 if let Some(m) = map {
114 for id in ids {
115 if let Some(item) = m.get(id) {
116 out.push(item.clone());
117 }
118 }
119 }
120 Ok(out)
121 }
122
123 async fn delete(&self, ids: &[String], collection: Option<&str>) -> Result<(), VecStoreError> {
124 let coll = self.coll(collection);
125 let mut guard = self.store.write().await;
126 if let Some(m) = guard.get_mut(&coll) {
127 for id in ids {
128 m.remove(id);
129 }
130 }
131 Ok(())
132 }
133
134 async fn upsert(
135 &self,
136 items: &[VecStoreItem],
137 collection: Option<&str>,
138 ) -> Result<(), VecStoreError> {
139 let coll = self.coll(collection);
140 let mut guard = self.store.write().await;
141 let map = guard.entry(coll).or_default();
142 for item in items {
143 map.insert(item.id.clone(), item.clone());
144 }
145 Ok(())
146 }
147}