1use std::collections::HashMap;
11
12use qdrant_client::Qdrant;
13use qdrant_client::qdrant::{
14 Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointStruct, QueryPointsBuilder,
15 ScrollPointsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder,
16};
17use uuid::Uuid;
18
19use super::{MemoryFilter, VectorError, VectorIndex};
20use crate::memory::{KindSelector, Memory, Scope};
21
/// Collection used when the caller does not pick one explicitly.
const DEFAULT_COLLECTION: &'static str = "memoir_memories";

/// Payload key holding the memory's persistent id; used to find, replace and delete points.
const PID_PAYLOAD_KEY: &'static str = "pid";

/// Payload key for the memory's creation time, stored as epoch milliseconds.
const CREATED_AT_PAYLOAD_KEY: &'static str = "created_at";

/// Payload key for the optional event time, stored as epoch milliseconds.
const EVENT_AT_PAYLOAD_KEY: &'static str = "event_at";

/// Payload key for the memory's confidence value, stored as an integer.
const CONFIDENCE_PAYLOAD_KEY: &'static str = "confidence";

/// Payload key for the optional category string.
const CATEGORY_PAYLOAD_KEY: &'static str = "category";

/// Payload keys written by the index itself.
///
/// User-supplied metadata entries with any of these names are dropped on upsert so they
/// cannot overwrite the fields the index filters on.
pub(crate) const RESERVED_PAYLOAD_KEYS: &[&'static str] = &[
    // identity
    PID_PAYLOAD_KEY,
    // scope
    "agent_id",
    "org_id",
    "user_id",
    // classification
    "kind",
    // timestamps
    CREATED_AT_PAYLOAD_KEY,
    EVENT_AT_PAYLOAD_KEY,
    // scoring / grouping
    CONFIDENCE_PAYLOAD_KEY,
    CATEGORY_PAYLOAD_KEY,
];
75
76#[derive(Clone)]
81pub struct QdrantIndex {
82 qdrant: Qdrant,
83 collection: String,
84}
85
86impl std::fmt::Debug for QdrantIndex {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("QdrantIndex")
89 .field("collection", &self.collection)
90 .finish_non_exhaustive()
91 }
92}
93
94impl QdrantIndex {
95 pub fn new(qdrant: Qdrant) -> Self {
97 Self {
98 qdrant,
99 collection: DEFAULT_COLLECTION.to_string(),
100 }
101 }
102
103 pub fn connect(url: impl Into<String>) -> Result<Self, VectorError> {
115 let qdrant = Qdrant::from_url(&url.into())
116 .build()
117 .map_err(|err| VectorError::Connection(err.to_string()))?;
118 Ok(Self::new(qdrant))
119 }
120
121 pub fn with_collection(mut self, collection: impl Into<String>) -> Self {
123 self.collection = collection.into();
124 self
125 }
126
127 pub fn collection_name(&self) -> &str {
129 &self.collection
130 }
131}
132
133impl VectorIndex for QdrantIndex {
134 async fn ensure_collection(&self, vector_dim: usize) -> Result<(), VectorError> {
135 let exists = self
136 .qdrant
137 .collection_exists(&self.collection)
138 .await
139 .map_err(connection)?;
140 if exists {
141 return Ok(());
142 }
143
144 self.qdrant
145 .create_collection(
146 CreateCollectionBuilder::new(&self.collection)
147 .vectors_config(VectorParamsBuilder::new(vector_dim as u64, Distance::Cosine)),
148 )
149 .await
150 .map_err(connection)?;
151 Ok(())
152 }
153
154 async fn upsert(&self, memory: &Memory, vector: Vec<f32>) -> Result<(), VectorError> {
155 self.delete_by_pids(&[&memory.pid]).await?;
159
160 let mut payload: HashMap<String, Value> = HashMap::new();
165 payload.insert(PID_PAYLOAD_KEY.to_string(), Value::from(memory.pid.clone()));
166 payload.insert("agent_id".to_string(), Value::from(memory.scope.agent_id.clone()));
167 payload.insert("org_id".to_string(), Value::from(memory.scope.org_id.clone()));
168 payload.insert("user_id".to_string(), Value::from(memory.scope.user_id.clone()));
169 payload.insert("kind".to_string(), Value::from(memory.kind.to_string()));
170 payload.insert(
171 CREATED_AT_PAYLOAD_KEY.to_string(),
172 Value::from(memory.created_at.timestamp_millis()),
173 );
174 if let Some(event_at) = memory.event_at {
175 payload.insert(
179 EVENT_AT_PAYLOAD_KEY.to_string(),
180 Value::from(event_at.timestamp_millis()),
181 );
182 }
183 payload.insert(
186 CONFIDENCE_PAYLOAD_KEY.to_string(),
187 Value::from(i64::from(memory.confidence.get())),
188 );
189 if let Some(category) = &memory.category {
190 payload.insert(CATEGORY_PAYLOAD_KEY.to_string(), Value::from(category.clone()));
193 }
194
195 if let Some(obj) = memory.metadata.as_object() {
203 for (k, v) in obj {
204 if RESERVED_PAYLOAD_KEYS.iter().any(|reserved| reserved == k) {
205 continue;
206 }
207 payload.insert(k.clone(), Value::from(v.clone()));
208 }
209 }
210
211 let point = PointStruct::new(Uuid::new_v4().to_string(), vector, payload);
212
213 self.qdrant
214 .upsert_points(UpsertPointsBuilder::new(&self.collection, vec![point]))
215 .await
216 .map_err(connection)?;
217
218 Ok(())
219 }
220
221 async fn search(
222 &self,
223 scope: Scope,
224 query_embedding: Vec<f32>,
225 limit: usize,
226 kinds: KindSelector,
227 extra_filter: Option<MemoryFilter>,
228 min_similarity: Option<f32>,
229 ) -> Result<Vec<(String, f32)>, VectorError> {
230 if kinds.is_empty() {
231 return Ok(Vec::new());
232 }
233
234 let mut must = vec![
240 Condition::matches("agent_id", scope.agent_id),
241 Condition::matches("org_id", scope.org_id),
242 Condition::matches("user_id", scope.user_id),
243 ];
244 if !kinds.includes_all() {
245 let names: Vec<String> = kinds.included_kinds().into_iter().map(|k| k.to_string()).collect();
246 must.push(Condition::matches("kind", names));
247 }
248
249 let mut must_not = Vec::new();
250 let mut should = Vec::new();
251 if let Some(extra) = extra_filter {
252 let translated: Filter = extra.into();
253 must.extend(translated.must);
254 must_not.extend(translated.must_not);
255 should.extend(translated.should);
256 }
257
258 let filter = Filter {
259 must,
260 must_not,
261 should,
262 min_should: None,
263 };
264
265 let mut request = QueryPointsBuilder::new(&self.collection)
266 .query(query_embedding)
267 .limit(limit as u64)
268 .filter(filter)
269 .with_payload(true);
270 if let Some(threshold) = min_similarity {
271 request = request.score_threshold(threshold);
272 }
273
274 let response = self.qdrant.query(request).await.map_err(connection)?;
275
276 let mut hits = Vec::with_capacity(response.result.len());
277 for scored in response.result {
278 if let Some(pid) = pid_from_payload(&scored.payload) {
279 hits.push((pid, scored.score));
280 }
281 }
282 Ok(hits)
283 }
284
285 async fn delete_by_pids(&self, pids: &[&str]) -> Result<(), VectorError> {
286 if pids.is_empty() {
287 return Ok(());
288 }
289
290 let conditions: Vec<Condition> = pids
294 .iter()
295 .map(|p| Condition::matches(PID_PAYLOAD_KEY, (*p).to_string()))
296 .collect();
297 let filter = Filter::should(conditions);
298
299 self.qdrant
300 .delete_points(DeletePointsBuilder::new(&self.collection).points(filter))
301 .await
302 .map_err(connection)?;
303 Ok(())
304 }
305
306 async fn list_pids_in_scope(&self, scope: Scope, page_size: usize) -> Result<Vec<String>, VectorError> {
307 let filter = Filter::must(vec![
308 Condition::matches("agent_id", scope.agent_id),
309 Condition::matches("org_id", scope.org_id),
310 Condition::matches("user_id", scope.user_id),
311 ]);
312
313 let mut pids = Vec::new();
314 let mut offset: Option<qdrant_client::qdrant::PointId> = None;
315
316 loop {
317 let mut request = ScrollPointsBuilder::new(&self.collection)
318 .filter(filter.clone())
319 .limit(page_size as u32)
320 .with_payload(true)
321 .with_vectors(false);
322 if let Some(o) = offset.take() {
323 request = request.offset(o);
324 }
325
326 let response = self.qdrant.scroll(request).await.map_err(connection)?;
327
328 for point in response.result {
329 if let Some(pid) = pid_from_payload(&point.payload) {
330 pids.push(pid);
331 }
332 }
333
334 match response.next_page_offset {
335 Some(next) => offset = Some(next),
336 None => break,
337 }
338 }
339
340 Ok(pids)
341 }
342}
343
344fn connection<E: std::fmt::Display>(err: E) -> VectorError {
345 VectorError::Connection(err.to_string())
346}
347
348fn pid_from_payload(payload: &HashMap<String, Value>) -> Option<String> {
354 payload
355 .get(PID_PAYLOAD_KEY)
356 .and_then(|v| v.as_str().map(|s| s.to_string()))
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload map containing exactly one `key -> value` entry.
    fn payload_with(key: &str, value: Value) -> HashMap<String, Value> {
        let mut payload = HashMap::new();
        payload.insert(key.to_string(), value);
        payload
    }

    #[test]
    fn should_extract_pid_from_payload_when_present() {
        let payload = payload_with(PID_PAYLOAD_KEY, Value::from("my-pid".to_string()));
        let extracted = pid_from_payload(&payload);
        assert_eq!(extracted.as_deref(), Some("my-pid"));
    }

    #[test]
    fn should_return_none_when_pid_absent_from_payload() {
        let payload = payload_with("other", Value::from("x".to_string()));
        assert!(pid_from_payload(&payload).is_none());
    }

    #[test]
    fn should_return_none_when_pid_value_is_not_a_string() {
        let payload = payload_with(PID_PAYLOAD_KEY, Value::from(42i64));
        assert!(pid_from_payload(&payload).is_none());
    }
}