cerebro 1.1.8

A blazing-fast AI memory layer that enables teams of specialized agents to collaborate through a shared cognitive architecture.
Documentation
use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{postgres::PgPoolOptions, Pool, Postgres, Row};

/// A PostgreSQL-backed vector store built on the pgvector extension.
///
/// Wraps an sqlx connection pool; every node is persisted in the
/// `cerebro_nodes` table that [`PgVectorStore::new`] creates if missing.
pub struct PgVectorStore {
    /// Connection pool through which all statements of this store are executed.
    pool: Pool<Postgres>,
}

impl PgVectorStore {
    pub async fn new(database_url: &str) -> Result<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect(database_url)
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
            .execute(&pool)
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        sqlx::query(
            "CREATE TABLE IF NOT EXISTS cerebro_nodes (
                id UUID PRIMARY KEY,
                document_id TEXT NOT NULL,
                chunk_index INT NOT NULL,
                text_content TEXT NOT NULL,
                embedding VECTOR(1536),
                ts tsvector GENERATED ALWAYS AS (to_tsvector('english', text_content)) STORED
            )",
        )
        .execute(&pool)
        .await
        .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        Ok(Self { pool })
    }
}

#[async_trait]
impl VectorStore for PgVectorStore {
    /// Inserts the given nodes, replacing any row already stored under the same id.
    ///
    /// All statements run inside one transaction that is committed only after every
    /// node has been written. On conflict the whole row is refreshed (document id,
    /// chunk index, text and embedding), so a re-upserted node never keeps stale
    /// chunk metadata.
    ///
    /// # Errors
    ///
    /// Returns [`CerebroError::StorageError`] if a chunk index does not fit the 32-bit
    /// `chunk_index` column, or if beginning the transaction, any statement, or the
    /// commit fails.
    async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
        let mut tx = self
            .pool
            .begin()
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;

        for node in nodes {
            // Checked conversion: an `as` cast would silently wrap an oversized index
            // and store a wrong (possibly negative) value.
            let chunk_index = i32::try_from(node.chunk.index).map_err(|_| {
                CerebroError::StorageError(format!(
                    "chunk index {} of node {} does not fit in a 32-bit column",
                    node.chunk.index, node.id
                ))
            })?;

            // NOTE(review): an id that is not a valid UUID is replaced by a fresh random
            // one, so such a node cannot later be fetched or overwritten through its
            // original id. Kept as-is for backward compatibility — confirm this is intended.
            let uuid = uuid::Uuid::parse_str(&node.id).unwrap_or_else(|_| uuid::Uuid::new_v4());
            let vec_embedding = Vector::from(node.embedding);

            sqlx::query(
                "INSERT INTO cerebro_nodes (id, document_id, chunk_index, text_content, embedding)
                 VALUES ($1, $2, $3, $4, $5)
                 ON CONFLICT (id) DO UPDATE SET
                 document_id = EXCLUDED.document_id,
                 chunk_index = EXCLUDED.chunk_index,
                 text_content = EXCLUDED.text_content,
                 embedding = EXCLUDED.embedding",
            )
            .bind(uuid)
            .bind(&node.chunk.document_id)
            .bind(chunk_index)
            .bind(&node.chunk.text)
            .bind(vec_embedding)
            .execute(&mut *tx)
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;
        }

        tx.commit()
            .await
            .map_err(|e| CerebroError::StorageError(e.to_string()))?;
        Ok(())
    }

    async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
        let mut result = Vec::new();
        for id in node_ids {
            if let Ok(uuid) = uuid::Uuid::parse_str(id) {
                let row = sqlx::query("SELECT document_id, chunk_index, text_content, embedding FROM cerebro_nodes WHERE id = $1")
                    .bind(uuid)
                    .fetch_optional(&self.pool).await
                    .map_err(|e| CerebroError::StorageError(e.to_string()))?;

                if let Some(r) = row {
                    let vec: Vector = r.get("embedding");
                    result.push(Node {
                        id: id.to_string(),
                        chunk: Chunk {
                            document_id: r.get("document_id"),
                            index: r.get::<i32, _>("chunk_index") as usize,
                            text: r.get("text_content"),
                        },
                        embedding: vec.to_vec(),
                        edges: vec![],
                    });
                }
            }
        }
        Ok(result)
    }

    async fn search(
        &self,
        text_query: &str,
        embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<(Node, f32)>> {
        let query_vec = Vector::from(embedding.to_vec());

        let rows = sqlx::query(
            "WITH vector_search AS (
                SELECT id, ROW_NUMBER() OVER (ORDER BY embedding <=> $1) as vector_rank
                FROM cerebro_nodes
                ORDER BY embedding <=> $1
                LIMIT 100
             ),
             text_search AS (
                SELECT id, ROW_NUMBER() OVER (ORDER BY ts_rank(ts, websearch_to_tsquery('english', $2)) DESC) as text_rank
                FROM cerebro_nodes
                WHERE ts @@ websearch_to_tsquery('english', $2)
                LIMIT 100
             ),
             rrf AS (
                SELECT COALESCE(v.id, t.id) as id,
                       COALESCE(1.0 / (60 + v.vector_rank), 0.0) + COALESCE(1.0 / (60 + t.text_rank), 0.0) AS rrf_score
                FROM vector_search v
                FULL OUTER JOIN text_search t ON v.id = t.id
                ORDER BY rrf_score DESC
                LIMIT $3
             )
             SELECT c.id, c.document_id, c.chunk_index, c.text_content, c.embedding, r.rrf_score
             FROM rrf r
             JOIN cerebro_nodes c ON r.id = c.id
             ORDER BY r.rrf_score DESC"
        )
        .bind(query_vec)
        .bind(text_query)
        .bind(top_k as i32)
        .fetch_all(&self.pool).await
        .map_err(|e| CerebroError::StorageError(format!("Search failed: {}", e)))?;

        let mut results = Vec::new();
        for r in rows {
            let uid: uuid::Uuid = r.get("id");
            let vec: Vector = r.get("embedding");
            let sim: f64 = r.get("rrf_score");
            results.push((
                Node {
                    id: uid.to_string(),
                    chunk: Chunk {
                        document_id: r.get("document_id"),
                        index: r.get::<i32, _>("chunk_index") as usize,
                        text: r.get("text_content"),
                    },
                    embedding: vec.to_vec(),
                    edges: vec![],
                },
                sim as f32,
            ));
        }
        Ok(results)
    }

    /// Removes every stored node whose `document_id` equals `doc_id`.
    ///
    /// Succeeds whether or not any row matched; a failing statement is reported as
    /// [`CerebroError::StorageError`] carrying the driver's message.
    async fn delete_document(&self, doc_id: &str) -> Result<()> {
        let statement =
            sqlx::query("DELETE FROM cerebro_nodes WHERE document_id = $1").bind(doc_id);

        match statement.execute(&self.pool).await {
            Ok(_) => Ok(()),
            Err(e) => Err(CerebroError::StorageError(e.to_string())),
        }
    }

    async fn get_all_nodes(&self) -> Result<Vec<Node>> {
        let rows = sqlx::query(
            "SELECT id, document_id, chunk_index, text_content, embedding FROM cerebro_nodes",
        )
        .fetch_all(&self.pool)
        .await
        .map_err(|e| CerebroError::StorageError(format!("Search failed: {}", e)))?;

        let mut results = Vec::new();
        for r in rows {
            let uid: uuid::Uuid = r.get("id");
            let vec: Vector = r.get("embedding");
            results.push(Node {
                id: uid.to_string(),
                chunk: Chunk {
                    document_id: r.get("document_id"),
                    index: r.get::<i32, _>("chunk_index") as usize,
                    text: r.get("text_content"),
                },
                embedding: vec.to_vec(),
                edges: vec![],
            });
        }
        Ok(results)
    }
}