use crate::models::{Chunk, Node};
use crate::traits::{CerebroError, Result, VectorStore};
use async_trait::async_trait;
use pgvector::Vector;
use sqlx::{postgres::PgPoolOptions, Pool, Postgres, Row};
/// [`VectorStore`] implementation backed by PostgreSQL with the `pgvector` extension.
///
/// Nodes are stored in the `cerebro_nodes` table, which [`PgVectorStore::new`] creates
/// if it does not exist. Cloning is cheap: every clone shares the same connection pool.
#[derive(Debug, Clone)]
pub struct PgVectorStore {
    /// Connection pool used for every query issued by this store.
    pool: Pool<Postgres>,
}
impl PgVectorStore {
pub async fn new(database_url: &str) -> Result<Self> {
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(database_url)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
.execute(&pool)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
sqlx::query(
"CREATE TABLE IF NOT EXISTS cerebro_nodes (
id UUID PRIMARY KEY,
document_id TEXT NOT NULL,
chunk_index INT NOT NULL,
text_content TEXT NOT NULL,
embedding VECTOR(1536),
ts tsvector GENERATED ALWAYS AS (to_tsvector('english', text_content)) STORED
)",
)
.execute(&pool)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(Self { pool })
}
}
#[async_trait]
impl VectorStore for PgVectorStore {
async fn upsert(&self, nodes: Vec<Node>) -> Result<()> {
let mut tx = self
.pool
.begin()
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
for node in nodes {
let vec_embedding = Vector::from(node.embedding);
let uuid = uuid::Uuid::parse_str(&node.id).unwrap_or_else(|_| uuid::Uuid::new_v4());
sqlx::query(
"INSERT INTO cerebro_nodes (id, document_id, chunk_index, text_content, embedding)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO UPDATE SET
text_content = EXCLUDED.text_content,
embedding = EXCLUDED.embedding",
)
.bind(uuid)
.bind(&node.chunk.document_id)
.bind(node.chunk.index as i32)
.bind(&node.chunk.text)
.bind(vec_embedding)
.execute(&mut *tx)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
}
tx.commit()
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get(&self, node_ids: &[&str]) -> Result<Vec<Node>> {
let mut result = Vec::new();
for id in node_ids {
if let Ok(uuid) = uuid::Uuid::parse_str(id) {
let row = sqlx::query("SELECT document_id, chunk_index, text_content, embedding FROM cerebro_nodes WHERE id = $1")
.bind(uuid)
.fetch_optional(&self.pool).await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
if let Some(r) = row {
let vec: Vector = r.get("embedding");
result.push(Node {
id: id.to_string(),
chunk: Chunk {
document_id: r.get("document_id"),
index: r.get::<i32, _>("chunk_index") as usize,
text: r.get("text_content"),
},
embedding: vec.to_vec(),
edges: vec![],
});
}
}
}
Ok(result)
}
async fn search(
&self,
text_query: &str,
embedding: &[f32],
top_k: usize,
) -> Result<Vec<(Node, f32)>> {
let query_vec = Vector::from(embedding.to_vec());
let rows = sqlx::query(
"WITH vector_search AS (
SELECT id, ROW_NUMBER() OVER (ORDER BY embedding <=> $1) as vector_rank
FROM cerebro_nodes
ORDER BY embedding <=> $1
LIMIT 100
),
text_search AS (
SELECT id, ROW_NUMBER() OVER (ORDER BY ts_rank(ts, websearch_to_tsquery('english', $2)) DESC) as text_rank
FROM cerebro_nodes
WHERE ts @@ websearch_to_tsquery('english', $2)
LIMIT 100
),
rrf AS (
SELECT COALESCE(v.id, t.id) as id,
COALESCE(1.0 / (60 + v.vector_rank), 0.0) + COALESCE(1.0 / (60 + t.text_rank), 0.0) AS rrf_score
FROM vector_search v
FULL OUTER JOIN text_search t ON v.id = t.id
ORDER BY rrf_score DESC
LIMIT $3
)
SELECT c.id, c.document_id, c.chunk_index, c.text_content, c.embedding, r.rrf_score
FROM rrf r
JOIN cerebro_nodes c ON r.id = c.id
ORDER BY r.rrf_score DESC"
)
.bind(query_vec)
.bind(text_query)
.bind(top_k as i32)
.fetch_all(&self.pool).await
.map_err(|e| CerebroError::StorageError(format!("Search failed: {}", e)))?;
let mut results = Vec::new();
for r in rows {
let uid: uuid::Uuid = r.get("id");
let vec: Vector = r.get("embedding");
let sim: f64 = r.get("rrf_score");
results.push((
Node {
id: uid.to_string(),
chunk: Chunk {
document_id: r.get("document_id"),
index: r.get::<i32, _>("chunk_index") as usize,
text: r.get("text_content"),
},
embedding: vec.to_vec(),
edges: vec![],
},
sim as f32,
));
}
Ok(results)
}
async fn delete_document(&self, doc_id: &str) -> Result<()> {
sqlx::query("DELETE FROM cerebro_nodes WHERE document_id = $1")
.bind(doc_id)
.execute(&self.pool)
.await
.map_err(|e| CerebroError::StorageError(e.to_string()))?;
Ok(())
}
async fn get_all_nodes(&self) -> Result<Vec<Node>> {
let rows = sqlx::query(
"SELECT id, document_id, chunk_index, text_content, embedding FROM cerebro_nodes",
)
.fetch_all(&self.pool)
.await
.map_err(|e| CerebroError::StorageError(format!("Search failed: {}", e)))?;
let mut results = Vec::new();
for r in rows {
let uid: uuid::Uuid = r.get("id");
let vec: Vector = r.get("embedding");
results.push(Node {
id: uid.to_string(),
chunk: Chunk {
document_id: r.get("document_id"),
index: r.get::<i32, _>("chunk_index") as usize,
text: r.get("text_content"),
},
embedding: vec.to_vec(),
edges: vec![],
});
}
Ok(results)
}
}