use std::marker::PhantomData;
use spiresql::vector::types::SearchOptions;
use crate::collection::Collection;
use crate::document::Doc;
use crate::error::{Error, Result};
/// A single search result: a stored document together with its identifier
/// and similarity score.
///
/// `T` is the document type the payload was deserialized into.
#[derive(Clone, Debug)]
pub struct Hit<T> {
    /// Identifier of the matched record, as text.
    pub id: String,
    /// Similarity score of the match; larger means closer to the query.
    pub score: f32,
    /// The matched document itself.
    pub doc: T,
}
/// Builder for a vector similarity search over a [`Collection`].
///
/// Instances are created by the crate-internal constructors (`query`,
/// `similar_id`, `similar_vec`), refined with the chainable setters, and
/// executed with `run`, `docs`, or `first`.
pub struct Search<T: Doc> {
/// Collection the search is executed against.
collection: Collection<T>,
/// What the search is driven by: free text, an existing id, or a raw vector.
mode: SearchMode,
/// SQL filter text stored by `filter`.
/// NOTE(review): no code visible in this file reads this field, so the
/// filter currently appears to have no effect on results — confirm.
filter_sql: Option<String>,
/// Maximum number of results requested from the vector index.
limit: usize,
/// Results whose score is below this value are discarded.
min_score: f32,
}
/// The input a [`Search`] derives its query vector from.
pub(crate) enum SearchMode {
/// Free-text query; `Search::run` embeds it with the collection's embedder.
Query(String),
/// Id of an existing record to find neighbours of.
/// `Search::run` currently rejects this mode with a "not yet implemented" error.
SimilarId(String),
/// Caller-supplied vector, used as the query vector as-is.
SimilarVec(Vec<f32>),
}
impl<T: Doc> Search<T> {
pub(crate) fn query(collection: Collection<T>, query: String) -> Self {
Self {
collection,
mode: SearchMode::Query(query),
filter_sql: None,
limit: 10,
min_score: 0.0,
}
}
pub(crate) fn similar_id(collection: Collection<T>, id: String) -> Self {
Self {
collection,
mode: SearchMode::SimilarId(id),
filter_sql: None,
limit: 10,
min_score: 0.0,
}
}
pub(crate) fn similar_vec(collection: Collection<T>, vec: Vec<f32>) -> Self {
Self {
collection,
mode: SearchMode::SimilarVec(vec),
filter_sql: None,
limit: 10,
min_score: 0.0,
}
}
pub fn filter(mut self, sql: &str) -> Self {
self.filter_sql = Some(sql.to_string());
self
}
pub fn limit(mut self, n: usize) -> Self {
self.limit = n;
self
}
pub fn min_score(mut self, s: f32) -> Self {
self.min_score = s;
self
}
pub async fn run(self) -> Result<Vec<Hit<T>>> {
let query_vec = match &self.mode {
SearchMode::Query(text) => self.collection.spire.inner.embedder.embed(text).await?,
SearchMode::SimilarVec(vec) => vec.clone(),
SearchMode::SimilarId(_id) => {
return Err(Error::Other("similar_id not yet implemented".to_string()));
}
};
let index_name = self.collection.index_name();
let opts = SearchOptions::default().k(self.limit as u32).with_payload();
let results = match self
.collection
.spire
.inner
.vector
.search(&index_name, &query_vec, opts.clone())
.await
{
Ok(r) => r,
Err(spiresql::vector::error::VectorError::IndexNotFound(_)) => {
self.collection.ensure().await?;
self.collection
.spire
.inner
.vector
.search(&index_name, &query_vec, opts)
.await?
}
Err(e) => return Err(Error::Vector(e)),
};
let mut hits = Vec::with_capacity(results.len());
for result in results {
let score = 1.0 - result.distance;
if score < self.min_score {
continue;
}
let id = String::from_utf8_lossy(&result.id).to_string();
if let Some(payload) = &result.payload {
match serde_json::from_slice::<T>(payload) {
Ok(doc) => {
hits.push(Hit { id, score, doc });
}
Err(_) => {
continue;
}
}
}
}
Ok(hits)
}
pub async fn docs(self) -> Result<Vec<T>> {
Ok(self.run().await?.into_iter().map(|h| h.doc).collect())
}
pub async fn first(mut self) -> Result<Option<Hit<T>>> {
self.limit = 1;
Ok(self.run().await?.into_iter().next())
}
}
/// Builder for a SQL-style filtered query over a [`Collection`], with
/// optional ordering and result limit.
pub struct Filter<T: Doc> {
/// Collection the filter is bound to.
collection: Collection<T>,
/// Text of the WHERE condition supplied at construction.
sql_where: String,
/// Ordering clause in the form "<column> ASC" or "<column> DESC", if set.
order_by: Option<String>,
/// Maximum number of rows to return, if set.
limit: Option<usize>,
/// Marker for the document type.
/// NOTE(review): `collection` already mentions `T`, so this marker looks
/// redundant — confirm before removing.
_phantom: PhantomData<T>,
}
impl<T: Doc> Filter<T> {
    /// Creates a filter over `collection` using `sql_where` as the condition,
    /// with no ordering and no limit.
    pub(crate) fn new(collection: Collection<T>, sql_where: String) -> Self {
        Self {
            order_by: None,
            limit: None,
            _phantom: PhantomData,
            collection,
            sql_where,
        }
    }

    /// Orders results by `col`, descending when `desc` is true and ascending
    /// otherwise. A later call replaces any earlier ordering.
    ///
    /// NOTE(review): `col` is interpolated into the clause verbatim; callers
    /// should not pass untrusted text here.
    pub fn order_by(mut self, col: &str, desc: bool) -> Self {
        let clause = if desc {
            format!("{col} DESC")
        } else {
            format!("{col} ASC")
        };
        self.order_by = Some(clause);
        self
    }

    /// Caps the number of rows to return at `n`.
    pub fn limit(self, n: usize) -> Self {
        Self {
            limit: Some(n),
            ..self
        }
    }

    /// Consumes the filter and returns the matching documents.
    ///
    /// NOTE(review): this is currently a placeholder — the configured
    /// collection, condition, ordering and limit are discarded and an empty
    /// vector is always returned.
    pub async fn run(self) -> Result<Vec<T>> {
        // Explicitly consume the settings so they are not reported as unused.
        drop((self.collection, self.sql_where, self.order_by, self.limit));
        Ok(vec![])
    }

    /// Consumes the filter and returns the number of matching documents.
    ///
    /// NOTE(review): this is currently a placeholder that always reports zero.
    pub async fn count(self) -> Result<u64> {
        Ok(0)
    }
}