use crate::{
    DefaultTokenizer,
    embedder::{DefaultTokenEmbedder, Embedder, EmbedderBuilder, TokenEmbedder},
    scorer::{ScoredDocument, Scorer},
    Tokenizer,
};
use std::{
collections::HashMap,
fmt::{self, Debug, Display},
hash::Hash,
marker::PhantomData,
};
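/// A searchable document: an identifier of type `K` paired with its text contents.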
#[derive(Eq, PartialEq, Debug, Clone, PartialOrd, Hash)]
pub struct Document<K> {
pub id: K,
pub contents: String,
}
impl<K> Display for Document<K> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.contents)
}
}
impl<K> Document<K> {
pub fn new(id: K, contents: impl Into<String>) -> Document<K> {
Document {
id,
contents: contents.into(),
}
}
}
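/// A single search hit: the matching document together with its relevance score.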
#[derive(PartialEq, Debug, Clone)]
pub struct SearchResult<K> {
pub document: Document<K>,
pub score: f32,
}
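/// An in-memory search engine. Documents are embedded with an [`Embedder`] and ranked
/// against queries by a [`Scorer`]; `K` is the document identifier type.
///
/// A minimal usage sketch, mirroring the tests below (assumes the `default_tokenizer`
/// feature is enabled and that `SearchEngineBuilder` and `Language` are in scope):
///
/// ```ignore
/// let corpus = vec!["space station", "bacon and avocado sandwich"];
/// let engine = SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();
/// let results = engine.search("sandwich with bacon", 5);
/// assert_eq!(results.len(), 1);
/// ```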
pub struct SearchEngine<K, D: TokenEmbedder = DefaultTokenEmbedder, T = DefaultTokenizer> {
embedder: Embedder<D, T>,
scorer: Scorer<K, D::EmbeddingSpace>,
documents: HashMap<K, String>,
}
impl<K: Debug, D: TokenEmbedder + Debug, T: Debug> Debug for SearchEngine<K, D, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"SearchEngine {{ embedder: {:?}, documents: {:?} }}",
self.embedder, self.documents
)
}
}
impl<K, D, T> SearchEngine<K, D, T>
where
K: Hash + Eq + Clone,
D: TokenEmbedder,
D::EmbeddingSpace: Eq + Hash + Clone,
T: Tokenizer,
{
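    /// Inserts a document into the index, replacing any existing document with the same id.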
pub fn upsert(&mut self, document: impl Into<Document<K>>) {
let document = document.into();
let embedding = self.embedder.embed(document.contents.as_str());
if self.documents.contains_key(&document.id) {
self.remove(&document.id);
}
self.documents
.insert(document.id.clone(), document.contents);
self.scorer.upsert(&document.id, embedding);
}
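    /// Removes the document with the given id from the index, if present.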
pub fn remove(&mut self, document_id: &K) {
self.documents.remove(document_id);
self.scorer.remove(document_id);
}
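    /// Returns the document with the given id, or `None` if it is not in the index.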
pub fn get(&self, document_id: &K) -> Option<Document<K>> {
self.documents.get(document_id).map(|contents| Document {
id: document_id.clone(),
contents: contents.clone(),
})
}
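    /// Iterates over all indexed documents in arbitrary order.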
pub fn iter(&self) -> impl Iterator<Item = Document<K>> + '_ {
self.documents.iter().map(|(id, contents)| Document {
id: id.clone(),
contents: contents.clone(),
})
}
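    /// Embeds `query` and returns matching documents ordered by descending score.
    /// `limit` caps the number of results; pass `None` to return every match.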
pub fn search(&self, query: &str, limit: impl Into<Option<usize>>) -> Vec<SearchResult<K>> {
let query_embedding = self.embedder.embed(query);
let matches = self.scorer.matches(&query_embedding);
matches
.into_iter()
.take(limit.into().unwrap_or(usize::MAX))
.filter_map(|ScoredDocument { id, score }| {
self.get(&id)
.map(|document| SearchResult { document, score })
})
.collect()
}
}
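/// Builder for [`SearchEngine`]. Configures the tokenizer and embedder parameters and
/// seeds the engine with an initial set of documents.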
pub struct SearchEngineBuilder<K, D = DefaultTokenEmbedder, T = DefaultTokenizer> {
embedder_builder: EmbedderBuilder<D, T>,
documents: Vec<Document<K>>,
document_id_type: PhantomData<K>,
token_embedder_type: PhantomData<D>,
}
impl<K, D, T> SearchEngineBuilder<K, D, T>
where
K: Hash + Eq + Clone,
D: TokenEmbedder,
D::EmbeddingSpace: Eq + Hash + Clone,
T: Tokenizer + Sync,
{
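    /// Creates a builder with no initial documents, using `avgdl` as the embedder's
    /// average document length.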
pub fn with_avgdl(avgdl: f32) -> SearchEngineBuilder<K, D, T>
where
T: Default,
{
SearchEngineBuilder {
embedder_builder: EmbedderBuilder::<D, T>::with_avgdl(avgdl),
documents: Vec::new(),
document_id_type: PhantomData,
token_embedder_type: PhantomData,
}
}
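    /// Creates a builder from a tokenizer and an initial set of documents, fitting the
    /// embedder to their contents.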
pub fn with_tokenizer_and_documents(
tokenizer: T,
documents: impl IntoIterator<Item = impl Into<Document<K>>>,
) -> SearchEngineBuilder<K, D, T> {
let documents = documents.into_iter().map(|d| d.into()).collect::<Vec<_>>();
SearchEngineBuilder {
embedder_builder: EmbedderBuilder::<D, T>::with_tokenizer_and_fit_to_corpus(
tokenizer,
&documents
.iter()
.map(|d| d.contents.as_str())
.collect::<Vec<_>>(),
),
documents,
document_id_type: PhantomData,
token_embedder_type: PhantomData,
}
}
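    /// Creates a builder from a tokenizer and a plain corpus of strings, assigning each
    /// entry a sequential `u32` id.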
pub fn with_tokenizer_and_corpus(
tokenizer: T,
corpus: impl IntoIterator<Item = impl Into<String>>,
) -> SearchEngineBuilder<u32, D, T> {
let documents = corpus
.into_iter()
.enumerate()
.map(|(id, document)| Document::new(id as u32, document.into()))
.collect::<Vec<_>>();
SearchEngineBuilder::<u32, D, T>::with_tokenizer_and_documents(tokenizer, documents)
}
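    /// Overrides the tokenizer used by the embedder.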
pub fn tokenizer(self, tokenizer: T) -> Self {
Self {
embedder_builder: self.embedder_builder.tokenizer(tokenizer),
..self
}
}
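    /// Sets the `k1` parameter on the underlying embedder builder.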
pub fn k1(self, k1: f32) -> Self {
Self {
embedder_builder: self.embedder_builder.k1(k1),
..self
}
}
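    /// Sets the `b` parameter on the underlying embedder builder.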
pub fn b(self, b: f32) -> Self {
Self {
embedder_builder: self.embedder_builder.b(b),
..self
}
}
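    /// Sets the `avgdl` (average document length) parameter on the underlying embedder builder.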
pub fn avgdl(self, avgdl: f32) -> Self {
Self {
embedder_builder: self.embedder_builder.avgdl(avgdl),
..self
}
}
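    /// Builds the search engine and indexes the builder's initial documents.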
pub fn build(self) -> SearchEngine<K, D, T> {
let mut search_engine = SearchEngine::<K, D, T> {
embedder: self.embedder_builder.build(),
scorer: Scorer::<K, D::EmbeddingSpace>::new(),
documents: HashMap::new(),
};
for document in self.documents {
search_engine.upsert(document);
}
search_engine
}
}
#[cfg(feature = "default_tokenizer")]
impl<K, D> SearchEngineBuilder<K, D, DefaultTokenizer>
where
K: Hash + Eq + Clone,
D: TokenEmbedder,
D::EmbeddingSpace: Eq + Hash + Clone,
{
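    /// Creates a builder that uses a [`DefaultTokenizer`] for the given language mode,
    /// seeded with the given documents.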
pub fn with_documents(
language_mode: impl Into<crate::LanguageMode>,
documents: impl IntoIterator<Item = impl Into<Document<K>>>,
) -> Self {
Self::with_tokenizer_and_documents(DefaultTokenizer::new(language_mode), documents)
}
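    /// Creates a builder that uses a [`DefaultTokenizer`] for the given language mode,
    /// assigning each corpus entry a sequential `u32` id.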
pub fn with_corpus(
language_mode: impl Into<crate::LanguageMode>,
corpus: impl IntoIterator<Item = impl Into<String>>,
) -> SearchEngineBuilder<u32, D, DefaultTokenizer> {
SearchEngineBuilder::<u32, D, DefaultTokenizer>::with_tokenizer_and_corpus(
DefaultTokenizer::new(language_mode),
corpus,
)
}
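    /// Replaces the tokenizer with a [`DefaultTokenizer`] for the given language mode.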
    pub fn language_mode(self, language_mode: impl Into<crate::LanguageMode>) -> Self {
        self.tokenizer(DefaultTokenizer::new(language_mode))
    }
}
#[cfg(test)]
mod tests {
use insta::assert_debug_snapshot;
use super::*;
use crate::{
test_data_loader::tests::{read_recipes, Recipe},
Language, LanguageMode,
};
impl From<Recipe> for Document<String> {
fn from(value: Recipe) -> Self {
Document::new(value.title, value.recipe)
}
}
fn create_recipe_search_engine(
recipe_file: &str,
language_mode: impl Into<LanguageMode>,
) -> SearchEngine<String, u32> {
let recipes = read_recipes(recipe_file);
SearchEngineBuilder::with_documents(language_mode, recipes).build()
}
#[test]
fn search_returns_relevant_documents() {
let corpus = vec!["space station", "bacon and avocado sandwich"];
let search_engine =
SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();
let results = search_engine.search("sandwich with bacon", 5);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].document.contents, "bacon and avocado sandwich");
assert!(results[0].score > 0.0);
}
#[test]
fn search_does_not_return_unrelated_documents() {
let corpus = vec!["space station", "bacon and avocado sandwich"];
let search_engine =
SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();
let results = search_engine.search("maths and computer science", 5);
assert!(results.is_empty());
}
#[test]
fn it_can_insert_a_document() {
let mut search_engine = SearchEngineBuilder::<&str>::with_avgdl(2.0).build();
let document = Document::new("hello world", "bananas and apples");
let document_id = document.id;
search_engine.upsert(document.clone());
let result = search_engine.get(&document_id);
        assert_eq!(result.unwrap(), document);
}
#[test]
fn it_can_remove_a_document() {
let mut search_engine = SearchEngineBuilder::<usize>::with_avgdl(2.0).build();
let document = Document::new(123, "bananas and apples");
        let document_id = document.id;
search_engine.upsert(document);
search_engine.remove(&document_id);
assert!(search_engine.get(&document_id).is_none());
}
#[test]
fn it_can_update_a_document() {
let document_id = "hello_world";
let document = Document::new(document_id, "bananas and apples");
let mut search_engine =
SearchEngineBuilder::<&str>::with_documents(Language::English, vec![document]).build();
let new_document = Document::new(document_id, "oranges and papayas");
search_engine.upsert(new_document.clone());
let result = search_engine.get(&document_id);
        assert_eq!(result.unwrap(), new_document);
}
#[test]
fn handles_empty_input() {
let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
let document = Document::new(123, "");
search_engine.upsert(document);
let results = search_engine.search("bacon sandwich", 5);
assert!(results.is_empty());
}
#[test]
fn handles_empty_search() {
let mut search_engine = SearchEngineBuilder::<u32>::with_avgdl(2.0).build();
let document = Document::new(123, "pencil and paper");
search_engine.upsert(document);
let results = search_engine.search("", 5);
assert!(results.is_empty());
}
#[test]
fn it_returns_exact_matches_with_highest_score() {
let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);
let results = search_engine.search(
"To make guacamole, start by mashing 2 ripe avocados in a bowl.",
None,
);
assert!(!results.is_empty());
assert_eq!(results[0].document.id, "Guacamole");
}
#[test]
fn it_only_returns_results_containing_query() {
let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);
let results = search_engine.search("vegetable", 5);
assert_eq!(results.len(), 5);
assert!(results
.iter()
.all(|result| result.document.contents.contains("vegetable")));
}
#[test]
fn it_returns_results_sorted_by_score() {
let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);
let results = search_engine.search("chicken", 1000);
assert!(!results.is_empty());
assert!(results
.windows(2)
.all(|result_pair| { result_pair[0].score >= result_pair[1].score }));
}
#[test]
fn it_ranks_shorter_documents_higher() {
let documents = [
Document {
id: 0,
contents: "Correct horse battery staple bacon bacon bacon".to_string(),
},
Document {
id: 1,
contents: "Correct horse battery staple".to_string(),
},
];
let search_engine =
SearchEngineBuilder::<u32>::with_documents(Language::English, documents).build();
let results = search_engine.search("staple", 2);
assert_eq!(results.len(), 2);
assert_eq!(results[0].document.id, 1);
assert_eq!(results[1].document.id, 0);
assert!(results[0].score > results[1].score);
}
#[test]
fn it_matches_common_unicode_equivalents() {
let corpus = vec!["étude"];
let search_engine =
SearchEngineBuilder::<u32>::with_corpus(Language::French, corpus).build();
let results_1 = search_engine.search("etude", None);
let results_2 = search_engine.search("étude", None);
assert_eq!(results_1.len(), 1);
assert_eq!(results_2.len(), 1);
assert_eq!(results_1, results_2);
}
#[test]
fn it_can_search_for_emoji() {
let corpus = vec!["🔥"];
let search_engine =
SearchEngineBuilder::<u32>::with_corpus(Language::English, corpus).build();
let results_1 = search_engine.search("🔥", None);
let results_2 = search_engine.search("fire", None);
assert_eq!(results_1.len(), 1);
assert_eq!(results_2.len(), 1);
assert_eq!(results_1, results_2);
}
#[test]
fn it_matches_snapshot_en() {
let search_engine = create_recipe_search_engine("recipes_en.csv", Language::English);
let mut results = search_engine.search("bake", None);
results.sort_by_key(|result| result.document.id.clone());
insta::with_settings!({snapshot_path => "../snapshots"}, {
assert_debug_snapshot!(results);
});
}
#[test]
fn it_matches_snapshot_de() {
let search_engine = create_recipe_search_engine("recipes_de.csv", Language::German);
let mut results = search_engine.search("backen", None);
results.sort_by_key(|result| result.document.id.clone());
insta::with_settings!({snapshot_path => "../snapshots"}, {
assert_debug_snapshot!(results);
});
}
}