use std::ops::Range;
use lancedb::{
DistanceType,
query::{QueryBase, VectorQuery},
};
use rig::{
embeddings::embedding::EmbeddingModel,
vector_store::{
VectorStoreError, VectorStoreIndex,
request::{FilterError, SearchFilter, VectorSearchRequest},
},
};
use serde::Deserialize;
use serde_json::Value;
use utils::{FilterTableColumns, QueryToJson};
mod utils;
/// Converts a LanceDB error into the store-agnostic [`VectorStoreError`].
///
/// The original error is boxed and carried inside
/// `VectorStoreError::DatastoreError`, so no information is discarded.
fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
    let boxed = Box::new(e);
    VectorStoreError::DatastoreError(boxed)
}
fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
VectorStoreError::JsonError(e)
}
/// Vector index that answers similarity searches against a LanceDB table.
///
/// The query text of each search request is embedded with `model`, the
/// resulting vector is searched for in `table`, and `search_params` is applied
/// to every query before it is executed.
pub struct LanceDbVectorIndex<M: EmbeddingModel> {
/// Embedding model used to turn the request's query text into a vector.
model: M,
/// LanceDB table the searches run against.
table: lancedb::Table,
/// Name of the column whose string value is returned as the document id.
id_field: String,
/// Optional tuning (distance type, search type, ...) applied to each query.
search_params: SearchParams,
}
impl<M> LanceDbVectorIndex<M>
where
    M: EmbeddingModel,
{
    /// Creates an index over `table`.
    ///
    /// * `table` - LanceDB table to search.
    /// * `model` - embedding model used for the query text.
    /// * `id_field` - name of the column holding the document id.
    /// * `search_params` - tuning applied to every query.
    ///
    /// The constructor only stores its arguments; it performs no validation
    /// and currently always returns `Ok`.
    pub async fn new(
        table: lancedb::Table,
        model: M,
        id_field: &str,
        search_params: SearchParams,
    ) -> Result<Self, lancedb::Error> {
        let id_field = id_field.to_owned();
        Ok(Self {
            model,
            table,
            id_field,
            search_params,
        })
    }

    /// Applies the configured [`SearchParams`] to `query` and returns it.
    ///
    /// Unset parameters leave the query untouched. `nprobes` and
    /// `refine_factor` are only forwarded when the search type is
    /// [`SearchType::Approximate`].
    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
        let params = self.search_params.clone();

        if let Some(metric) = params.distance_type {
            query = query.distance_type(metric);
        }

        match params.search_type {
            // A flat search skips the vector index altogether.
            Some(SearchType::Flat) => {
                query = query.bypass_vector_index();
            }
            // Index tuning only applies to approximate searches.
            Some(SearchType::Approximate) => {
                if let Some(probes) = params.nprobes {
                    query = query.nprobes(probes);
                }
                if let Some(factor) = params.refine_factor {
                    query = query.refine_factor(factor);
                }
            }
            None => {}
        }

        if params.post_filter == Some(true) {
            query = query.postfilter();
        }

        if let Some(name) = params.column.as_deref() {
            query = query.column(name);
        }

        query
    }
}
/// Strategy used when running a vector search.
///
/// The enum carries no data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchType {
    /// Bypass the vector index when searching.
    Flat,
    /// Search through the vector index; `nprobes` and `refine_factor` from
    /// `SearchParams` are only applied for this variant.
    Approximate,
}
/// A filter expression for LanceDB queries, rendered as a string.
///
/// The wrapped `Result` holds either the rendered expression or the
/// [`FilterError`] raised while building it, so construction errors are
/// deferred until the filter is unwrapped with `into_inner`.
#[derive(Debug, Clone)]
pub struct LanceDBFilter(Result<String, FilterError>);
impl serde::Serialize for LanceDBFilter {
    /// Serializes the filter as a single string.
    ///
    /// A successfully built filter is written as its expression text; a failed
    /// one is written as the `Display` output of its error.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self.0.as_deref() {
            Err(err) => serializer.collect_str(err),
            Ok(expr) => serializer.serialize_str(expr),
        }
    }
}
impl<'de> serde::Deserialize<'de> for LanceDBFilter {
    /// Deserializes a filter from a string.
    ///
    /// The string is taken verbatim as an already-built expression; it is not
    /// parsed or validated here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        String::deserialize(deserializer).map(|expr| Self(Ok(expr)))
    }
}
/// Pairs two filter results, succeeding only when both succeed.
///
/// When both sides are errors the left-hand error is the one returned.
fn zip_result(
    l: Result<String, FilterError>,
    r: Result<String, FilterError>,
) -> Result<(String, String), FilterError> {
    // `l?` is evaluated first, so a left-hand error takes priority.
    Ok((l?, r?))
}
impl SearchFilter for LanceDBFilter {
    type Value = serde_json::Value;

    /// Builds `key = value`. `key` is inserted verbatim; `value` is rendered
    /// through `escape_value`.
    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
        let column = key.as_ref();
        Self(escape_value(value).map(|literal| format!("{column} = {literal}")))
    }

    /// Builds `key > value`. `key` is inserted verbatim; `value` is rendered
    /// through `escape_value`.
    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
        let column = key.as_ref();
        Self(escape_value(value).map(|literal| format!("{column} > {literal}")))
    }

    /// Builds `key < value`. `key` is inserted verbatim; `value` is rendered
    /// through `escape_value`.
    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
        let column = key.as_ref();
        Self(escape_value(value).map(|literal| format!("{column} < {literal}")))
    }

    /// Combines two filters as `(lhs) AND (rhs)`; an error on either side is
    /// carried through, the left-hand one taking priority.
    fn and(self, rhs: Self) -> Self {
        let both = zip_result(self.0, rhs.0);
        Self(both.map(|(left, right)| format!("({left}) AND ({right})")))
    }

    /// Combines two filters as `(lhs) OR (rhs)`; an error on either side is
    /// carried through, the left-hand one taking priority.
    fn or(self, rhs: Self) -> Self {
        let both = zip_result(self.0, rhs.0);
        Self(both.map(|(left, right)| format!("({left}) OR ({right})")))
    }
}
fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
use serde_json::Value::*;
match value {
Null => Ok("NULL".into()),
Bool(b) => Ok(b.to_string()),
Number(n) => Ok(n.to_string()),
String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
Array(xs) => Ok(format!(
"({})",
xs.into_iter()
.map(escape_value)
.collect::<Result<Vec<_>, _>>()?
.join(", ")
)),
Object(_) => Err(FilterError::TypeError(
"objects not supported in SQLite backend".into(),
)),
}
}
impl LanceDBFilter {
    /// Consumes the filter, returning the rendered expression or the error
    /// raised while building it.
    pub fn into_inner(self) -> Result<String, FilterError> {
        self.0
    }

    /// Escapes every value and joins the results with `,` (no space).
    fn join_escaped(values: Vec<<Self as SearchFilter>::Value>) -> Result<String, FilterError> {
        let escaped = values
            .into_iter()
            .map(escape_value)
            .collect::<Result<Vec<_>, FilterError>>()?;
        Ok(escaped.join(","))
    }

    /// Negates the filter as `NOT (expr)`.
    // Kept as an inherent method rather than an `std::ops::Not` impl, hence
    // the lint suppression.
    #[allow(clippy::should_implement_trait)]
    pub fn not(self) -> Self {
        let negated = self.0.map(|expr| format!("NOT ({expr})"));
        Self(negated)
    }

    /// Builds `key IN (v1,v2,...)`. `key` is inserted verbatim.
    pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
        Self(Self::join_escaped(values).map(|list| format!("{key} IN ({list})")))
    }

    /// Builds `key LIKE 'pattern'`; the pattern is escaped as a string literal.
    pub fn like<S>(key: String, pattern: S) -> Self
    where
        S: AsRef<str>,
    {
        let literal = serde_json::Value::String(pattern.as_ref().into());
        Self(escape_value(literal).map(|pat| format!("{key} LIKE {pat}")))
    }

    /// Builds `key ILIKE 'pattern'`; the pattern is escaped as a string literal.
    pub fn ilike<S>(key: String, pattern: S) -> Self
    where
        S: AsRef<str>,
    {
        let literal = serde_json::Value::String(pattern.as_ref().into());
        Self(escape_value(literal).map(|pat| format!("{key} ILIKE {pat}")))
    }

    /// Builds `key IS NULL`.
    pub fn is_null(key: String) -> Self {
        let expr = format!("{key} IS NULL");
        Self(Ok(expr))
    }

    /// Builds `key IS NOT NULL`.
    pub fn is_not_null(key: String) -> Self {
        let expr = format!("{key} IS NOT NULL");
        Self(Ok(expr))
    }

    /// Builds `array_has_any(key, ARRAY[v1,v2,...])`.
    pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
        Self(Self::join_escaped(values).map(|list| format!("array_has_any({key}, ARRAY[{list}])")))
    }

    /// Builds `array_has_all(key, ARRAY[v1,v2,...])`.
    pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
        Self(Self::join_escaped(values).map(|list| format!("array_has_all({key}, ARRAY[{list}])")))
    }

    /// Builds `array_length(key) = length`.
    pub fn array_length(key: String, length: i32) -> Self {
        let expr = format!("array_length({key}) = {length}");
        Self(Ok(expr))
    }

    /// Builds `key BETWEEN start AND end` from the bounds of `range`.
    ///
    /// NOTE(review): a Rust `Range` excludes its end, whereas SQL `BETWEEN` is
    /// conventionally inclusive of both bounds; confirm which is intended.
    pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
    where
        T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
    {
        let expr = format!("{key} BETWEEN {start} AND {end}");
        Self(Ok(expr))
    }
}
/// Optional tuning applied to every query built by `LanceDbVectorIndex`.
///
/// Every field defaults to `None`, in which case the corresponding query
/// setting is left untouched.
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
/// Distance type forwarded to the query when set.
distance_type: Option<DistanceType>,
/// Flat or approximate search; see `SearchType`.
search_type: Option<SearchType>,
/// Forwarded to the query only when `search_type` is `Approximate`.
nprobes: Option<usize>,
/// Forwarded to the query only when `search_type` is `Approximate`.
refine_factor: Option<u32>,
/// When `Some(true)`, the query is switched to post-filtering.
post_filter: Option<bool>,
/// Name of the column to search, forwarded to the query when set.
column: Option<String>,
}
impl SearchParams {
    /// Sets the distance type forwarded to the query.
    pub fn distance_type(self, distance_type: DistanceType) -> Self {
        Self {
            distance_type: Some(distance_type),
            ..self
        }
    }

    /// Sets the search type (flat or approximate).
    pub fn search_type(self, search_type: SearchType) -> Self {
        Self {
            search_type: Some(search_type),
            ..self
        }
    }

    /// Sets `nprobes`; it is only applied to approximate searches.
    pub fn nprobes(self, nprobes: usize) -> Self {
        Self {
            nprobes: Some(nprobes),
            ..self
        }
    }

    /// Sets the refine factor; it is only applied to approximate searches.
    pub fn refine_factor(self, refine_factor: u32) -> Self {
        Self {
            refine_factor: Some(refine_factor),
            ..self
        }
    }

    /// Sets the post-filter flag; post-filtering is enabled on the query only
    /// when this is `true`.
    pub fn post_filter(self, post_filter: bool) -> Self {
        Self {
            post_filter: Some(post_filter),
            ..self
        }
    }

    /// Sets the name of the column to search.
    pub fn column(self, column: &str) -> Self {
        Self {
            column: Some(column.to_string()),
            ..self
        }
    }
}
impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
where
M: EmbeddingModel + Sync + Send,
{
type Filter = LanceDBFilter;
async fn top_n<T: for<'a> Deserialize<'a> + Send>(
&self,
req: VectorSearchRequest<LanceDBFilter>,
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
let prompt_embedding = self.model.embed_text(req.query()).await?;
let mut query = self
.table
.vector_search(prompt_embedding.vec.clone())
.map_err(lancedb_to_rig_error)?
.limit(req.samples() as usize)
.distance_range(None, req.threshold().map(|x| x as f32))
.select(lancedb::query::Select::Columns(
self.table
.schema()
.await
.map_err(lancedb_to_rig_error)?
.filter_embeddings(),
));
if let Some(filter) = req.filter() {
query = query.only_if(filter.clone().into_inner()?)
}
self.build_query(query)
.execute_query()
.await?
.into_iter()
.enumerate()
.map(|(i, value)| {
Ok((
match value.get("_distance") {
Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
_ => 0.0,
},
match value.get(self.id_field.clone()) {
Some(Value::String(id)) => id.to_string(),
_ => format!("unknown{i}"),
},
serde_json::from_value(value).map_err(serde_to_rig_error)?,
))
})
.collect()
}
async fn top_n_ids(
&self,
req: VectorSearchRequest<LanceDBFilter>,
) -> Result<Vec<(f64, String)>, VectorStoreError> {
let prompt_embedding = self.model.embed_text(req.query()).await?;
let mut query = self
.table
.query()
.select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
.nearest_to(prompt_embedding.vec.clone())
.map_err(lancedb_to_rig_error)?
.distance_range(None, req.threshold().map(|x| x as f32))
.limit(req.samples() as usize);
if let Some(filter) = req.filter() {
query = query.only_if(filter.clone().into_inner()?)
}
self.build_query(query)
.execute_query()
.await?
.into_iter()
.map(|value| {
Ok((
match value.get("distance") {
Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
_ => 0.0,
},
match value.get(self.id_field.clone()) {
Some(Value::String(id)) => id.to_string(),
_ => "".to_string(),
},
))
})
.collect()
}
}