//! `rig_lancedb` — LanceDB-backed vector store integration for Rig (crate root `lib.rs`).

1use lancedb::{
2    DistanceType,
3    query::{QueryBase, VectorQuery},
4};
5use rig::{
6    embeddings::embedding::EmbeddingModel,
7    vector_store::{
8        VectorStoreError, VectorStoreIndex,
9        request::{FilterError, SearchFilter, VectorSearchRequest},
10    },
11};
12use serde::Deserialize;
13use serde_json::Value;
14use utils::{FilterTableColumns, QueryToJson};
15
16mod utils;
17
18fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
19    VectorStoreError::DatastoreError(Box::new(e))
20}
21
22fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
23    VectorStoreError::JsonError(e)
24}
25
/// Type on which vector searches can be performed for a lanceDb table.
/// # Example
/// ```
/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
///
/// let openai_client = Client::from_env();
///
/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
/// ```
pub struct LanceDbVectorIndex<M: EmbeddingModel> {
    /// Defines which model is used to generate embeddings for the vector store.
    /// NOTE(review): presumably must match the model used to embed the rows
    /// already stored in `table` — confirm against ingestion code.
    model: M,
    /// LanceDB table containing embeddings.
    table: lancedb::Table,
    /// Column name in `table` that contains the id of a record.
    id_field: String,
    /// Vector search params that are used during vector search operations
    /// (applied in `build_query`).
    search_params: SearchParams,
}
48
49impl<M> LanceDbVectorIndex<M>
50where
51    M: EmbeddingModel,
52{
53    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
54    /// Define the id field name of the table.
55    /// Define search parameters that will be used to perform vector searches on the table.
56    pub async fn new(
57        table: lancedb::Table,
58        model: M,
59        id_field: &str,
60        search_params: SearchParams,
61    ) -> Result<Self, lancedb::Error> {
62        Ok(Self {
63            table,
64            model,
65            id_field: id_field.to_string(),
66            search_params,
67        })
68    }
69
70    /// Apply the search_params to the vector query.
71    /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
72    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
73        let SearchParams {
74            distance_type,
75            search_type,
76            nprobes,
77            refine_factor,
78            post_filter,
79            column,
80        } = self.search_params.clone();
81
82        if let Some(distance_type) = distance_type {
83            query = query.distance_type(distance_type);
84        }
85
86        if let Some(SearchType::Flat) = search_type {
87            query = query.bypass_vector_index();
88        }
89
90        if let Some(SearchType::Approximate) = search_type {
91            if let Some(nprobes) = nprobes {
92                query = query.nprobes(nprobes);
93            }
94            if let Some(refine_factor) = refine_factor {
95                query = query.refine_factor(refine_factor);
96            }
97        }
98
99        if let Some(true) = post_filter {
100            query = query.postfilter();
101        }
102
103        if let Some(column) = column {
104            query = query.column(column.as_str())
105        }
106
107        query
108    }
109}
110
/// Controls how the nearest-neighbour search is executed.
/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information.
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Flat (exhaustive) search, also called ENN or kNN.
    Flat,
    /// Approximate Nearest Neighbor search, also called ANN.
    Approximate,
}
119
/// An eDSL for filtering expressions, is rendered as a `WHERE` clause.
///
/// Wraps either the rendered SQL fragment or the first `FilterError`
/// encountered while rendering; errors are surfaced by `into_inner`.
#[derive(Debug, Clone)]
pub struct LanceDBFilter(Result<String, FilterError>);
123
124fn zip_result(
125    l: Result<String, FilterError>,
126    r: Result<String, FilterError>,
127) -> Result<(String, String), FilterError> {
128    l.and_then(|l| r.map(|r| (l, r)))
129}
130
131impl SearchFilter for LanceDBFilter {
132    type Value = serde_json::Value;
133
134    fn eq(key: String, value: Self::Value) -> Self {
135        Self(escape_value(value).map(|s| format!("{key} = {s}")))
136    }
137
138    fn gt(key: String, value: Self::Value) -> Self {
139        Self(escape_value(value).map(|s| format!("{key} > {s}")))
140    }
141
142    fn lt(key: String, value: Self::Value) -> Self {
143        Self(escape_value(value).map(|s| format!("{key} < {s}")))
144    }
145
146    fn and(self, rhs: Self) -> Self {
147        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) AND ({r})")))
148    }
149
150    fn or(self, rhs: Self) -> Self {
151        Self(zip_result(self.0, rhs.0).map(|(l, r)| format!("({l}) OR ({r})")))
152    }
153}
154
155fn escape_value(value: serde_json::Value) -> Result<String, FilterError> {
156    use serde_json::Value::*;
157
158    match value {
159        Null => Ok("NULL".into()),
160        Bool(b) => Ok(b.to_string()),
161        Number(n) => Ok(n.to_string()),
162        String(s) => Ok(format!("'{}'", s.replace("'", "''"))),
163        Array(xs) => Ok(format!(
164            "({})",
165            xs.into_iter()
166                .map(escape_value)
167                .collect::<Result<Vec<_>, _>>()?
168                .join(", ")
169        )),
170        Object(_) => Err(FilterError::TypeError(
171            "objects not supported in SQLite backend".into(),
172        )),
173    }
174}
175
176impl LanceDBFilter {
177    pub fn into_inner(self) -> Result<String, FilterError> {
178        self.0
179    }
180
181    #[allow(clippy::should_implement_trait)]
182    pub fn not(self) -> Self {
183        Self(self.0.map(|s| format!("NOT ({s})")))
184    }
185}
186
/// Parameters used to perform a vector search on a LanceDb table.
/// # Example
/// ```
/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    /// Distance metric; should match the metric the index was trained with.
    distance_type: Option<DistanceType>,
    /// Flat (kNN) vs approximate (ANN) search; `None` lets LanceDB decide.
    search_type: Option<SearchType>,
    /// Number of IVF partitions probed — ANN only.
    nprobes: Option<usize>,
    /// Re-ranking refine factor — ANN only.
    refine_factor: Option<u32>,
    /// When `Some(true)`, filter after the vector search instead of before.
    post_filter: Option<bool>,
    /// Vector column to search; only needed when the table has several.
    column: Option<String>,
}
201
impl SearchParams {
    /// Sets the distance type of the search params.
    /// Always set the distance_type to match the value used to train the index.
    /// The default is DistanceType::L2.
    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
        self.distance_type = Some(distance_type);
        self
    }

    /// Sets the search type of the search params.
    /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table.
    /// To use the mentioned defaults, do not set the search type.
    pub fn search_type(mut self, search_type: SearchType) -> Self {
        self.search_type = Some(search_type);
        self
    }

    /// Sets the nprobes of the search params.
    /// Only set this value when the search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn nprobes(mut self, nprobes: usize) -> Self {
        self.nprobes = Some(nprobes);
        self
    }

    /// Sets the refine factor of the search params.
    /// Only set this value when the search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
        self.refine_factor = Some(refine_factor);
        self
    }

    /// Sets the post filter of the search params.
    /// If set to true, filtering will happen after the vector search instead of before.
    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
    pub fn post_filter(mut self, post_filter: bool) -> Self {
        self.post_filter = Some(post_filter);
        self
    }

    /// Sets the column of the search params.
    /// Only set this value if there is more than one column that contains lists of floats.
    /// If there is only one column of list of floats, this column will be chosen for the vector search automatically.
    pub fn column(mut self, column: &str) -> Self {
        self.column = Some(column.to_string());
        self
    }
}
251
252impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
253where
254    M: EmbeddingModel + Sync + Send,
255{
256    type Filter = LanceDBFilter;
257
258    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
259    /// # Example
260    /// ```
261    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
262    /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
263    ///
264    /// let openai_client = Client::from_env();
265    ///
266    /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
267    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
268    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
269    ///
270    /// // Query the index
271    /// let result = vector_store_index
272    ///     .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
273    ///     .await?;
274    /// ```
275    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
276        &self,
277        req: VectorSearchRequest<LanceDBFilter>,
278    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
279        let prompt_embedding = self.model.embed_text(req.query()).await?;
280
281        let mut query = self
282            .table
283            .vector_search(prompt_embedding.vec.clone())
284            .map_err(lancedb_to_rig_error)?
285            .limit(req.samples() as usize)
286            .distance_range(None, req.threshold().map(|x| x as f32))
287            .select(lancedb::query::Select::Columns(
288                self.table
289                    .schema()
290                    .await
291                    .map_err(lancedb_to_rig_error)?
292                    .filter_embeddings(),
293            ));
294
295        if let Some(filter) = req.filter() {
296            query = query.only_if(filter.clone().into_inner()?)
297        }
298
299        self.build_query(query)
300            .execute_query()
301            .await?
302            .into_iter()
303            .enumerate()
304            .map(|(i, value)| {
305                Ok((
306                    match value.get("_distance") {
307                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
308                        _ => 0.0,
309                    },
310                    match value.get(self.id_field.clone()) {
311                        Some(Value::String(id)) => id.to_string(),
312                        _ => format!("unknown{i}"),
313                    },
314                    serde_json::from_value(value).map_err(serde_to_rig_error)?,
315                ))
316            })
317            .collect()
318    }
319
320    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
321    /// # Example
322    /// ```
323    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
324    /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
325    ///
326    /// let openai_client = Client::from_env();
327    ///
328    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
329    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
330    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
331    ///
332    /// // Query the index
333    /// let result = vector_store_index
334    ///     .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
335    ///     .await?;
336    /// ```
337    async fn top_n_ids(
338        &self,
339        req: VectorSearchRequest<LanceDBFilter>,
340    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
341        let prompt_embedding = self.model.embed_text(req.query()).await?;
342
343        let mut query = self
344            .table
345            .query()
346            .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
347            .nearest_to(prompt_embedding.vec.clone())
348            .map_err(lancedb_to_rig_error)?
349            .distance_range(None, req.threshold().map(|x| x as f32))
350            .limit(req.samples() as usize);
351
352        if let Some(filter) = req.filter() {
353            query = query.only_if(filter.clone().into_inner()?)
354        }
355
356        self.build_query(query)
357            .execute_query()
358            .await?
359            .into_iter()
360            .map(|value| {
361                Ok((
362                    match value.get("distance") {
363                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
364                        _ => 0.0,
365                    },
366                    match value.get(self.id_field.clone()) {
367                        Some(Value::String(id)) => id.to_string(),
368                        _ => "".to_string(),
369                    },
370                ))
371            })
372            .collect()
373    }
374}