1use crate::DefaultTokenizer;
2use crate::{
3 embedder::{DefaultTokenEmbedder, Embedder, EmbedderBuilder, TokenEmbedder},
4 scorer::{ScoredDocument, Scorer},
5 Tokenizer,
6};
7use std::{
8 collections::HashMap,
9 fmt::{self, Debug, Display},
10 hash::Hash,
11 marker::PhantomData,
12};
13
/// A piece of text paired with a caller-chosen identifier of type `K`.
///
/// [`SearchEngine`] stores documents keyed by `id`, so two documents with equal ids are treated
/// as versions of the same document.
#[derive(Eq, PartialEq, Debug, Clone, PartialOrd, Hash)]
pub struct Document<K> {
    /// Identifier of the document.
    pub id: K,
    /// The text of the document.
    pub contents: String,
}
23
24impl<K> Display for Document<K> {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 write!(f, "{}", self.contents)
27 }
28}
29
30impl<K> Document<K> {
31 pub fn new(id: K, contents: impl Into<String>) -> Document<K> {
33 Document {
34 id,
35 contents: contents.into(),
36 }
37 }
38}
39
/// A single hit returned by [`SearchEngine::search`].
#[derive(PartialEq, Debug, Clone)]
pub struct SearchResult<K> {
    /// The document that matched the query.
    pub document: Document<K>,
    /// The score the scorer reported for this document against the query.
    pub score: f32,
}
49
/// A search engine holding documents keyed by an id of type `K`.
///
/// `D` is the token embedder and `T` the tokenizer used by the underlying [`Embedder`].
pub struct SearchEngine<K, D: TokenEmbedder = DefaultTokenEmbedder, T = DefaultTokenizer> {
    /// Turns document and query text into embeddings.
    embedder: Embedder<D, T>,
    /// Receives the embedding of every upserted document and produces scored matches for a
    /// query embedding.
    scorer: Scorer<K, D::EmbeddingSpace>,
    /// The raw contents of every stored document, by id.
    documents: HashMap<K, String>,
}
60
61impl<K: Debug, D: TokenEmbedder + Debug, T: Debug> Debug for SearchEngine<K, D, T> {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 write!(
64 f,
65 "SearchEngine {{ embedder: {:?}, documents: {:?} }}",
66 self.embedder, self.documents
67 )
68 }
69}
70
71impl<K, D, T> SearchEngine<K, D, T>
72where
73 K: Hash + Eq + Clone,
74 D: TokenEmbedder,
75 D::EmbeddingSpace: Eq + Hash + Clone,
76 T: Tokenizer,
77{
78 pub fn upsert(&mut self, document: impl Into<Document<K>>) {
82 let document = document.into();
83 let embedding = self.embedder.embed(document.contents.as_str());
84
85 if self.documents.contains_key(&document.id) {
86 self.remove(&document.id);
87 }
88 self.documents
89 .insert(document.id.clone(), document.contents);
90
91 self.scorer.upsert(&document.id, embedding);
92 }
93
94 pub fn remove(&mut self, document_id: &K) {
96 self.documents.remove(document_id);
97 self.scorer.remove(document_id);
98 }
99
100 pub fn get(&self, document_id: &K) -> Option<Document<K>> {
102 self.documents.get(document_id).map(|contents| Document {
103 id: document_id.clone(),
104 contents: contents.clone(),
105 })
106 }
107
108 pub fn iter(&self) -> impl Iterator<Item = Document<K>> + '_ {
110 self.documents.iter().map(|(id, contents)| Document {
111 id: id.clone(),
112 contents: contents.clone(),
113 })
114 }
115
116 pub fn search(&self, query: &str, limit: impl Into<Option<usize>>) -> Vec<SearchResult<K>> {
119 let query_embedding = self.embedder.embed(query);
120
121 let matches = self.scorer.matches(&query_embedding);
123
124 matches
125 .into_iter()
126 .take(limit.into().unwrap_or(usize::MAX))
127 .filter_map(|ScoredDocument { id, score }| {
128 self.get(&id)
129 .map(|document| SearchResult { document, score })
130 })
131 .collect()
132 }
133}
134
/// Builder for a [`SearchEngine`].
pub struct SearchEngineBuilder<K, D = DefaultTokenEmbedder, T = DefaultTokenizer> {
    /// Builder for the engine's embedder; the parameter setters on this type forward to it.
    embedder_builder: EmbedderBuilder<D, T>,
    /// Documents that are upserted into the engine when it is built.
    documents: Vec<Document<K>>,
    /// Zero-sized marker for the document id type.
    document_id_type: PhantomData<K>,
    /// Zero-sized marker for the token embedder type.
    token_embedder_type: PhantomData<D>,
}
143
144impl<K, D, T> SearchEngineBuilder<K, D, T>
145where
146 K: Hash + Eq + Clone,
147 D: TokenEmbedder,
148 D::EmbeddingSpace: Eq + Hash + Clone,
149 T: Tokenizer + Sync,
150{
151 pub fn with_avgdl(avgdl: f32) -> SearchEngineBuilder<K, D, T>
162 where
163 T: Default,
164 {
165 SearchEngineBuilder {
166 embedder_builder: EmbedderBuilder::<D, T>::with_avgdl(avgdl),
167 documents: Vec::new(),
168 document_id_type: PhantomData,
169 token_embedder_type: PhantomData,
170 }
171 }
172
173 pub fn with_tokenizer_and_documents(
177 tokenizer: T,
178 documents: impl IntoIterator<Item = impl Into<Document<K>>>,
179 ) -> SearchEngineBuilder<K, D, T> {
180 let documents = documents.into_iter().map(|d| d.into()).collect::<Vec<_>>();
181 SearchEngineBuilder {
182 embedder_builder: EmbedderBuilder::<D, T>::with_tokenizer_and_fit_to_corpus(
183 tokenizer,
184 &documents
185 .iter()
186 .map(|d| d.contents.as_str())
187 .collect::<Vec<_>>(),
188 ),
189 documents,
190 document_id_type: PhantomData,
191 token_embedder_type: PhantomData,
192 }
193 }
194
195 pub fn with_tokenizer_and_corpus(
200 tokenizer: T,
201 corpus: impl IntoIterator<Item = impl Into<String>>,
202 ) -> SearchEngineBuilder<u32, D, T> {
203 let documents = corpus
204 .into_iter()
205 .enumerate()
206 .map(|(id, document)| Document::new(id as u32, document.into()))
207 .collect::<Vec<_>>();
208 SearchEngineBuilder::<u32, D, T>::with_tokenizer_and_documents(tokenizer, documents)
209 }
210
211 pub fn tokenizer(self, tokenizer: T) -> Self {
213 Self {
214 embedder_builder: self.embedder_builder.tokenizer(tokenizer),
215 ..self
216 }
217 }
218
219 pub fn k1(self, k1: f32) -> Self {
221 Self {
222 embedder_builder: self.embedder_builder.k1(k1),
223 ..self
224 }
225 }
226
227 pub fn b(self, b: f32) -> Self {
229 Self {
230 embedder_builder: self.embedder_builder.b(b),
231 ..self
232 }
233 }
234
235 pub fn avgdl(self, avgdl: f32) -> Self {
237 Self {
238 embedder_builder: self.embedder_builder.avgdl(avgdl),
239 ..self
240 }
241 }
242
243 pub fn build(self) -> SearchEngine<K, D, T> {
245 let mut search_engine = SearchEngine::<K, D, T> {
246 embedder: self.embedder_builder.build(),
247 scorer: Scorer::<K, D::EmbeddingSpace>::new(),
248 documents: HashMap::new(),
249 };
250 for document in self.documents {
251 search_engine.upsert(document);
252 }
253 search_engine
254 }
255}
256
#[cfg(feature = "default_tokenizer")]
impl<K, D> SearchEngineBuilder<K, D, DefaultTokenizer>
where
    K: Hash + Eq + Clone,
    D: TokenEmbedder,
    D::EmbeddingSpace: Eq + Hash + Clone,
{
    /// Starts a builder from `documents`, using a [`DefaultTokenizer`] created from
    /// `language_mode`.
    pub fn with_documents(
        language_mode: impl Into<crate::LanguageMode>,
        documents: impl IntoIterator<Item = impl Into<Document<K>>>,
    ) -> Self {
        let tokenizer = DefaultTokenizer::new(language_mode);
        Self::with_tokenizer_and_documents(tokenizer, documents)
    }

    /// Starts a builder from plain texts, using a [`DefaultTokenizer`] created from
    /// `language_mode`. The returned builder always uses `u32` ids, whatever `K` is on the
    /// type this is called through.
    pub fn with_corpus(
        language_mode: impl Into<crate::LanguageMode>,
        corpus: impl IntoIterator<Item = impl Into<String>>,
    ) -> SearchEngineBuilder<u32, D, DefaultTokenizer> {
        let tokenizer = DefaultTokenizer::new(language_mode);
        SearchEngineBuilder::<u32, D, DefaultTokenizer>::with_tokenizer_and_corpus(
            tokenizer, corpus,
        )
    }

    /// Replaces the tokenizer with a [`DefaultTokenizer`] created from `language_mode`.
    pub fn language_mode(self, language_mode: impl Into<crate::LanguageMode>) -> Self {
        let tokenizer = DefaultTokenizer::new(language_mode);
        self.tokenizer(tokenizer)
    }
}
295
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;
    use crate::{
        test_data_loader::tests::{read_recipes, Recipe},
        Language, LanguageMode,
    };

    /// Lets recipe fixtures be fed straight into the builder: the title becomes the id and the
    /// recipe text becomes the contents.
    impl From<Recipe> for Document<String> {
        fn from(value: Recipe) -> Self {
            Document::new(value.title, value.recipe)
        }
    }

    /// Builds an engine from the recipes that `read_recipes` loads for `recipe_file`.
    fn create_recipe_search_engine(
        recipe_file: &str,
        language_mode: impl Into<LanguageMode>,
    ) -> SearchEngine<String, u32> {
        let recipes = read_recipes(recipe_file);

        SearchEngineBuilder::with_documents(language_mode, recipes).build()
    }

    #[test]
    fn search_returns_relevant_documents() {
        let corpus = vec!["space station", "bacon and avocado sandwich"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        let results = search_engine.search("sandwich with bacon", 5);

        // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.contents, "bacon and avocado sandwich");
        assert!(results[0].score > 0.0);
    }

    #[test]
    fn search_does_not_return_unrelated_documents() {
        let corpus = vec!["space station", "bacon and avocado sandwich"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        let results = search_engine.search("maths and computer science", 5);

        assert!(results.is_empty());
    }

    #[test]
    fn it_can_insert_a_document() {
        let mut search_engine = SearchEngineBuilder::<&str>::with_avgdl(2.0).build();
        let document = Document::new("hello world", "bananas and apples");
        let document_id = document.id;

        search_engine.upsert(document.clone());
        let result = search_engine.get(&document_id);

        assert_eq!(result, Some(document));
    }

    #[test]
    fn it_can_remove_a_document() {
        let mut search_engine = SearchEngineBuilder::<usize>::with_avgdl(2.0).build();
        let document = Document::new(123, "bananas and apples");
        // The id is `Copy`, so it can be read out before the document is moved.
        let document_id = document.id;

        search_engine.upsert(document);
        search_engine.remove(&document_id);

        assert!(search_engine.get(&document_id).is_none());
    }

    #[test]
    fn it_can_update_a_document() {
        let document_id = "hello_world";
        let document = Document::new(document_id, "bananas and apples");
        let mut search_engine =
            SearchEngineBuilder::<&str>::with_documents(Language::English, vec![document]).build();
        let new_document = Document::new(document_id, "oranges and papayas");

        search_engine.upsert(new_document.clone());
        let result = search_engine.get(&document_id);

        assert_eq!(result, Some(new_document));
    }

    #[test]
    fn handles_empty_input() {
        let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
        let document = Document::new(123, "");

        search_engine.upsert(document);

        let results = search_engine.search("bacon sandwich", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn handles_empty_search() {
        let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
        let document = Document::new(123, "pencil and paper");

        search_engine.upsert(document);

        let results = search_engine.search("", 5);
        assert!(results.is_empty());
    }

    #[test]
    fn it_returns_exact_matches_with_highest_score() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search(
            "To make guacamole, start by mashing 2 ripe avocados in a bowl.",
            None,
        );

        assert!(!results.is_empty());
        assert_eq!(results[0].document.id, "Guacamole");
    }

    #[test]
    fn it_only_returns_results_containing_query() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search("vegetable", 5);

        assert_eq!(results.len(), 5);
        // Every hit must actually mention the query term.
        assert!(results
            .iter()
            .all(|result| result.document.contents.contains("vegetable")));
    }

    #[test]
    fn it_returns_results_sorted_by_score() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let results = search_engine.search("chicken", 1000);

        assert!(!results.is_empty());
        // Scores must be non-increasing from one result to the next.
        assert!(results
            .windows(2)
            .all(|result_pair| { result_pair[0].score >= result_pair[1].score }));
    }

    #[test]
    fn it_ranks_shorter_documents_higher() {
        let documents = [
            Document {
                id: 0,
                contents: "Correct horse battery staple bacon bacon bacon".to_string(),
            },
            Document {
                id: 1,
                contents: "Correct horse battery staple".to_string(),
            },
        ];
        let search_engine =
            SearchEngineBuilder::<u32>::with_documents(Language::English, documents).build();

        let results = search_engine.search("staple", 2);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].document.id, 1);
        assert_eq!(results[1].document.id, 0);
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn it_matches_common_unicode_equivalents() {
        let corpus = vec!["étude"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::French, corpus).build();

        // The accented and unaccented spellings are expected to give identical results.
        let results_1 = search_engine.search("etude", None);
        let results_2 = search_engine.search("étude", None);

        assert_eq!(results_1.len(), 1);
        assert_eq!(results_2.len(), 1);
        assert_eq!(results_1, results_2);
    }

    #[test]
    fn it_can_search_for_emoji() {
        let corpus = vec!["🔥"];
        let search_engine =
            SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();

        // The emoji and the word "fire" are expected to give identical results.
        let results_1 = search_engine.search("🔥", None);
        let results_2 = search_engine.search("fire", None);

        assert_eq!(results_1.len(), 1);
        assert_eq!(results_2.len(), 1);
        assert_eq!(results_1, results_2);
    }

    #[test]
    fn it_matches_snapshot_en() {
        let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);

        let mut results = search_engine.search("bake", None);
        // Order by document id so the snapshot does not depend on the order search returns.
        results.sort_by_key(|result| result.document.id.clone());

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(results);
        });
    }

    #[test]
    fn it_matches_snapshot_de() {
        let search_engine = create_recipe_search_engine("recipes_de.csv", Language::German);

        let mut results = search_engine.search("backen", None);
        // Order by document id so the snapshot does not depend on the order search returns.
        results.sort_by_key(|result| result.document.id.clone());

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(results);
        });
    }
}