1#![cfg_attr(
2 test,
3 allow(
4 clippy::expect_used,
5 clippy::indexing_slicing,
6 clippy::panic,
7 clippy::unwrap_used,
8 clippy::unreachable
9 )
10)]
11
12use std::ops::Range;
13
14use lancedb::{
15 DistanceType,
16 query::{QueryBase, VectorQuery},
17};
18use rig::{
19 embeddings::embedding::EmbeddingModel,
20 vector_store::{
21 VectorStoreError, VectorStoreIndex,
22 request::{FilterError, SearchFilter, VectorSearchRequest},
23 },
24};
25use serde::Deserialize;
26use serde_json::Value;
27use utils::{FilterTableColumns, QueryToJson};
28
29mod utils;
30
31fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
32 VectorStoreError::DatastoreError(Box::new(e))
33}
34
35fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
36 VectorStoreError::JsonError(e)
37}
38
39pub struct LanceDbVectorIndex<M: EmbeddingModel> {
53 model: M,
55 table: lancedb::Table,
57 id_field: String,
59 search_params: SearchParams,
61}
62
63impl<M> LanceDbVectorIndex<M>
64where
65 M: EmbeddingModel,
66{
67 pub async fn new(
71 table: lancedb::Table,
72 model: M,
73 id_field: &str,
74 search_params: SearchParams,
75 ) -> Result<Self, lancedb::Error> {
76 Ok(Self {
77 table,
78 model,
79 id_field: id_field.to_string(),
80 search_params,
81 })
82 }
83
84 fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
87 let SearchParams {
88 distance_type,
89 search_type,
90 nprobes,
91 refine_factor,
92 post_filter,
93 column,
94 } = self.search_params.clone();
95
96 if let Some(distance_type) = distance_type {
97 query = query.distance_type(distance_type);
98 }
99
100 if let Some(SearchType::Flat) = search_type {
101 query = query.bypass_vector_index();
102 }
103
104 if let Some(SearchType::Approximate) = search_type {
105 if let Some(nprobes) = nprobes {
106 query = query.nprobes(nprobes);
107 }
108 if let Some(refine_factor) = refine_factor {
109 query = query.refine_factor(refine_factor);
110 }
111 }
112
113 if let Some(true) = post_filter {
114 query = query.postfilter();
115 }
116
117 if let Some(column) = column {
118 query = query.column(column.as_str())
119 }
120
121 query
122 }
123}
124
/// How a vector search traverses the table.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq` and `Eq`. Callers can
/// then compare and pass it around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchType {
    /// Exhaustive search: the query bypasses the vector index.
    Flat,
    /// Approximate search through the vector index; `nprobes` and `refine_factor`
    /// from the search parameters are only applied with this variant.
    Approximate,
}
133
134#[derive(Debug, Clone)]
136pub struct LanceDBFilter(Result<String, FilterError>);
137
138impl serde::Serialize for LanceDBFilter {
139 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
140 where
141 S: serde::Serializer,
142 {
143 match &self.0 {
144 Ok(s) => serializer.serialize_str(s),
145 Err(e) => serializer.collect_str(e),
146 }
147 }
148}
149
150impl<'de> serde::Deserialize<'de> for LanceDBFilter {
151 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
152 where
153 D: serde::Deserializer<'de>,
154 {
155 let s = String::deserialize(deserializer)?;
156 Ok(LanceDBFilter(Ok(s)))
158 }
159}
160
161fn zip_result(
162 l: Result<String, FilterError>,
163 r: Result<String, FilterError>,
164) -> Result<(String, String), FilterError> {
165 l.and_then(|l| r.map(|r| (l, r)))
166}
167
168impl SearchFilter for LanceDBFilter {
169 type Value = serde_json::Value;
170
171 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
172 Self(escape_value(value).map(|s| format!("{} = {s}", key.as_ref())))
173 }
174
175 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
176 Self(escape_value(value).map(|s| format!("{} > {s}", key.as_ref())))
177 }
178
179 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
180 Self(escape_value(value).map(|s| format!("{} < {s}", key.as_ref())))
181 }
182
183 fn and(self, rhs: Self) -> Self {
184 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
185 }
186
187 fn or(self, rhs: Self) -> Self {
188 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
189 }
190}
191
192fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
193 use serde_json::Value::*;
194
195 match value {
196 Null => Ok("NULL".into()),
197 Bool(b) => Ok(b.to_string()),
198 Number(n) => Ok(n.to_string()),
199 String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
200 Array(xs) => Ok(format!(
201 "({})",
202 xs.into_iter()
203 .map(escape_value)
204 .collect::<Result<Vec<_>, _>>()?
205 .join(", ")
206 )),
207 Object(_) => Err(FilterError::TypeError(
208 "objects not supported in SQLite backend".into(),
209 )),
210 }
211}
212
213impl LanceDBFilter {
214 pub fn into_inner(self) -> Result<String, FilterError> {
215 self.0
216 }
217
218 #[allow(clippy::should_implement_trait)]
219 pub fn not(self) -> Self {
220 Self(self.0.map(|s| format!("NOT ({s})")))
221 }
222
223 pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
225 Self(
226 values
227 .into_iter()
228 .map(escape_value)
229 .collect::<Result<Vec<_>, FilterError>>()
230 .map(|xs| xs.join(","))
231 .map(|xs| format!("{key} IN ({xs})")),
232 )
233 }
234
235 pub fn like<S>(key: String, pattern: S) -> Self
237 where
238 S: AsRef<str>,
239 {
240 Self(
241 escape_value(serde_json::Value::String(pattern.as_ref().into()))
242 .map(|pat| format!("{key} LIKE {pat}")),
243 )
244 }
245
246 pub fn ilike<S>(key: String, pattern: S) -> Self
248 where
249 S: AsRef<str>,
250 {
251 Self(
252 escape_value(serde_json::Value::String(pattern.as_ref().into()))
253 .map(|pat| format!("{key} ILIKE {pat}")),
254 )
255 }
256
257 pub fn is_null(key: String) -> Self {
259 Self(Ok(format!("{key} IS NULL")))
260 }
261
262 pub fn is_not_null(key: String) -> Self {
264 Self(Ok(format!("{key} IS NOT NULL")))
265 }
266
267 pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
269 Self(
270 values
271 .into_iter()
272 .map(escape_value)
273 .collect::<Result<Vec<_>, FilterError>>()
274 .map(|xs| xs.join(","))
275 .map(|xs| format!("array_has_any({key}, ARRAY[{xs}])")),
276 )
277 }
278
279 pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
281 Self(
282 values
283 .into_iter()
284 .map(escape_value)
285 .collect::<Result<Vec<_>, FilterError>>()
286 .map(|xs| xs.join(","))
287 .map(|xs| format!("array_has_all({key}, ARRAY[{xs}])")),
288 )
289 }
290
291 pub fn array_length(key: String, length: i32) -> Self {
293 Self(Ok(format!("array_length({key}) = {length}")))
294 }
295
296 pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
298 where
299 T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
300 {
301 Self(Ok(format!("{key} BETWEEN {start} AND {end}")))
302 }
303}
304
305#[derive(Debug, Clone, Default)]
311pub struct SearchParams {
312 distance_type: Option<DistanceType>,
313 search_type: Option<SearchType>,
314 nprobes: Option<usize>,
315 refine_factor: Option<u32>,
316 post_filter: Option<bool>,
317 column: Option<String>,
318}
319
320impl SearchParams {
321 pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
325 self.distance_type = Some(distance_type);
326 self
327 }
328
329 pub fn search_type(mut self, search_type: SearchType) -> Self {
333 self.search_type = Some(search_type);
334 self
335 }
336
337 pub fn nprobes(mut self, nprobes: usize) -> Self {
341 self.nprobes = Some(nprobes);
342 self
343 }
344
345 pub fn refine_factor(mut self, refine_factor: u32) -> Self {
349 self.refine_factor = Some(refine_factor);
350 self
351 }
352
353 pub fn post_filter(mut self, post_filter: bool) -> Self {
357 self.post_filter = Some(post_filter);
358 self
359 }
360
361 pub fn column(mut self, column: &str) -> Self {
365 self.column = Some(column.to_string());
366 self
367 }
368}
369
370impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
371where
372 M: EmbeddingModel + Sync + Send,
373{
374 type Filter = LanceDBFilter;
375
376 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
395 &self,
396 req: VectorSearchRequest<LanceDBFilter>,
397 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
398 let prompt_embedding = self.model.embed_text(req.query()).await?;
399
400 let mut query = self
401 .table
402 .vector_search(prompt_embedding.vec.clone())
403 .map_err(lancedb_to_rig_error)?
404 .limit(req.samples() as usize)
405 .distance_range(None, req.threshold().map(|x| x as f32))
406 .select(lancedb::query::Select::Columns(
407 self.table
408 .schema()
409 .await
410 .map_err(lancedb_to_rig_error)?
411 .filter_embeddings(),
412 ));
413
414 if let Some(filter) = req.filter() {
415 query = query.only_if(filter.clone().into_inner()?)
416 }
417
418 self.build_query(query)
419 .execute_query()
420 .await?
421 .into_iter()
422 .enumerate()
423 .map(|(i, value)| {
424 Ok((
425 match value.get("_distance") {
426 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
427 _ => 0.0,
428 },
429 match value.get(self.id_field.clone()) {
430 Some(Value::String(id)) => id.to_string(),
431 _ => format!("unknown{i}"),
432 },
433 serde_json::from_value(value).map_err(serde_to_rig_error)?,
434 ))
435 })
436 .collect()
437 }
438
439 async fn top_n_ids(
458 &self,
459 req: VectorSearchRequest<LanceDBFilter>,
460 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
461 let prompt_embedding = self.model.embed_text(req.query()).await?;
462
463 let mut query = self
464 .table
465 .query()
466 .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
467 .nearest_to(prompt_embedding.vec.clone())
468 .map_err(lancedb_to_rig_error)?
469 .distance_range(None, req.threshold().map(|x| x as f32))
470 .limit(req.samples() as usize);
471
472 if let Some(filter) = req.filter() {
473 query = query.only_if(filter.clone().into_inner()?)
474 }
475
476 self.build_query(query)
477 .execute_query()
478 .await?
479 .into_iter()
480 .map(|value| {
481 Ok((
482 match value.get("distance") {
483 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
484 _ => 0.0,
485 },
486 match value.get(self.id_field.clone()) {
487 Some(Value::String(id)) => id.to_string(),
488 _ => "".to_string(),
489 },
490 ))
491 })
492 .collect()
493 }
494}