bm_25/
search.rs

1use crate::DefaultTokenizer;
2use crate::{
3    embedder::{DefaultTokenEmbedder, Embedder, EmbedderBuilder, TokenEmbedder},
4    scorer::{ScoredDocument, Scorer},
5    Tokenizer,
6};
7use std::{
8    collections::HashMap,
9    fmt::{self, Debug, Display},
10    hash::Hash,
11    marker::PhantomData,
12};
13
/// A document that can be inserted into a search engine.
///
/// `K` is the type of the document id. Note that it is more efficient to use a numeric type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct Document<K> {
    /// The unique identifier of the document.
    pub id: K,
    /// The text contents of the document.
    pub contents: String,
}
23
24impl<K> Display for Document<K> {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        write!(f, "{}", self.contents)
27    }
28}
29
30impl<K> Document<K> {
31    /// Creates a new document with the given id and contents.
32    pub fn new(id: K, contents: impl Into<String>) -> Document<K> {
33        Document {
34            id,
35            contents: contents.into(),
36        }
37    }
38}
39
40/// A search result, containing a document and its BM25 score.
41#[derive(PartialEq, Debug, Clone)]
42pub struct SearchResult<K> {
43    /// The document that was found.
44    pub document: Document<K>,
45    /// The BM25 score of the document. A higher score means the document is more relevant to the
46    /// query.
47    pub score: f32,
48}
49
50/// A search engine that ranks documents with BM25. K is the type of the document id, D is the
51/// type of the token embedder and T is the type of the tokenizer.
52pub struct SearchEngine<K, D: TokenEmbedder = DefaultTokenEmbedder, T = DefaultTokenizer> {
53    // The embedder used to convert documents into embeddings.
54    embedder: Embedder<D, T>,
55    // A scorer for document embeddings.
56    scorer: Scorer<K, D::EmbeddingSpace>,
57    // A mapping from document ids to document contents.
58    documents: HashMap<K, String>,
59}
60
61impl<K: Debug, D: TokenEmbedder + Debug, T: Debug> Debug for SearchEngine<K, D, T> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        write!(
64            f,
65            "SearchEngine {{ embedder: {:?}, documents: {:?} }}",
66            self.embedder, self.documents
67        )
68    }
69}
70
71impl<K, D, T> SearchEngine<K, D, T>
72where
73    K: Hash + Eq + Clone,
74    D: TokenEmbedder,
75    D::EmbeddingSpace: Eq + Hash + Clone,
76    T: Tokenizer,
77{
78    /// Upserts a document into the search engine. If a document with the same id already exists,
79    /// it will be replaced. Note that upserting a document will change the true value of `avgdl`.
80    /// The more `avgdl` drifts from its true value, the less accurate the BM25 scores will be.
81    pub fn upsert(&mut self, document: impl Into<Document<K>>) {
82        let document = document.into();
83        let embedding = self.embedder.embed(document.contents.as_str());
84
85        if self.documents.contains_key(&document.id) {
86            self.remove(&document.id);
87        }
88        self.documents
89            .insert(document.id.clone(), document.contents);
90
91        self.scorer.upsert(&document.id, embedding);
92    }
93
94    /// Removes a document from the search engine if it exists.
95    pub fn remove(&mut self, document_id: &K) {
96        self.documents.remove(document_id);
97        self.scorer.remove(document_id);
98    }
99
100    /// Gets the contents of a document by its id.
101    pub fn get(&self, document_id: &K) -> Option<Document<K>> {
102        self.documents.get(document_id).map(|contents| Document {
103            id: document_id.clone(),
104            contents: contents.clone(),
105        })
106    }
107
108    /// Returns an iterator over the documents in the search engine.
109    pub fn iter(&self) -> impl Iterator<Item = Document<K>> + '_ {
110        self.documents.iter().map(|(id, contents)| Document {
111            id: id.clone(),
112            contents: contents.clone(),
113        })
114    }
115
116    /// Searches the documents for the given query and returns the top `limit` results.
117    /// Only the document contents are searched, not the document ids.
118    pub fn search(&self, query: &str, limit: impl Into<Option<usize>>) -> Vec<SearchResult<K>> {
119        let query_embedding = self.embedder.embed(query);
120
121        // Reduce search space by filtering out all documents whose score would be 0
122        let matches = self.scorer.matches(&query_embedding);
123
124        matches
125            .into_iter()
126            .take(limit.into().unwrap_or(usize::MAX))
127            .filter_map(|ScoredDocument { id, score }| {
128                self.get(&id)
129                    .map(|document| SearchResult { document, score })
130            })
131            .collect()
132    }
133}
134
/// A consuming builder for SearchEngine. K is the type of the document id, D is the type of the
/// token embedder and T is the type of the tokenizer.
pub struct SearchEngineBuilder<K, D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    // Builder for the embedder that the finished search engine will use.
    embedder_builder: EmbedderBuilder<D, T>,
    // Documents that `build` inserts into the search engine.
    documents: Vec<Document<K>>,
    // Zero-sized marker for the document id type `K`.
    document_id_type: PhantomData<K>,
    // Zero-sized marker for the token embedder type `D`.
    token_embedder_type: PhantomData<D>,
}
143
144impl<K, D, T> SearchEngineBuilder<K, D, T>
145where
146    K: Hash + Eq + Clone,
147    D: TokenEmbedder,
148    D::EmbeddingSpace: Eq + Hash + Clone,
149    T: Tokenizer + Sync,
150{
151    /// Constructs a new SearchEngineBuilder with the given average document length. Use this if you
152    /// know the average document length in advance. If you don't, but you have your full corpus
153    /// ahead of time, use `with_documents` or `with_corpus` instead.
154    ///
155    /// If you have neither the full corpus nor a sample of it, you can configure the embedder to
156    /// disregard document length by setting `b` to 0.0. In this case, it doesn't matter what
157    /// value you pass to `with_avgdl`.
158    ///
159    /// The average document length is the average number of tokens in a document from your corpus;
160    /// if you need access to this value, you can construct an Embedder and call `avgdl` on it.
161    pub fn with_avgdl(avgdl: f32) -> SearchEngineBuilder<K, D, T>
162    where
163        T: Default,
164    {
165        SearchEngineBuilder {
166            embedder_builder: EmbedderBuilder::<D, T>::with_avgdl(avgdl),
167            documents: Vec::new(),
168            document_id_type: PhantomData,
169            token_embedder_type: PhantomData,
170        }
171    }
172
173    /// Constructs a new SearchEngineBuilder with the given documents. The search engine will fit
174    /// to the given documents, using the given tokenizer. When you call `build`, the builder
175    /// will pre-populate the search engine with the given documents, and pass on the tokenizer.
176    pub fn with_tokenizer_and_documents(
177        tokenizer: T,
178        documents: impl IntoIterator<Item = impl Into<Document<K>>>,
179    ) -> SearchEngineBuilder<K, D, T> {
180        let documents = documents.into_iter().map(|d| d.into()).collect::<Vec<_>>();
181        SearchEngineBuilder {
182            embedder_builder: EmbedderBuilder::<D, T>::with_tokenizer_and_fit_to_corpus(
183                tokenizer,
184                &documents
185                    .iter()
186                    .map(|d| d.contents.as_str())
187                    .collect::<Vec<_>>(),
188            ),
189            documents,
190            document_id_type: PhantomData,
191            token_embedder_type: PhantomData,
192        }
193    }
194
195    /// Constructs a new SearchEngineBuilder with the corpus. The search engine will fit
196    /// to the given corpus, using the given tokenizer. When you call `build`, the builder
197    /// will pre-populate the search engine with the given corpus, and pass on the tokenizer.
198    /// This function will automatically generate u32 ids for each entry in your corpus.
199    pub fn with_tokenizer_and_corpus(
200        tokenizer: T,
201        corpus: impl IntoIterator<Item = impl Into<String>>,
202    ) -> SearchEngineBuilder<u32, D, T> {
203        let documents = corpus
204            .into_iter()
205            .enumerate()
206            .map(|(id, document)| Document::new(id as u32, document.into()))
207            .collect::<Vec<_>>();
208        SearchEngineBuilder::<u32, D, T>::with_tokenizer_and_documents(tokenizer, documents)
209    }
210
211    /// Sets the tokenizer of the embedder.
212    pub fn tokenizer(self, tokenizer: T) -> Self {
213        Self {
214            embedder_builder: self.embedder_builder.tokenizer(tokenizer),
215            ..self
216        }
217    }
218
219    /// Sets the k1 parameter of the embedder.
220    pub fn k1(self, k1: f32) -> Self {
221        Self {
222            embedder_builder: self.embedder_builder.k1(k1),
223            ..self
224        }
225    }
226
227    /// Sets the b parameter of the embedder.
228    pub fn b(self, b: f32) -> Self {
229        Self {
230            embedder_builder: self.embedder_builder.b(b),
231            ..self
232        }
233    }
234
235    /// Overrides the average document length of the embedder.
236    pub fn avgdl(self, avgdl: f32) -> Self {
237        Self {
238            embedder_builder: self.embedder_builder.avgdl(avgdl),
239            ..self
240        }
241    }
242
243    /// Builds the search engine.
244    pub fn build(self) -> SearchEngine<K, D, T> {
245        let mut search_engine = SearchEngine::<K, D, T> {
246            embedder: self.embedder_builder.build(),
247            scorer: Scorer::<K, D::EmbeddingSpace>::new(),
248            documents: HashMap::new(),
249        };
250        for document in self.documents {
251            search_engine.upsert(document);
252        }
253        search_engine
254    }
255}
256
#[cfg(feature = "default_tokenizer")]
impl<K, D> SearchEngineBuilder<K, D, DefaultTokenizer>
where
    K: Hash + Eq + Clone,
    D: TokenEmbedder,
    D::EmbeddingSpace: Eq + Hash + Clone,
{
    /// Creates a builder from the given documents, tokenizing them with the default tokenizer
    /// in the given language mode. The search engine is fitted to these documents, and `build`
    /// pre-populates the engine with them and passes the tokenizer on.
    pub fn with_documents(
        language_mode: impl Into<crate::LanguageMode>,
        documents: impl IntoIterator<Item = impl Into<Document<K>>>,
    ) -> Self {
        let tokenizer = DefaultTokenizer::new(language_mode);
        Self::with_tokenizer_and_documents(tokenizer, documents)
    }

    /// Creates a builder from the given corpus, tokenizing it with the default tokenizer in the
    /// given language mode. The search engine is fitted to the corpus, and `build` pre-populates
    /// the engine with it and passes the tokenizer on. A u32 id is generated automatically for
    /// each entry in your corpus.
    pub fn with_corpus(
        language_mode: impl Into<crate::LanguageMode>,
        corpus: impl IntoIterator<Item = impl Into<String>>,
    ) -> SearchEngineBuilder<u32, D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        SearchEngineBuilder::<u32, D, DefaultTokenizer>::with_tokenizer_and_corpus(
            tokenizer, corpus,
        )
    }

    /// Replaces the tokenizer with a default tokenizer configured for the given language mode.
    pub fn language_mode(self, language_mode: impl Into<crate::LanguageMode>) -> Self {
        self.tokenizer(DefaultTokenizer::new(language_mode))
    }
}
295
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;
    use crate::{
        test_data_loader::tests::{read_recipes, Recipe},
        Language, LanguageMode,
    };

    // Recipes are indexed by title, with the recipe text as the searchable contents.
    impl From<Recipe> for Document<String> {
        fn from(value: Recipe) -> Self {
            Document::new(value.title, value.recipe)
        }
    }

    /// Builds a search engine that is fitted to, and populated with, the recipes in the given
    /// file.
    fn create_recipe_search_engine(
        recipe_file: &str,
        language_mode: impl Into<LanguageMode>,
    ) -> SearchEngine<String, u32> {
        let recipes = read_recipes(recipe_file);

        SearchEngineBuilder::with_documents(language_mode, recipes).build()
    }

    #[test]
    fn search_returns_relevant_documents() {
        let corpus = vec!["space station", "bacon and avocado sandwich"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        let results = search_engine.search("sandwich with bacon", 5);

        // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.contents, "bacon and avocado sandwich");
        assert!(results[0].score > 0.0);
    }

    #[test]
    fn search_does_not_return_unrelated_documents() {
        let corpus = vec!["space station", "bacon and avocado sandwich"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        let results = search_engine.search("maths and computer science", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn it_can_insert_a_document() {
        let mut search_engine = SearchEngineBuilder::<&str>::with_avgdl(2.0).build();
        let document = Document::new("hello world", "bananas and apples");
        let document_id = document.id;

        search_engine.upsert(document.clone());
        let result = search_engine.get(&document_id);

        assert_eq!(result, Some(document));
    }

    #[test]
    fn it_can_remove_a_document() {
        let mut search_engine = SearchEngineBuilder::<usize>::with_avgdl(2.0).build();
        let document = Document::new(123, "bananas and apples");
        // The id is a `Copy` integer, so no clone is needed to keep it after the upsert.
        let document_id = document.id;

        search_engine.upsert(document);
        search_engine.remove(&document_id);

        assert!(search_engine.get(&document_id).is_none());
    }

    #[test]
    fn it_can_update_a_document() {
        let document_id = "hello_world";
        let document = Document::new(document_id, "bananas and apples");
        let mut search_engine =
            SearchEngineBuilder::<&str>::with_documents(Language::English, vec![document]).build();
        let new_document = Document::new(document_id, "oranges and papayas");

        search_engine.upsert(new_document.clone());
        let result = search_engine.get(&document_id);

        assert_eq!(result, Some(new_document));
    }

    #[test]
    fn handles_empty_input() {
        let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
        let document = Document::new(123, "");

        search_engine.upsert(document);

        let results = search_engine.search("bacon sandwich", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn handles_empty_search() {
        let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
        let document = Document::new(123, "pencil and paper");

        search_engine.upsert(document);

        let results = search_engine.search("", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn it_returns_exact_matches_with_highest_score() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search(
            "To make guacamole, start by mashing 2 ripe avocados in a bowl.",
            None,
        );

        assert!(!results.is_empty());
        assert_eq!(results[0].document.id, "Guacamole");
    }

    #[test]
    fn it_only_returns_results_containing_query() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search("vegetable", 5);

        // At least 5 recipes contain the word "vegetable"
        assert_eq!(results.len(), 5);
        assert!(results
            .iter()
            .all(|result| result.document.contents.contains("vegetable")));
    }

    #[test]
    fn it_returns_results_sorted_by_score() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search("chicken", 1000);

        assert!(!results.is_empty());
        assert!(results
            .windows(2)
            .all(|result_pair| result_pair[0].score >= result_pair[1].score));
    }

    #[test]
    fn it_ranks_shorter_documents_higher() {
        let documents = [
            Document {
                id: 0,
                contents: "Correct horse battery staple bacon bacon bacon".to_string(),
            },
            Document {
                id: 1,
                contents: "Correct horse battery staple".to_string(),
            },
        ];
        let search_engine =
            SearchEngineBuilder::<u32>::with_documents(Language::English, documents).build();

        let results = search_engine.search("staple", 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].document.id, 1);
        assert_eq!(results[1].document.id, 0);
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn it_matches_common_unicode_equivalents() {
        let corpus = vec!["étude"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::French, corpus).build();

        let results_1 = search_engine.search("etude", None);
        let results_2 = search_engine.search("étude", None);

        assert_eq!(results_1.len(), 1);
        assert_eq!(results_2.len(), 1);
        assert_eq!(results_1, results_2);
    }

    #[test]
    fn it_can_search_for_emoji() {
        let corpus = vec!["🔥"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        let results_1 = search_engine.search("🔥", None);
        let results_2 = search_engine.search("fire", None);

        assert_eq!(results_1.len(), 1);
        assert_eq!(results_2.len(), 1);
        assert_eq!(results_1, results_2);
    }

    #[test]
    fn it_matches_snapshot_en() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let mut results = search_engine.search("bake", None);
        // sort the results by document id to make the snapshot deterministic
        results.sort_by_key(|result| result.document.id.clone());

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(results);
        });
    }

    #[test]
    fn it_matches_snapshot_de() {
        let search_engine = create_recipe_search_engine("recipes_de.csv", Language::German);

        let mut results = search_engine.search("backen", None);

        // sort the results by document id to make the snapshot deterministic
        results.sort_by_key(|result| result.document.id.clone());

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(results);
        });
    }
}
518}