Skip to main content

memoir_core/vector/
qdrant.rs

1//! [`VectorIndex`] implementation backed by Qdrant.
2//!
3//! Qdrant only accepts `u64` or UUID values for point IDs (`PointId`), but
4//! memoir-core pids are nanoid strings — incompatible. Each upsert generates
5//! a fresh UUIDv4 as the point ID and stores the memoir pid in the point's
6//! payload under the `pid` key. Search, scroll, and delete paths all
7//! resolve the memoir pid via the payload; the UUID point ID is an
8//! implementation detail nobody outside this module sees.
9
10use std::collections::HashMap;
11
12use qdrant_client::Qdrant;
13use qdrant_client::qdrant::{
14    Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointStruct, QueryPointsBuilder,
15    ScrollPointsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder,
16};
17use uuid::Uuid;
18
19use super::{MemoryFilter, VectorError, VectorIndex};
20use crate::memory::{KindSelector, Memory, Scope};
21
/// Collection name used when the caller does not choose one with
/// `QdrantIndex::with_collection`.
const DEFAULT_COLLECTION: &str = "memoir_memories";

/// Payload key that carries the memoir pid of the memory a point stores.
///
/// The Qdrant point ID is a random UUID rather than the pid, so the pid has
/// to travel in the payload for reads and deletes to find it.
const PID_PAYLOAD_KEY: &str = "pid";

/// Payload key for the write (wall-clock) time: i64 milliseconds since the
/// Unix epoch.
///
/// Range-filterable through [`super::FilterCondition::Range`], with bounds
/// given in the same millisecond units. The encoding follows the polypixel
/// template convention (checked against rig-service `models/messages.rs:139`,
/// which uses `timestamp_millis`).
const CREATED_AT_PAYLOAD_KEY: &str = "created_at";

/// Payload key for the event time: i64 milliseconds since the Unix epoch.
///
/// When a memory has no known event time the key is left out of the payload
/// altogether (it is never written as null). A range filter on this key
/// therefore never matches such memories. That is intended: a memory with
/// no event time cannot satisfy "from last week".
const EVENT_AT_PAYLOAD_KEY: &str = "event_at";

/// Payload key for confidence, stored as an i64 percentage in 0-100.
///
/// Every memory has a confidence, so this key is always written. It is
/// range-filterable via [`super::FilterCondition::Range`], e.g. "keep only
/// rows >= 80" for the selection layer (epic 0011).
const CONFIDENCE_PAYLOAD_KEY: &str = "confidence";

/// Payload key for the category label.
///
/// Rows that have not been categorized yet omit the key instead of storing
/// null. An equality filter on it therefore skips them, which is the same
/// "missing key never matches" behavior as `event_at`.
const CATEGORY_PAYLOAD_KEY: &str = "category";

/// Payload keys owned by memoir-core that consumer metadata may not use.
///
/// A memory's top-level `metadata` fields are flattened into the point
/// payload. That lets a caller's [`super::FilterCondition`] address them
/// directly (`field: "role"` matches `metadata.role`). Write-time validation
/// rejects metadata that collides with these names, which stops callers from
/// injecting a `pid` or a scope value through metadata. See
/// [`crate::store::MemoryStore::remember`] and the remember client path.
pub(crate) const RESERVED_PAYLOAD_KEYS: &[&str] = &[
    PID_PAYLOAD_KEY,
    "agent_id",
    "org_id",
    "user_id",
    "kind",
    CREATED_AT_PAYLOAD_KEY,
    EVENT_AT_PAYLOAD_KEY,
    CONFIDENCE_PAYLOAD_KEY,
    CATEGORY_PAYLOAD_KEY,
];
75
76/// Default [`VectorIndex`] backed by Qdrant.
77///
78/// Constructed via [`Self::new`]. Collection name defaults to
79/// `memoir_memories`; override with [`Self::with_collection`].
80#[derive(Clone)]
81pub struct QdrantIndex {
82    qdrant: Qdrant,
83    collection: String,
84}
85
86impl std::fmt::Debug for QdrantIndex {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("QdrantIndex")
89            .field("collection", &self.collection)
90            .finish_non_exhaustive()
91    }
92}
93
94impl QdrantIndex {
95    /// Builds an index from an existing Qdrant client.
96    pub fn new(qdrant: Qdrant) -> Self {
97        Self {
98            qdrant,
99            collection: DEFAULT_COLLECTION.to_string(),
100        }
101    }
102
103    /// Connects to Qdrant at `url`, building the client internally.
104    ///
105    /// The URL-in entry point the [`Client`](crate::client::Client) builder uses,
106    /// so a consumer configures the vector backend with a connection string and
107    /// never names the `qdrant_client::Qdrant` type. `url` is a Qdrant gRPC
108    /// endpoint (e.g. `http://localhost:6334`).
109    ///
110    /// # Errors
111    ///
112    /// Returns [`VectorError::Connection`] if `url` is malformed or the client
113    /// cannot be constructed.
114    pub fn connect(url: impl Into<String>) -> Result<Self, VectorError> {
115        let qdrant = Qdrant::from_url(&url.into())
116            .build()
117            .map_err(|err| VectorError::Connection(err.to_string()))?;
118        Ok(Self::new(qdrant))
119    }
120
121    /// Sets the Qdrant collection name used for vector storage.
122    pub fn with_collection(mut self, collection: impl Into<String>) -> Self {
123        self.collection = collection.into();
124        self
125    }
126
127    /// Returns the Qdrant collection name configured for this index.
128    pub fn collection_name(&self) -> &str {
129        &self.collection
130    }
131}
132
133impl VectorIndex for QdrantIndex {
134    async fn ensure_collection(&self, vector_dim: usize) -> Result<(), VectorError> {
135        let exists = self
136            .qdrant
137            .collection_exists(&self.collection)
138            .await
139            .map_err(connection)?;
140        if exists {
141            return Ok(());
142        }
143
144        self.qdrant
145            .create_collection(
146                CreateCollectionBuilder::new(&self.collection)
147                    .vectors_config(VectorParamsBuilder::new(vector_dim as u64, Distance::Cosine)),
148            )
149            .await
150            .map_err(connection)?;
151        Ok(())
152    }
153
154    async fn upsert(&self, memory: &Memory, vector: Vec<f32>) -> Result<(), VectorError> {
155        // First delete any prior points carrying this pid in their payload,
156        // since the Qdrant point ID is a fresh UUID per upsert and won't
157        // collide with a previous write's ID.
158        self.delete_by_pids(&[&memory.pid]).await?;
159
160        // First-class payload keys. Owned by memoir-core and protected
161        // against consumer-metadata clobbering by `RESERVED_PAYLOAD_KEYS`.
162        // Timestamps are i64 epoch milliseconds, matching the polypixel
163        // template convention (rig-service `models/messages.rs:139`).
164        let mut payload: HashMap<String, Value> = HashMap::new();
165        payload.insert(PID_PAYLOAD_KEY.to_string(), Value::from(memory.pid.clone()));
166        payload.insert("agent_id".to_string(), Value::from(memory.scope.agent_id.clone()));
167        payload.insert("org_id".to_string(), Value::from(memory.scope.org_id.clone()));
168        payload.insert("user_id".to_string(), Value::from(memory.scope.user_id.clone()));
169        payload.insert("kind".to_string(), Value::from(memory.kind.to_string()));
170        payload.insert(
171            CREATED_AT_PAYLOAD_KEY.to_string(),
172            Value::from(memory.created_at.timestamp_millis()),
173        );
174        if let Some(event_at) = memory.event_at {
175            // Omit (not write null): Qdrant range filters treat missing
176            // payload keys as "fail to match", which is the right semantic
177            // for "memories with known event-time in this window."
178            payload.insert(
179                EVENT_AT_PAYLOAD_KEY.to_string(),
180                Value::from(event_at.timestamp_millis()),
181            );
182        }
183        // Confidence is always present (every row has one). Stored as i64 so
184        // it filters via the same numeric Range path as the timestamps.
185        payload.insert(
186            CONFIDENCE_PAYLOAD_KEY.to_string(),
187            Value::from(i64::from(memory.confidence.get())),
188        );
189        if let Some(category) = &memory.category {
190            // Omit when absent (same rationale as event_at): an equality
191            // filter on category should exclude not-yet-categorized rows.
192            payload.insert(CATEGORY_PAYLOAD_KEY.to_string(), Value::from(category.clone()));
193        }
194
195        // Flatten metadata's top-level object into the payload alongside
196        // the first-class keys. Reserved-key collisions are prevented by
197        // validation at the write boundary (Client::remember /
198        // RememberBuilder); reaching this code with a colliding key would
199        // mean a bug upstream, so we drop the colliding entries
200        // defensively rather than panicking. The `From<serde_json::Value>`
201        // impl on qdrant_client `Value` handles every JSON variant.
202        if let Some(obj) = memory.metadata.as_object() {
203            for (k, v) in obj {
204                if RESERVED_PAYLOAD_KEYS.iter().any(|reserved| reserved == k) {
205                    continue;
206                }
207                payload.insert(k.clone(), Value::from(v.clone()));
208            }
209        }
210
211        let point = PointStruct::new(Uuid::new_v4().to_string(), vector, payload);
212
213        self.qdrant
214            .upsert_points(UpsertPointsBuilder::new(&self.collection, vec![point]))
215            .await
216            .map_err(connection)?;
217
218        Ok(())
219    }
220
221    async fn search(
222        &self,
223        scope: Scope,
224        query_embedding: Vec<f32>,
225        limit: usize,
226        kinds: KindSelector,
227        extra_filter: Option<MemoryFilter>,
228        min_similarity: Option<f32>,
229    ) -> Result<Vec<(String, f32)>, VectorError> {
230        if kinds.is_empty() {
231            return Ok(Vec::new());
232        }
233
234        // Scope conditions go in `must` first so an `extra_filter.must` cannot
235        // accidentally widen scope: a caller-supplied `must` adds to AND, not
236        // replaces. A caller-supplied `must_not` on `agent_id` (or any scope
237        // field) would only narrow further, not widen — Qdrant evaluates
238        // `must AND NOT must_not`.
239        let mut must = vec![
240            Condition::matches("agent_id", scope.agent_id),
241            Condition::matches("org_id", scope.org_id),
242            Condition::matches("user_id", scope.user_id),
243        ];
244        if !kinds.includes_all() {
245            let names: Vec<String> = kinds.included_kinds().into_iter().map(|k| k.to_string()).collect();
246            must.push(Condition::matches("kind", names));
247        }
248
249        let mut must_not = Vec::new();
250        let mut should = Vec::new();
251        if let Some(extra) = extra_filter {
252            let translated: Filter = extra.into();
253            must.extend(translated.must);
254            must_not.extend(translated.must_not);
255            should.extend(translated.should);
256        }
257
258        let filter = Filter {
259            must,
260            must_not,
261            should,
262            min_should: None,
263        };
264
265        let mut request = QueryPointsBuilder::new(&self.collection)
266            .query(query_embedding)
267            .limit(limit as u64)
268            .filter(filter)
269            .with_payload(true);
270        if let Some(threshold) = min_similarity {
271            request = request.score_threshold(threshold);
272        }
273
274        let response = self.qdrant.query(request).await.map_err(connection)?;
275
276        let mut hits = Vec::with_capacity(response.result.len());
277        for scored in response.result {
278            if let Some(pid) = pid_from_payload(&scored.payload) {
279                hits.push((pid, scored.score));
280            }
281        }
282        Ok(hits)
283    }
284
285    async fn delete_by_pids(&self, pids: &[&str]) -> Result<(), VectorError> {
286        if pids.is_empty() {
287            return Ok(());
288        }
289
290        // Pids live in payload, not in the point ID, so delete by payload
291        // filter. Each pid translates to a `match` condition; the wrapper
292        // `Filter::should` (logical OR) covers a batch of pids in one call.
293        let conditions: Vec<Condition> = pids
294            .iter()
295            .map(|p| Condition::matches(PID_PAYLOAD_KEY, (*p).to_string()))
296            .collect();
297        let filter = Filter::should(conditions);
298
299        self.qdrant
300            .delete_points(DeletePointsBuilder::new(&self.collection).points(filter))
301            .await
302            .map_err(connection)?;
303        Ok(())
304    }
305
306    async fn list_pids_in_scope(&self, scope: Scope, page_size: usize) -> Result<Vec<String>, VectorError> {
307        let filter = Filter::must(vec![
308            Condition::matches("agent_id", scope.agent_id),
309            Condition::matches("org_id", scope.org_id),
310            Condition::matches("user_id", scope.user_id),
311        ]);
312
313        let mut pids = Vec::new();
314        let mut offset: Option<qdrant_client::qdrant::PointId> = None;
315
316        loop {
317            let mut request = ScrollPointsBuilder::new(&self.collection)
318                .filter(filter.clone())
319                .limit(page_size as u32)
320                .with_payload(true)
321                .with_vectors(false);
322            if let Some(o) = offset.take() {
323                request = request.offset(o);
324            }
325
326            let response = self.qdrant.scroll(request).await.map_err(connection)?;
327
328            for point in response.result {
329                if let Some(pid) = pid_from_payload(&point.payload) {
330                    pids.push(pid);
331                }
332            }
333
334            match response.next_page_offset {
335                Some(next) => offset = Some(next),
336                None => break,
337            }
338        }
339
340        Ok(pids)
341    }
342}
343
344fn connection<E: std::fmt::Display>(err: E) -> VectorError {
345    VectorError::Connection(err.to_string())
346}
347
348/// Extracts the memoir pid from a Qdrant point's payload, if present.
349///
350/// Returns `None` when the payload lacks a `pid` key or carries a non-string
351/// value — both should be impossible for points written via [`QdrantIndex::upsert`],
352/// but defending against malformed remote state keeps the search side robust.
353fn pid_from_payload(payload: &HashMap<String, Value>) -> Option<String> {
354    payload
355        .get(PID_PAYLOAD_KEY)
356        .and_then(|v| v.as_str().map(|s| s.to_string()))
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload map holding exactly one `key -> value` entry.
    fn payload_with(key: &str, value: Value) -> HashMap<String, Value> {
        let mut payload = HashMap::new();
        payload.insert(key.to_owned(), value);
        payload
    }

    #[test]
    fn should_extract_pid_from_payload_when_present() {
        let payload = payload_with(PID_PAYLOAD_KEY, Value::from("my-pid".to_string()));
        let extracted = pid_from_payload(&payload);
        assert_eq!(extracted.as_deref(), Some("my-pid"));
    }

    #[test]
    fn should_return_none_when_pid_absent_from_payload() {
        let payload = payload_with("other", Value::from("x".to_string()));
        assert!(pid_from_payload(&payload).is_none());
    }

    #[test]
    fn should_return_none_when_pid_value_is_not_a_string() {
        let payload = payload_with(PID_PAYLOAD_KEY, Value::from(42i64));
        assert!(pid_from_payload(&payload).is_none());
    }
}
380}