//! rig-lancedb — LanceDB vector store integration for Rig (lib.rs).

1#![cfg_attr(
2    test,
3    allow(
4        clippy::expect_used,
5        clippy::indexing_slicing,
6        clippy::panic,
7        clippy::unwrap_used,
8        clippy::unreachable
9    )
10)]
11
12use std::ops::Range;
13
14use lancedb::{
15    DistanceType,
16    query::{QueryBase, VectorQuery},
17};
18use rig::{
19    embeddings::embedding::EmbeddingModel,
20    vector_store::{
21        VectorStoreError, VectorStoreIndex,
22        request::{FilterError, SearchFilter, VectorSearchRequest},
23    },
24};
25use serde::Deserialize;
26use serde_json::Value;
27use utils::{FilterTableColumns, QueryToJson};
28
29mod utils;
30
31fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
32    VectorStoreError::DatastoreError(Box::new(e))
33}
34
35fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
36    VectorStoreError::JsonError(e)
37}
38
/// Type on which vector searches can be performed for a lanceDb table.
///
/// Bundles an embedding model, a LanceDB table, the name of the table's id
/// column, and the search parameters applied to every query.
/// # Example
/// ```ignore
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::client::ProviderClient;
/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
///
/// let openai_client = Client::from_env()?;
///
/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
/// ```
pub struct LanceDbVectorIndex<M: EmbeddingModel> {
    /// Defines which model is used to generate embeddings for the vector store.
    model: M,
    /// LanceDB table containing embeddings.
    table: lancedb::Table,
    /// Column name in `table` that contains the id of a record.
    id_field: String,
    /// Vector search params that are used during vector search operations.
    search_params: SearchParams,
}
62
63impl<M> LanceDbVectorIndex<M>
64where
65    M: EmbeddingModel,
66{
67    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
68    /// Define the id field name of the table.
69    /// Define search parameters that will be used to perform vector searches on the table.
70    pub async fn new(
71        table: lancedb::Table,
72        model: M,
73        id_field: &str,
74        search_params: SearchParams,
75    ) -> Result<Self, lancedb::Error> {
76        Ok(Self {
77            table,
78            model,
79            id_field: id_field.to_string(),
80            search_params,
81        })
82    }
83
84    /// Apply the search_params to the vector query.
85    /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
86    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
87        let SearchParams {
88            distance_type,
89            search_type,
90            nprobes,
91            refine_factor,
92            post_filter,
93            column,
94        } = self.search_params.clone();
95
96        if let Some(distance_type) = distance_type {
97            query = query.distance_type(distance_type);
98        }
99
100        if let Some(SearchType::Flat) = search_type {
101            query = query.bypass_vector_index();
102        }
103
104        if let Some(SearchType::Approximate) = search_type {
105            if let Some(nprobes) = nprobes {
106                query = query.nprobes(nprobes);
107            }
108            if let Some(refine_factor) = refine_factor {
109                query = query.refine_factor(refine_factor);
110            }
111        }
112
113        if let Some(true) = post_filter {
114            query = query.postfilter();
115        }
116
117        if let Some(column) = column {
118            query = query.column(column.as_str())
119        }
120
121        query
122    }
123}
124
/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information.
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Flat search, also called ENN or kNN.
    Flat,
    /// Approximate Nearest Neighbor search, also called ANN.
    Approximate,
}
133
/// An eDSL for filtering expressions, is rendered as a `WHERE` clause.
///
/// Wraps either the rendered SQL fragment or the first `FilterError`
/// produced while building it; the error is surfaced when the filter is
/// consumed via `into_inner`.
#[derive(Debug, Clone)]
pub struct LanceDBFilter(Result<String, FilterError>);
137
138impl serde::Serialize for LanceDBFilter {
139    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
140    where
141        S: serde::Serializer,
142    {
143        match &self.0 {
144            Ok(s) => serializer.serialize_str(s),
145            Err(e) => serializer.collect_str(e),
146        }
147    }
148}
149
150impl<'de> serde::Deserialize<'de> for LanceDBFilter {
151    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
152    where
153        D: serde::Deserializer<'de>,
154    {
155        let s = String::deserialize(deserializer)?;
156        // We can't deserialize to Error, so just create an Ok variant
157        Ok(LanceDBFilter(Ok(s)))
158    }
159}
160
161fn zip_result(
162    l: Result<String, FilterError>,
163    r: Result<String, FilterError>,
164) -> Result<(String, String), FilterError> {
165    l.and_then(|l| r.map(|r| (l, r)))
166}
167
168impl SearchFilter for LanceDBFilter {
169    type Value = serde_json::Value;
170
171    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
172        Self(escape_value(value).map(|s| format!("{} = {s}", key.as_ref())))
173    }
174
175    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
176        Self(escape_value(value).map(|s| format!("{} > {s}", key.as_ref())))
177    }
178
179    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
180        Self(escape_value(value).map(|s| format!("{} < {s}", key.as_ref())))
181    }
182
183    fn and(self, rhs: Self) -> Self {
184        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
185    }
186
187    fn or(self, rhs: Self) -> Self {
188        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
189    }
190}
191
192fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
193    use serde_json::Value::*;
194
195    match value {
196        Null => Ok("NULL".into()),
197        Bool(b) => Ok(b.to_string()),
198        Number(n) => Ok(n.to_string()),
199        String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
200        Array(xs) => Ok(format!(
201            "({})",
202            xs.into_iter()
203                .map(escape_value)
204                .collect::<Result<Vec<_>, _>>()?
205                .join(", ")
206        )),
207        Object(_) => Err(FilterError::TypeError(
208            "objects not supported in SQLite backend".into(),
209        )),
210    }
211}
212
213impl LanceDBFilter {
214    pub fn into_inner(self) -> Result<String, FilterError> {
215        self.0
216    }
217
218    #[allow(clippy::should_implement_trait)]
219    pub fn not(self) -> Self {
220        Self(self.0.map(|s| format!("NOT ({s})")))
221    }
222
223    /// IN operator
224    pub fn in_values(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
225        Self(
226            values
227                .into_iter()
228                .map(escape_value)
229                .collect::<Result<Vec<_>, FilterError>>()
230                .map(|xs| xs.join(","))
231                .map(|xs| format!("{key} IN ({xs})")),
232        )
233    }
234
235    /// LIKE operator (string pattern matching)
236    pub fn like<S>(key: String, pattern: S) -> Self
237    where
238        S: AsRef<str>,
239    {
240        Self(
241            escape_value(serde_json::Value::String(pattern.as_ref().into()))
242                .map(|pat| format!("{key} LIKE {pat}")),
243        )
244    }
245
246    /// ILIKE operator (case-insensitive pattern matching)
247    pub fn ilike<S>(key: String, pattern: S) -> Self
248    where
249        S: AsRef<str>,
250    {
251        Self(
252            escape_value(serde_json::Value::String(pattern.as_ref().into()))
253                .map(|pat| format!("{key} ILIKE {pat}")),
254        )
255    }
256
257    /// IS NULL check
258    pub fn is_null(key: String) -> Self {
259        Self(Ok(format!("{key} IS NULL")))
260    }
261
262    /// IS NOT NULL check
263    pub fn is_not_null(key: String) -> Self {
264        Self(Ok(format!("{key} IS NOT NULL")))
265    }
266
267    /// Array has any (for LIST columns with scalar index)
268    pub fn array_has_any(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
269        Self(
270            values
271                .into_iter()
272                .map(escape_value)
273                .collect::<Result<Vec<_>, FilterError>>()
274                .map(|xs| xs.join(","))
275                .map(|xs| format!("array_has_any({key}, ARRAY[{xs}])")),
276        )
277    }
278
279    /// Array has all (for LIST columns with scalar index)
280    pub fn array_has_all(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
281        Self(
282            values
283                .into_iter()
284                .map(escape_value)
285                .collect::<Result<Vec<_>, FilterError>>()
286                .map(|xs| xs.join(","))
287                .map(|xs| format!("array_has_all({key}, ARRAY[{xs}])")),
288        )
289    }
290
291    /// Array length comparison
292    pub fn array_length(key: String, length: i32) -> Self {
293        Self(Ok(format!("array_length({key}) = {length}")))
294    }
295
296    /// BETWEEN operator
297    pub fn between<T>(key: String, Range { start, end }: Range<T>) -> Self
298    where
299        T: PartialOrd + std::fmt::Display + Into<serde_json::Number>,
300    {
301        Self(Ok(format!("{key} BETWEEN {start} AND {end}")))
302    }
303}
304
/// Parameters used to perform a vector search on a LanceDb table.
/// # Example
/// ```
/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    // Distance metric; should match the metric the index was trained with.
    distance_type: Option<DistanceType>,
    // Flat (kNN) vs approximate (ANN) search; `None` lets LanceDB decide.
    search_type: Option<SearchType>,
    // ANN-only: number of partitions probed during search.
    nprobes: Option<usize>,
    // ANN-only: re-ranking factor applied after the initial search.
    refine_factor: Option<u32>,
    // When `Some(true)`, filtering runs after the vector search.
    post_filter: Option<bool>,
    // Vector column to search; only needed when several vector columns exist.
    column: Option<String>,
}
319
impl SearchParams {
    /// Sets the distance type of the search params.
    /// Always set the distance_type to match the value used to train the index.
    /// The default is DistanceType::L2.
    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
        self.distance_type = Some(distance_type);
        self
    }

    /// Sets the search type of the search params.
    /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table.
    /// To use the mentioned defaults, do not set the search type.
    pub fn search_type(mut self, search_type: SearchType) -> Self {
        self.search_type = Some(search_type);
        self
    }

    /// Sets the nprobes of the search params.
    /// Only set this value when the search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn nprobes(mut self, nprobes: usize) -> Self {
        self.nprobes = Some(nprobes);
        self
    }

    /// Sets the refine factor of the search params.
    /// Only set this value when search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
        self.refine_factor = Some(refine_factor);
        self
    }

    /// Sets the post filter of the search params.
    /// If set to true, filtering will happen after the vector search instead of before.
    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
    pub fn post_filter(mut self, post_filter: bool) -> Self {
        self.post_filter = Some(post_filter);
        self
    }

    /// Sets the column of the search params.
    /// Only set this value if there is more than one column that contains lists of floats.
    /// If there is only one column of list of floats, this column will be chosen for the vector search automatically.
    pub fn column(mut self, column: &str) -> Self {
        self.column = Some(column.to_string());
        self
    }
}
369
370impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
371where
372    M: EmbeddingModel + Sync + Send,
373{
374    type Filter = LanceDBFilter;
375
376    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
377    /// # Example
378    /// ```ignore
379    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
380    /// use rig::client::ProviderClient;
381    /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
382    ///
383    /// let openai_client = Client::from_env()?;
384    ///
385    /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
386    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
387    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
388    ///
389    /// // Query the index
390    /// let result = vector_store_index
391    ///     .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
392    ///     .await?;
393    /// ```
394    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
395        &self,
396        req: VectorSearchRequest<LanceDBFilter>,
397    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
398        let prompt_embedding = self.model.embed_text(req.query()).await?;
399
400        let mut query = self
401            .table
402            .vector_search(prompt_embedding.vec.clone())
403            .map_err(lancedb_to_rig_error)?
404            .limit(req.samples() as usize)
405            .distance_range(None, req.threshold().map(|x| x as f32))
406            .select(lancedb::query::Select::Columns(
407                self.table
408                    .schema()
409                    .await
410                    .map_err(lancedb_to_rig_error)?
411                    .filter_embeddings(),
412            ));
413
414        if let Some(filter) = req.filter() {
415            query = query.only_if(filter.clone().into_inner()?)
416        }
417
418        self.build_query(query)
419            .execute_query()
420            .await?
421            .into_iter()
422            .enumerate()
423            .map(|(i, value)| {
424                Ok((
425                    match value.get("_distance") {
426                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
427                        _ => 0.0,
428                    },
429                    match value.get(self.id_field.clone()) {
430                        Some(Value::String(id)) => id.to_string(),
431                        _ => format!("unknown{i}"),
432                    },
433                    serde_json::from_value(value).map_err(serde_to_rig_error)?,
434                ))
435            })
436            .collect()
437    }
438
439    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
440    /// # Example
441    /// ```ignore
442    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
443    /// use rig::client::ProviderClient;
444    /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
445    ///
446    /// let openai_client = Client::from_env()?;
447    ///
448    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
449    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
450    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
451    ///
452    /// // Query the index
453    /// let result = vector_store_index
454    ///     .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
455    ///     .await?;
456    /// ```
457    async fn top_n_ids(
458        &self,
459        req: VectorSearchRequest<LanceDBFilter>,
460    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
461        let prompt_embedding = self.model.embed_text(req.query()).await?;
462
463        let mut query = self
464            .table
465            .query()
466            .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
467            .nearest_to(prompt_embedding.vec.clone())
468            .map_err(lancedb_to_rig_error)?
469            .distance_range(None, req.threshold().map(|x| x as f32))
470            .limit(req.samples() as usize);
471
472        if let Some(filter) = req.filter() {
473            query = query.only_if(filter.clone().into_inner()?)
474        }
475
476        self.build_query(query)
477            .execute_query()
478            .await?
479            .into_iter()
480            .map(|value| {
481                Ok((
482                    match value.get("distance") {
483                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
484                        _ => 0.0,
485                    },
486                    match value.get(self.id_field.clone()) {
487                        Some(Value::String(id)) => id.to_string(),
488                        _ => "".to_string(),
489                    },
490                ))
491            })
492            .collect()
493    }
494}