use std::collections::HashMap;
use qdrant_client::Qdrant;
use qdrant_client::qdrant::{
Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, PointStruct, QueryPointsBuilder,
ScrollPointsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder,
};
use uuid::Uuid;
use super::{MemoryFilter, VectorError, VectorIndex};
use crate::memory::{KindSelector, Memory, Scope};
/// Collection name a `QdrantIndex` targets unless `with_collection` overrides it.
const DEFAULT_COLLECTION: &str = "memoir_memories";
/// Payload key holding the memory's `pid`; `pid_from_payload` reads it back from query results.
const PID_PAYLOAD_KEY: &str = "pid";
/// Payload key holding `created_at`, written as `timestamp_millis()`.
const CREATED_AT_PAYLOAD_KEY: &str = "created_at";
/// Payload key holding `event_at` as `timestamp_millis()`; only written when the memory has one.
const EVENT_AT_PAYLOAD_KEY: &str = "event_at";
/// Payload key holding the memory's confidence, written as an `i64`.
const CONFIDENCE_PAYLOAD_KEY: &str = "confidence";
/// Payload key holding the memory's category; only written when the memory has one.
const CATEGORY_PAYLOAD_KEY: &str = "category";
/// Payload keys that `upsert` fills from the memory's own fields.
///
/// Entries of a memory's free-form `metadata` object that use one of these keys are skipped
/// when the payload is built, so metadata cannot replace the structured values.
pub(crate) const RESERVED_PAYLOAD_KEYS: &[&str] = &[
PID_PAYLOAD_KEY,
"agent_id",
"org_id",
"user_id",
"kind",
CREATED_AT_PAYLOAD_KEY,
EVENT_AT_PAYLOAD_KEY,
CONFIDENCE_PAYLOAD_KEY,
CATEGORY_PAYLOAD_KEY,
];
/// [`VectorIndex`] implementation that sends every request to one Qdrant collection.
#[derive(Clone)]
pub struct QdrantIndex {
// Client used for all collection, upsert, query, delete and scroll requests.
qdrant: Qdrant,
// Name of the collection every request targets; starts as `DEFAULT_COLLECTION`.
collection: String,
}
/// Debug output shows only the collection name; the client field is left out and the
/// struct is rendered as non-exhaustive.
impl std::fmt::Debug for QdrantIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = f.debug_struct("QdrantIndex");
        rendered.field("collection", &self.collection);
        rendered.finish_non_exhaustive()
    }
}
impl QdrantIndex {
    /// Wraps an already-built client, targeting the default collection.
    pub fn new(qdrant: Qdrant) -> Self {
        let collection = String::from(DEFAULT_COLLECTION);
        Self { qdrant, collection }
    }

    /// Builds a client for `url` and wraps it via [`QdrantIndex::new`].
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::Connection`] carrying the client's error text when the
    /// client cannot be built.
    pub fn connect(url: impl Into<String>) -> Result<Self, VectorError> {
        let endpoint: String = url.into();
        match Qdrant::from_url(&endpoint).build() {
            Ok(client) => Ok(Self::new(client)),
            Err(err) => Err(VectorError::Connection(err.to_string())),
        }
    }

    /// Returns the index with its target collection replaced by `collection`.
    pub fn with_collection(self, collection: impl Into<String>) -> Self {
        Self {
            collection: collection.into(),
            ..self
        }
    }

    /// Name of the collection this index sends its requests to.
    pub fn collection_name(&self) -> &str {
        self.collection.as_str()
    }
}
impl VectorIndex for QdrantIndex {
/// Creates the collection with cosine-distance vectors of `vector_dim` dimensions when it
/// does not exist yet.
///
/// An existing collection is left as it is; its vector configuration is not inspected.
///
/// # Errors
///
/// Any client failure from the existence check or the creation request is returned as
/// [`VectorError::Connection`].
async fn ensure_collection(&self, vector_dim: usize) -> Result<(), VectorError> {
    let already_present = self
        .qdrant
        .collection_exists(&self.collection)
        .await
        .map_err(connection)?;
    if !already_present {
        let vectors = VectorParamsBuilder::new(vector_dim as u64, Distance::Cosine);
        let create = CreateCollectionBuilder::new(&self.collection).vectors_config(vectors);
        self.qdrant.create_collection(create).await.map_err(connection)?;
    }
    Ok(())
}
/// Replaces the stored point for `memory` with one holding `vector` and a freshly built payload.
///
/// Points whose `pid` payload equals `memory.pid` are deleted first; the new point is then
/// written under a random (v4) UUID. The payload carries the pid, the three scope ids, the
/// kind, `created_at` and the confidence, plus `event_at` and the category when present.
/// Entries of `memory.metadata` (when it is a JSON object) are copied in as well, except
/// those whose key appears in `RESERVED_PAYLOAD_KEYS`.
///
/// NOTE(review): the delete and the write are two separate requests, so an error from the
/// write is returned after the delete request has already been issued.
///
/// # Errors
///
/// Client failures from either request are returned as [`VectorError::Connection`].
async fn upsert(&self, memory: &Memory, vector: Vec<f32>) -> Result<(), VectorError> {
    self.delete_by_pids(&[&memory.pid]).await?;

    let scope = &memory.scope;
    // Fields every point carries.
    let mut payload: HashMap<String, Value> = HashMap::from([
        (PID_PAYLOAD_KEY.to_string(), Value::from(memory.pid.clone())),
        ("agent_id".to_string(), Value::from(scope.agent_id.clone())),
        ("org_id".to_string(), Value::from(scope.org_id.clone())),
        ("user_id".to_string(), Value::from(scope.user_id.clone())),
        ("kind".to_string(), Value::from(memory.kind.to_string())),
        (
            CREATED_AT_PAYLOAD_KEY.to_string(),
            Value::from(memory.created_at.timestamp_millis()),
        ),
        (
            CONFIDENCE_PAYLOAD_KEY.to_string(),
            Value::from(i64::from(memory.confidence.get())),
        ),
    ]);

    // Optional fields are simply absent from the payload when unset.
    if let Some(event_at) = memory.event_at {
        payload.insert(
            EVENT_AT_PAYLOAD_KEY.to_string(),
            Value::from(event_at.timestamp_millis()),
        );
    }
    if let Some(category) = &memory.category {
        payload.insert(CATEGORY_PAYLOAD_KEY.to_string(), Value::from(category.clone()));
    }

    // Free-form metadata rides along, minus anything that collides with a reserved key.
    if let Some(extra) = memory.metadata.as_object() {
        let caller_entries = extra
            .iter()
            .filter(|(key, _)| !RESERVED_PAYLOAD_KEYS.iter().any(|reserved| reserved == *key))
            .map(|(key, value)| (key.clone(), Value::from(value.clone())));
        payload.extend(caller_entries);
    }

    let point_id = Uuid::new_v4().to_string();
    let point = PointStruct::new(point_id, vector, payload);
    let request = UpsertPointsBuilder::new(&self.collection, vec![point]);
    self.qdrant.upsert_points(request).await.map_err(connection)?;
    Ok(())
}
/// Runs a similarity query restricted to `scope` and returns `(pid, score)` pairs.
///
/// * `scope` - its agent, org and user ids are each added as a `must` condition.
/// * `query_embedding` - the query vector.
/// * `limit` - maximum number of hits requested.
/// * `kinds` - an empty selector short-circuits to an empty result without a request;
///   a selector that does not include every kind adds a `must` match on the `kind` payload.
/// * `extra_filter` - when given, its translated `must`, `must_not`, `should` and
///   `min_should` clauses are merged into the request filter.
/// * `min_similarity` - when given, passed to the query as the score threshold.
///
/// Hits whose payload has no string `pid` entry are skipped.
///
/// # Errors
///
/// Client failures are returned as [`VectorError::Connection`].
async fn search(
    &self,
    scope: Scope,
    query_embedding: Vec<f32>,
    limit: usize,
    kinds: KindSelector,
    extra_filter: Option<MemoryFilter>,
    min_similarity: Option<f32>,
) -> Result<Vec<(String, f32)>, VectorError> {
    if kinds.is_empty() {
        return Ok(Vec::new());
    }

    let mut must = vec![
        Condition::matches("agent_id", scope.agent_id),
        Condition::matches("org_id", scope.org_id),
        Condition::matches("user_id", scope.user_id),
    ];
    if !kinds.includes_all() {
        let names: Vec<String> = kinds
            .included_kinds()
            .into_iter()
            .map(|k| k.to_string())
            .collect();
        must.push(Condition::matches("kind", names));
    }

    let mut must_not = Vec::new();
    let mut should = Vec::new();
    let mut min_should = None;
    if let Some(extra) = extra_filter {
        let translated: Filter = extra.into();
        must.extend(translated.must);
        must_not.extend(translated.must_not);
        should.extend(translated.should);
        // Carry the clause over instead of dropping it: a `min_should` holds its own
        // conditions, so discarding it would silently loosen the caller's filter.
        min_should = translated.min_should;
    }
    let filter = Filter {
        must,
        must_not,
        should,
        min_should,
    };

    let mut request = QueryPointsBuilder::new(&self.collection)
        .query(query_embedding)
        .limit(limit as u64)
        .filter(filter)
        .with_payload(true);
    if let Some(threshold) = min_similarity {
        request = request.score_threshold(threshold);
    }

    let response = self.qdrant.query(request).await.map_err(connection)?;
    let hits = response
        .result
        .into_iter()
        .filter_map(|scored| pid_from_payload(&scored.payload).map(|pid| (pid, scored.score)))
        .collect();
    Ok(hits)
}
/// Deletes every point whose `pid` payload equals one of `pids`.
///
/// An empty slice returns immediately without issuing a request.
///
/// # Errors
///
/// Client failures are returned as [`VectorError::Connection`].
async fn delete_by_pids(&self, pids: &[&str]) -> Result<(), VectorError> {
    if pids.is_empty() {
        return Ok(());
    }
    // One match-any keyword condition covers the whole batch, the same form `search`
    // uses for kinds, instead of one `should` clause per pid.
    let keywords: Vec<String> = pids.iter().map(|pid| (*pid).to_string()).collect();
    let filter = Filter::must(vec![Condition::matches(PID_PAYLOAD_KEY, keywords)]);
    self.qdrant
        .delete_points(DeletePointsBuilder::new(&self.collection).points(filter))
        .await
        .map_err(connection)?;
    Ok(())
}
async fn list_pids_in_scope(&self, scope: Scope, page_size: usize) -> Result<Vec<String>, VectorError> {
let filter = Filter::must(vec![
Condition::matches("agent_id", scope.agent_id),
Condition::matches("org_id", scope.org_id),
Condition::matches("user_id", scope.user_id),
]);
let mut pids = Vec::new();
let mut offset: Option<qdrant_client::qdrant::PointId> = None;
loop {
let mut request = ScrollPointsBuilder::new(&self.collection)
.filter(filter.clone())
.limit(page_size as u32)
.with_payload(true)
.with_vectors(false);
if let Some(o) = offset.take() {
request = request.offset(o);
}
let response = self.qdrant.scroll(request).await.map_err(connection)?;
for point in response.result {
if let Some(pid) = pid_from_payload(&point.payload) {
pids.push(pid);
}
}
match response.next_page_offset {
Some(next) => offset = Some(next),
None => break,
}
}
Ok(pids)
}
}
/// Converts any displayable error into [`VectorError::Connection`] carrying its rendered text.
fn connection<E: std::fmt::Display>(err: E) -> VectorError {
    let message = err.to_string();
    VectorError::Connection(message)
}
/// Reads the `pid` entry of a point payload as an owned string.
///
/// Returns `None` when the key is missing or its value is not a string.
fn pid_from_payload(payload: &HashMap<String, Value>) -> Option<String> {
    let raw = payload.get(PID_PAYLOAD_KEY)?;
    raw.as_str().map(|text| text.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload map holding exactly one entry.
    fn single_entry_payload(key: &str, value: Value) -> HashMap<String, Value> {
        let mut payload = HashMap::new();
        payload.insert(key.to_string(), value);
        payload
    }

    #[test]
    fn should_extract_pid_from_payload_when_present() {
        let payload = single_entry_payload(PID_PAYLOAD_KEY, Value::from("my-pid".to_string()));
        assert_eq!(pid_from_payload(&payload).as_deref(), Some("my-pid"));
    }

    #[test]
    fn should_return_none_when_pid_absent_from_payload() {
        let payload = single_entry_payload("other", Value::from("x".to_string()));
        assert!(pid_from_payload(&payload).is_none());
    }

    #[test]
    fn should_return_none_when_pid_value_is_not_a_string() {
        let payload = single_entry_payload(PID_PAYLOAD_KEY, Value::from(42i64));
        assert!(pid_from_payload(&payload).is_none());
    }
}