use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use ailake_core::{AilakeError, AilakeResult};
/// Upper bound on the number of distinct terms kept in `IdfStats::term_df`.
const MAX_VOCAB: usize = 50_000;
/// Minimum token length accepted by `tokenize`.
const MIN_TERM_LEN: usize = 2;
/// BM25 term-frequency saturation parameter (`k1`).
const K1: f32 = 1.2;
/// BM25 document-length normalisation strength (`b`): `0.0` disables length
/// normalisation, `1.0` applies it fully.
const B: f32 = 0.75;
/// Splits `text` into lowercase terms for BM25.
///
/// Every character for which `char::is_alphanumeric` is false acts as a separator, so
/// punctuation and whitespace never appear inside a term, while digits do. Tokens with
/// fewer than `MIN_TERM_LEN` characters are discarded. The remaining tokens are
/// lowercased and returned in input order, duplicates included.
pub fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        // Count characters, not bytes: `str::len` would let one-character
        // non-ASCII tokens (e.g. "é", "ß") through while dropping "a".
        .filter(|t| t.chars().count() >= MIN_TERM_LEN)
        .map(str::to_lowercase)
        .collect()
}
/// Corpus-level statistics for BM25: document count, token count and per-term
/// document frequency.
///
/// Persisted with bincode via [`IdfStats::to_bytes`]. bincode encodes struct fields
/// positionally, so reordering or retyping fields breaks previously written stats.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct IdfStats {
    /// Number of documents folded in via [`IdfStats::merge_batch`].
    pub doc_count: u64,
    /// Total number of `tokenize` tokens across those documents.
    pub total_tokens: u64,
    /// Document frequency: number of documents in which each term appeared at least once.
    /// Capped at `MAX_VOCAB` entries by `merge_batch`.
    pub term_df: HashMap<String, u64>,
}

impl IdfStats {
    /// Average document length in tokens.
    ///
    /// Returns `1.0` when there is no token data (no documents, or only documents that
    /// tokenized to nothing). This keeps BM25's `doc_len / avgdl` ratio finite.
    pub fn avg_doc_len(&self) -> f32 {
        if self.doc_count == 0 || self.total_tokens == 0 {
            1.0
        } else {
            self.total_tokens as f32 / self.doc_count as f32
        }
    }

    /// BM25 inverse document frequency of `term`:
    /// `ln(1 + (N - df + 0.5) / (df + 0.5))`, where `N` is `doc_count` and `df` is the
    /// term's document frequency (`0` for unknown or pruned terms).
    ///
    /// The `+ 1` inside the logarithm keeps the value positive whenever `df <= N`,
    /// which `merge_batch` maintains. Even common terms therefore never subtract from
    /// a score.
    pub fn idf(&self, term: &str) -> f32 {
        let df = self.term_df.get(term).copied().unwrap_or(0) as f32;
        let n = self.doc_count as f32;
        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
    }

    /// Folds a batch of documents into the statistics.
    ///
    /// Each text counts as one document. `total_tokens` grows by its token count, and
    /// every distinct term in it bumps that term's document frequency by exactly one.
    /// If the vocabulary then exceeds `MAX_VOCAB`, it is pruned to the most frequent
    /// terms.
    pub fn merge_batch(&mut self, texts: &[&str]) {
        for &text in texts {
            let terms = tokenize(text);
            self.doc_count += 1;
            self.total_tokens += terms.len() as u64;
            // Document frequency counts a term once per document, however often it occurs.
            let distinct: std::collections::HashSet<&str> =
                terms.iter().map(String::as_str).collect();
            for term in distinct {
                *self.term_df.entry(term.to_owned()).or_insert(0) += 1;
            }
        }
        if self.term_df.len() > MAX_VOCAB {
            self.prune_vocab();
        }
    }

    /// Keeps only the `MAX_VOCAB` terms with the highest document frequency.
    ///
    /// Ties are broken by the term itself (lexicographically smaller wins). The
    /// surviving vocabulary is then fully determined by the counts, not by the
    /// randomized `HashMap` iteration order.
    fn prune_vocab(&mut self) {
        let mut ranked: Vec<(String, u64)> = self.term_df.drain().collect();
        ranked.sort_unstable_by(|(term_a, df_a), (term_b, df_b)| {
            df_b.cmp(df_a).then_with(|| term_a.cmp(term_b))
        });
        ranked.truncate(MAX_VOCAB);
        // `drain` left the map empty but kept its allocation; refill it in place.
        self.term_df.extend(ranked);
    }

    /// Serializes the stats with bincode and compresses the result with zstd (level 3).
    ///
    /// # Errors
    /// `AilakeError::Bincode` if serialization fails, `AilakeError::Io` if compression fails.
    pub fn to_bytes(&self) -> AilakeResult<Vec<u8>> {
        let raw = bincode::serialize(self).map_err(|e| AilakeError::Bincode(e.to_string()))?;
        zstd::encode_all(&raw[..], 3).map_err(AilakeError::Io)
    }

    /// Inverse of [`IdfStats::to_bytes`]: zstd-decompresses `bytes`, then bincode-decodes them.
    ///
    /// # Errors
    /// `AilakeError::Io` if decompression fails, `AilakeError::Bincode` if decoding fails.
    // NOTE(review): no size limit is applied to decompression or decoding here;
    // confirm the bytes always come from a trusted source.
    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<Self> {
        let raw = zstd::decode_all(bytes).map_err(AilakeError::Io)?;
        bincode::deserialize(&raw).map_err(|e| AilakeError::Bincode(e.to_string()))
    }
}
/// Okapi BM25 scorer backed by corpus statistics from an [`IdfStats`].
///
/// Uses the module constants `K1` (term-frequency saturation) and `B` (document-length
/// normalisation). Higher scores mean more relevant documents.
pub struct BM25Scorer<'a> {
    stats: &'a IdfStats,
}

impl<'a> BM25Scorer<'a> {
    /// Creates a scorer that reads IDF values and the average document length from `stats`.
    pub fn new(stats: &'a IdfStats) -> Self {
        Self { stats }
    }

    /// BM25 score of `doc_text` for `query_text`.
    ///
    /// Both texts are split with `tokenize`. Returns `0.0` when the query has no terms
    /// or shares none with the document. A query term that appears `n` times in the
    /// query contributes `n` times.
    pub fn score(&self, query_text: &str, doc_text: &str) -> f32 {
        let query = self.prepare_query(query_text);
        self.score_prepared(&query, doc_text)
    }

    /// Scores each document in `docs` against `query_text`, preserving input order.
    ///
    /// Produces exactly the values [`score`](Self::score) would for each document. The
    /// query is tokenized and its IDF weights looked up once for the whole batch rather
    /// than once per document.
    pub fn score_batch(&self, query_text: &str, docs: &[&str]) -> Vec<f32> {
        let query = self.prepare_query(query_text);
        docs.iter()
            .map(|doc| self.score_prepared(&query, doc))
            .collect()
    }

    /// Tokenizes the query and pairs every term (duplicates kept) with its IDF.
    fn prepare_query(&self, query_text: &str) -> Vec<(String, f32)> {
        tokenize(query_text)
            .into_iter()
            .map(|term| {
                let idf = self.stats.idf(&term);
                (term, idf)
            })
            .collect()
    }

    /// BM25 sum over pre-weighted query terms for a single document.
    fn score_prepared(&self, query: &[(String, f32)], doc_text: &str) -> f32 {
        if query.is_empty() {
            // Nothing can match; skip tokenizing the document.
            return 0.0;
        }
        let doc_terms = tokenize(doc_text);
        let mut tf_map: HashMap<&str, u32> = HashMap::new();
        for term in &doc_terms {
            *tf_map.entry(term.as_str()).or_insert(0) += 1;
        }
        let doc_len = doc_terms.len() as f32;
        let avgdl = self.stats.avg_doc_len();
        // Per-document part of the BM25 denominator: K1 scaled by this document's length
        // relative to the corpus average, with B controlling how strongly.
        let length_norm = K1 * (1.0 - B + B * doc_len / avgdl);
        let mut score = 0.0f32;
        for (term, idf) in query {
            // Terms absent from the document contribute nothing.
            if let Some(&tf) = tf_map.get(term.as_str()) {
                let tf = tf as f32;
                score += *idf * (tf * (K1 + 1.0) / (tf + length_norm));
            }
        }
        score
    }
}
/// Strategy for combining the vector ranking with the BM25 ranking in a hybrid query.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HybridFusion {
    /// Reciprocal rank fusion. This is the default.
    #[default]
    Rrf,
    /// Linear weighted combination of the two signals.
    Linear,
}
/// Settings for a hybrid query that combines vector search with BM25 text scoring.
///
/// Build it with [`HybridConfig::new`] and the `with_*` methods. The fields are public,
/// so values assigned directly bypass the builder's validation.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Query text; empty by default.
    pub query_text: String,
    /// Names of the text columns to use; defaults to `["chunk_text"]`.
    pub text_columns: Vec<String>,
    /// Relative weight of the BM25 signal. Defaults to `0.5`;
    /// [`HybridConfig::with_bm25_weight`] keeps it within `[0.0, 1.0]`.
    pub bm25_weight: f32,
    /// Fusion strategy; defaults to [`HybridFusion::Rrf`].
    pub fusion: HybridFusion,
    /// Optional candidate-pool size; `None` by default.
    // NOTE(review): presumably the number of candidates gathered before fusion — confirm at the call site.
    pub candidate_pool: Option<usize>,
}

impl Default for HybridConfig {
    /// Empty query, the `chunk_text` column, BM25 weight `0.5`, RRF fusion, no candidate pool.
    fn default() -> Self {
        Self {
            query_text: String::new(),
            text_columns: vec!["chunk_text".to_string()],
            bm25_weight: 0.5,
            fusion: HybridFusion::Rrf,
            candidate_pool: None,
        }
    }
}

impl HybridConfig {
    /// Default configuration with `query_text` set.
    pub fn new(query_text: impl Into<String>) -> Self {
        Self {
            query_text: query_text.into(),
            ..Default::default()
        }
    }

    /// Replaces the text columns with the single column `col`.
    pub fn with_text_column(mut self, col: impl Into<String>) -> Self {
        self.text_columns = vec![col.into()];
        self
    }

    /// Replaces the text columns with `cols`.
    pub fn with_text_columns(mut self, cols: Vec<String>) -> Self {
        self.text_columns = cols;
        self
    }

    /// Sets the BM25 weight, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN `w` is ignored and the current weight is kept. `f32::clamp` would pass NaN
    /// through and make every fused score NaN.
    pub fn with_bm25_weight(mut self, w: f32) -> Self {
        if !w.is_nan() {
            self.bm25_weight = w.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the fusion strategy.
    pub fn with_fusion(mut self, fusion: HybridFusion) -> Self {
        self.fusion = fusion;
        self
    }

    /// Sets the candidate-pool size to `Some(n)`.
    pub fn with_candidate_pool(mut self, n: usize) -> Self {
        self.candidate_pool = Some(n);
        self
    }
}
/// Weighted reciprocal-rank-fusion score, negated so that smaller means better.
///
/// Each signal contributes `weight / (60 + rank)`. The vector rank gets weight
/// `1 - bm25_weight` and the BM25 rank gets `bm25_weight`. Negating the sum lets an
/// ascending sort put the best-fused candidate first.
pub fn rrf_score(vec_rank: usize, bm25_rank: usize, bm25_weight: f32) -> f32 {
    const SMOOTHING: f32 = 60.0;
    let reciprocal = |weight: f32, rank: usize| weight / (SMOOTHING + rank as f32);
    let fused = reciprocal(1.0 - bm25_weight, vec_rank) + reciprocal(bm25_weight, bm25_rank);
    -fused
}
/// Weighted combination of a min-max-normalised vector distance and a flipped,
/// min-max-normalised BM25 score; lower means better.
///
/// The vector term is `(vec_dist - min_vec) / (max_vec - min_vec)`, or `0.0` when that
/// range is narrower than `f32::EPSILON`. The BM25 term is `1 - (bm25 - min_bm25) /
/// (max_bm25 - min_bm25)`, with the inner value taken as `0.5` when its range is that
/// narrow. The terms are weighted `1 - bm25_weight` and `bm25_weight` respectively.
pub fn linear_score(
    vec_dist: f32,
    min_vec: f32,
    max_vec: f32,
    bm25: f32,
    min_bm25: f32,
    max_bm25: f32,
    bm25_weight: f32,
) -> f32 {
    /// Min-max scales `x` against `[lo, hi]`, or yields `flat` for a (near-)empty range.
    fn unit_scale(x: f32, lo: f32, hi: f32, flat: f32) -> f32 {
        let span = hi - lo;
        if span.abs() < f32::EPSILON {
            flat
        } else {
            (x - lo) / span
        }
    }

    let vec_part = unit_scale(vec_dist, min_vec, max_vec, 0.0);
    // BM25 is higher-is-better; flip it so both terms read lower-is-better.
    let bm25_part = 1.0 - unit_scale(bm25, min_bm25, max_bm25, 0.5);
    (1.0 - bm25_weight) * vec_part + bm25_weight * bm25_part
}
/// Relative path of the BM25 statistics file.
// NOTE(review): presumably holds the bytes produced by `IdfStats::to_bytes` — confirm with the writer.
pub const BM25_STATS_FILE: &str = "metadata/ailake_bm25_stats.bin";
/// Property key under which the BM25 statistics path is recorded.
// NOTE(review): presumably a table property — confirm where it is read and written.
pub const BM25_STATS_PATH_PROP: &str = "ailake.bm25.stats-path";
#[cfg(test)]
mod tests {
    use super::*;

    /// Punctuation separates tokens, output is lowercased, and one-letter words vanish.
    #[test]
    fn tokenize_basic() {
        let got = tokenize("Hello, World! This is a test.");
        for expected in ["hello", "world", "test"] {
            assert!(got.contains(&expected.to_string()));
        }
        assert!(!got.contains(&"a".to_string()));
    }

    /// With no documents at all, an unseen term still receives a strictly positive IDF.
    #[test]
    fn idf_empty_corpus_returns_positive() {
        let empty = IdfStats::default();
        let weight = empty.idf("unknown_term");
        assert!(weight > 0.0, "IDF should be positive for unseen term");
    }

    /// Document frequency grows by one per document containing the term.
    #[test]
    fn merge_batch_accumulates_df() {
        let mut stats = IdfStats::default();
        stats.merge_batch(&["the quick brown fox", "the lazy dog"]);
        assert_eq!(stats.doc_count, 2);
        assert_eq!(stats.term_df["the"], 2, "the appears in both docs");
        for seen_once in ["fox", "dog"] {
            assert_eq!(stats.term_df[seen_once], 1);
        }
    }

    /// Documents sharing query terms outscore a document that shares none.
    #[test]
    fn bm25_scorer_ranks_relevant_doc_higher() {
        let corpus = [
            "rust programming language systems",
            "python machine learning data science",
            "rust memory safety zero cost abstractions",
        ];
        let mut stats = IdfStats::default();
        stats.merge_batch(&corpus);
        let scorer = BM25Scorer::new(&stats);
        let query = "rust systems programming";
        let [s0, s1, s2] = corpus.map(|doc| scorer.score(query, doc));
        assert!(
            s0 > s1,
            "rust doc scores higher than python doc: s0={s0}, s1={s1}"
        );
        assert!(
            s2 > s1,
            "rust doc scores higher than python doc: s2={s2}, s1={s1}"
        );
    }

    /// Serializing then deserializing preserves the counters.
    #[test]
    fn idf_stats_roundtrip() {
        let mut original = IdfStats::default();
        original.merge_batch(&["hello world foo bar", "foo baz qux"]);
        let encoded = original.to_bytes().unwrap();
        let decoded = IdfStats::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.doc_count, original.doc_count);
        assert_eq!(decoded.term_df["foo"], 2);
        assert_eq!(decoded.term_df["hello"], 1);
    }

    /// More than `MAX_VOCAB` distinct terms in one batch triggers pruning.
    #[test]
    fn vocab_cap_prunes_to_max() {
        let words: Vec<String> = (0..=MAX_VOCAB + 100).map(|n| format!("term{n}")).collect();
        let doc = words.join(" ");
        let mut stats = IdfStats::default();
        stats.merge_batch(&[doc.as_str()]);
        let vocab_size = stats.term_df.len();
        assert!(
            vocab_size <= MAX_VOCAB,
            "vocab should be capped at {MAX_VOCAB}"
        );
    }

    /// Fused RRF scores are negated so that ascending order ranks best first.
    #[test]
    fn rrf_score_is_negative() {
        let fused = rrf_score(0, 0, 0.5);
        assert!(
            fused < 0.0,
            "RRF score should be negated for sort-ascending convention"
        );
    }

    /// In-range inputs produce a combined score inside the unit interval.
    #[test]
    fn linear_score_in_range() {
        let s = linear_score(0.5, 0.0, 1.0, 0.8, 0.0, 1.0, 0.5);
        let unit = 0.0..=1.0;
        assert!(unit.contains(&s), "linear score should be in [0,1]: {s}");
    }
}