// rig_lancedb/lib.rs

1use lancedb::{
2    DistanceType,
3    query::{QueryBase, VectorQuery},
4};
5use rig::{
6    embeddings::embedding::EmbeddingModel,
7    vector_store::{VectorStoreError, VectorStoreIndex, request::VectorSearchRequest},
8};
9use serde::Deserialize;
10use serde_json::Value;
11use utils::{FilterTableColumns, QueryToJson};
12
13mod utils;
14
15fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError {
16    VectorStoreError::DatastoreError(Box::new(e))
17}
18
19fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
20    VectorStoreError::JsonError(e)
21}
22
23/// Type on which vector searches can be performed for a lanceDb table.
24/// # Example
25/// ```
26/// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
27/// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
28///
29/// let openai_client = Client::from_env();
30///
31/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
32/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
33/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
34/// ```
35pub struct LanceDbVectorIndex<M: EmbeddingModel> {
36    /// Defines which model is used to generate embeddings for the vector store.
37    model: M,
38    /// LanceDB table containing embeddings.
39    table: lancedb::Table,
40    /// Column name in `table` that contains the id of a record.
41    id_field: String,
42    /// Vector search params that are used during vector search operations.
43    search_params: SearchParams,
44}
45
46impl<M> LanceDbVectorIndex<M>
47where
48    M: EmbeddingModel,
49{
50    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
51    /// Define the id field name of the table.
52    /// Define search parameters that will be used to perform vector searches on the table.
53    pub async fn new(
54        table: lancedb::Table,
55        model: M,
56        id_field: &str,
57        search_params: SearchParams,
58    ) -> Result<Self, lancedb::Error> {
59        Ok(Self {
60            table,
61            model,
62            id_field: id_field.to_string(),
63            search_params,
64        })
65    }
66
67    /// Apply the search_params to the vector query.
68    /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
69    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
70        let SearchParams {
71            distance_type,
72            search_type,
73            nprobes,
74            refine_factor,
75            post_filter,
76            column,
77        } = self.search_params.clone();
78
79        if let Some(distance_type) = distance_type {
80            query = query.distance_type(distance_type);
81        }
82
83        if let Some(SearchType::Flat) = search_type {
84            query = query.bypass_vector_index();
85        }
86
87        if let Some(SearchType::Approximate) = search_type {
88            if let Some(nprobes) = nprobes {
89                query = query.nprobes(nprobes);
90            }
91            if let Some(refine_factor) = refine_factor {
92                query = query.refine_factor(refine_factor);
93            }
94        }
95
96        if let Some(true) = post_filter {
97            query = query.postfilter();
98        }
99
100        if let Some(column) = column {
101            query = query.column(column.as_str())
102        }
103
104        query
105    }
106}
107
/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information.
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Flat (exhaustive) search, also called ENN or kNN.
    Flat,
    /// Approximate Nearest Neighbor search, also called ANN.
    Approximate,
}
116
/// Parameters used to perform a vector search on a LanceDb table.
/// All fields are optional; unset fields fall back to LanceDB's defaults.
/// # Example
/// ```
/// let search_params = rig_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
/// ```
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
    /// Distance metric used for the search (LanceDB defaults to L2 when unset).
    distance_type: Option<DistanceType>,
    /// Flat (kNN) vs approximate (ANN) search; unset lets LanceDB decide based on index presence.
    search_type: Option<SearchType>,
    /// Number of partitions probed during ANN search; only meaningful for `SearchType::Approximate`.
    nprobes: Option<usize>,
    /// ANN refine factor; only meaningful for `SearchType::Approximate`.
    refine_factor: Option<u32>,
    /// When `Some(true)`, filtering runs after the vector search instead of before.
    post_filter: Option<bool>,
    /// Explicit vector column to search; needed only when the table has multiple float-list columns.
    column: Option<String>,
}
131
132impl SearchParams {
133    /// Sets the distance type of the search params.
134    /// Always set the distance_type to match the value used to train the index.
135    /// The default is DistanceType::L2.
136    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
137        self.distance_type = Some(distance_type);
138        self
139    }
140
141    /// Sets the search type of the search params.
142    /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table.
143    /// To use the mentioned defaults, do not set the search type.
144    pub fn search_type(mut self, search_type: SearchType) -> Self {
145        self.search_type = Some(search_type);
146        self
147    }
148
149    /// Sets the nprobes of the search params.
150    /// Only set this value only when the search type is ANN.
151    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
152    pub fn nprobes(mut self, nprobes: usize) -> Self {
153        self.nprobes = Some(nprobes);
154        self
155    }
156
157    /// Sets the refine factor of the search params.
158    /// Only set this value only when search type is ANN.
159    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
160    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
161        self.refine_factor = Some(refine_factor);
162        self
163    }
164
165    /// Sets the post filter of the search params.
166    /// If set to true, filtering will happen after the vector search instead of before.
167    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
168    pub fn post_filter(mut self, post_filter: bool) -> Self {
169        self.post_filter = Some(post_filter);
170        self
171    }
172
173    /// Sets the column of the search params.
174    /// Only set this value if there is more than one column that contains lists of floats.
175    /// If there is only one column of list of floats, this column will be chosen for the vector search automatically.
176    pub fn column(mut self, column: &str) -> Self {
177        self.column = Some(column.to_string());
178        self
179    }
180}
181
182impl<M> VectorStoreIndex for LanceDbVectorIndex<M>
183where
184    M: EmbeddingModel + Sync + Send,
185{
186    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
187    /// # Example
188    /// ```
189    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
190    /// use rig::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
191    ///
192    /// let openai_client = Client::from_env();
193    ///
194    /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
195    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
196    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
197    ///
198    /// // Query the index
199    /// let result = vector_store_index
200    ///     .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
201    ///     .await?;
202    /// ```
203    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
204        &self,
205        req: VectorSearchRequest,
206    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
207        let prompt_embedding = self.model.embed_text(req.query()).await?;
208
209        let query = self
210            .table
211            .vector_search(prompt_embedding.vec.clone())
212            .map_err(lancedb_to_rig_error)?
213            .limit(req.samples() as usize)
214            .distance_range(None, req.threshold().map(|x| x as f32))
215            .select(lancedb::query::Select::Columns(
216                self.table
217                    .schema()
218                    .await
219                    .map_err(lancedb_to_rig_error)?
220                    .filter_embeddings(),
221            ));
222
223        self.build_query(query)
224            .execute_query()
225            .await?
226            .into_iter()
227            .enumerate()
228            .map(|(i, value)| {
229                Ok((
230                    match value.get("_distance") {
231                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
232                        _ => 0.0,
233                    },
234                    match value.get(self.id_field.clone()) {
235                        Some(Value::String(id)) => id.to_string(),
236                        _ => format!("unknown{i}"),
237                    },
238                    serde_json::from_value(value).map_err(serde_to_rig_error)?,
239                ))
240            })
241            .collect()
242    }
243
244    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
245    /// # Example
246    /// ```
247    /// use rig_lancedb::{LanceDbVectorIndex, SearchParams};
248    /// use rig::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
249    ///
250    /// let openai_client = Client::from_env();
251    ///
252    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
253    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
254    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
255    ///
256    /// // Query the index
257    /// let result = vector_store_index
258    ///     .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
259    ///     .await?;
260    /// ```
261    async fn top_n_ids(
262        &self,
263        req: VectorSearchRequest,
264    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
265        let prompt_embedding = self.model.embed_text(req.query()).await?;
266
267        let query = self
268            .table
269            .query()
270            .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
271            .nearest_to(prompt_embedding.vec.clone())
272            .map_err(lancedb_to_rig_error)?
273            .distance_range(None, req.threshold().map(|x| x as f32))
274            .limit(req.samples() as usize);
275
276        self.build_query(query)
277            .execute_query()
278            .await?
279            .into_iter()
280            .map(|value| {
281                Ok((
282                    match value.get("distance") {
283                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
284                        _ => 0.0,
285                    },
286                    match value.get(self.id_field.clone()) {
287                        Some(Value::String(id)) => id.to_string(),
288                        _ => "".to_string(),
289                    },
290                ))
291            })
292            .collect()
293    }
294}