use std::collections::HashMap;
use async_trait::async_trait;
use surrealdb::Surreal;
use surrealdb::engine::any::{Any, connect};
use surrealdb_types::{RecordId, RecordIdKey, SurrealValue};
use tracing::debug;
use crate::document::{Chunk, SearchResult};
use crate::error::{RagError, Result};
use crate::vectorstore::VectorStore;
/// SurrealDB namespace selected by the `in_memory`, `rocksdb` and `remote` constructors.
const DEFAULT_NS: &str = "adk_rag";
/// SurrealDB database selected by the `in_memory`, `rocksdb` and `remote` constructors.
const DEFAULT_DB: &str = "vectors";
/// Record content written by `upsert` for one chunk.
///
/// The chunk id is deliberately not a field: `upsert` uses it as the key of the
/// record id (`table:id`) instead. This type is also the (discarded)
/// deserialization target for the records returned by `upsert` and `delete`.
#[derive(Debug, SurrealValue)]
struct ChunkRow {
    /// Chunk text; `create_collection` declares this field as `string`.
    text: String,
    /// Embedding vector; the field covered by the HNSW index from `create_collection`.
    embedding: Vec<f32>,
    /// String-to-string metadata attached to the chunk.
    metadata: HashMap<String, String>,
    /// Identifier of the document the chunk was taken from.
    document_id: String,
}
/// One row produced by the KNN `SELECT` issued in `search`.
#[derive(Debug, SurrealValue)]
struct SearchRow {
    /// Full record id; only its key is exposed to callers as the chunk id.
    id: RecordId,
    /// Chunk text.
    text: String,
    /// Chunk metadata.
    metadata: HashMap<String, String>,
    /// Identifier of the source document.
    document_id: String,
    /// Value of `vector::distance::knn()` for this row; `search` orders by it
    /// ascending and reports `1.0 - distance` as the score.
    distance: f32,
}
/// Renders the key half of a SurrealDB record id as a plain string.
///
/// String, numeric and UUID keys render as their bare value. Any other key
/// variant has no obvious textual form, so its `Debug` representation is used.
fn record_id_key_to_string(key: &RecordIdKey) -> String {
    match key {
        RecordIdKey::String(text) => text.to_owned(),
        RecordIdKey::Number(number) => number.to_string(),
        RecordIdKey::Uuid(uuid) => uuid.to_string(),
        other => format!("{other:?}"),
    }
}
/// [`VectorStore`] implementation backed by SurrealDB.
///
/// Each collection maps to one table (name sanitized by `sanitize_table_name`)
/// holding chunk rows plus an HNSW index on the `embedding` field.
pub struct SurrealVectorStore {
    /// Connection all queries run on. The `in_memory`, `rocksdb` and `remote`
    /// constructors select `DEFAULT_NS`/`DEFAULT_DB` on it; `from_connection`
    /// stores the caller's connection unchanged.
    db: Surreal<Any>,
}
impl SurrealVectorStore {
    /// Starts an embedded in-memory SurrealDB engine (`mem://`).
    ///
    /// Nothing is persisted; data lives only as long as the connection.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::VectorStoreError`] if the engine cannot be started or
    /// the default namespace/database cannot be selected.
    pub async fn in_memory() -> Result<Self> {
        Self::connect_to("mem://").await
    }

    /// Opens an embedded RocksDB-backed SurrealDB engine stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::VectorStoreError`] if the engine cannot be opened or
    /// the default namespace/database cannot be selected.
    pub async fn rocksdb(path: &str) -> Result<Self> {
        let address = format!("rocksdb://{path}");
        Self::connect_to(&address).await
    }

    /// Connects to a SurrealDB endpoint given as a full URL.
    ///
    /// No sign-in is performed here; the endpoint must accept the connection as is.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::VectorStoreError`] if the connection fails or the
    /// default namespace/database cannot be selected.
    pub async fn remote(url: &str) -> Result<Self> {
        Self::connect_to(url).await
    }

    /// Wraps an existing connection without modifying it.
    ///
    /// Unlike the other constructors this does NOT select `DEFAULT_NS` /
    /// `DEFAULT_DB`; the caller is responsible for namespace, database and auth.
    pub fn from_connection(db: Surreal<Any>) -> Self {
        Self { db }
    }

    /// Shared constructor path: connect to `address`, then select the default
    /// namespace and database.
    async fn connect_to(address: &str) -> Result<Self> {
        let db = connect(address).await.map_err(Self::map_err)?;
        db.use_ns(DEFAULT_NS).use_db(DEFAULT_DB).await.map_err(Self::map_err)?;
        Ok(Self { db })
    }

    /// Converts a SurrealDB error into this crate's backend error variant.
    fn map_err(e: surrealdb::Error) -> RagError {
        RagError::VectorStoreError { backend: "surrealdb".to_string(), message: e.to_string() }
    }

    /// Maps a collection name to a table name safe to splice into SurrealQL.
    ///
    /// Every character other than an ASCII letter, ASCII digit or `_` becomes `_`.
    /// Unescaped SurrealQL identifiers are limited to that ASCII set, so keeping
    /// non-ASCII letters (as `char::is_alphanumeric` would) produces queries that
    /// fail to parse. Because the mapping is lossy, distinct names can collide
    /// (e.g. `"a-b"` and `"a_b"`).
    ///
    /// NOTE(review): names starting with a digit are passed through unchanged;
    /// confirm SurrealDB accepts them as unescaped table names.
    ///
    /// # Errors
    ///
    /// Returns [`RagError::VectorStoreError`] when the result is empty, which
    /// happens only for an empty input since each character maps to one character.
    fn sanitize_table_name(name: &str) -> Result<String> {
        let sanitized: String = name
            .chars()
            .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
            .collect();
        if sanitized.is_empty() {
            return Err(RagError::VectorStoreError {
                backend: "surrealdb".to_string(),
                message: "collection name is empty after sanitization".to_string(),
            });
        }
        Ok(sanitized)
    }
}
#[async_trait]
impl VectorStore for SurrealVectorStore {
    /// Defines the table, its typed fields and a cosine HNSW index for `name`.
    ///
    /// Every statement uses `IF NOT EXISTS`, so re-running this for an existing
    /// collection does not redefine the table, fields or index. In particular,
    /// the index keeps the dimension it was first created with.
    ///
    /// # Errors
    ///
    /// Fails if the name sanitizes to nothing, if the query cannot be sent or
    /// parsed, or if any individual `DEFINE` statement fails on the server.
    async fn create_collection(&self, name: &str, dimensions: usize) -> Result<()> {
        let table = Self::sanitize_table_name(name)?;
        let sql = format!(
            "DEFINE TABLE IF NOT EXISTS {table}; \
             DEFINE FIELD IF NOT EXISTS text ON {table} TYPE string; \
             DEFINE FIELD IF NOT EXISTS embedding ON {table} TYPE array<float>; \
             DEFINE FIELD IF NOT EXISTS metadata ON {table} FLEXIBLE TYPE object; \
             DEFINE FIELD IF NOT EXISTS document_id ON {table} TYPE string; \
             DEFINE INDEX IF NOT EXISTS idx_{table}_hnsw ON {table} \
             FIELDS embedding HNSW DIMENSION {dimensions} DIST COSINE;"
        );
        // Awaiting `query` only reports transport/parse failures; per-statement
        // errors stay inside the response until `check()` surfaces them.
        self.db
            .query(&sql)
            .await
            .and_then(|response| response.check())
            .map_err(Self::map_err)?;
        debug!(collection = name, table = %table, dimensions, "created surrealdb collection");
        Ok(())
    }

    /// Removes the collection's table, if present.
    ///
    /// # Errors
    ///
    /// Fails if the name sanitizes to nothing, if the query cannot be sent, or
    /// if the `REMOVE TABLE` statement itself fails on the server.
    async fn delete_collection(&self, name: &str) -> Result<()> {
        let table = Self::sanitize_table_name(name)?;
        self.db
            .query(format!("REMOVE TABLE IF EXISTS {table};"))
            .await
            .and_then(|response| response.check())
            .map_err(Self::map_err)?;
        debug!(collection = name, table = %table, "deleted surrealdb collection");
        Ok(())
    }

    /// Inserts or replaces each chunk as record `table:chunk.id`.
    ///
    /// Chunks are written one request at a time, in order. This is not atomic:
    /// if a write fails, the chunks before it remain stored and the error is
    /// returned.
    async fn upsert(&self, collection: &str, chunks: &[Chunk]) -> Result<()> {
        if chunks.is_empty() {
            return Ok(());
        }
        let table = Self::sanitize_table_name(collection)?;
        for chunk in chunks {
            let row = ChunkRow {
                text: chunk.text.clone(),
                embedding: chunk.embedding.clone(),
                metadata: chunk.metadata.clone(),
                document_id: chunk.document_id.clone(),
            };
            // The returned record is not needed; it is only deserialized to
            // satisfy the typed API.
            let _: Option<ChunkRow> = self
                .db
                .upsert((&table as &str, &chunk.id as &str))
                .content(row)
                .await
                .map_err(Self::map_err)?;
        }
        debug!(collection, count = chunks.len(), "upserted chunks to surrealdb");
        Ok(())
    }

    /// Deletes the records `table:id` for every id in `ids`, one request each.
    ///
    /// Like `upsert`, this is not atomic across ids.
    async fn delete(&self, collection: &str, ids: &[&str]) -> Result<()> {
        if ids.is_empty() {
            return Ok(());
        }
        let table = Self::sanitize_table_name(collection)?;
        for id in ids {
            let _: Option<ChunkRow> =
                self.db.delete((&table as &str, *id)).await.map_err(Self::map_err)?;
        }
        debug!(collection, count = ids.len(), "deleted chunks from surrealdb");
        Ok(())
    }

    /// Returns up to `top_k` chunks nearest to `embedding`, closest first.
    ///
    /// `score` is `1.0 - distance`. This assumes `vector::distance::knn()`
    /// reports cosine distance, which would make the score the cosine
    /// similarity. Returned chunks carry an empty `embedding` because the query
    /// does not select it.
    ///
    /// NOTE(review): with a distance metric as the second KNN parameter
    /// (`<|K,COSINE|>`), SurrealDB documents a brute-force scan rather than use
    /// of the HNSW index defined in `create_collection`; `<|K,EF|>` targets the
    /// index. Confirm which is intended before changing it, because results
    /// become approximate.
    async fn search(
        &self,
        collection: &str,
        embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<SearchResult>> {
        let table = Self::sanitize_table_name(collection)?;
        // Asking for zero neighbours can only produce an empty result; avoid
        // sending a `<|0,...|>` operator to the server.
        if top_k == 0 {
            return Ok(Vec::new());
        }
        let sql = format!(
            "SELECT id, text, metadata, document_id, \
             vector::distance::knn() AS distance \
             FROM {table} \
             WHERE embedding <|{top_k},COSINE|> $embedding \
             ORDER BY distance;"
        );
        let mut response = self
            .db
            .query(&sql)
            .bind(("embedding", embedding.to_vec()))
            .await
            .map_err(Self::map_err)?;
        // `take(0)` also surfaces a server-side error from the SELECT statement.
        let rows: Vec<SearchRow> = response.take(0).map_err(Self::map_err)?;
        let results = rows
            .into_iter()
            .map(|row| SearchResult {
                chunk: Chunk {
                    id: record_id_key_to_string(&row.id.key),
                    text: row.text,
                    embedding: vec![],
                    metadata: row.metadata,
                    document_id: row.document_id,
                },
                score: 1.0 - row.distance,
            })
            .collect();
        Ok(results)
    }
}