cerebro 1.1.8

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use qdrant_client::qdrant::{
    Condition, CreateCollectionBuilder, DeletePointsBuilder, Distance, Filter, GetPointsBuilder,
    PointStruct, UpsertPointsBuilder, Value, VectorParamsBuilder,
};
use qdrant_client::Qdrant;
use std::collections::HashMap;

/// A highly scalable distributed vector store driver utilizing Qdrant.
///
/// Bundles a [`Qdrant`] client with the name of the one collection that every
/// operation of this store reads from and writes to.
pub struct QdrantVectorStore {
    /// Client handle through which all requests of this store are issued.
    client: Qdrant,
    /// Name of the Qdrant collection passed to every request builder.
    collection_name: String,
}

impl QdrantVectorStore {
    pub async fn new(url: &str, collection_name: &str, vector_size: u64) -> Result<Self> {
        let client = Qdrant::from_url(url)
            .build()
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        // Ensure collection exists
        if !client
            .collection_exists(collection_name)
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?
        {
            client
                .create_collection(
                    CreateCollectionBuilder::new(collection_name)
                        .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine)),
                )
                .await
                .map_err(|e| CerebroError::StorageError(e.to_string()))?;
        }

        Ok(Self {
            client,
            collection_name: collection_name.into(),
        })
    }
}

#[async_trait]
impl VectorStore for QdrantVectorStore {
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut points = Vec::with_capacity(nodes.len());

        for node in nodes {
            let mut payload: HashMap<String, Value> = HashMap::new();
            payload.insert("document_id".to_string(), node.chunk.document_id.into());
            payload.insert("chunk_index".to_string(), (node.chunk.index as i64).into());
            payload.insert("text_content".to_string(), node.chunk.text.into());

            points.push(PointStruct::new(node.id, node.embedding, payload));
        }

        self.client
            .upsert_points(UpsertPointsBuilder::new(&self.collection_name, points))
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        Ok(())
    }

    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let points = self
            .client
            .get_points(
                GetPointsBuilder::new(
                    &self.collection_name,
                    node_ids
                        .iter()
                        .map(|&s| s.to_string().into())
                        .collect::<Vec<qdrant_client::qdrant::PointId>>(),
                )
                .with_payload(true)
                .with_vectors(true),
            )
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        let mut results = Vec::new();
        for point in points.result {
            let id = point
                .id
                .and_then(|i| i.point_id_options)
                .map(|opt| match opt {
                    qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
                    _ => String::new(),
                })
                .unwrap_or_default();

            if id.is_empty() {
                continue;
            } // UUID sanity

            let payload = point.payload;
            let document_id = payload
                .get("document_id")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                    _ => String::new(),
                })
                .unwrap_or_default();

            let chunk_index = payload
                .get("chunk_index")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
                    _ => 0,
                })
                .unwrap_or(0);

            let text_content = payload
                .get("text_content")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                    _ => String::new(),
                })
                .unwrap_or_default();

            let embedding =
                if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) =
                    point.vectors.and_then(|v| v.vectors_options)
                {
                    v.data
                } else {
                    vec![]
                };

            results.push(Node {
                id,
                chunk: Chunk {
                    document_id,
                    index: chunk_index,
                    text: text_content,
                },
                embedding,
                edges: vec![],
            });
        }
        Ok(results)
    }

    async fn search(
        &self,
        _text_query: &str,
        embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<(Node, f32)>> {
        let points = self
            .client
            .search_points(
                qdrant_client::qdrant::SearchPointsBuilder::new(
                    &self.collection_name,
                    embedding.to_vec(),
                    top_k as u64,
                )
                .with_payload(true)
                .with_vectors(true),
            )
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        let mut results = Vec::new();
        for point in points.result {
            let id = point
                .id
                .and_then(|i| i.point_id_options)
                .map(|opt| match opt {
                    qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
                    _ => String::new(),
                })
                .unwrap_or_default();

            if id.is_empty() {
                continue;
            }

            let payload = point.payload;
            let document_id = payload
                .get("document_id")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                    _ => String::new(),
                })
                .unwrap_or_default();

            let chunk_index = payload
                .get("chunk_index")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
                    _ => 0,
                })
                .unwrap_or(0);

            let text_content = payload
                .get("text_content")
                .and_then(|v| v.kind.clone())
                .map(|k| match k {
                    qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                    _ => String::new(),
                })
                .unwrap_or_default();

            let embedding =
                if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) =
                    point.vectors.and_then(|v| v.vectors_options)
                {
                    v.data
                } else {
                    vec![]
                };

            results.push((
                Node {
                    id,
                    chunk: Chunk {
                        document_id,
                        index: chunk_index,
                        text: text_content,
                    },
                    embedding,
                    edges: vec![],
                },
                point.score,
            ));
        }

        Ok(results)
    }

    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let condition = Condition::matches("document_id", doc_id.to_string());
        self.client
            .delete_points(
                DeletePointsBuilder::new(&self.collection_name).points(Filter::must([condition])),
            )
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        Ok(())
    }

    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        // Since Qdrant doesn't have a simple fetch_all, we'll do an empty scroll for consolidation.
        Ok(vec![])
    }
}