Skip to main content

memoir_core/vector/
mod.rs

1//! Vector-index abstraction for similarity search.
2//!
3//! Defines [`VectorIndex`], implemented by [`QdrantIndex`] (the default) and
4//! by callers who want to plug in a different vector backend. Memoir's
5//! source-of-truth row storage is a separate concern handled by
6//! [`crate::store::MemoryStore`]; this trait covers only the vector index.
7
8mod error;
9mod filter;
10pub mod qdrant;
11
12pub use error::VectorError;
13pub use filter::{FilterCondition, MatchValue, MatchValues, MemoryFilter, NumericRange};
14pub use qdrant::QdrantIndex;
15
16use std::future::Future;
17
18use crate::memory::{KindSelector, Memory, Scope};
19
20#[cfg(test)]
21use crate::memory::MemoryKind;
22
23/// Stores and queries vectors keyed by memory pid.
24///
25/// Implementations own the vector-backend connection. The trait methods are
26/// async and `Send`-bound so callers can drive them from any tokio runtime.
27pub trait VectorIndex: Send + Sync + 'static {
28    /// Ensures the backing collection exists with the configured dimension.
29    ///
30    /// Idempotent: callers invoke this on startup; second-call is a no-op.
31    /// `vector_dim` must match the dimension produced by the embedding model
32    /// the consumer pairs with this index.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`VectorError::Connection`] if the backend is unreachable,
37    /// [`VectorError::BadRequest`] if the collection exists with a
38    /// different vector dimension than requested.
39    fn ensure_collection(&self, vector_dim: usize) -> impl Future<Output = Result<(), VectorError>> + Send;
40
41    /// Upserts a memory's vector + payload for similarity search and filtering.
42    ///
43    /// The payload carries enough of the memory's state to support payload-
44    /// level filters at search time: scope (`agent_id`, `org_id`, `user_id`),
45    /// `kind`, `created_at`, `event_at` (when known), and the memory's
46    /// arbitrary JSON `metadata` flattened to top-level payload keys. The
47    /// source-of-truth row in Postgres still holds the canonical copy; the
48    /// payload is a derived index. Callers are responsible for ensuring the
49    /// Postgres row exists before this completes
50    /// ([`crate::store::IndexStatus::Pending`] covers the gap).
51    ///
52    /// # Errors
53    ///
54    /// Returns [`VectorError::Connection`] on backend errors and
55    /// [`VectorError::BadRequest`] when the vector's dimension does not
56    /// match the collection's.
57    fn upsert(&self, memory: &Memory, vector: Vec<f32>) -> impl Future<Output = Result<(), VectorError>> + Send;
58
59    /// Returns the top similarity hits within `scope`, filtered by kind.
60    ///
61    /// Returns pid+score tuples ordered by descending score. The caller
62    /// hydrates these into full [`crate::memory::Memory`] values via
63    /// [`crate::store::MemoryStore::find_by_pids`].
64    ///
65    /// `extra_filter` AND-joins with the scope + kind filter — caller-supplied
66    /// conditions cannot widen scope. An empty filter (or `None`) is inert.
67    /// `min_similarity` sets a score floor; hits below it are dropped by the
68    /// backend before they reach the result. `None` applies no floor.
69    ///
70    /// # Errors
71    ///
72    /// Returns [`VectorError::Connection`] on backend errors and
73    /// [`VectorError::BadRequest`] when the query vector's dimension does
74    /// not match the collection's.
75    fn search(
76        &self,
77        scope: Scope,
78        query_embedding: Vec<f32>,
79        limit: usize,
80        kinds: KindSelector,
81        extra_filter: Option<MemoryFilter>,
82        min_similarity: Option<f32>,
83    ) -> impl Future<Output = Result<Vec<(String, f32)>, VectorError>> + Send;
84
85    /// Deletes vectors for the given pids.
86    ///
87    /// Best-effort: failures are not propagated up to user-facing requests
88    /// in the canonical Forget flow. The caller decides whether to surface
89    /// errors (e.g. reconciliation propagates; user-facing Forget logs).
90    ///
91    /// # Errors
92    ///
93    /// Returns [`VectorError::Connection`] on backend errors.
94    fn delete_by_pids(&self, pids: &[&str]) -> impl Future<Output = Result<(), VectorError>> + Send;
95
96    /// Returns every pid in the index that matches `scope`.
97    ///
98    /// Used by the reconciliation sweep's orphan-cleanup pass. Implementations
99    /// paginate internally using `page_size` and concatenate the result.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`VectorError::Connection`] on backend errors.
104    fn list_pids_in_scope(
105        &self,
106        scope: Scope,
107        page_size: usize,
108    ) -> impl Future<Output = Result<Vec<String>, VectorError>> + Send;
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory [`VectorIndex`] proving the trait can be implemented without
    /// a real vector backend.
    ///
    /// Points are keyed by pid. `search` and `list_pids_in_scope` honour the
    /// requested scope; the kind selector, extra filter, similarity floor and
    /// page size are ignored, and every hit is reported with a fixed score.
    #[derive(Default)]
    struct StubIndex {
        points: Mutex<HashMap<String, (Scope, MemoryKind, Vec<f32>)>>,
    }

    impl VectorIndex for StubIndex {
        /// Nothing to provision for the in-memory map; always succeeds.
        async fn ensure_collection(&self, _vector_dim: usize) -> Result<(), VectorError> {
            Ok(())
        }

        /// Stores (or overwrites) the point under the memory's pid.
        async fn upsert(&self, memory: &Memory, vector: Vec<f32>) -> Result<(), VectorError> {
            self.points
                .lock()
                .unwrap()
                .insert(memory.pid.clone(), (memory.scope.clone(), memory.kind, vector));
            Ok(())
        }

        /// Returns up to `limit` pids stored under `scope`, each with a fixed
        /// score of `0.5`.
        async fn search(
            &self,
            scope: Scope,
            _query_embedding: Vec<f32>,
            limit: usize,
            _kinds: KindSelector,
            _extra_filter: Option<MemoryFilter>,
            _min_similarity: Option<f32>,
        ) -> Result<Vec<(String, f32)>, VectorError> {
            Ok(self
                .points
                .lock()
                .unwrap()
                .iter()
                // Match the trait contract ("within `scope`") and the scope
                // filtering already done by `list_pids_in_scope`.
                .filter(|(_, (s, _, _))| s == &scope)
                .take(limit)
                .map(|(pid, _)| (pid.clone(), 0.5))
                .collect())
        }

        /// Removes every listed pid; unknown pids are silently skipped.
        async fn delete_by_pids(&self, pids: &[&str]) -> Result<(), VectorError> {
            let mut points = self.points.lock().unwrap();
            for pid in pids {
                points.remove(*pid);
            }
            Ok(())
        }

        /// Returns all pids stored under `scope` in one go (no pagination).
        async fn list_pids_in_scope(&self, scope: Scope, _page_size: usize) -> Result<Vec<String>, VectorError> {
            Ok(self
                .points
                .lock()
                .unwrap()
                .iter()
                .filter(|(_, (s, _, _))| s == &scope)
                .map(|(pid, _)| pid.clone())
                .collect())
        }
    }

    /// Exercises every trait method once against the stub, checking scope
    /// isolation and that deletion actually removes the point.
    #[tokio::test(flavor = "current_thread")]
    async fn should_implement_trait_with_in_test_stub() {
        use chrono::Utc;

        let index = StubIndex::default();
        let scope = Scope {
            agent_id: "a".to_string(),
            org_id: "o".to_string(),
            user_id: "u".to_string(),
        };
        // Differs from `scope` only by agent, to prove scope isolation.
        let other_scope = Scope {
            agent_id: "other".to_string(),
            org_id: "o".to_string(),
            user_id: "u".to_string(),
        };
        let now: chrono::DateTime<chrono::FixedOffset> = Utc::now().into();
        let memory = Memory {
            pid: "pid1".to_string(),
            scope: scope.clone(),
            content: "hello".to_string(),
            metadata: serde_json::json!({}),
            kind: MemoryKind::Episodic,
            source_pid: None,
            supersession: None,
            created_at: now,
            updated_at: now,
            event_at: None,
            score: None,
            status: crate::store::IndexStatus::Pending,
            confidence: crate::memory::Confidence::default(),
            category: None,
            retirement: None,
        };
        let query = vec![0.1, 0.2, 0.3, 0.4];

        index.ensure_collection(4).await.unwrap();
        index.upsert(&memory, vec![0.1, 0.2, 0.3, 0.4]).await.unwrap();

        let hits = index
            .search(scope.clone(), query.clone(), 5, KindSelector::default(), None, None)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "pid1");

        // A point stored under one scope must not leak into another.
        let foreign_hits = index
            .search(other_scope.clone(), query.clone(), 5, KindSelector::default(), None, None)
            .await
            .unwrap();
        assert!(foreign_hits.is_empty());

        let listed = index.list_pids_in_scope(scope.clone(), 10).await.unwrap();
        assert_eq!(listed, vec!["pid1".to_string()]);
        assert!(index.list_pids_in_scope(other_scope, 10).await.unwrap().is_empty());

        index.delete_by_pids(&["pid1"]).await.unwrap();

        // Deletion must be observable through both read paths.
        assert!(index.list_pids_in_scope(scope.clone(), 10).await.unwrap().is_empty());
        let hits_after_delete = index
            .search(scope, query, 5, KindSelector::default(), None, None)
            .await
            .unwrap();
        assert!(hits_after_delete.is_empty());
    }
}