use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use surrealdb::types::SurrealValue;
use tracing::debug;
use post_cortex_embeddings::{SearchMatch, VectorMetadata};
use crate::traits::VectorStorage;
use super::SurrealDBStorage;
use super::MIN_VECTOR_LEN;
use super::records::{EmbeddingRecord, KnnResult};
#[async_trait]
impl VectorStorage for SurrealDBStorage {
async fn add_vector(&self, vector: Vec<f32>, metadata: VectorMetadata) -> Result<String> {
if vector.len() < MIN_VECTOR_LEN {
return Err(anyhow::anyhow!(
"Vector too short: got {} dims, need at least {}",
vector.len(),
MIN_VECTOR_LEN,
));
}
debug!(
"SurrealDBStorage: Adding vector for content {}",
metadata.id
);
let record = EmbeddingRecord {
content_id: metadata.id.clone(),
session_id: metadata.source.clone(),
vector,
text: metadata.text,
content_type: metadata.content_type,
timestamp: metadata.timestamp.to_rfc3339(),
metadata: metadata.metadata,
};
let _: Option<EmbeddingRecord> = self
.db
.upsert(("embedding", metadata.id.clone()))
.content(record)
.await?;
Ok(metadata.id)
}
async fn add_vectors_batch(
&self,
vectors: Vec<(Vec<f32>, VectorMetadata)>,
) -> Result<Vec<String>> {
let mut ids = Vec::new();
for (vector, metadata) in vectors {
let id = self.add_vector(vector, metadata).await?;
ids.push(id);
}
Ok(ids)
}
async fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchMatch>> {
if query.len() < MIN_VECTOR_LEN {
return Err(anyhow::anyhow!(
"Query vector dimension mismatch: at least {}, got {}",
MIN_VECTOR_LEN,
query.len()
));
}
debug!("SurrealDBStorage: HNSW search for {} nearest vectors", k);
let query_vec: Vec<f32> = query.to_vec();
let mut response = self
.db
.query("SELECT *, vector::distance::knn() AS distance FROM embedding WHERE vector <|$k|> $query_vec")
.bind(("k", k as i64))
.bind(("query_vec", query_vec))
.await?;
let records: Vec<KnnResult> = response.take(0)?;
let matches: Vec<SearchMatch> = records
.into_iter()
.map(|r| {
let similarity = 1.0 - r.distance.min(1.0);
SearchMatch {
vector_id: 0,
similarity,
metadata: VectorMetadata {
id: r.content_id,
text: r.text,
source: r.session_id,
content_type: r.content_type,
timestamp: Self::parse_datetime(&r.timestamp),
metadata: r.metadata,
},
}
})
.collect();
debug!("SurrealDBStorage: HNSW found {} matches", matches.len());
Ok(matches)
}
async fn search_in_session(
&self,
query: &[f32],
k: usize,
session_id: &str,
) -> Result<Vec<SearchMatch>> {
if query.len() < MIN_VECTOR_LEN {
return Err(anyhow::anyhow!(
"Query vector dimension mismatch: at least {}, got {}",
MIN_VECTOR_LEN,
query.len()
));
}
debug!(
"SurrealDBStorage: HNSW search for {} nearest vectors in session {}",
k, session_id
);
let fetch_count = (k * 5).max(50); let query_vec: Vec<f32> = query.to_vec();
let mut response = self
.db
.query("SELECT *, vector::distance::knn() AS distance FROM embedding WHERE vector <|$fetch_count|> $query_vec")
.bind(("fetch_count", fetch_count as i64))
.bind(("query_vec", query_vec))
.await?;
let records: Vec<KnnResult> = response.take(0)?;
let matches: Vec<SearchMatch> = records
.into_iter()
.filter(|r| r.session_id == session_id)
.take(k)
.map(|r| {
let similarity = 1.0 - r.distance.min(1.0);
SearchMatch {
vector_id: 0,
similarity,
metadata: VectorMetadata {
id: r.content_id,
text: r.text,
source: r.session_id,
content_type: r.content_type,
timestamp: Self::parse_datetime(&r.timestamp),
metadata: r.metadata,
},
}
})
.collect();
debug!(
"SurrealDBStorage: HNSW found {} matches in session",
matches.len()
);
Ok(matches)
}
async fn search_by_content_type(
&self,
query: &[f32],
k: usize,
content_type: &str,
) -> Result<Vec<SearchMatch>> {
if query.len() < MIN_VECTOR_LEN {
return Err(anyhow::anyhow!(
"Query vector dimension mismatch: at least {}, got {}",
MIN_VECTOR_LEN,
query.len()
));
}
debug!(
"SurrealDBStorage: HNSW search for {} nearest vectors of type {}",
k, content_type
);
let fetch_count = (k * 5).max(50);
let query_vec: Vec<f32> = query.to_vec();
let mut response = self
.db
.query("SELECT *, vector::distance::knn() AS distance FROM embedding WHERE vector <|$fetch_count|> $query_vec")
.bind(("fetch_count", fetch_count as i64))
.bind(("query_vec", query_vec))
.await?;
let records: Vec<KnnResult> = response.take(0)?;
let matches: Vec<SearchMatch> = records
.into_iter()
.filter(|r| r.content_type == content_type)
.take(k)
.map(|r| {
let similarity = 1.0 - r.distance.min(1.0);
SearchMatch {
vector_id: 0,
similarity,
metadata: VectorMetadata {
id: r.content_id,
text: r.text,
source: r.session_id,
content_type: r.content_type,
timestamp: Self::parse_datetime(&r.timestamp),
metadata: r.metadata,
},
}
})
.collect();
debug!(
"SurrealDBStorage: HNSW found {} matches of type {}",
matches.len(),
content_type
);
Ok(matches)
}
async fn remove_vector(&self, id: &str) -> Result<bool> {
let result: Option<EmbeddingRecord> = self.delete("embedding", id).await?;
Ok(result.is_some())
}
async fn has_session_embeddings(&self, session_id: &str) -> bool {
let count = self.count_session_embeddings(session_id).await;
count > 0
}
async fn count_session_embeddings(&self, session_id: &str) -> usize {
let result = self
.db
.query("SELECT count() FROM embedding WHERE session_id = $session_id GROUP ALL")
.bind(("session_id", session_id.to_string()))
.await;
if let Ok(mut response) = result {
#[derive(Deserialize, SurrealValue)]
struct CountResult {
count: i64,
}
if let Ok(Some(count)) = response.take::<Option<CountResult>>(0) {
return count.count as usize;
}
}
0
}
async fn total_count(&self) -> usize {
let result = self.db.query("SELECT count() FROM embedding GROUP ALL").await;
if let Ok(mut response) = result {
#[derive(Deserialize, SurrealValue)]
struct CountResult {
count: i64,
}
if let Ok(Some(count)) = response.take::<Option<CountResult>>(0) {
return count.count as usize;
}
}
0
}
async fn get_session_vectors(
&self,
session_id: &str,
) -> Result<Vec<(Vec<f32>, VectorMetadata)>> {
let mut response = self
.db
.query("SELECT * FROM embedding WHERE session_id = $session_id")
.bind(("session_id", session_id.to_string()))
.await?;
let records: Vec<EmbeddingRecord> = response.take(0)?;
Ok(records
.into_iter()
.map(|r| {
(
r.vector,
VectorMetadata {
id: r.content_id,
text: r.text,
source: r.session_id,
content_type: r.content_type,
timestamp: Self::parse_datetime(&r.timestamp),
metadata: r.metadata,
},
)
})
.collect())
}
async fn get_all_vectors(&self) -> Result<Vec<(Vec<f32>, VectorMetadata)>> {
let mut all_vectors = Vec::new();
let limit = 1000;
let mut start = 0;
loop {
let mut response = self
.db
.query("SELECT * FROM embedding LIMIT $limit START $start")
.bind(("limit", limit))
.bind(("start", start))
.await?;
let records: Vec<EmbeddingRecord> = response.take(0)?;
if records.is_empty() {
break;
}
for r in records {
all_vectors.push((
r.vector,
VectorMetadata {
id: r.content_id,
text: r.text,
source: r.session_id,
content_type: r.content_type,
timestamp: Self::parse_datetime(&r.timestamp),
metadata: r.metadata,
},
));
}
start += limit;
}
Ok(all_vectors)
}
}