1use std::collections::HashMap;
4use std::num::NonZeroUsize;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12use crate::MemoryResult;
13use crate::embeddings::EmbeddingVector;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct VectorPoint {
18 id: Uuid,
19 embedding: EmbeddingVector,
20 #[serde(default)]
21 metadata: Value,
22 #[serde(default)]
23 tags: Vec<String>,
24}
25
26impl VectorPoint {
27 #[must_use]
29 pub fn new(id: Uuid, embedding: EmbeddingVector) -> Self {
30 Self {
31 id,
32 embedding,
33 metadata: Value::Null,
34 tags: Vec::new(),
35 }
36 }
37
38 #[must_use]
40 pub fn with_metadata(mut self, metadata: Value) -> Self {
41 self.metadata = metadata;
42 self
43 }
44
45 #[must_use]
47 pub fn with_tags<I, S>(mut self, tags: I) -> Self
48 where
49 I: IntoIterator<Item = S>,
50 S: Into<String>,
51 {
52 self.tags = tags.into_iter().map(Into::into).collect();
53 self
54 }
55
56 #[must_use]
58 pub fn id(&self) -> Uuid {
59 self.id
60 }
61
62 #[must_use]
64 pub fn embedding(&self) -> &EmbeddingVector {
65 &self.embedding
66 }
67
68 #[must_use]
70 pub fn tags(&self) -> &[String] {
71 &self.tags
72 }
73
74 #[must_use]
76 pub fn metadata(&self) -> &Value {
77 &self.metadata
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct VectorQuery {
84 embedding: EmbeddingVector,
85 top_k: NonZeroUsize,
86 #[serde(default)]
87 tags: Vec<String>,
88}
89
90impl VectorQuery {
91 #[must_use]
93 pub fn new(embedding: EmbeddingVector, top_k: NonZeroUsize) -> Self {
94 Self {
95 embedding,
96 top_k,
97 tags: Vec::new(),
98 }
99 }
100
101 #[must_use]
103 pub fn with_tags<I, S>(mut self, tags: I) -> Self
104 where
105 I: IntoIterator<Item = S>,
106 S: Into<String>,
107 {
108 self.tags = tags.into_iter().map(Into::into).collect();
109 self
110 }
111
112 #[must_use]
114 pub fn embedding(&self) -> &EmbeddingVector {
115 &self.embedding
116 }
117
118 #[must_use]
120 pub fn top_k(&self) -> usize {
121 self.top_k.get()
122 }
123
124 #[must_use]
126 pub fn tags(&self) -> &[String] {
127 &self.tags
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct VectorMatch {
134 id: Uuid,
135 score: f32,
136 #[serde(default)]
137 metadata: Value,
138 #[serde(default)]
139 tags: Vec<String>,
140}
141
142impl VectorMatch {
143 #[must_use]
145 pub fn new(id: Uuid, score: f32, metadata: Value, tags: Vec<String>) -> Self {
146 Self {
147 id,
148 score,
149 metadata,
150 tags,
151 }
152 }
153
154 #[must_use]
156 pub fn id(&self) -> Uuid {
157 self.id
158 }
159
160 #[must_use]
162 pub fn score(&self) -> f32 {
163 self.score
164 }
165
166 #[must_use]
168 pub fn metadata(&self) -> &Value {
169 &self.metadata
170 }
171
172 #[must_use]
174 pub fn tags(&self) -> &[String] {
175 &self.tags
176 }
177}
178
179#[async_trait]
181pub trait VectorStoreClient: Send + Sync {
182 async fn upsert(&self, point: VectorPoint) -> MemoryResult<()>;
184
185 async fn remove(&self, id: Uuid) -> MemoryResult<()>;
187
188 async fn query(&self, query: VectorQuery) -> MemoryResult<Vec<VectorMatch>>;
190}
191
192pub struct LocalVectorStore {
194 points: RwLock<HashMap<Uuid, VectorPoint>>,
195}
196
197impl LocalVectorStore {
198 #[must_use]
200 pub fn new() -> Self {
201 Self {
202 points: RwLock::new(HashMap::new()),
203 }
204 }
205}
206
207impl Default for LocalVectorStore {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213#[async_trait]
214impl VectorStoreClient for LocalVectorStore {
215 async fn upsert(&self, point: VectorPoint) -> MemoryResult<()> {
216 let mut guard = self.points.write().await;
217 guard.insert(point.id(), point);
218 Ok(())
219 }
220
221 async fn remove(&self, id: Uuid) -> MemoryResult<()> {
222 let mut guard = self.points.write().await;
223 guard.remove(&id);
224 Ok(())
225 }
226
227 async fn query(&self, query: VectorQuery) -> MemoryResult<Vec<VectorMatch>> {
228 let guard = self.points.read().await;
229 let mut matches = Vec::new();
230
231 let query_embedding = query.embedding();
232 let query_tags = query.tags();
233
234 for point in guard.values() {
235 if !query_tags.is_empty()
236 && !query_tags
237 .iter()
238 .all(|tag| point.tags().iter().any(|candidate| candidate == tag))
239 {
240 continue;
241 }
242
243 if point.embedding().len() != query_embedding.len() {
244 continue;
245 }
246
247 let score = cosine_similarity(point.embedding(), query_embedding);
248 matches.push(VectorMatch::new(
249 point.id(),
250 score,
251 point.metadata().clone(),
252 point.tags().to_vec(),
253 ));
254 }
255
256 matches.sort_by(|a, b| {
257 b.score
258 .partial_cmp(&a.score)
259 .unwrap_or(std::cmp::Ordering::Equal)
260 });
261 matches.truncate(query.top_k());
262 Ok(matches)
263 }
264}
265
266fn cosine_similarity(lhs: &EmbeddingVector, rhs: &EmbeddingVector) -> f32 {
267 let numerator = lhs.dot(rhs);
268 let denominator = lhs.magnitude() * rhs.magnitude();
269 if denominator == 0.0 {
270 0.0
271 } else {
272 numerator / denominator
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn basic_query() {
        let store = LocalVectorStore::new();

        store
            .upsert(
                VectorPoint::new(
                    Uuid::new_v4(),
                    EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                )
                .with_tags(["alpha"]),
            )
            .await
            .unwrap();

        store
            .upsert(
                VectorPoint::new(
                    Uuid::new_v4(),
                    EmbeddingVector::new(vec![0.0, 1.0, 0.0]).unwrap(),
                )
                .with_tags(["beta"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
            NonZeroUsize::new(1).unwrap(),
        );
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].tags(), ["alpha"]);
        assert!((matches[0].score() - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn respects_tag_filter() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();
        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 1.0]).unwrap())
                    .with_tags(["alpha", "beta"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 1.0]).unwrap(),
            NonZeroUsize::new(5).unwrap(),
        )
        .with_tags(["beta", "alpha"]);
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id(), id);
    }

    #[tokio::test]
    async fn tag_filter_requires_every_tag() {
        let store = LocalVectorStore::new();
        store
            .upsert(
                VectorPoint::new(Uuid::new_v4(), EmbeddingVector::new(vec![1.0, 1.0]).unwrap())
                    .with_tags(["alpha"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 1.0]).unwrap(),
            NonZeroUsize::new(5).unwrap(),
        )
        .with_tags(["alpha", "beta"]);
        let matches = store.query(query).await.unwrap();
        assert!(matches.is_empty());
    }

    #[tokio::test]
    async fn orders_by_descending_similarity_and_truncates() {
        let store = LocalVectorStore::new();
        let exact = Uuid::new_v4();
        let diagonal = Uuid::new_v4();
        let orthogonal = Uuid::new_v4();
        for (id, values) in vec![
            (orthogonal, vec![0.0, 1.0, 0.0]),
            (exact, vec![1.0, 0.0, 0.0]),
            (diagonal, vec![1.0, 1.0, 0.0]),
        ] {
            store
                .upsert(VectorPoint::new(id, EmbeddingVector::new(values).unwrap()))
                .await
                .unwrap();
        }

        let all = store
            .query(VectorQuery::new(
                EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                NonZeroUsize::new(3).unwrap(),
            ))
            .await
            .unwrap();
        let ids: Vec<Uuid> = all.iter().map(VectorMatch::id).collect();
        assert_eq!(ids, vec![exact, diagonal, orthogonal]);

        let best_two = store
            .query(VectorQuery::new(
                EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                NonZeroUsize::new(2).unwrap(),
            ))
            .await
            .unwrap();
        let ids: Vec<Uuid> = best_two.iter().map(VectorMatch::id).collect();
        assert_eq!(ids, vec![exact, diagonal]);
    }

    #[tokio::test]
    async fn skips_points_with_mismatched_dimensions() {
        let store = LocalVectorStore::new();
        store
            .upsert(VectorPoint::new(
                Uuid::new_v4(),
                EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
            ))
            .await
            .unwrap();

        let matches = store
            .query(VectorQuery::new(
                EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                NonZeroUsize::new(5).unwrap(),
            ))
            .await
            .unwrap();
        assert!(matches.is_empty());
    }

    #[tokio::test]
    async fn upsert_replaces_point_with_same_id() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();
        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 0.0]).unwrap())
                    .with_tags(["alpha"]),
            )
            .await
            .unwrap();
        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 0.0]).unwrap())
                    .with_tags(["beta"])
                    .with_metadata(Value::from("payload")),
            )
            .await
            .unwrap();

        let matches = store
            .query(VectorQuery::new(
                EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
                NonZeroUsize::new(5).unwrap(),
            ))
            .await
            .unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id(), id);
        assert_eq!(matches[0].tags(), ["beta"]);
        assert_eq!(matches[0].metadata(), &Value::from("payload"));
    }

    #[tokio::test]
    async fn remove_deletes_point_and_tolerates_unknown_ids() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();
        store
            .upsert(VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 0.0]).unwrap()))
            .await
            .unwrap();

        store.remove(id).await.unwrap();
        store.remove(Uuid::new_v4()).await.unwrap();

        let matches = store
            .query(VectorQuery::new(
                EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
                NonZeroUsize::new(5).unwrap(),
            ))
            .await
            .unwrap();
        assert!(matches.is_empty());
    }
}