use std::{cmp::Ordering, collections::HashMap};
use crate::superfile::fts::tokenize::Tokenizer;
/// BM25 `k1` parameter: scales the length-normalization term and sets how
/// quickly repeated occurrences of a term saturate its score contribution.
const K1: f32 = 1.2;
/// BM25 `b` parameter: blends a constant (`1 - b`) with the document's length
/// relative to the corpus average (`b * dl / avgdl`) when normalizing scores.
const B: f32 = 0.75;
/// Per-document statistics captured when the corpus is indexed.
#[derive(Debug, Clone)]
struct DocStats {
    /// Caller-supplied identifier, returned verbatim in query results.
    doc_id: u64,
    /// Document length, counted in tokens emitted by the tokenizer.
    dl: u32,
    /// Number of occurrences of each distinct term in this document.
    tf: HashMap<String, u32>,
}

/// BM25 scorer that evaluates every indexed document for every query.
#[derive(Debug, Clone)]
pub struct BruteForceBm25 {
    /// One entry per indexed document, in corpus order.
    docs: Vec<DocStats>,
    /// Document frequency: how many documents contain each term at least once.
    df: HashMap<String, u32>,
    /// Mean document length in tokens (`0.0` when the corpus is empty).
    avgdl: f32,
    /// Number of indexed documents.
    n: u32,
}
impl BruteForceBm25 {
    /// Builds the scorer by tokenizing every `(doc_id, text)` pair in `corpus`.
    ///
    /// For each document this records its token count and per-term
    /// frequencies, and accumulates corpus-wide document frequencies and the
    /// mean document length (`0.0` for an empty corpus).
    pub fn index(corpus: &[(u64, &str)], tokenizer: &dyn Tokenizer) -> Self {
        let mut docs: Vec<DocStats> = Vec::with_capacity(corpus.len());
        let mut df: HashMap<String, u32> = HashMap::new();
        let mut total_tokens: u64 = 0;
        for (doc_id, text) in corpus {
            let mut tf: HashMap<String, u32> = HashMap::new();
            let mut dl: u32 = 0;
            tokenizer.tokenize_each(text, &mut |tok| {
                dl += 1;
                *tf.entry(tok.to_owned()).or_default() += 1;
            });
            // A term counts once per document towards `df`, however often it repeats.
            for term in tf.keys() {
                *df.entry(term.clone()).or_default() += 1;
            }
            total_tokens += u64::from(dl);
            docs.push(DocStats {
                doc_id: *doc_id,
                dl,
                tf,
            });
        }
        // NOTE: the cast truncates for corpora larger than `u32::MAX` documents.
        let n = docs.len() as u32;
        let avgdl = if n == 0 {
            0.0
        } else {
            total_tokens as f32 / n as f32
        };
        Self { docs, df, avgdl, n }
    }

    /// Tokenizes `query` with `tokenizer` and returns the `k` best documents
    /// under OR semantics (see [`Self::top_k_terms`]).
    ///
    /// Returns an empty vector, without tokenizing, when `k == 0` or nothing
    /// is indexed.
    pub fn top_k(&self, query: &str, k: usize, tokenizer: &dyn Tokenizer) -> Vec<(u64, f32)> {
        if k == 0 || self.n == 0 {
            return Vec::new();
        }
        let mut q_terms: Vec<String> = Vec::new();
        tokenizer.tokenize_each(query, &mut |tok| q_terms.push(tok.to_owned()));
        self.top_k_terms(&q_terms, k)
    }

    /// Scores every document against `terms` with OR semantics and returns at
    /// most `k` `(doc_id, score)` pairs.
    ///
    /// A document's score is the sum of the BM25 contributions of the query
    /// terms it contains; a term repeated in `terms` contributes once per
    /// repetition. Only documents with a strictly positive score are returned,
    /// ordered by descending score and then ascending `doc_id`.
    pub fn top_k_terms(&self, terms: &[String], k: usize) -> Vec<(u64, f32)> {
        if k == 0 || terms.is_empty() || self.n == 0 {
            return Vec::new();
        }
        // IDF depends only on the term, so compute it once per query instead
        // of once per (document, term) pair. Terms absent from the corpus can
        // never contribute and are dropped here.
        let weighted: Vec<(&str, f32)> = terms
            .iter()
            .filter_map(|term| self.term_idf(term).map(|idf| (term.as_str(), idf)))
            .collect();
        if weighted.is_empty() {
            return Vec::new();
        }
        let mut scored: Vec<(u64, f32)> = Vec::with_capacity(self.docs.len());
        for doc in &self.docs {
            let dl_norm = self.length_norm(doc.dl);
            let mut score: f32 = 0.0;
            for &(term, idf) in &weighted {
                if let Some(&tf) = doc.tf.get(term) {
                    score += Self::term_score(idf, tf, dl_norm);
                }
            }
            if score > 0.0 {
                scored.push((doc.doc_id, score));
            }
        }
        Self::rank_top_k(scored, k)
    }

    /// Scores documents against `terms` with AND semantics and returns at most
    /// `k` `(doc_id, score)` pairs.
    ///
    /// Only documents containing *every* term are returned; their score is the
    /// sum of the BM25 contributions of all terms. Results are ordered by
    /// descending score and then ascending `doc_id`.
    pub fn top_k_terms_and(&self, terms: &[String], k: usize) -> Vec<(u64, f32)> {
        if k == 0 || terms.is_empty() || self.n == 0 {
            return Vec::new();
        }
        // If any term occurs in no document the conjunction cannot be
        // satisfied, so there is nothing to scan.
        let weighted: Option<Vec<(&str, f32)>> = terms
            .iter()
            .map(|term| self.term_idf(term).map(|idf| (term.as_str(), idf)))
            .collect();
        let Some(weighted) = weighted else {
            return Vec::new();
        };
        let mut scored: Vec<(u64, f32)> = Vec::with_capacity(self.docs.len());
        for doc in &self.docs {
            let dl_norm = self.length_norm(doc.dl);
            // `try_fold` stops at the first term the document lacks.
            let total = weighted.iter().try_fold(0.0_f32, |acc, &(term, idf)| {
                doc.tf
                    .get(term)
                    .map(|&tf| acc + Self::term_score(idf, tf, dl_norm))
            });
            if let Some(score) = total {
                scored.push((doc.doc_id, score));
            }
        }
        Self::rank_top_k(scored, k)
    }

    /// Number of documents in the index.
    pub fn n_docs(&self) -> u32 {
        self.n
    }

    /// Inverse document frequency of `term`, or `None` when no indexed
    /// document contains it.
    fn term_idf(&self, term: &str) -> Option<f32> {
        let df = self.df.get(term).copied().filter(|&count| count > 0)? as f32;
        let n = self.n as f32;
        Some((1.0 + (n - df + 0.5) / (df + 0.5)).ln())
    }

    /// Length-normalization part of the BM25 denominator for a document of
    /// `dl` tokens.
    fn length_norm(&self, dl: u32) -> f32 {
        let dl = dl as f32;
        // Clamp the average so a corpus of only empty documents cannot divide by zero.
        K1 * (1.0 - B + B * dl / self.avgdl.max(f32::MIN_POSITIVE))
    }

    /// BM25 contribution of one term occurring `tf` times in a document.
    fn term_score(idf: f32, tf: u32, dl_norm: f32) -> f32 {
        let tf = tf as f32;
        idf * tf * (K1 + 1.0) / (tf + dl_norm)
    }

    /// Result ordering: higher score first, ties broken by smaller `doc_id`.
    ///
    /// Uses `total_cmp` so the comparator is a total order for any `f32`.
    fn cmp_hits(a: &(u64, f32), b: &(u64, f32)) -> Ordering {
        b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0))
    }

    /// Sorts `scored` into result order and keeps the first `k` entries.
    fn rank_top_k(mut scored: Vec<(u64, f32)>, k: usize) -> Vec<(u64, f32)> {
        scored.sort_unstable_by(Self::cmp_hits);
        scored.truncate(k);
        scored
    }
}