// rig_lancedb/lib.rs

1use std::ops::Range;
2
3use lancedb::{
4    DistanceType,
5    query::{QueryBase, VectorQuery},
6};
7use rig::{
8    embeddings::embedding::EmbeddingModel,
9    vector_store::{
10        VectorStoreError, VectorStoreIndex,
11        request::{FilterError, SearchFilter, VectorSearchRequest},
12    },
13};
14use serde::Deserialize;
15use serde_json::Value;
16use utils::{FilterTableColumns, QueryToJson};
17
18mod utils;
19
20fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
21    VectorStoreError::DatastoreError(Box::new(e))
22}
23
/// Map a `serde_json` (de)serialization error to the vector store's JSON error variant.
fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
    VectorStoreError::JsonError(e)
}
27
/// Type on which vector searches can be performed for a LanceDB table.
/// # Example
/// ```
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
///
/// let openai_client = Client::from_env();
///
/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
/// ```
pub struct LanceDbVectorIndex<M: EmbeddingModel> {
    /// Embedding model used to turn the query text into a vector at search time.
    model: M,
    /// LanceDB table containing the embeddings (and any payload columns).
    table: lancedb::Table,
    /// Name of the column in `table` that holds a record's unique id.
    id_field: String,
    /// Vector search params applied to every search on this index.
    search_params: SearchParams,
}
50
51impl<M> LanceDbVectorIndex<M>
52where
53    M: EmbeddingModel,
54{
55    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
56    /// Define the id field name of the table.
57    /// Define search parameters that will be used to perform vector searches on the table.
58    pub async fn new(
59        table: lancedb::Table,
60        model: M,
61        id_field: &str,
62        search_params: SearchParams,
63    ) -> Result<Self, lancedb::Error> {
64        Ok(Self {
65            table,
66            model,
67            id_field: id_field.to_string(),
68            search_params,
69        })
70    }
71
72    /// Apply the search_params to the vector query.
73    /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
74    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
75        let SearchParams {
76            distance_type,
77            search_type,
78            nprobes,
79            refine_factor,
80            post_filter,
81            column,
82        } = self.search_params.clone();
83
84        if let Some(distance_type) = distance_type {
85            query = query.distance_type(distance_type);
86        }
87
88        if let Some(SearchType::Flat) = search_type {
89            query = query.bypass_vector_index();
90        }
91
92        if let Some(SearchType::Approximate) = search_type {
93            if let Some(nprobes) = nprobes {
94                query = query.nprobes(nprobes);
95            }
96            if let Some(refine_factor) = refine_factor {
97                query = query.refine_factor(refine_factor);
98            }
99        }
100
101        if let Some(true) = post_filter {
102            query = query.postfilter();
103        }
104
105        if let Some(column) = column {
106            query = query.column(column.as_str())
107        }
108
109        query
110    }
111}
112
/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information.
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Flat (exhaustive) search, also called ENN or kNN.
    Flat,
    /// Approximate Nearest Neighbor search, also called ANN.
    Approximate,
}
121
/// An eDSL for filtering expressions, rendered as a SQL `WHERE` clause.
/// Wraps either the rendered clause or the first error hit while building it;
/// the error is only surfaced when the filter is consumed via `into_inner`.
#[derive(Debug, Clone)]
pub struct LanceDBFilter(Result<String, FilterError>);
125
126impl serde::Serialize for LanceDBFilter {
127    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128    where
129        S: serde::Serializer,
130    {
131        match &self.0 {
132            Ok(s) => serializer.serialize_str(s),
133            Err(e) => serializer.collect_str(e),
134        }
135    }
136}
137
138impl<'de> serde::Deserialize<'de> for LanceDBFilter {
139    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
140    where
141        D: serde::Deserializer<'de>,
142    {
143        let s = String::deserialize(deserializer)?;
144        // We can't deserialize to Error, so just create an Ok variant
145        Ok(LanceDBFilter(Ok(s)))
146    }
147}
148
149fn zip_result(
150    l: Result<String, FilterError>,
151    r: Result<String, FilterError>,
152) -> Result<(String, String), FilterError> {
153    l.and_then(|l| r.map(|r| (l, r)))
154}
155
156impl SearchFilter for LanceDBFilter {
157    type Value = serde_json::Value;
158
159    fn eq(key: String, value: Self::Value) -> Self {
160        Self(escape_value(value).map(|s| format!("{key} = {s}")))
161    }
162
163    fn gt(key: String, value: Self::Value) -> Self {
164        Self(escape_value(value).map(|s| format!("{key} > {s}")))
165    }
166
167    fn lt(key: String, value: Self::Value) -> Self {
168        Self(escape_value(value).map(|s| format!("{key} < {s}")))
169    }
170
171    fn and(self, rhs: Self) -> Self {
172        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
173    }
174
175    fn or(self, rhs: Self) -> Self {
176        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
177    }
178}
179
180fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
181    use serde_json::Value::*;
182
183    match value {
184        Null => Ok("NULL".into()),
185        Bool(b) => Ok(b.to_string()),
186        Number(n) => Ok(n.to_string()),
187        String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
188        Array(xs) => Ok(format!(
189            "({})",
190            xs.into_iter()
191                .map(escape_value)
192                .collect::<Result<Vec<_>, _>>()?
193                .join(", ")
194        )),
195        Object(_) => Err(FilterError::TypeError(
196            "objects not supported in SQLite backend".into(),
197        )),
198    }
199}
200
201impl LanceDBFilter {
202    pub fn into_inner(self) -> Result<String, FilterError> {
203        self.0
204    }
205
206    #[allow(clippy::should_implement_trait)]
207    pub fn not(self) -> Self {
208        Self(self.0.map(|s| format!("NOT ({s})")))
209    }
210
211    /// IN operator
212    pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
213        Self(
214            values
215                .into_iter()
216                .map(escape_value)
217                .collect::<Result<Vec<_>, FilterError>>()
218                .map(|xs| xs.join(","))
219                .map(|xs| format!("{key} IN ({xs})")),
220        )
221    }
222
223    /// LIKE operator (string pattern matching)
224    pub fn like<S>(key: String, pattern: S) -> Self
225    where
226        S: AsRef<str>,
227    {
228        Self(
229            escape_value(serde_json::Value::String(pattern.as_ref().into()))
230                .map(|pat| format!("{key} LIKE {pat}")),
231        )
232    }
233
234    /// ILIKE operator (case-insensitive pattern matching)
235    pub fn ilike<S>(key: String, pattern: S) -> Self
236    where
237        S: AsRef<str>,
238    {
239        Self(
240            escape_value(serde_json::Value::String(pattern.as_ref().into()))
241                .map(|pat| format!("{key} ILIKE {pat}")),
242        )
243    }
244
245    /// IS NULL check
246    pub fn is_null(key: String) -> Self {
247        Self(Ok(format!("{key} IS NULL")))
248    }
249
250    /// IS NOT NULL check
251    pub fn is_not_null(key: String) -> Self {
252        Self(Ok(format!("{key} IS NOT NULL")))
253    }
254
255    /// Array has any (for LIST columns with scalar index)
256    pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
257        Self(
258            values
259                .into_iter()
260                .map(escape_value)
261                .collect::<Result<Vec<_>, FilterError>>()
262                .map(|xs| xs.join(","))
263                .map(|xs| format!("array_has_any({key}, ARRAY[{xs}])")),
264        )
265    }
266
267    /// Array has all (for LIST columns with scalar index)
268    pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
269        Self(
270            values
271                .into_iter()
272                .map(escape_value)
273                .collect::<Result<Vec<_>, FilterError>>()
274                .map(|xs| xs.join(","))
275                .map(|xs| format!("array_has_all({key}, ARRAY[{xs}])")),
276        )
277    }
278
279    /// Array length comparison
280    pub fn array_length(key: String, length: i32) -> Self {
281        Self(Ok(format!("array_length({key}) = {length}")))
282    }
283
284    /// BETWEEN operator
285    pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
286    where
287        T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
288    {
289        Self(Ok(format!("{key} BETWEEN {start} AND {end}")))
290    }
291}
292
/// Parameters used to perform a vector search on a LanceDB table.
/// Every field is optional; unset fields fall back to LanceDB's defaults.
/// # Example
/// ```
/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    // Distance metric; should match the metric the index was trained with.
    distance_type: Option<DistanceType>,
    // Flat (exact) vs. approximate search; `None` lets LanceDB decide.
    search_type: Option<SearchType>,
    // ANN-only tuning knob; see the LanceDB ANN docs.
    nprobes: Option<usize>,
    // ANN-only tuning knob; see the LanceDB ANN docs.
    refine_factor: Option<u32>,
    // When `Some(true)`, filters are applied after the vector search.
    post_filter: Option<bool>,
    // Vector column to search; only needed when several vector columns exist.
    column: Option<String>,
}
307
308impl SearchParams {
309    /// Sets the distance type of the search params.
310    /// Always set the distance_type to match the value used to train the index.
311    /// The default is DistanceType::L2.
312    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
313        self.distance_type = Some(distance_type);
314        self
315    }
316
317    /// Sets the search type of the search params.
318    /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table.
319    /// To use the mentioned defaults, do not set the search type.
320    pub fn search_type(mut self, search_type: SearchType) -> Self {
321        self.search_type = Some(search_type);
322        self
323    }
324
325    /// Sets the nprobes of the search params.
326    /// Only set this value only when the search type is ANN.
327    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
328    pub fn nprobes(mut self, nprobes: usize) -> Self {
329        self.nprobes = Some(nprobes);
330        self
331    }
332
333    /// Sets the refine factor of the search params.
334    /// Only set this value only when search type is ANN.
335    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
336    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
337        self.refine_factor = Some(refine_factor);
338        self
339    }
340
341    /// Sets the post filter of the search params.
342    /// If set to true, filtering will happen after the vector search instead of before.
343    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
344    pub fn post_filter(mut self, post_filter: bool) -> Self {
345        self.post_filter = Some(post_filter);
346        self
347    }
348
349    /// Sets the column of the search params.
350    /// Only set this value if there is more than one column that contains lists of floats.
351    /// If there is only one column of list of floats, this column will be chosen for the vector search automatically.
352    pub fn column(mut self, column: &str) -> Self {
353        self.column = Some(column.to_string());
354        self
355    }
356}
357
358impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
359where
360    M: EmbeddingModel + Sync + Send,
361{
362    type Filter = LanceDBFilter;
363
364    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
365    /// # Example
366    /// ```
367    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
368    /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
369    ///
370    /// let openai_client = Client::from_env();
371    ///
372    /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
373    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
374    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
375    ///
376    /// // Query the index
377    /// let result = vector_store_index
378    ///     .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
379    ///     .await?;
380    /// ```
381    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
382        &self,
383        req: VectorSearchRequest<LanceDBFilter>,
384    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385        let prompt_embedding = self.model.embed_text(req.query()).await?;
386
387        let mut query = self
388            .table
389            .vector_search(prompt_embedding.vec.clone())
390            .map_err(lancedb_to_rig_error)?
391            .limit(req.samples() as usize)
392            .distance_range(None, req.threshold().map(|x| x as f32))
393            .select(lancedb::query::Select::Columns(
394                self.table
395                    .schema()
396                    .await
397                    .map_err(lancedb_to_rig_error)?
398                    .filter_embeddings(),
399            ));
400
401        if let Some(filter) = req.filter() {
402            query = query.only_if(filter.clone().into_inner()?)
403        }
404
405        self.build_query(query)
406            .execute_query()
407            .await?
408            .into_iter()
409            .enumerate()
410            .map(|(i, value)| {
411                Ok((
412                    match value.get("_distance") {
413                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
414                        _ => 0.0,
415                    },
416                    match value.get(self.id_field.clone()) {
417                        Some(Value::String(id)) => id.to_string(),
418                        _ => format!("unknown{i}"),
419                    },
420                    serde_json::from_value(value).map_err(serde_to_rig_error)?,
421                ))
422            })
423            .collect()
424    }
425
426    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
427    /// # Example
428    /// ```
429    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
430    /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
431    ///
432    /// let openai_client = Client::from_env();
433    ///
434    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
435    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
436    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
437    ///
438    /// // Query the index
439    /// let result = vector_store_index
440    ///     .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
441    ///     .await?;
442    /// ```
443    async fn top_n_ids(
444        &self,
445        req: VectorSearchRequest<LanceDBFilter>,
446    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
447        let prompt_embedding = self.model.embed_text(req.query()).await?;
448
449        let mut query = self
450            .table
451            .query()
452            .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
453            .nearest_to(prompt_embedding.vec.clone())
454            .map_err(lancedb_to_rig_error)?
455            .distance_range(None, req.threshold().map(|x| x as f32))
456            .limit(req.samples() as usize);
457
458        if let Some(filter) = req.filter() {
459            query = query.only_if(filter.clone().into_inner()?)
460        }
461
462        self.build_query(query)
463            .execute_query()
464            .await?
465            .into_iter()
466            .map(|value| {
467                Ok((
468                    match value.get("distance") {
469                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
470                        _ => 0.0,
471                    },
472                    match value.get(self.id_field.clone()) {
473                        Some(Value::String(id)) => id.to_string(),
474                        _ => "".to_string(),
475                    },
476                ))
477            })
478            .collect()
479    }
480}