use ahash::{AHashMap, AHashSet};
use probe_code::search::elastic_query::Expr;
use probe_code::search::tokenization;
use rust_stemmers::{Algorithm, Stemmer};
use std::sync::OnceLock;

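// AHash-backed aliases: drop-in maps/sets using a faster, non-cryptographic hasher.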
type HashMap<K, V> = AHashMap<K, V>;
type HashSet<T> = AHashSet<T>;

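/// Maps each unique query term to a compact `u8` index used for per-document
/// term-frequency storage.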
pub type QueryTokenMap = HashMap<String, u8>;

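/// Term-frequency and document-frequency statistics computed over a document set.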
pub struct TfDfResult {
    /// Per-document term frequencies, keyed by the query token's `u8` index.
    pub term_frequencies: Vec<HashMap<u8, usize>>,
    /// Number of documents containing each term.
    pub document_frequencies: HashMap<String, usize>,
    /// Token count of each document.
    pub document_lengths: Vec<usize>,
}

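/// Input parameters for [`rank_documents`].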
pub struct RankingParams<'a> {
    /// Documents to rank.
    pub documents: &'a [&'a str],
    /// Raw query string, parsed with `elastic_query::parse_query`.
    pub query: &'a str,
    /// Optional pre-tokenized documents; when present, tokenization is skipped.
    pub pre_tokenized: Option<&'a [Vec<String>]>,
}

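/// Returns a process-wide English stemmer, initialized on first use.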
pub fn get_stemmer() -> &'static Stemmer {
    static STEMMER: OnceLock<Stemmer> = OnceLock::new();
    STEMMER.get_or_init(|| Stemmer::create(Algorithm::English))
}

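/// Tokenizes `text` using the shared tokenization module.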
pub fn tokenize(text: &str) -> Vec<String> {
    tokenization::tokenize(text)
}

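/// Tokenizes `text` and appends the tokens of `filename`, so filename terms
/// also contribute to matching.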
pub fn preprocess_text_with_filename(text: &str, filename: &str) -> Vec<String> {
    let mut tokens = tokenize(text);
    let filename_tokens = tokenize(filename);
    tokens.extend(filename_tokens);
    tokens
}

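/// Computes the average document length (avgdl) used by BM25.
/// Returns 0.0 for an empty slice.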
pub fn compute_avgdl(lengths: &[usize]) -> f64 {
    if lengths.is_empty() {
        return 0.0;
    }
    let sum: f64 = lengths.iter().map(|&x| x as f64).sum();
    sum / lengths.len() as f64
}

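/// Per-document inputs for BM25 scoring, precomputed once per query so the
/// per-document scoring loop only does lookups and arithmetic.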
pub struct PrecomputedBm25Params<'a> {
    /// Term frequencies of this document, keyed by query token index.
    pub doc_tf: &'a HashMap<u8, usize>,
    /// Length of this document in tokens.
    pub doc_len: usize,
    /// Average document length across the corpus.
    pub avgdl: f64,
    /// Precomputed IDF values for the query terms.
    pub idfs: &'a HashMap<String, f64>,
    /// Mapping from query term to its `u8` index.
    pub query_token_map: &'a QueryTokenMap,
    /// BM25 `k1` parameter (term-frequency saturation).
    pub k1: f64,
    /// BM25 `b` parameter (length normalization).
    pub b: f64,
}

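/// Collects every keyword that appears anywhere in the parsed query expression.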
pub fn extract_query_terms(expr: &Expr) -> HashSet<String> {
    use Expr::*;
    let mut terms = HashSet::new();

    match expr {
        Term { keywords, .. } => {
            terms.extend(keywords.iter().cloned());
        }
        And(left, right) | Or(left, right) => {
            terms.extend(extract_query_terms(left));
            terms.extend(extract_query_terms(right));
        }
    }

    terms
}

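/// Precomputes the BM25 IDF for each query term:
/// `idf = ln(1 + (N - df + 0.5) / (df + 0.5))`.
/// Terms with a document frequency of zero are skipped.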
pub fn precompute_idfs(
    terms: &HashSet<String>,
    dfs: &HashMap<String, usize>,
    n_docs: usize,
) -> HashMap<String, f64> {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    if debug_mode {
        println!(
            "DEBUG: Precomputing IDF values for {terms_len} terms",
            terms_len = terms.len()
        );
    }

    terms
        .iter()
        .filter_map(|term| {
            let df = *dfs.get(term).unwrap_or(&0);
            if df > 0 {
                let numerator = (n_docs as f64 - df as f64) + 0.5;
                let denominator = df as f64 + 0.5;
                let idf = (1.0 + (numerator / denominator)).ln();
                Some((term.as_str(), idf))
            } else {
                None
            }
        })
        .map(|(term, idf)| (term.to_string(), idf))
        .collect()
}

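/// Assigns each unique query term a stable `u8` index (terms are sorted first,
/// so the mapping is deterministic). Fails if there are more than 256 terms.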
fn generate_query_token_map(query_terms: &HashSet<String>) -> Result<QueryTokenMap, &'static str> {
    if query_terms.len() > 256 {
        return Err("Query exceeds the 256 unique token limit for u8 mapping");
    }

    let mut token_map = QueryTokenMap::new();
    let mut index: u8 = 0;

    let mut sorted_terms: Vec<&str> = query_terms.iter().map(|s| s.as_str()).collect();
    sorted_terms.sort();

    for term in sorted_terms {
        token_map.insert(term.to_string(), index);
        index = index.wrapping_add(1);
    }

    Ok(token_map)
}

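/// BM25 contribution of a single query token for one document:
/// `idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avgdl))`.
/// Returns 0.0 if the token is not in the query map or does not occur in the document.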
fn bm25_single_token_optimized(token: &str, params: &PrecomputedBm25Params) -> f64 {
    let Some(&token_index) = params.query_token_map.get(token) else {
        return 0.0;
    };

    let freq_in_doc = *params.doc_tf.get(&token_index).unwrap_or(&0) as f64;
    if freq_in_doc <= 0.0 {
        return 0.0;
    }

    let idf = *params.idfs.get(token).unwrap_or(&0.0);

    let tf_part = (freq_in_doc * (params.k1 + 1.0))
        / (freq_in_doc
            + params.k1 * (1.0 - params.b + params.b * (params.doc_len as f64 / params.avgdl)));

    idf * tf_part
}

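/// Sums the BM25 contributions of all keywords in a query term.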
fn score_term_bm25_optimized(keywords: &[String], params: &PrecomputedBm25Params) -> f64 {
    let mut total = 0.0;
    for kw in keywords {
        total += bm25_single_token_optimized(kw, params);
    }
    total
}

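/// Evaluates the query expression against one document. Returns `None` when the
/// document must be filtered out: an excluded term that matches, a required term
/// that does not, or an `And` whose side fails. Otherwise returns the summed
/// BM25 score.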
pub fn score_expr_bm25_optimized(expr: &Expr, params: &PrecomputedBm25Params) -> Option<f64> {
    use Expr::*;
    match expr {
        Term {
            keywords,
            required,
            excluded,
            ..
        } => {
            let score = score_term_bm25_optimized(keywords, params);

            if *excluded {
                if score > 0.0 {
                    None
                } else {
                    Some(0.0)
                }
            } else if *required {
                if score > 0.0 {
                    Some(score)
                } else {
                    None
                }
            } else {
                Some(score)
            }
        }
        And(left, right) => {
            let lscore = score_expr_bm25_optimized(left, params)?;
            let rscore = score_expr_bm25_optimized(right, params)?;
            Some(lscore + rscore)
        }
        Or(left, right) => {
            let l = score_expr_bm25_optimized(left, params);
            let r = score_expr_bm25_optimized(right, params);
            match (l, r) {
                (None, None) => None,
                (None, Some(rs)) => Some(rs),
                (Some(ls), None) => Some(ls),
                (Some(ls), Some(rs)) => Some(ls + rs),
            }
        }
    }
}

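/// Ranks `params.documents` against `params.query` using BM25 over the parsed
/// query expression. Returns `(document_index, score)` pairs for matching
/// documents, sorted by descending score with ties broken by ascending index.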
pub fn rank_documents(params: &RankingParams) -> Vec<(usize, f64)> {
    use rayon::prelude::*;
    use std::cmp::Ordering;

    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    let parsed_expr = match crate::search::elastic_query::parse_query(params.query, false) {
        Ok(expr) => expr,
        Err(e) => {
            if debug_mode {
                eprintln!("DEBUG: parse_query failed: {e:?}");
            }
            eprintln!("WARNING: Query parsing failed: {e:?}. Returning empty results.");
            return vec![];
        }
    };

    let query_terms = extract_query_terms(&parsed_expr);

    let query_token_map = match generate_query_token_map(&query_terms) {
        Ok(map) => map,
        Err(e) => {
            if debug_mode {
                eprintln!("DEBUG: Failed to generate query token map: {e}");
            }
            eprintln!("WARNING: {e}");
            return vec![];
        }
    };

    if debug_mode {
        println!(
            "DEBUG: Generated query token map with {} entries",
            query_token_map.len()
        );
    }

    let tf_df_result = if let Some(pre_tokenized) = &params.pre_tokenized {
        if debug_mode {
            println!("DEBUG: Using pre-tokenized content for ranking");
        }
        compute_tf_df_from_tokenized(pre_tokenized, &query_token_map)
    } else {
        if debug_mode {
            println!("DEBUG: Tokenizing documents for ranking");
        }
        let tokenized_docs: Vec<Vec<String>> =
            params.documents.iter().map(|doc| tokenize(doc)).collect();
        compute_tf_df_from_tokenized(&tokenized_docs, &query_token_map)
    };

    let n_docs = params.documents.len();
    let avgdl = compute_avgdl(&tf_df_result.document_lengths);

    let precomputed_idfs =
        precompute_idfs(&query_terms, &tf_df_result.document_frequencies, n_docs);

    if debug_mode {
        println!(
            "DEBUG: Precomputed IDF values for {} unique query terms",
            precomputed_idfs.len()
        );
    }

    // Standard BM25 defaults.
    let k1 = 1.2;
    let b = 0.75;

    if debug_mode {
        println!("DEBUG: Starting parallel document scoring for {n_docs} documents");
    }

    let scored_docs: Vec<(usize, Option<f64>)> = (0..tf_df_result.term_frequencies.len())
        .collect::<Vec<_>>()
        .par_iter()
        .map(|&i| {
            let doc_tf = &tf_df_result.term_frequencies[i];
            let doc_len = tf_df_result.document_lengths[i];

            let precomputed_bm25_params = PrecomputedBm25Params {
                doc_tf,
                doc_len,
                avgdl,
                idfs: &precomputed_idfs,
                query_token_map: &query_token_map,
                k1,
                b,
            };

            let bm25_score_opt = score_expr_bm25_optimized(&parsed_expr, &precomputed_bm25_params);

            (i, bm25_score_opt)
        })
        .collect();

    if debug_mode {
        println!("DEBUG: Parallel document scoring completed");
    }

    let mut filtered_docs: Vec<(usize, f64)> = scored_docs
        .into_iter()
        .filter_map(|(i, score_opt)| score_opt.map(|score| (i, score)))
        .collect();

    // Sort by descending score; break ties by ascending document index.
    filtered_docs.sort_by(|a, b| {
        match b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal) {
            Ordering::Equal => a.0.cmp(&b.0),
            other => other,
        }
    });

    if debug_mode {
        println!(
            "DEBUG: Sorted {} matching documents by score",
            filtered_docs.len()
        );
    }

    filtered_docs
}

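/// Computes per-document term frequencies (keyed by query token index),
/// corpus-wide document frequencies, and document lengths from already
/// tokenized documents, in parallel with rayon.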
pub fn compute_tf_df_from_tokenized(
    tokenized_docs: &[Vec<String>],
    query_token_map: &QueryTokenMap,
) -> TfDfResult {
    use rayon::prelude::*;

    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    if debug_mode {
        println!(
            "DEBUG: Starting parallel TF-DF computation from pre-tokenized content for {docs_len} documents",
            docs_len = tokenized_docs.len()
        );
    }

    #[allow(clippy::type_complexity)]
    let doc_results: Vec<(
        HashMap<u8, usize>,
        HashMap<String, usize>,
        usize,
        HashSet<String>,
    )> = tokenized_docs
        .par_iter()
        .map(|tokens| {
            let mut tf_u8 = HashMap::new();
            let mut tf_str = HashMap::new();
            for token in tokens.iter() {
                *tf_str.entry(token.clone()).or_insert(0) += 1;

                if let Some(&token_index) = query_token_map.get(token) {
                    *tf_u8.entry(token_index).or_insert(0) += 1;
                }
            }

            let unique_terms: HashSet<String> = tf_str.keys().cloned().collect();

            (tf_u8, tf_str, tokens.len(), unique_terms)
        })
        .collect();

    let mut term_frequencies = Vec::with_capacity(tokenized_docs.len());
    let mut document_lengths = Vec::with_capacity(tokenized_docs.len());

    let min_chunk_size = tokenized_docs
        .len()
        .checked_div(rayon::current_num_threads())
        .unwrap_or(1)
        .max(1);
    let document_frequencies = doc_results
        .par_iter()
        .with_min_len(min_chunk_size)
        .map(|(_, _, _, unique_terms)| {
            let mut local_df = HashMap::new();
            for term in unique_terms {
                *local_df.entry(term.clone()).or_insert(0) += 1;
            }
            local_df
        })
        .reduce(HashMap::new, |mut acc, local_df| {
            for (term, count) in local_df {
                *acc.entry(term).or_insert(0) += count;
            }
            acc
        });

    if debug_mode {
        println!(
            "DEBUG: Parallel DF computation completed with {} unique terms",
            document_frequencies.len()
        );
    }

    for (tf_u8, _, doc_len, _) in doc_results {
        term_frequencies.push(tf_u8);
        document_lengths.push(doc_len);
    }

    if debug_mode {
        println!("DEBUG: Parallel TF-DF computation from pre-tokenized content completed");
        println!("DEBUG: Using u8 indices for term frequencies (optimized storage)");
    }

    TfDfResult {
        term_frequencies,
        document_frequencies,
        document_lengths,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_bm25_scoring() {
        let docs = vec!["api process load", "another random text with process"];
        let query = "+api +process +load";
        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: None,
        };

        let results = rank_documents(&params);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 > 0.0);
        assert!(results[0].1 < 10.0);
    }

    #[test]
    fn test_bm25_scoring_with_pre_tokenized() {
        let docs = vec!["api process load", "another random text with process"];
        let query = "+api +process +load";
        let pre_tokenized = vec![
            vec!["api".to_string(), "process".to_string(), "load".to_string()],
            vec![
                "another".to_string(),
                "random".to_string(),
                "text".to_string(),
                "with".to_string(),
                "process".to_string(),
            ],
        ];

        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: Some(&pre_tokenized),
        };

        let results = rank_documents(&params);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 > 0.0);
        assert!(results[0].1 < 10.0);
    }

    #[test]
    fn test_relative_bm25_scoring() {
        let docs = vec![
            "api process load data",
            "api process load",
            "api process",
            "api",
        ];
        let query = "api process load data";
        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: None,
        };

        let results = rank_documents(&params);
        assert_eq!(results.len(), 4);

        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert_eq!(results[2].0, 2);
        assert_eq!(results[3].0, 3);
        assert!(results[0].1 > results[1].1);
        assert!(results[1].1 > results[2].1);
        assert!(results[2].1 > results[3].1);
    }

    #[test]
    fn test_generate_query_token_map_basic() {
        let mut query_terms = HashSet::new();
        query_terms.insert("apple".to_string());
        query_terms.insert("banana".to_string());
        query_terms.insert("cherry".to_string());

        let token_map = generate_query_token_map(&query_terms).unwrap();

        assert_eq!(token_map.len(), 3);

        let mut indices = HashSet::new();
        for (_, &idx) in &token_map {
            assert!(indices.insert(idx), "Duplicate index found");
        }

        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        assert!(indices.contains(&2));
    }

    #[test]
    fn test_generate_query_token_map_empty() {
        let query_terms = HashSet::new();

        let token_map = generate_query_token_map(&query_terms).unwrap();

        assert!(token_map.is_empty());
    }

    #[test]
    fn test_generate_query_token_map_deterministic() {
        let mut query_terms1 = HashSet::new();
        query_terms1.insert("apple".to_string());
        query_terms1.insert("banana".to_string());
        query_terms1.insert("cherry".to_string());

        let mut query_terms2 = HashSet::new();
        query_terms2.insert("cherry".to_string());
        query_terms2.insert("apple".to_string());
        query_terms2.insert("banana".to_string());

        let token_map1 = generate_query_token_map(&query_terms1).unwrap();
        let token_map2 = generate_query_token_map(&query_terms2).unwrap();

        assert_eq!(token_map1.len(), token_map2.len());

        for (term, &idx1) in &token_map1 {
            assert_eq!(
                Some(&idx1),
                token_map2.get(term),
                "Term '{term}' has different indices in the two maps"
            );
        }
    }

    #[test]
    fn test_generate_query_token_map_too_many_terms() {
        let query_terms: HashSet<String> = (0..257).map(|i| format!("term{i}")).collect();

        let result = generate_query_token_map(&query_terms);

        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            "Query exceeds the 256 unique token limit for u8 mapping"
        );
    }

    #[test]
    fn test_compute_tf_df_with_u8_indices() {
        let docs = vec![
            vec![
                "apple".to_string(),
                "banana".to_string(),
                "cherry".to_string(),
            ],
            vec!["apple".to_string(), "banana".to_string()],
            vec!["apple".to_string()],
        ];

        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);
        query_token_map.insert("banana".to_string(), 1);
        query_token_map.insert("cherry".to_string(), 2);

        let tf_df_result = compute_tf_df_from_tokenized(&docs, &query_token_map);

        assert_eq!(tf_df_result.document_lengths[0], 3);
        assert_eq!(tf_df_result.document_lengths[1], 2);
        assert_eq!(tf_df_result.document_lengths[2], 1);

        // Term frequencies for document 0 ("apple banana cherry").
        assert_eq!(*tf_df_result.term_frequencies[0].get(&0).unwrap(), 1);
        assert_eq!(*tf_df_result.term_frequencies[0].get(&1).unwrap(), 1);
        assert_eq!(*tf_df_result.term_frequencies[0].get(&2).unwrap(), 1);
        // Term frequencies for document 1 ("apple banana").
        assert_eq!(*tf_df_result.term_frequencies[1].get(&0).unwrap(), 1);
        assert_eq!(*tf_df_result.term_frequencies[1].get(&1).unwrap(), 1);
        assert!(tf_df_result.term_frequencies[1].get(&2).is_none());
        // Term frequencies for document 2 ("apple").
        assert_eq!(*tf_df_result.term_frequencies[2].get(&0).unwrap(), 1);
        assert!(tf_df_result.term_frequencies[2].get(&1).is_none());
        assert!(tf_df_result.term_frequencies[2].get(&2).is_none());
        // Document frequencies across the corpus.
        assert_eq!(*tf_df_result.document_frequencies.get("apple").unwrap(), 3);
        assert_eq!(*tf_df_result.document_frequencies.get("banana").unwrap(), 2);
        assert_eq!(*tf_df_result.document_frequencies.get("cherry").unwrap(), 1);
    }

    #[test]
    fn test_bm25_scoring_with_u8_indices() {
        let _doc_content = "apple banana cherry";
        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);
        query_token_map.insert("banana".to_string(), 1);
        query_token_map.insert("cherry".to_string(), 2);

        let mut doc_tf = HashMap::new();
        doc_tf.insert(0u8, 1);
        doc_tf.insert(1u8, 1);
        doc_tf.insert(2u8, 1);

        let mut doc_freqs = HashMap::new();
        doc_freqs.insert("apple".to_string(), 1);
        doc_freqs.insert("banana".to_string(), 1);
        doc_freqs.insert("cherry".to_string(), 1);

        let mut idfs = HashMap::new();
        idfs.insert("apple".to_string(), 1.0);
        idfs.insert("banana".to_string(), 1.0);
        idfs.insert("cherry".to_string(), 1.0);

        let params = PrecomputedBm25Params {
            doc_tf: &doc_tf,
            doc_len: 3,
            avgdl: 3.0,
            idfs: &idfs,
            query_token_map: &query_token_map,
            k1: 1.2,
            b: 0.75,
        };

        let apple_score = bm25_single_token_optimized("apple", &params);
        let banana_score = bm25_single_token_optimized("banana", &params);
        let cherry_score = bm25_single_token_optimized("cherry", &params);

        assert!(apple_score > 0.0);
        assert_eq!(apple_score, banana_score);
        assert_eq!(banana_score, cherry_score);

        let unknown_score = bm25_single_token_optimized("unknown", &params);
        assert_eq!(unknown_score, 0.0);

        let keywords = vec!["apple".to_string(), "banana".to_string()];
        let term_score = score_term_bm25_optimized(&keywords, &params);

        assert_eq!(term_score, apple_score + banana_score);
    }
}