use std::marker::PhantomData;
use spire_proto::spiredb::cluster::{
ColumnDef, ColumnType, CreateTableRequest, schema_service_client::SchemaServiceClient,
};
use spiresql::vector::types::{Algorithm, IndexParams};
use crate::client::Spire;
use crate::document::Doc;
use crate::error::{Error, Result};
use crate::search::{Filter, Search};
use crate::watch::WatchStream;
/// Derives the in-process document-cache key for document `id` in `collection`.
///
/// The hasher state is built from fixed seeds, so a given `(collection, id)` pair
/// always hashes to the same `u64` within a process.
fn doc_cache_key(collection: &str, id: &str) -> u64 {
    let hasher_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
    hasher_state.hash_one((collection, id))
}
/// A typed handle to a named document collection whose documents are of type `T`.
pub struct Collection<T>
where
    T: Doc,
{
    /// Client used for schema, vector-index, embedding, and cache access.
    pub(crate) spire: Spire,
    /// Logical collection name, without the `_ai_` storage prefix.
    pub(crate) name: String,
    /// Binds the handle to its document type without storing a `T`.
    pub(crate) _phantom: PhantomData<T>,
}
// Implemented by hand rather than derived so that cloning does not require `T: Clone`;
// only the client handle and the name are duplicated.
impl<T: Doc> Clone for Collection<T> {
    fn clone(&self) -> Self {
        let Self { spire, name, .. } = self;
        Self {
            spire: spire.clone(),
            name: name.clone(),
            _phantom: PhantomData,
        }
    }
}
impl<T: Doc> Collection<T> {
pub(crate) fn new(spire: Spire, name: String) -> Self {
Self {
spire,
name,
_phantom: PhantomData,
}
}
pub fn table_name(&self) -> String {
format!("_ai_{}", self.name)
}
pub fn index_name(&self) -> String {
format!("_ai_{}_vec", self.name)
}
pub async fn ensure(&self) -> Result<()> {
let table = self.table_name();
let index = self.index_name();
let dims = self.spire.inner.embedder.dimensions() as u32;
let mut schema_client = SchemaServiceClient::new(self.spire.inner.pd_channel.clone());
let columns = vec![
ColumnDef {
name: "id".to_string(),
r#type: ColumnType::TypeString.into(),
nullable: false,
..Default::default()
},
ColumnDef {
name: "doc".to_string(),
r#type: ColumnType::TypeBytes.into(),
nullable: false,
..Default::default()
},
ColumnDef {
name: "embed_text".to_string(),
r#type: ColumnType::TypeString.into(),
nullable: true,
..Default::default()
},
ColumnDef {
name: "created_at".to_string(),
r#type: ColumnType::TypeTimestamp.into(),
nullable: true,
..Default::default()
},
];
let request = CreateTableRequest {
name: table.clone(),
columns,
primary_key: vec!["id".to_string()],
};
match schema_client.create_table(request).await {
Ok(_) => {}
Err(status) if status.code() == tonic::Code::AlreadyExists => {
}
Err(e) => return Err(Error::Grpc(e)),
}
if dims > 0 {
let params = IndexParams::new(&index, &table, "embedding")
.algorithm(Algorithm::Manode)
.dimensions(dims);
match self.spire.inner.vector.create_index(params).await {
Ok(_) => {}
Err(spiresql::vector::error::VectorError::IndexAlreadyExists(_)) => {}
Err(e) => return Err(Error::Vector(e)),
}
}
Ok(())
}
pub async fn insert(&self, doc: &T) -> Result<String> {
let id = doc.id().to_string();
let doc_json = serde_json::to_vec(doc)?;
let embed_text = doc.embed_text();
let cache_key = doc_cache_key(&self.name, &id);
self.spire
.inner
.doc_cache
.insert(cache_key, doc_json.clone());
let embedding = if !embed_text.is_empty() {
Some(self.spire.inner.embedder.embed(&embed_text).await?)
} else {
None
};
if let Some(ref vec) = embedding {
self.vector_insert(id.as_bytes(), vec, &doc_json).await?;
}
Ok(id)
}
pub async fn insert_many(&self, docs: &[T]) -> Result<Vec<String>> {
if docs.is_empty() {
return Ok(Vec::new());
}
let ids: Vec<String> = docs.iter().map(|d| d.id().to_string()).collect();
let texts: Vec<String> = docs.iter().map(|d| d.embed_text()).collect();
let non_empty: Vec<String> = texts.iter().filter(|t| !t.is_empty()).cloned().collect();
let embeddings = if !non_empty.is_empty() {
self.spire.inner.embedder.embed_batch(&non_empty).await?
} else {
Vec::new()
};
let mut embed_iter = embeddings.into_iter();
for (i, doc) in docs.iter().enumerate() {
let doc_json = serde_json::to_vec(doc)?;
let cache_key = doc_cache_key(&self.name, &ids[i]);
self.spire
.inner
.doc_cache
.insert(cache_key, doc_json.clone());
if !texts[i].is_empty()
&& let Some(vec) = embed_iter.next()
{
self.vector_insert(ids[i].as_bytes(), &vec, &doc_json)
.await?;
}
}
Ok(ids)
}
async fn vector_insert(&self, doc_id: &[u8], vec: &[f32], payload: &[u8]) -> Result<u64> {
let index_name = self.index_name();
match self
.spire
.inner
.vector
.insert(&index_name, doc_id, vec, Some(payload))
.await
{
Ok(id) => Ok(id),
Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
self.ensure().await?;
Ok(self
.spire
.inner
.vector
.insert(&index_name, doc_id, vec, Some(payload))
.await?)
}
Err(e) => Err(Error::Vector(e)),
}
}
pub async fn upsert(&self, doc: &T) -> Result<String> {
let id = doc.id().to_string();
let _ = self
.spire
.inner
.vector
.delete(&self.index_name(), id.as_bytes())
.await;
self.insert(doc).await
}
pub async fn delete(&self, id: &str) -> Result<bool> {
let cache_key = doc_cache_key(&self.name, id);
self.spire.inner.doc_cache.remove(&cache_key);
match self
.spire
.inner
.vector
.delete(&self.index_name(), id.as_bytes())
.await
{
Ok(_) => Ok(true),
Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => Ok(false),
Err(e) => Err(Error::Vector(e)),
}
}
pub async fn get(&self, id: &str) -> Result<Option<T>> {
let cache_key = doc_cache_key(&self.name, id);
if let Some(bytes) = self.spire.inner.doc_cache.get(&cache_key)
&& let Ok(doc) = serde_json::from_slice::<T>(&bytes)
{
return Ok(Some(doc));
}
match self
.spire
.inner
.vector
.get_payload(&self.index_name(), id.as_bytes())
.await
{
Ok(Some(payload)) => {
self.spire
.inner
.doc_cache
.insert(cache_key, payload.clone());
match serde_json::from_slice::<T>(&payload) {
Ok(doc) => Ok(Some(doc)),
Err(_) => Ok(None),
}
}
Ok(None) => Ok(None),
Err(_) => Ok(None),
}
}
pub async fn get_many(&self, ids: &[&str]) -> Result<Vec<T>> {
let mut docs = Vec::new();
for id in ids {
if let Some(doc) = self.get(id).await? {
docs.push(doc);
}
}
Ok(docs)
}
pub async fn all(&self) -> Result<Vec<T>> {
let dims = self.spire.inner.embedder.dimensions();
if dims == 0 {
return Ok(Vec::new());
}
let val = 1.0 / (dims as f32).sqrt();
let query_vec = vec![val; dims];
let index_name = self.index_name();
let opts = spiresql::vector::types::SearchOptions::default()
.k(10_000)
.with_payload();
let results = match self
.spire
.inner
.vector
.search(&index_name, &query_vec, opts.clone())
.await
{
Ok(r) => r,
Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
self.ensure().await?;
self.spire
.inner
.vector
.search(&index_name, &query_vec, opts)
.await?
}
Err(e) => return Err(Error::Vector(e)),
};
let mut docs = Vec::with_capacity(results.len());
for result in results {
if let Some(payload) = &result.payload
&& let Ok(doc) = serde_json::from_slice::<T>(payload)
{
docs.push(doc);
}
}
Ok(docs)
}
pub fn search(&self, query: &str) -> Search<T> {
Search::query(self.clone(), query.to_string())
}
pub fn similar(&self, id: &str) -> Search<T> {
Search::similar_id(self.clone(), id.to_string())
}
pub fn similar_vec(&self, vec: &[f32]) -> Search<T> {
Search::similar_vec(self.clone(), vec.to_vec())
}
pub fn filter(&self, sql_where: &str) -> Filter<T> {
Filter::new(self.clone(), sql_where.to_string())
}
pub async fn watch(&self) -> Result<WatchStream<T>> {
WatchStream::new(&self.spire.inner.stream_addr, &self.table_name()).await
}
}