agent_memory/
vector_store_api.rs

1//! Vector store traits and a local in-memory implementation.
2
3use std::collections::HashMap;
4use std::num::NonZeroUsize;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
12use crate::MemoryResult;
13use crate::embeddings::EmbeddingVector;
14
/// Record stored in a vector database.
///
/// Pairs an embedding with an identifier plus optional, free-form metadata and
/// tags. Fields are private: build a point with [`VectorPoint::new`] and the
/// `with_*` builders, and read it back through the accessor methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
    /// Identifier of the point; `LocalVectorStore` keys its map by this value, so
    /// upserting a point with an existing id replaces the previous entry.
    id: Uuid,
    /// Embedding compared against query embeddings.
    embedding: EmbeddingVector,
    /// Arbitrary JSON payload; falls back to `Value::default()` (`Null`) when the
    /// field is missing during deserialization.
    #[serde(default)]
    metadata: Value,
    /// Labels used for tag filtering; empty when missing during deserialization.
    #[serde(default)]
    tags: Vec<String>,
}
25
26impl VectorPoint {
27    /// Creates a new vector point with optional metadata.
28    #[must_use]
29    pub fn new(id: Uuid, embedding: EmbeddingVector) -> Self {
30        Self {
31            id,
32            embedding,
33            metadata: Value::Null,
34            tags: Vec::new(),
35        }
36    }
37
38    /// Assigns metadata to the point.
39    #[must_use]
40    pub fn with_metadata(mut self, metadata: Value) -> Self {
41        self.metadata = metadata;
42        self
43    }
44
45    /// Assigns tags to the point.
46    #[must_use]
47    pub fn with_tags<I, S>(mut self, tags: I) -> Self
48    where
49        I: IntoIterator<Item = S>,
50        S: Into<String>,
51    {
52        self.tags = tags.into_iter().map(Into::into).collect();
53        self
54    }
55
56    /// Returns the identifier.
57    #[must_use]
58    pub fn id(&self) -> Uuid {
59        self.id
60    }
61
62    /// Returns the embedding reference.
63    #[must_use]
64    pub fn embedding(&self) -> &EmbeddingVector {
65        &self.embedding
66    }
67
68    /// Returns tags associated with the point.
69    #[must_use]
70    pub fn tags(&self) -> &[String] {
71        &self.tags
72    }
73
74    /// Returns the metadata payload.
75    #[must_use]
76    pub fn metadata(&self) -> &Value {
77        &self.metadata
78    }
79}
80
/// Query parameters for retrieving similar vectors.
///
/// Created with [`VectorQuery::new`] and optionally narrowed with
/// [`VectorQuery::with_tags`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorQuery {
    /// Embedding the stored points are compared against.
    embedding: EmbeddingVector,
    /// Maximum number of matches requested; `NonZeroUsize` rules out a request for
    /// zero results.
    top_k: NonZeroUsize,
    /// Tags a point must all carry to match; empty when missing during
    /// deserialization (`LocalVectorStore` applies no tag filter in that case).
    #[serde(default)]
    tags: Vec<String>,
}
89
90impl VectorQuery {
91    /// Creates a new query request.
92    #[must_use]
93    pub fn new(embedding: EmbeddingVector, top_k: NonZeroUsize) -> Self {
94        Self {
95            embedding,
96            top_k,
97            tags: Vec::new(),
98        }
99    }
100
101    /// Restricts results to vectors tagged with all provided labels.
102    #[must_use]
103    pub fn with_tags<I, S>(mut self, tags: I) -> Self
104    where
105        I: IntoIterator<Item = S>,
106        S: Into<String>,
107    {
108        self.tags = tags.into_iter().map(Into::into).collect();
109        self
110    }
111
112    /// Returns the embedding driving the query.
113    #[must_use]
114    pub fn embedding(&self) -> &EmbeddingVector {
115        &self.embedding
116    }
117
118    /// Returns the desired number of results.
119    #[must_use]
120    pub fn top_k(&self) -> usize {
121        self.top_k.get()
122    }
123
124    /// Returns tags to enforce during search.
125    #[must_use]
126    pub fn tags(&self) -> &[String] {
127        &self.tags
128    }
129}
130
/// Match returned from a vector store query.
///
/// `LocalVectorStore` fills it with the matched point's id, its similarity score,
/// and copies of the point's metadata and tags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorMatch {
    /// Identifier of the matched point.
    id: Uuid,
    /// Similarity to the query embedding; `LocalVectorStore` reports cosine
    /// similarity (higher is more similar).
    score: f32,
    /// Metadata payload of the match; `Value::default()` (`Null`) when missing
    /// during deserialization.
    #[serde(default)]
    metadata: Value,
    /// Tags of the match; empty when missing during deserialization.
    #[serde(default)]
    tags: Vec<String>,
}
141
142impl VectorMatch {
143    /// Creates a match structure.
144    #[must_use]
145    pub fn new(id: Uuid, score: f32, metadata: Value, tags: Vec<String>) -> Self {
146        Self {
147            id,
148            score,
149            metadata,
150            tags,
151        }
152    }
153
154    /// Returns the identifier.
155    #[must_use]
156    pub fn id(&self) -> Uuid {
157        self.id
158    }
159
160    /// Returns cosine similarity score.
161    #[must_use]
162    pub fn score(&self) -> f32 {
163        self.score
164    }
165
166    /// Returns metadata payload.
167    #[must_use]
168    pub fn metadata(&self) -> &Value {
169        &self.metadata
170    }
171
172    /// Returns tags associated with the match.
173    #[must_use]
174    pub fn tags(&self) -> &[String] {
175        &self.tags
176    }
177}
178
/// Interface for vector store clients.
///
/// Implementors must be `Send + Sync`; the methods are async via `async_trait`.
#[async_trait]
pub trait VectorStoreClient: Send + Sync {
    /// Inserts `point`, or replaces the point already stored under the same id.
    async fn upsert(&self, point: VectorPoint) -> MemoryResult<()>;

    /// Removes the point stored under `id`, if present.
    async fn remove(&self, id: Uuid) -> MemoryResult<()>;

    /// Executes a similarity query and returns up to `query.top_k()` matches
    /// ordered by descending score.
    async fn query(&self, query: VectorQuery) -> MemoryResult<Vec<VectorMatch>>;
}
191
/// Simple in-memory vector store using cosine similarity.
///
/// Points live in a `HashMap` behind a `tokio::sync::RwLock`: queries take the
/// shared (read) lock, upserts and removals take the exclusive (write) lock.
/// Queries scan every stored point (there is no index), so their cost grows
/// linearly with the number of points.
pub struct LocalVectorStore {
    /// Stored points keyed by `VectorPoint::id`.
    points: RwLock<HashMap<Uuid, VectorPoint>>,
}
196
197impl LocalVectorStore {
198    /// Creates an empty store.
199    #[must_use]
200    pub fn new() -> Self {
201        Self {
202            points: RwLock::new(HashMap::new()),
203        }
204    }
205}
206
207impl Default for LocalVectorStore {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213#[async_trait]
214impl VectorStoreClient for LocalVectorStore {
215    async fn upsert(&self, point: VectorPoint) -> MemoryResult<()> {
216        let mut guard = self.points.write().await;
217        guard.insert(point.id(), point);
218        Ok(())
219    }
220
221    async fn remove(&self, id: Uuid) -> MemoryResult<()> {
222        let mut guard = self.points.write().await;
223        guard.remove(&id);
224        Ok(())
225    }
226
227    async fn query(&self, query: VectorQuery) -> MemoryResult<Vec<VectorMatch>> {
228        let guard = self.points.read().await;
229        let mut matches = Vec::new();
230
231        let query_embedding = query.embedding();
232        let query_tags = query.tags();
233
234        for point in guard.values() {
235            if !query_tags.is_empty()
236                && !query_tags
237                    .iter()
238                    .all(|tag| point.tags().iter().any(|candidate| candidate == tag))
239            {
240                continue;
241            }
242
243            if point.embedding().len() != query_embedding.len() {
244                continue;
245            }
246
247            let score = cosine_similarity(point.embedding(), query_embedding);
248            matches.push(VectorMatch::new(
249                point.id(),
250                score,
251                point.metadata().clone(),
252                point.tags().to_vec(),
253            ));
254        }
255
256        matches.sort_by(|a, b| {
257            b.score
258                .partial_cmp(&a.score)
259                .unwrap_or(std::cmp::Ordering::Equal)
260        });
261        matches.truncate(query.top_k());
262        Ok(matches)
263    }
264}
265
266fn cosine_similarity(lhs: &EmbeddingVector, rhs: &EmbeddingVector) -> f32 {
267    let numerator = lhs.dot(rhs);
268    let denominator = lhs.magnitude() * rhs.magnitude();
269    if denominator == 0.0 {
270        0.0
271    } else {
272        numerator / denominator
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn basic_query() {
        let store = LocalVectorStore::new();

        store
            .upsert(
                VectorPoint::new(
                    Uuid::new_v4(),
                    EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                )
                .with_tags(["alpha"]),
            )
            .await
            .unwrap();

        store
            .upsert(
                VectorPoint::new(
                    Uuid::new_v4(),
                    EmbeddingVector::new(vec![0.0, 1.0, 0.0]).unwrap(),
                )
                .with_tags(["beta"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
            NonZeroUsize::new(1).unwrap(),
        );
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].tags(), ["alpha"]);
        assert!((matches[0].score() - 1.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn respects_tag_filter() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();
        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 1.0]).unwrap())
                    .with_tags(["alpha", "beta"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 1.0]).unwrap(),
            NonZeroUsize::new(5).unwrap(),
        )
        .with_tags(["beta", "alpha"]);
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id(), id);
    }

    /// Results must come back best-first, as the trait contract promises.
    #[tokio::test]
    async fn orders_by_descending_score() {
        let store = LocalVectorStore::new();
        let exact = Uuid::new_v4();
        let partial = Uuid::new_v4();
        let orthogonal = Uuid::new_v4();

        for (id, values) in [
            (orthogonal, vec![0.0, 1.0, 0.0]),
            (exact, vec![1.0, 0.0, 0.0]),
            (partial, vec![1.0, 1.0, 0.0]),
        ] {
            store
                .upsert(VectorPoint::new(id, EmbeddingVector::new(values).unwrap()))
                .await
                .unwrap();
        }

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
            NonZeroUsize::new(3).unwrap(),
        );
        let matches = store.query(query).await.unwrap();
        let ids: Vec<Uuid> = matches.iter().map(VectorMatch::id).collect();
        assert_eq!(ids, [exact, partial, orthogonal]);
        assert!(matches
            .windows(2)
            .all(|pair| pair[0].score() >= pair[1].score()));
    }

    /// Points whose dimensionality differs from the query are skipped, not scored.
    #[tokio::test]
    async fn skips_points_with_mismatched_dimensions() {
        let store = LocalVectorStore::new();
        let matching = Uuid::new_v4();

        store
            .upsert(VectorPoint::new(
                Uuid::new_v4(),
                EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
            ))
            .await
            .unwrap();
        store
            .upsert(VectorPoint::new(
                matching,
                EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
            ))
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
            NonZeroUsize::new(5).unwrap(),
        );
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id(), matching);
    }

    /// Removed points disappear from results; removing an absent id is not an error.
    #[tokio::test]
    async fn remove_drops_point_from_results() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();

        store
            .upsert(VectorPoint::new(
                id,
                EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
            ))
            .await
            .unwrap();
        store.remove(id).await.unwrap();
        store.remove(id).await.unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
            NonZeroUsize::new(1).unwrap(),
        );
        assert!(store.query(query).await.unwrap().is_empty());
    }

    /// Upserting an existing id replaces the stored point, including metadata and tags.
    #[tokio::test]
    async fn upsert_replaces_existing_point() {
        let store = LocalVectorStore::new();
        let id = Uuid::new_v4();

        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 0.0]).unwrap())
                    .with_metadata(Value::from("v1"))
                    .with_tags(["old"]),
            )
            .await
            .unwrap();
        store
            .upsert(
                VectorPoint::new(id, EmbeddingVector::new(vec![1.0, 0.0]).unwrap())
                    .with_metadata(Value::from("v2"))
                    .with_tags(["new"]),
            )
            .await
            .unwrap();

        let query = VectorQuery::new(
            EmbeddingVector::new(vec![1.0, 0.0]).unwrap(),
            NonZeroUsize::new(5).unwrap(),
        );
        let matches = store.query(query).await.unwrap();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].id(), id);
        assert_eq!(matches[0].metadata(), &Value::from("v2"));
        assert_eq!(matches[0].tags(), ["new"]);
    }
}