// bep_lancedb/lib.rs

1use lancedb::{
2    query::{QueryBase, VectorQuery},
3    DistanceType,
4};
5use bep::{
6    embeddings::embedding::EmbeddingModel,
7    vector_store::{VectorStoreError, VectorStoreIndex},
8};
9use serde::Deserialize;
10use serde_json::Value;
11use utils::{FilterTableColumns, QueryToJson};
12
13mod utils;
14
15fn lancedb_to_bep_error(e: lancedb::Error) -> VectorStoreError {
16    VectorStoreError::DatastoreError(Box::new(e))
17}
18
19fn serde_to_bep_error(e: serde_json::Error) -> VectorStoreError {
20    VectorStoreError::JsonError(e)
21}
22
23/// Type on which vector searches can be performed for a lanceDb table.
24/// # Example
25/// ```
26/// use bep_lancedb::{LanceDbVectorIndex, SearchParams};
27/// use bep::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
28///
29/// let openai_client = Client::from_env();
30///
31/// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
32/// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
33/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
34/// ```
35pub struct LanceDbVectorIndex<M: EmbeddingModel> {
36    /// Defines which model is used to generate embeddings for the vector store.
37    model: M,
38    /// LanceDB table containing embeddings.
39    table: lancedb::Table,
40    /// Column name in `table` that contains the id of a record.
41    id_field: String,
42    /// Vector search params that are used during vector search operations.
43    search_params: SearchParams,
44}
45
46impl<M: EmbeddingModel> LanceDbVectorIndex<M> {
47    /// Create an instance of `LanceDbVectorIndex` with an existing table and model.
48    /// Define the id field name of the table.
49    /// Define search parameters that will be used to perform vector searches on the table.
50    pub async fn new(
51        table: lancedb::Table,
52        model: M,
53        id_field: &str,
54        search_params: SearchParams,
55    ) -> Result<Self, lancedb::Error> {
56        Ok(Self {
57            table,
58            model,
59            id_field: id_field.to_string(),
60            search_params,
61        })
62    }
63
64    /// Apply the search_params to the vector query.
65    /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait.
66    fn build_query(&self, mut query: VectorQuery) -> VectorQuery {
67        let SearchParams {
68            distance_type,
69            search_type,
70            nprobes,
71            refine_factor,
72            post_filter,
73            column,
74        } = self.search_params.clone();
75
76        if let Some(distance_type) = distance_type {
77            query = query.distance_type(distance_type);
78        }
79
80        if let Some(SearchType::Flat) = search_type {
81            query = query.bypass_vector_index();
82        }
83
84        if let Some(SearchType::Approximate) = search_type {
85            if let Some(nprobes) = nprobes {
86                query = query.nprobes(nprobes);
87            }
88            if let Some(refine_factor) = refine_factor {
89                query = query.refine_factor(refine_factor);
90            }
91        }
92
93        if let Some(true) = post_filter {
94            query = query.postfilter();
95        }
96
97        if let Some(column) = column {
98            query = query.column(column.as_str())
99        }
100
101        query
102    }
103}
104
/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information.
#[derive(Debug, Clone)]
pub enum SearchType {
    /// Flat search, also called ENN or kNN.
    Flat,
    /// Approximate Nearest Neighbor search, also called ANN.
    Approximate,
}
113
114/// Parameters used to perform a vector search on a LanceDb table.
115/// # Example
116/// ```
117/// let search_params = bep_lancedb::SearchParams::default().distance_type(lancedb::DistanceType::Cosine);
118/// ```
119#[derive(Debug, Clone, Default)]
120pub struct SearchParams {
121    distance_type: Option<DistanceType>,
122    search_type: Option<SearchType>,
123    nprobes: Option<usize>,
124    refine_factor: Option<u32>,
125    post_filter: Option<bool>,
126    column: Option<String>,
127}
128
129impl SearchParams {
130    /// Sets the distance type of the search params.
131    /// Always set the distance_type to match the value used to train the index.
132    /// The default is DistanceType::L2.
133    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
134        self.distance_type = Some(distance_type);
135        self
136    }
137
138    /// Sets the search type of the search params.
139    /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table.
140    /// To use the mentioned defaults, do not set the search type.
141    pub fn search_type(mut self, search_type: SearchType) -> Self {
142        self.search_type = Some(search_type);
143        self
144    }
145
146    /// Sets the nprobes of the search params.
147    /// Only set this value only when the search type is ANN.
148    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
149    pub fn nprobes(mut self, nprobes: usize) -> Self {
150        self.nprobes = Some(nprobes);
151        self
152    }
153
154    /// Sets the refine factor of the search params.
155    /// Only set this value only when search type is ANN.
156    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
157    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
158        self.refine_factor = Some(refine_factor);
159        self
160    }
161
162    /// Sets the post filter of the search params.
163    /// If set to true, filtering will happen after the vector search instead of before.
164    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
165    pub fn post_filter(mut self, post_filter: bool) -> Self {
166        self.post_filter = Some(post_filter);
167        self
168    }
169
170    /// Sets the column of the search params.
171    /// Only set this value if there is more than one column that contains lists of floats.
172    /// If there is only one column of list of floats, this column will be chosen for the vector search automatically.
173    pub fn column(mut self, column: &str) -> Self {
174        self.column = Some(column.to_string());
175        self
176    }
177}
178
179impl<M: EmbeddingModel + Sync + Send> VectorStoreIndex for LanceDbVectorIndex<M> {
180    /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
181    /// # Example
182    /// ```
183    /// use bep_lancedb::{LanceDbVectorIndex, SearchParams};
184    /// use bep::providers::openai::{EmbeddingModel, Client, TEXT_EMBEDDING_ADA_002};
185    ///
186    /// let openai_client = Client::from_env();
187    ///
188    /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here.
189    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
190    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
191    ///
192    /// // Query the index
193    /// let result = vector_store_index
194    ///     .top_n::<String>("My boss says I zindle too much, what does that mean?", 1)
195    ///     .await?;
196    /// ```
197    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
198        &self,
199        query: &str,
200        n: usize,
201    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
202        let prompt_embedding = self.model.embed_text(query).await?;
203
204        let query = self
205            .table
206            .vector_search(prompt_embedding.vec.clone())
207            .map_err(lancedb_to_bep_error)?
208            .limit(n)
209            .select(lancedb::query::Select::Columns(
210                self.table
211                    .schema()
212                    .await
213                    .map_err(lancedb_to_bep_error)?
214                    .filter_embeddings(),
215            ));
216
217        self.build_query(query)
218            .execute_query()
219            .await?
220            .into_iter()
221            .enumerate()
222            .map(|(i, value)| {
223                Ok((
224                    match value.get("_distance") {
225                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
226                        _ => 0.0,
227                    },
228                    match value.get(self.id_field.clone()) {
229                        Some(Value::String(id)) => id.to_string(),
230                        _ => format!("unknown{i}"),
231                    },
232                    serde_json::from_value(value).map_err(serde_to_bep_error)?,
233                ))
234            })
235            .collect()
236    }
237
238    /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`.
239    /// # Example
240    /// ```
241    /// use bep_lancedb::{LanceDbVectorIndex, SearchParams};
242    /// use bep::providers::openai::{Client, TEXT_EMBEDDING_ADA_002, EmbeddingModel};
243    ///
244    /// let openai_client = Client::from_env();
245    ///
246    /// let table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here.
247    /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here.
248    /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?;
249    ///
250    /// // Query the index
251    /// let result = vector_store_index
252    ///     .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
253    ///     .await?;
254    /// ```
255    async fn top_n_ids(
256        &self,
257        query: &str,
258        n: usize,
259    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
260        let prompt_embedding = self.model.embed_text(query).await?;
261
262        let query = self
263            .table
264            .query()
265            .select(lancedb::query::Select::Columns(vec![self.id_field.clone()]))
266            .nearest_to(prompt_embedding.vec.clone())
267            .map_err(lancedb_to_bep_error)?
268            .limit(n);
269
270        self.build_query(query)
271            .execute_query()
272            .await?
273            .into_iter()
274            .map(|value| {
275                Ok((
276                    match value.get("distance") {
277                        Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
278                        _ => 0.0,
279                    },
280                    match value.get(self.id_field.clone()) {
281                        Some(Value::String(id)) => id.to_string(),
282                        _ => "".to_string(),
283                    },
284                ))
285            })
286            .collect()
287    }
288}