1use lancedb::{
2 DistanceType,
3 query::{QueryBase, VectorQuery},
4};
5use rig::{
6 embeddings::embedding::EmbeddingModel,
7 vector_store::{
8 VectorStoreError, VectorStoreIndex,
9 request::{FilterError, SearchFilter, VectorSearchRequest},
10 },
11};
12use serde::Deserialize;
13use serde_json::Value;
14use utils::{FilterTableColumns, QueryToJson};
15
16mod utils;
17
18fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
19 VectorStoreError::DatastoreError(Box::new(e))
20}
21
22fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
23 VectorStoreError::JsonError(e)
24}
25
26pub struct LanceDbVectorIndex<M: EmbeddingModel> {
39 model: M,
41 table: lancedb::Table,
43 id_field: String,
45 search_params: SearchParams,
47}
48
49impl<M> LanceDbVectorIndex<M>
50where
51 M: EmbeddingModel,
52{
53 pub async fn new(
57 table: lancedb::Table,
58 model: M,
59 id_field: &str,
60 search_params: SearchParams,
61 ) -> Result<Self, lancedb::Error> {
62 Ok(Self {
63 table,
64 model,
65 id_field: id_field.to_string(),
66 search_params,
67 })
68 }
69
70 fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
73 let SearchParams {
74 distance_type,
75 search_type,
76 nprobes,
77 refine_factor,
78 post_filter,
79 column,
80 } = self.search_params.clone();
81
82 if let Some(distance_type) = distance_type {
83 query = query.distance_type(distance_type);
84 }
85
86 if let Some(SearchType::Flat) = search_type {
87 query = query.bypass_vector_index();
88 }
89
90 if let Some(SearchType::Approximate) = search_type {
91 if let Some(nprobes) = nprobes {
92 query = query.nprobes(nprobes);
93 }
94 if let Some(refine_factor) = refine_factor {
95 query = query.refine_factor(refine_factor);
96 }
97 }
98
99 if let Some(true) = post_filter {
100 query = query.postfilter();
101 }
102
103 if let Some(column) = column {
104 query = query.column(column.as_str())
105 }
106
107 query
108 }
109}
110
/// How a vector search should be executed.
#[derive(Clone, Debug)]
pub enum SearchType {
    /// Flat search: the query builder calls `bypass_vector_index()` on the
    /// query when this variant is selected.
    Flat,
    /// Approximate search: the only variant for which the configured
    /// `nprobes` and `refine_factor` are forwarded to the query.
    Approximate,
}
119
120#[derive(Debug, Clone)]
122pub struct LanceDBFilter(Result<String, FilterError>);
123
124fn zip_result(
125 l: Result<String, FilterError>,
126 r: Result<String, FilterError>,
127) -> Result<(String, String), FilterError> {
128 l.and_then(|l| r.map(|r| (l, r)))
129}
130
131impl SearchFilter for LanceDBFilter {
132 type Value = serde_json::Value;
133
134 fn eq(key: String, value: Self::Value) -> Self {
135 Self(escape_value(value).map(|s| format!("{key} = {s}")))
136 }
137
138 fn gt(key: String, value: Self::Value) -> Self {
139 Self(escape_value(value).map(|s| format!("{key} > {s}")))
140 }
141
142 fn lt(key: String, value: Self::Value) -> Self {
143 Self(escape_value(value).map(|s| format!("{key} < {s}")))
144 }
145
146 fn and(self, rhs: Self) -> Self {
147 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
148 }
149
150 fn or(self, rhs: Self) -> Self {
151 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
152 }
153}
154
155fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
156 use serde_json::Value::*;
157
158 match value {
159 Null => Ok("NULL".into()),
160 Bool(b) => Ok(b.to_string()),
161 Number(n) => Ok(n.to_string()),
162 String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
163 Array(xs) => Ok(format!(
164 "({})",
165 xs.into_iter()
166 .map(escape_value)
167 .collect::<Result<Vec<_>, _>>()?
168 .join(", ")
169 )),
170 Object(_) => Err(FilterError::TypeError(
171 "objects not supported in SQLite backend".into(),
172 )),
173 }
174}
175
176impl LanceDBFilter {
177 pub fn into_inner(self) -> Result<String, FilterError> {
178 self.0
179 }
180
181 #[allow(clippy::should_implement_trait)]
182 pub fn not(self) -> Self {
183 Self(self.0.map(|s| format!("NOT ({s})")))
184 }
185}
186
187#[derive(Debug, Clone, Default)]
193pub struct SearchParams {
194 distance_type: Option<DistanceType>,
195 search_type: Option<SearchType>,
196 nprobes: Option<usize>,
197 refine_factor: Option<u32>,
198 post_filter: Option<bool>,
199 column: Option<String>,
200}
201
202impl SearchParams {
203 pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
207 self.distance_type = Some(distance_type);
208 self
209 }
210
211 pub fn search_type(mut self, search_type: SearchType) -> Self {
215 self.search_type = Some(search_type);
216 self
217 }
218
219 pub fn nprobes(mut self, nprobes: usize) -> Self {
223 self.nprobes = Some(nprobes);
224 self
225 }
226
227 pub fn refine_factor(mut self, refine_factor: u32) -> Self {
231 self.refine_factor = Some(refine_factor);
232 self
233 }
234
235 pub fn post_filter(mut self, post_filter: bool) -> Self {
239 self.post_filter = Some(post_filter);
240 self
241 }
242
243 pub fn column(mut self, column: &str) -> Self {
247 self.column = Some(column.to_string());
248 self
249 }
250}
251
252impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
253where
254 M: EmbeddingModel + Sync + Send,
255{
256 type Filter = LanceDBFilter;
257
258 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
276 &self,
277 req: VectorSearchRequest<LanceDBFilter>,
278 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
279 let prompt_embedding = self.model.embed_text(req.query()).await?;
280
281 let mut query = self
282 .table
283 .vector_search(prompt_embedding.vec.clone())
284 .map_err(lancedb_to_rig_error)?
285 .limit(req.samples() as usize)
286 .distance_range(None, req.threshold().map(|x| x as f32))
287 .select(lancedb::query::Select::Columns(
288 self.table
289 .schema()
290 .await
291 .map_err(lancedb_to_rig_error)?
292 .filter_embeddings(),
293 ));
294
295 if let Some(filter) = req.filter() {
296 query = query.only_if(filter.clone().into_inner()?)
297 }
298
299 self.build_query(query)
300 .execute_query()
301 .await?
302 .into_iter()
303 .enumerate()
304 .map(|(i, value)| {
305 Ok((
306 match value.get("_distance") {
307 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
308 _ => 0.0,
309 },
310 match value.get(self.id_field.clone()) {
311 Some(Value::String(id)) => id.to_string(),
312 _ => format!("unknown{i}"),
313 },
314 serde_json::from_value(value).map_err(serde_to_rig_error)?,
315 ))
316 })
317 .collect()
318 }
319
320 async fn top_n_ids(
338 &self,
339 req: VectorSearchRequest<LanceDBFilter>,
340 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
341 let prompt_embedding = self.model.embed_text(req.query()).await?;
342
343 let mut query = self
344 .table
345 .query()
346 .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
347 .nearest_to(prompt_embedding.vec.clone())
348 .map_err(lancedb_to_rig_error)?
349 .distance_range(None, req.threshold().map(|x| x as f32))
350 .limit(req.samples() as usize);
351
352 if let Some(filter) = req.filter() {
353 query = query.only_if(filter.clone().into_inner()?)
354 }
355
356 self.build_query(query)
357 .execute_query()
358 .await?
359 .into_iter()
360 .map(|value| {
361 Ok((
362 match value.get("distance") {
363 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
364 _ => 0.0,
365 },
366 match value.get(self.id_field.clone()) {
367 Some(Value::String(id)) => id.to_string(),
368 _ => "".to_string(),
369 },
370 ))
371 })
372 .collect()
373 }
374}