1use std::ops::Range;
2
3use lancedb::{
4 DistanceType,
5 query::{QueryBase, VectorQuery},
6};
7use rig::{
8 embeddings::embedding::EmbeddingModel,
9 vector_store::{
10 VectorStoreError, VectorStoreIndex,
11 request::{FilterError, SearchFilter, VectorSearchRequest},
12 },
13};
14use serde::Deserialize;
15use serde_json::Value;
16use utils::{FilterTableColumns, QueryToJson};
17
18mod utils;
19
20fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
21 VectorStoreError::DatastoreError(Box::new(e))
22}
23
24fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
25 VectorStoreError::JsonError(e)
26}
27
28pub struct LanceDbVectorIndex<M: EmbeddingModel> {
41 model: M,
43 table: lancedb::Table,
45 id_field: String,
47 search_params: SearchParams,
49}
50
51impl<M> LanceDbVectorIndex<M>
52where
53 M: EmbeddingModel,
54{
55 pub async fn new(
59 table: lancedb::Table,
60 model: M,
61 id_field: &str,
62 search_params: SearchParams,
63 ) -> Result<Self, lancedb::Error> {
64 Ok(Self {
65 table,
66 model,
67 id_field: id_field.to_string(),
68 search_params,
69 })
70 }
71
72 fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
75 let SearchParams {
76 distance_type,
77 search_type,
78 nprobes,
79 refine_factor,
80 post_filter,
81 column,
82 } = self.search_params.clone();
83
84 if let Some(distance_type) = distance_type {
85 query = query.distance_type(distance_type);
86 }
87
88 if let Some(SearchType::Flat) = search_type {
89 query = query.bypass_vector_index();
90 }
91
92 if let Some(SearchType::Approximate) = search_type {
93 if let Some(nprobes) = nprobes {
94 query = query.nprobes(nprobes);
95 }
96 if let Some(refine_factor) = refine_factor {
97 query = query.refine_factor(refine_factor);
98 }
99 }
100
101 if let Some(true) = post_filter {
102 query = query.postfilter();
103 }
104
105 if let Some(column) = column {
106 query = query.column(column.as_str())
107 }
108
109 query
110 }
111}
112
/// How the vector search is executed.
///
/// Deriving `Copy`/`PartialEq`/`Eq` lets callers compare and pass the value freely; the
/// enum carries no data, so this costs nothing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchType {
    /// Exhaustive search: the index's query calls `bypass_vector_index`.
    Flat,
    /// Index-backed search: enables the `nprobes` and `refine_factor` options.
    Approximate,
}
121
122#[derive(Debug, Clone)]
124pub struct LanceDBFilter(Result<String, FilterError>);
125
126impl serde::Serialize for LanceDBFilter {
127 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128 where
129 S: serde::Serializer,
130 {
131 match &self.0 {
132 Ok(s) => serializer.serialize_str(s),
133 Err(e) => serializer.collect_str(e),
134 }
135 }
136}
137
138impl<'de> serde::Deserialize<'de> for LanceDBFilter {
139 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140 where
141 D: serde::Deserializer<'de>,
142 {
143 let s = String::deserialize(deserializer)?;
144 Ok(LanceDBFilter(Ok(s)))
146 }
147}
148
149fn zip_result(
150 l: Result<String, FilterError>,
151 r: Result<String, FilterError>,
152) -> Result<(String, String), FilterError> {
153 l.and_then(|l| r.map(|r| (l, r)))
154}
155
156impl SearchFilter for LanceDBFilter {
157 type Value = serde_json::Value;
158
159 fn eq(key: String, value: Self::Value) -> Self {
160 Self(escape_value(value).map(|s| format!("{key} = {s}")))
161 }
162
163 fn gt(key: String, value: Self::Value) -> Self {
164 Self(escape_value(value).map(|s| format!("{key} > {s}")))
165 }
166
167 fn lt(key: String, value: Self::Value) -> Self {
168 Self(escape_value(value).map(|s| format!("{key} < {s}")))
169 }
170
171 fn and(self, rhs: Self) -> Self {
172 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
173 }
174
175 fn or(self, rhs: Self) -> Self {
176 Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
177 }
178}
179
180fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
181 use serde_json::Value::*;
182
183 match value {
184 Null => Ok("NULL".into()),
185 Bool(b) => Ok(b.to_string()),
186 Number(n) => Ok(n.to_string()),
187 String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
188 Array(xs) => Ok(format!(
189 "({})",
190 xs.into_iter()
191 .map(escape_value)
192 .collect::<Result<Vec<_>, _>>()?
193 .join(", ")
194 )),
195 Object(_) => Err(FilterError::TypeError(
196 "objects not supported in SQLite backend".into(),
197 )),
198 }
199}
200
201impl LanceDBFilter {
202 pub fn into_inner(self) -> Result<String, FilterError> {
203 self.0
204 }
205
206 #[allow(clippy::should_implement_trait)]
207 pub fn not(self) -> Self {
208 Self(self.0.map(|s| format!("NOT ({s})")))
209 }
210
211 pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
213 Self(
214 values
215 .into_iter()
216 .map(escape_value)
217 .collect::<Result<Vec<_>, FilterError>>()
218 .map(|xs| xs.join(","))
219 .map(|xs| format!("{key} IN ({xs})")),
220 )
221 }
222
223 pub fn like<S>(key: String, pattern: S) -> Self
225 where
226 S: AsRef<str>,
227 {
228 Self(
229 escape_value(serde_json::Value::String(pattern.as_ref().into()))
230 .map(|pat| format!("{key} LIKE {pat}")),
231 )
232 }
233
234 pub fn ilike<S>(key: String, pattern: S) -> Self
236 where
237 S: AsRef<str>,
238 {
239 Self(
240 escape_value(serde_json::Value::String(pattern.as_ref().into()))
241 .map(|pat| format!("{key} ILIKE {pat}")),
242 )
243 }
244
245 pub fn is_null(key: String) -> Self {
247 Self(Ok(format!("{key} IS NULL")))
248 }
249
250 pub fn is_not_null(key: String) -> Self {
252 Self(Ok(format!("{key} IS NOT NULL")))
253 }
254
255 pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
257 Self(
258 values
259 .into_iter()
260 .map(escape_value)
261 .collect::<Result<Vec<_>, FilterError>>()
262 .map(|xs| xs.join(","))
263 .map(|xs| format!("array_has_any({key}, ARRAY[{xs}])")),
264 )
265 }
266
267 pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
269 Self(
270 values
271 .into_iter()
272 .map(escape_value)
273 .collect::<Result<Vec<_>, FilterError>>()
274 .map(|xs| xs.join(","))
275 .map(|xs| format!("array_has_all({key}, ARRAY[{xs}])")),
276 )
277 }
278
279 pub fn array_length(key: String, length: i32) -> Self {
281 Self(Ok(format!("array_length({key}) = {length}")))
282 }
283
284 pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
286 where
287 T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
288 {
289 Self(Ok(format!("{key} BETWEEN {start} AND {end}")))
290 }
291}
292
293#[derive(Debug, Clone, Default)]
299pub struct SearchParams {
300 distance_type: Option<DistanceType>,
301 search_type: Option<SearchType>,
302 nprobes: Option<usize>,
303 refine_factor: Option<u32>,
304 post_filter: Option<bool>,
305 column: Option<String>,
306}
307
308impl SearchParams {
309 pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
313 self.distance_type = Some(distance_type);
314 self
315 }
316
317 pub fn search_type(mut self, search_type: SearchType) -> Self {
321 self.search_type = Some(search_type);
322 self
323 }
324
325 pub fn nprobes(mut self, nprobes: usize) -> Self {
329 self.nprobes = Some(nprobes);
330 self
331 }
332
333 pub fn refine_factor(mut self, refine_factor: u32) -> Self {
337 self.refine_factor = Some(refine_factor);
338 self
339 }
340
341 pub fn post_filter(mut self, post_filter: bool) -> Self {
345 self.post_filter = Some(post_filter);
346 self
347 }
348
349 pub fn column(mut self, column: &str) -> Self {
353 self.column = Some(column.to_string());
354 self
355 }
356}
357
358impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
359where
360 M: EmbeddingModel + Sync + Send,
361{
362 type Filter = LanceDBFilter;
363
364 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
382 &self,
383 req: VectorSearchRequest<LanceDBFilter>,
384 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385 let prompt_embedding = self.model.embed_text(req.query()).await?;
386
387 let mut query = self
388 .table
389 .vector_search(prompt_embedding.vec.clone())
390 .map_err(lancedb_to_rig_error)?
391 .limit(req.samples() as usize)
392 .distance_range(None, req.threshold().map(|x| x as f32))
393 .select(lancedb::query::Select::Columns(
394 self.table
395 .schema()
396 .await
397 .map_err(lancedb_to_rig_error)?
398 .filter_embeddings(),
399 ));
400
401 if let Some(filter) = req.filter() {
402 query = query.only_if(filter.clone().into_inner()?)
403 }
404
405 self.build_query(query)
406 .execute_query()
407 .await?
408 .into_iter()
409 .enumerate()
410 .map(|(i, value)| {
411 Ok((
412 match value.get("_distance") {
413 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
414 _ => 0.0,
415 },
416 match value.get(self.id_field.clone()) {
417 Some(Value::String(id)) => id.to_string(),
418 _ => format!("unknown{i}"),
419 },
420 serde_json::from_value(value).map_err(serde_to_rig_error)?,
421 ))
422 })
423 .collect()
424 }
425
426 async fn top_n_ids(
444 &self,
445 req: VectorSearchRequest<LanceDBFilter>,
446 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
447 let prompt_embedding = self.model.embed_text(req.query()).await?;
448
449 let mut query = self
450 .table
451 .query()
452 .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
453 .nearest_to(prompt_embedding.vec.clone())
454 .map_err(lancedb_to_rig_error)?
455 .distance_range(None, req.threshold().map(|x| x as f32))
456 .limit(req.samples() as usize);
457
458 if let Some(filter) = req.filter() {
459 query = query.only_if(filter.clone().into_inner()?)
460 }
461
462 self.build_query(query)
463 .execute_query()
464 .await?
465 .into_iter()
466 .map(|value| {
467 Ok((
468 match value.get("distance") {
469 Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
470 _ => 0.0,
471 },
472 match value.get(self.id_field.clone()) {
473 Some(Value::String(id)) => id.to_string(),
474 _ => "".to_string(),
475 },
476 ))
477 })
478 .collect()
479 }
480}