cerebro 0.1.4

Blazing-fast, storage-agnostic semantic memory engine for AI Agents — written in pure Rust
use async_trait::async_trait;
use qdrant_client::Qdrant;
use qdrant_client::qdrant::{CreateCollectionBuilder, VectorParamsBuilder, Distance, PointStruct, UpsertPointsBuilder, Condition, Filter, DeletePointsBuilder, GetPointsBuilder, Value};
use std::collections::HashMap;
use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};

/// A highly scalable distributed vector store driver utilizing Qdrant.
///
/// Bundles a Qdrant client with the name of the one collection that this
/// store operates on.
pub struct QdrantVectorStore {
    /// Client handle through which every request of this store is sent.
    client: Qdrant,
    /// Name of the collection targeted by this store's requests.
    collection_name: String,
}

impl QdrantVectorStore {
    pub async fn new(url: &str, collection_name: &str, vector_size: u64) -> Result<Self> {
        let client = Qdrant::from_url(url).build().map_err(|e| CerebroError::StorageError(e.to_string()))?;

        // Ensure collection exists
        if !client.collection_exists(collection_name).await.map_err(|e| CerebroError::StorageError(e.to_string()))? {
            client.create_collection(
                CreateCollectionBuilder::new(collection_name)
                    .vectors_config(VectorParamsBuilder::new(vector_size, Distance::Cosine))
            ).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
        }

        Ok(Self { client, collection_name: collection_name.into() })
    }
}

#[async_trait]
impl VectorStore for QdrantVectorStore {
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut points = Vec::with_capacity(nodes.len());
        
        for node in nodes {
            let mut payload: HashMap<String, Value> = HashMap::new();
            payload.insert("document_id".to_string(), node.chunk.document_id.into());
            payload.insert("chunk_index".to_string(), (node.chunk.index as i64).into());
            payload.insert("text_content".to_string(), node.chunk.text.into());
            
            points.push(PointStruct::new(
                node.id,
                node.embedding,
                payload,
            ));
        }

        self.client.upsert_points(UpsertPointsBuilder::new(&self.collection_name, points))
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;
            
        Ok(())
    }

    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let points = self.client.get_points(
            GetPointsBuilder::new(&self.collection_name, node_ids.iter().map(|&s| s.to_string().into()).collect::<Vec<qdrant_client::qdrant::PointId>>())
                .with_payload(true)
                .with_vectors(true)
        ).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;

        let mut results = Vec::new();
        for point in points.result {
            let id = point.id.and_then(|i| i.point_id_options).map(|opt| match opt {
                qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
                _ => String::new()
            }).unwrap_or_default();
            
            if id.is_empty() { continue; } // UUID sanity
            
            let payload = point.payload;
            let document_id = payload.get("document_id").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                _ => String::new(),
            }).unwrap_or_default();
            
            let chunk_index = payload.get("chunk_index").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
                _ => 0,
            }).unwrap_or(0);
            
            let text_content = payload.get("text_content").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                _ => String::new(),
            }).unwrap_or_default();

            let embedding = if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) = point.vectors.and_then(|v| v.vectors_options) {
                v.data
            } else {
                vec![]
            };

            results.push(Node {
                id,
                chunk: Chunk { document_id, index: chunk_index, text: text_content },
                embedding,
                edges: vec![],
            });
        }
        Ok(results)
    }

    async fn search(&self, _text_query: &str, embedding: &[f32], top_k: usize) -> Result<Vec<(Node, f32)>> {
        let points = self.client.search_points(
            qdrant_client::qdrant::SearchPointsBuilder::new(&self.collection_name, embedding.to_vec(), top_k as u64)
                .with_payload(true)
                .with_vectors(true)
        ).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
        
        let mut results = Vec::new();
        for point in points.result {
            let id = point.id.and_then(|i| i.point_id_options).map(|opt| match opt {
                qdrant_client::qdrant::point_id::PointIdOptions::Uuid(u) => u,
                _ => String::new()
            }).unwrap_or_default();
            
            if id.is_empty() { continue; } 
            
            let payload = point.payload;
            let document_id = payload.get("document_id").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                _ => String::new(),
            }).unwrap_or_default();
            
            let chunk_index = payload.get("chunk_index").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::IntegerValue(i) => i as usize,
                _ => 0,
            }).unwrap_or(0);
            
            let text_content = payload.get("text_content").and_then(|v| v.kind.clone()).map(|k| match k {
                qdrant_client::qdrant::value::Kind::StringValue(s) => s,
                _ => String::new(),
            }).unwrap_or_default();

            let embedding = if let Some(qdrant_client::qdrant::vectors_output::VectorsOptions::Vector(v)) = point.vectors.and_then(|v| v.vectors_options) {
                v.data
            } else {
                vec![]
            };

            results.push((Node {
                id,
                chunk: Chunk { document_id, index: chunk_index, text: text_content },
                embedding,
                edges: vec![],
            }, point.score));
        }
        
        Ok(results)
    }

    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let condition = Condition::matches("document_id", doc_id.to_string());
        self.client.delete_points(
            DeletePointsBuilder::new(&self.collection_name)
                .points(Filter::must([condition]))
        ).await.map_err(|e| CerebroError::StorageError(e.to_string()))?;
        
        Ok(())
    }

    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        // Since Qdrant doesn't have a simple fetch_all, we'll do an empty scroll for consolidation.
        Ok(vec![])
    }
}