1use std::collections::HashMap;
19
20use serde::{Deserialize, Serialize};
21
22use ailake_core::{AilakeError, AilakeResult};
23
/// BM25 `k1` parameter: controls term-frequency saturation in
/// `BM25Scorer::score` (appears as `K1 + 1.0` and `K1 * (...)` in the formula).
const K1: f32 = 1.2;
/// BM25 `b` parameter: how strongly the `doc_len / avgdl` ratio influences
/// the length normalization in `BM25Scorer::score`.
const B: f32 = 0.75;
/// Maximum number of distinct terms kept in `IdfStats::term_df`;
/// `IdfStats::merge_batch` prunes the map back to this size when it is exceeded.
const MAX_VOCAB: usize = 50_000;
/// Minimum token length that `tokenize` keeps. The comparison uses `str::len`,
/// so the unit is UTF-8 bytes, not characters.
const MIN_TERM_LEN: usize = 2;
31
32pub fn tokenize(text: &str) -> Vec<String> {
34 text.split(|c: char| !c.is_alphanumeric())
35 .filter(|t| t.len() >= MIN_TERM_LEN)
36 .map(|t| t.to_lowercase())
37 .collect()
38}
39
/// Corpus-level statistics from which BM25 inverse document frequencies and
/// the average document length are derived.
///
/// Serializable with serde; `IdfStats::to_bytes` / `IdfStats::from_bytes`
/// provide a compressed binary encoding.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct IdfStats {
    /// Number of documents merged so far (`merge_batch` adds one per text).
    pub doc_count: u64,
    /// Total number of tokens over all merged documents.
    pub total_tokens: u64,
    /// Document frequency per term: the number of merged documents that
    /// contained the term at least once. `merge_batch` may prune this map
    /// when it grows beyond `MAX_VOCAB` entries.
    pub term_df: HashMap<String, u64>,
}
52
53impl IdfStats {
54 pub fn avg_doc_len(&self) -> f32 {
55 if self.doc_count == 0 {
56 1.0
57 } else {
58 self.total_tokens as f32 / self.doc_count as f32
59 }
60 }
61
62 pub fn idf(&self, term: &str) -> f32 {
64 let df = self.term_df.get(term).copied().unwrap_or(0) as f32;
65 let n = self.doc_count as f32;
66 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
69 }
70
71 pub fn merge_batch(&mut self, texts: &[&str]) {
76 for &text in texts {
77 let terms = tokenize(text);
78 self.doc_count += 1;
79 self.total_tokens += terms.len() as u64;
80
81 let mut seen = HashMap::<&str, ()>::new();
83 for term in &terms {
84 if seen.insert(term.as_str(), ()).is_none() {
85 *self.term_df.entry(term.clone()).or_insert(0) += 1;
86 }
87 }
88 }
89
90 if self.term_df.len() > MAX_VOCAB {
91 let mut pairs: Vec<(String, u64)> = self.term_df.drain().collect();
94 pairs.sort_unstable_by_key(|b| std::cmp::Reverse(b.1));
95 pairs.truncate(MAX_VOCAB);
96 self.term_df = pairs.into_iter().collect();
97 }
98 }
99
100 pub fn to_bytes(&self) -> AilakeResult<Vec<u8>> {
102 let raw = bincode::serialize(self).map_err(|e| AilakeError::Bincode(e.to_string()))?;
103 zstd::encode_all(&raw[..], 3).map_err(AilakeError::Io)
104 }
105
106 pub fn from_bytes(bytes: &[u8]) -> AilakeResult<Self> {
108 let raw = zstd::decode_all(bytes).map_err(AilakeError::Io)?;
109 bincode::deserialize(&raw).map_err(|e| AilakeError::Bincode(e.to_string()))
110 }
111}
112
/// BM25 scorer that borrows a set of corpus statistics for lifetime `'a`.
pub struct BM25Scorer<'a> {
    /// Corpus statistics that supply IDF values and the average document length.
    stats: &'a IdfStats,
}
117
118impl<'a> BM25Scorer<'a> {
119 pub fn new(stats: &'a IdfStats) -> Self {
120 Self { stats }
121 }
122
123 pub fn score(&self, query_text: &str, doc_text: &str) -> f32 {
125 let query_terms = tokenize(query_text);
126 if query_terms.is_empty() {
127 return 0.0;
128 }
129
130 let doc_terms = tokenize(doc_text);
131 let doc_len = doc_terms.len() as f32;
132 let avgdl = self.stats.avg_doc_len();
133
134 let mut tf_map: HashMap<&str, u32> = HashMap::new();
135 for term in &doc_terms {
136 *tf_map.entry(term.as_str()).or_insert(0) += 1;
137 }
138
139 let mut score = 0.0f32;
140 for term in &query_terms {
141 let tf = tf_map.get(term.as_str()).copied().unwrap_or(0) as f32;
142 if tf == 0.0 {
143 continue;
144 }
145 let idf = self.stats.idf(term);
146 let tf_norm = tf * (K1 + 1.0) / (tf + K1 * (1.0 - B + B * doc_len / avgdl));
148 score += idf * tf_norm;
149 }
150 score
151 }
152
153 pub fn score_batch(&self, query_text: &str, docs: &[&str]) -> Vec<f32> {
155 docs.iter().map(|doc| self.score(query_text, doc)).collect()
156 }
157}
158
/// Fusion strategy selector stored in `HybridConfig::fusion`.
///
/// NOTE(review): the code that dispatches on this enum is not in this part of
/// the file; the variants presumably map to `rrf_score` and `linear_score`
/// respectively — confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HybridFusion {
    /// Default variant (presumably rank-based fusion, see `rrf_score`).
    #[default]
    Rrf,
    /// Alternative variant (presumably min-max normalized fusion, see `linear_score`).
    Linear,
}
170
/// Configuration of a hybrid (vector + BM25) query, built with
/// `HybridConfig::new` and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Query text; set by `HybridConfig::new`, empty by default.
    pub query_text: String,
    /// Text column names; defaults to `["chunk_text"]`.
    /// NOTE(review): presumably the columns whose text is BM25-scored — confirm with the consumer.
    pub text_columns: Vec<String>,
    /// Weight of the BM25 component; defaults to `0.5`. `with_bm25_weight`
    /// clamps to `[0.0, 1.0]`, but the field is public, so a direct assignment
    /// bypasses that clamp.
    pub bm25_weight: f32,
    /// Fusion strategy; defaults to `HybridFusion::Rrf`.
    pub fusion: HybridFusion,
    /// Optional candidate pool size; `None` by default.
    /// NOTE(review): how `None` and the value are interpreted is decided by code outside this section.
    pub candidate_pool: Option<usize>,
}
189
190impl Default for HybridConfig {
191 fn default() -> Self {
192 Self {
193 query_text: String::new(),
194 text_columns: vec!["chunk_text".to_string()],
195 bm25_weight: 0.5,
196 fusion: HybridFusion::Rrf,
197 candidate_pool: None,
198 }
199 }
200}
201
202impl HybridConfig {
203 pub fn new(query_text: impl Into<String>) -> Self {
204 Self {
205 query_text: query_text.into(),
206 ..Default::default()
207 }
208 }
209
210 pub fn with_text_column(mut self, col: impl Into<String>) -> Self {
211 self.text_columns = vec![col.into()];
212 self
213 }
214
215 pub fn with_text_columns(mut self, cols: Vec<String>) -> Self {
216 self.text_columns = cols;
217 self
218 }
219
220 pub fn with_bm25_weight(mut self, w: f32) -> Self {
221 self.bm25_weight = w.clamp(0.0, 1.0);
222 self
223 }
224
225 pub fn with_fusion(mut self, fusion: HybridFusion) -> Self {
226 self.fusion = fusion;
227 self
228 }
229
230 pub fn with_candidate_pool(mut self, n: usize) -> Self {
231 self.candidate_pool = Some(n);
232 self
233 }
234}
235
/// Weighted reciprocal-rank fusion of a vector rank and a BM25 rank.
///
/// Each rank contributes `weight / (60 + rank)`, with the vector side weighted
/// by `1 - bm25_weight`. The sum is negated before it is returned, so a
/// better (smaller) rank yields a smaller value.
pub fn rrf_score(vec_rank: usize, bm25_rank: usize, bm25_weight: f32) -> f32 {
    const RRF_K: f32 = 60.0;
    let reciprocal = |weight: f32, rank: usize| weight / (RRF_K + rank as f32);

    let vec_weight = 1.0 - bm25_weight;
    let fused = reciprocal(vec_weight, vec_rank) + reciprocal(bm25_weight, bm25_rank);
    -fused
}
245
/// Linear fusion of a min-max normalized vector distance and BM25 score.
///
/// `vec_dist` is rescaled with `min_vec`/`max_vec` and `bm25` with
/// `min_bm25`/`max_bm25`. When a range is narrower than `f32::EPSILON` the
/// normalized value falls back to `0.0` for the vector side and `0.5` for the
/// BM25 side. The BM25 part enters as `1 - normalized`, so a higher BM25 score
/// lowers the result; the vector side is weighted by `1 - bm25_weight`.
/// Inputs are not clamped to their stated range.
pub fn linear_score(
    vec_dist: f32,
    min_vec: f32,
    max_vec: f32,
    bm25: f32,
    min_bm25: f32,
    max_bm25: f32,
    bm25_weight: f32,
) -> f32 {
    let norm_vec = match max_vec - min_vec {
        span if span.abs() < f32::EPSILON => 0.0,
        span => (vec_dist - min_vec) / span,
    };
    let norm_bm25 = match max_bm25 - min_bm25 {
        span if span.abs() < f32::EPSILON => 0.5,
        span => (bm25 - min_bm25) / span,
    };

    let vec_weight = 1.0 - bm25_weight;
    vec_weight * norm_vec + bm25_weight * (1.0 - norm_bm25)
}
273
/// Property key `ailake.bm25.stats-path`.
/// NOTE(review): presumably the property under which the location of the
/// BM25 statistics file is recorded — confirm with the code that reads it.
pub const BM25_STATS_PATH_PROP: &str = "ailake.bm25.stats-path";
/// Relative file path `metadata/ailake_bm25_stats.bin`.
/// NOTE(review): presumably the default location of the serialized
/// `IdfStats` — confirm with the code that writes it.
pub const BM25_STATS_FILE: &str = "metadata/ailake_bm25_stats.bin";
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokens are lowercased, split on punctuation, and one-letter words are dropped.
    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("Hello, World! This is a test.");
        let has = |needle: &str| tokens.iter().any(|t| t == needle);
        assert!(has("hello"));
        assert!(has("world"));
        assert!(has("test"));
        assert!(!has("a"));
    }

    /// An unknown term in an empty corpus still gets a positive IDF.
    #[test]
    fn idf_empty_corpus_returns_positive() {
        let idf = IdfStats::default().idf("unknown_term");
        assert!(idf > 0.0, "IDF should be positive for unseen term");
    }

    /// Document frequency counts documents, not occurrences.
    #[test]
    fn merge_batch_accumulates_df() {
        let corpus = ["the quick brown fox", "the lazy dog"];
        let mut stats = IdfStats::default();
        stats.merge_batch(&corpus);

        assert_eq!(stats.doc_count, 2);
        assert_eq!(stats.term_df["the"], 2, "the appears in both docs");
        assert_eq!(stats.term_df["fox"], 1);
        assert_eq!(stats.term_df["dog"], 1);
    }

    /// Documents sharing query terms outrank a document sharing none.
    #[test]
    fn bm25_scorer_ranks_relevant_doc_higher() {
        let docs = [
            "rust programming language systems",
            "python machine learning data science",
            "rust memory safety zero cost abstractions",
        ];
        let mut stats = IdfStats::default();
        stats.merge_batch(&docs);

        let scorer = BM25Scorer::new(&stats);
        let query = "rust systems programming";
        let [s0, s1, s2] = docs.map(|doc| scorer.score(query, doc));

        assert!(
            s0 > s1,
            "rust doc scores higher than python doc: s0={s0}, s1={s1}"
        );
        assert!(
            s2 > s1,
            "rust doc scores higher than python doc: s2={s2}, s1={s1}"
        );
    }

    /// Serializing and deserializing preserves the counters.
    #[test]
    fn idf_stats_roundtrip() {
        let mut stats = IdfStats::default();
        stats.merge_batch(&["hello world foo bar", "foo baz qux"]);

        let restored = IdfStats::from_bytes(&stats.to_bytes().unwrap()).unwrap();

        assert_eq!(restored.doc_count, stats.doc_count);
        assert_eq!(restored.term_df["foo"], 2);
        assert_eq!(restored.term_df["hello"], 1);
    }

    /// A document with more than `MAX_VOCAB` distinct terms triggers pruning.
    #[test]
    fn vocab_cap_prunes_to_max() {
        let mut doc = String::new();
        for i in 0..=MAX_VOCAB + 100 {
            if i > 0 {
                doc.push(' ');
            }
            doc.push_str(&format!("term{i}"));
        }

        let mut stats = IdfStats::default();
        stats.merge_batch(&[doc.as_str()]);

        assert!(
            stats.term_df.len() <= MAX_VOCAB,
            "vocab should be capped at {MAX_VOCAB}"
        );
    }

    /// The fused RRF value is returned negated.
    #[test]
    fn rrf_score_is_negative() {
        let s = rrf_score(0, 0, 0.5);
        assert!(
            s < 0.0,
            "RRF score should be negated for sort-ascending convention"
        );
    }

    /// In-range inputs produce a fused score inside the unit interval.
    #[test]
    fn linear_score_in_range() {
        let s = linear_score(0.5, 0.0, 1.0, 0.8, 0.0, 1.0, 0.5);
        let unit = 0.0..=1.0;
        assert!(
            unit.contains(&s),
            "linear score should be in [0,1]: {s}"
        );
    }
}