use crate::embedder::{DefaultEmbeddingSpace, Embedding};
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
fmt::Debug,
hash::Hash,
};
#[derive(PartialEq, Debug)]
pub struct ScoredDocument<K> {
pub id: K,
pub score: f32,
}
#[derive(Default)]
pub struct Scorer<K, D = DefaultEmbeddingSpace> {
embeddings: HashMap<K, Embedding<D>>,
inverted_token_index: HashMap<D, HashSet<K>>,
}
impl<K, D> Scorer<K, D>
where
D: Eq + Hash + Clone,
K: Eq + Hash + Clone,
{
pub fn new() -> Scorer<K, D> {
Scorer {
embeddings: HashMap::new(),
inverted_token_index: HashMap::new(),
}
}
pub fn upsert(&mut self, document_id: &K, embedding: Embedding<D>) {
if self.embeddings.contains_key(document_id) {
self.remove(document_id);
}
for token_index in embedding.indices() {
let documents_containing_token = self
.inverted_token_index
.entry(token_index.clone())
.or_default();
documents_containing_token.insert(document_id.clone());
}
self.embeddings.insert(document_id.clone(), embedding);
}
pub fn remove(&mut self, document_id: &K) {
if let Some(embedding) = self.embeddings.remove(document_id) {
for token_index in embedding.indices() {
if let Some(matches) = self.inverted_token_index.get_mut(token_index) {
matches.remove(document_id);
}
}
}
}
pub fn score(&self, document_id: &K, query_embedding: &Embedding<D>) -> Option<f32> {
let document_embedding = self.embeddings.get(document_id)?;
Some(self.score_(document_embedding, query_embedding))
}
pub fn matches(&self, query_embedding: &Embedding<D>) -> Vec<ScoredDocument<K>> {
let relevant_embeddings_it = query_embedding
.indices()
.filter_map(|token_index| self.inverted_token_index.get(token_index))
.flat_map(|document_set| document_set.iter())
.collect::<HashSet<_>>()
.into_iter()
.filter_map(|document_id| self.embeddings.get(document_id).map(|e| (document_id, e)));
let mut scores: Vec<_> = relevant_embeddings_it
.map(|(document_id, document_embedding)| ScoredDocument {
id: document_id.clone(),
score: self.score_(document_embedding, query_embedding),
})
.collect();
scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
scores
}
fn idf(&self, token_index: &D) -> f32 {
let token_frequency = self
.inverted_token_index
.get(token_index)
.map_or(0, |documents| documents.len()) as f32;
let numerator = self.embeddings.len() as f32 - token_frequency + 0.5;
let denominator = token_frequency + 0.5;
(1f32 + (numerator / denominator)).ln()
}
fn score_(&self, document_embedding: &Embedding<D>, query_embedding: &Embedding<D>) -> f32 {
let mut document_score = 0f32;
for token_index in query_embedding.indices() {
let token_idf = self.idf(token_index);
let token_index_value = document_embedding
.iter()
.find(|token_embedding| token_embedding.index == *token_index)
.map(|token_embedding| token_embedding.value)
.unwrap_or(0f32);
let token_score = token_idf * token_index_value;
document_score += token_score;
}
document_score
}
}
#[cfg(test)]
mod tests {
use crate::TokenEmbedding;
use super::*;
fn scorer_with_embeddings(embeddings: &Vec<Embedding>) -> Scorer<usize> {
let mut scorer = Scorer::<usize>::new();
for (i, document_embedding) in embeddings.iter().enumerate() {
scorer.upsert(&i, document_embedding.clone());
}
scorer
}
#[test]
fn it_scores_missing_document_as_none() {
let scorer = Scorer::<usize>::new();
let query_embedding = Embedding::any();
let score = scorer.score(&12345, &query_embedding);
let matches = scorer.matches(&query_embedding);
assert_eq!(score, None);
assert!(matches.is_empty());
}
#[test]
fn it_scores_mutually_exclusive_indices_as_zero() {
let document_embeddings = vec![Embedding(vec![TokenEmbedding::new(1, 1.0)])];
let scorer = scorer_with_embeddings(&document_embeddings);
let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
let score = scorer.score(&0, &query_embedding);
assert_eq!(score, Some(0.0));
}
#[test]
fn it_scores_rare_indices_higher_than_common_ones() {
let document_embeddings = vec![
Embedding(vec![TokenEmbedding::new(0, 1.0)]),
Embedding(vec![TokenEmbedding::new(0, 1.0)]),
Embedding(vec![TokenEmbedding::new(1, 1.0)]),
];
let scorer = scorer_with_embeddings(&document_embeddings);
let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
let score_2 = scorer.score(&2, &Embedding(vec![TokenEmbedding::new(1, 1.0)]));
assert!(score_1.unwrap() < score_2.unwrap());
}
#[test]
fn it_scores_longer_embeddings_lower_than_shorter_ones() {
let document_embeddings = vec![
Embedding(vec![
TokenEmbedding::new(0, 0.9),
TokenEmbedding::new(1, 0.9),
]),
Embedding(vec![TokenEmbedding::new(0, 1.0)]),
];
let scorer = scorer_with_embeddings(&document_embeddings);
let score_1 = scorer.score(&0, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
let score_2 = scorer.score(&1, &Embedding(vec![TokenEmbedding::new(0, 1.0)]));
assert!(score_1.unwrap() < score_2.unwrap());
}
#[test]
fn it_only_matches_embeddings_with_non_zero_score() {
let document_embeddings = vec![
Embedding(vec![TokenEmbedding::new(0, 1.0)]),
Embedding(vec![TokenEmbedding::new(1, 1.0)]),
];
let scorer = scorer_with_embeddings(&document_embeddings);
let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
let matches = scorer.matches(&query_embedding);
assert_eq!(
matches,
vec![ScoredDocument {
id: 0,
score: 0.6931472
}]
);
}
#[test]
fn it_does_not_score_frequent_terms_negatively() {
let document_embeddings = vec![Embedding(vec![
TokenEmbedding::new(0, 1.5),
TokenEmbedding::new(0, 1.5),
])];
let scorer = scorer_with_embeddings(&document_embeddings);
let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
let matches = scorer.matches(&query_embedding);
assert!(matches[0].score >= 0.0);
}
#[test]
fn it_sorts_matches_by_score() {
let document_embeddings = vec![
Embedding(vec![
TokenEmbedding::new(0, 0.9),
TokenEmbedding::new(1, 0.9),
]),
Embedding(vec![TokenEmbedding::new(0, 1.0)]),
];
let scorer = scorer_with_embeddings(&document_embeddings);
let query_embedding = Embedding(vec![TokenEmbedding::new(0, 1.0)]);
let matches = scorer.matches(&query_embedding);
assert_eq!(
matches,
vec![
ScoredDocument {
id: 1,
score: 0.1823216
},
ScoredDocument {
id: 0,
score: 0.16408943
}
]
);
}
}