probe_code/ranking.rs

use ahash::{AHashMap, AHashSet};
use probe_code::search::elastic_query::Expr;
use probe_code::search::tokenization;
use rust_stemmers::{Algorithm, Stemmer};
use std::sync::OnceLock;

// Replace standard collections with ahash versions for better performance
type HashMap<K, V> = AHashMap<K, V>;
type HashSet<T> = AHashSet<T>;

/// Maps unique query tokens (Strings) to unique u8 indices
pub type QueryTokenMap = HashMap<String, u8>;

/// Represents the result of term frequency and document frequency computation
pub struct TfDfResult {
    /// Term frequencies for each document, using u8 index for query tokens
    pub term_frequencies: Vec<HashMap<u8, usize>>,
    /// Document frequencies for each term (remains String-based for IDF)
    pub document_frequencies: HashMap<String, usize>,
    /// Document lengths (number of tokens in each document)
    pub document_lengths: Vec<usize>,
}

/// Parameters for document ranking
pub struct RankingParams<'a> {
    /// Documents to rank
    pub documents: &'a [&'a str],
    /// Query string
    pub query: &'a str,
    /// Pre-tokenized content (optional)
    pub pre_tokenized: Option<&'a [Vec<String>]>,
}

/// Returns a reference to the global stemmer instance
pub fn get_stemmer() -> &'static Stemmer {
    static STEMMER: OnceLock<Stemmer> = OnceLock::new();
    STEMMER.get_or_init(|| Stemmer::create(Algorithm::English))
}

/// Tokenizes text into lowercase words by splitting on whitespace and non-alphanumeric characters,
/// removes stop words, and applies stemming. Also splits camelCase/PascalCase identifiers.
pub fn tokenize(text: &str) -> Vec<String> {
    tokenization::tokenize(text)
}

/// Preprocesses text together with its filename for search by tokenizing both and
/// concatenating the token lists. The filename tokens (including directory components)
/// are appended so that filename matches can contribute to ranking.
pub fn preprocess_text_with_filename(text: &str, filename: &str) -> Vec<String> {
    let mut tokens = tokenize(text);
    let filename_tokens = tokenize(filename);
    tokens.extend(filename_tokens);
    tokens
}

/// Computes the average document length.
pub fn compute_avgdl(lengths: &[usize]) -> f64 {
    if lengths.is_empty() {
        return 0.0;
    }
    // Convert to f64 before summing to prevent potential integer overflow
    // when dealing with very large documents or a large number of documents
    let sum: f64 = lengths.iter().map(|&x| x as f64).sum();
    sum / lengths.len() as f64
}

// -------------------------------------------------------------------------
// BM25 EXACT (like Elasticsearch) with "bool" logic for must/should/must_not
// -------------------------------------------------------------------------

/// Parameters for BM25 calculation with precomputed IDF values
pub struct PrecomputedBm25Params<'a> {
    /// Document term frequencies using u8 indices
    pub doc_tf: &'a HashMap<u8, usize>,
    /// Document length
    pub doc_len: usize,
    /// Average document length
    pub avgdl: f64,
    /// Precomputed IDF values for query terms (remains String-based)
    pub idfs: &'a HashMap<String, f64>,
    /// Map from query term string to u8 index
    pub query_token_map: &'a QueryTokenMap,
    /// BM25 k1 parameter
    pub k1: f64,
    /// BM25 b parameter
    pub b: f64,
}

/// Extracts unique terms from a query expression
pub fn extract_query_terms(expr: &Expr) -> HashSet<String> {
    use Expr::*;
    let mut terms = HashSet::new();

    match expr {
        Term { keywords, .. } => {
            terms.extend(keywords.iter().cloned());
        }
        And(left, right) | Or(left, right) => {
            terms.extend(extract_query_terms(left));
            terms.extend(extract_query_terms(right));
        }
    }

    terms
}

/// Precomputes IDF values for a set of terms
pub fn precompute_idfs(
    terms: &HashSet<String>,
    dfs: &HashMap<String, usize>,
    n_docs: usize,
) -> HashMap<String, f64> {
    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    if debug_mode {
        println!(
            "DEBUG: Precomputing IDF values for {terms_len} terms",
            terms_len = terms.len()
        );
    }

    terms
        .iter()
        .filter_map(|term| {
            let df = *dfs.get(term).unwrap_or(&0);
            if df > 0 {
                let numerator = (n_docs as f64 - df as f64) + 0.5;
                let denominator = df as f64 + 0.5;
                let idf = (1.0 + (numerator / denominator)).ln();
                Some((term.as_str(), idf))
            } else {
                None
            }
        })
        .map(|(term, idf)| (term.to_string(), idf))
        .collect()
}

/// Generates a mapping from unique query tokens to unique u8 indices.
///
/// This function takes a set of query terms and assigns a unique u8 index to each term.
/// The mapping is used to efficiently represent query terms in memory-constrained contexts.
///
/// # Arguments
///
/// * `query_terms` - A set of unique query terms to map to indices
///
/// # Returns
///
/// * `Result<QueryTokenMap, &'static str>` - A mapping from terms to indices, or an error if there are too many terms
///
/// # Errors
///
/// Returns an error if the number of unique query terms exceeds 256 (the number of distinct u8 values).
fn generate_query_token_map(query_terms: &HashSet<String>) -> Result<QueryTokenMap, &'static str> {
    // Check if we have too many terms for u8 mapping
    if query_terms.len() > 256 {
        return Err("Query exceeds the 256 unique token limit for u8 mapping");
    }

    let mut token_map = QueryTokenMap::new();
    let mut index: u8 = 0;

    // Sort terms for deterministic mapping (HashMap iteration order isn't guaranteed)
    let mut sorted_terms: Vec<&str> = query_terms.iter().map(|s| s.as_str()).collect();
    sorted_terms.sort();

    // Assign each term a unique index
    for term in sorted_terms {
        token_map.insert(term.to_string(), index);
        index = index.wrapping_add(1); // Use wrapping_add to handle potential overflow safely
    }

    Ok(token_map)
}

/// Optimized BM25 single-token function using precomputed IDF values:
/// tf_part = freq * (k1+1) / (freq + k1*(1 - b + b*(docLen/avgdl)))
fn bm25_single_token_optimized(token: &str, params: &PrecomputedBm25Params) -> f64 {
    // Look up the token's index in the query_token_map
    let Some(&token_index) = params.query_token_map.get(token) else {
        // The query term has no u8 mapping, which should not happen if the map was built
        // from this query (the term may have been tokenized or stemmed differently).
        // Treat it as a non-match.
        return 0.0;
    };

    // Get frequency using the u8 index
    let freq_in_doc = *params.doc_tf.get(&token_index).unwrap_or(&0) as f64;
    if freq_in_doc <= 0.0 {
        return 0.0;
    }

    // Use precomputed IDF value (still using string for IDF lookup)
    let idf = *params.idfs.get(token).unwrap_or(&0.0);

    let tf_part = (freq_in_doc * (params.k1 + 1.0))
        / (freq_in_doc
            + params.k1 * (1.0 - params.b + params.b * (params.doc_len as f64 / params.avgdl)));

    idf * tf_part
}

/// Sum BM25 for all keywords in a single "Term" node using precomputed IDF values
fn score_term_bm25_optimized(keywords: &[String], params: &PrecomputedBm25Params) -> f64 {
    let mut total = 0.0;
    for kw in keywords {
        total += bm25_single_token_optimized(kw, params);
    }
    total
}

/// Recursively compute a doc's "ES-like BM25 bool query" score from the AST using precomputed IDF values:
/// - If it fails a must or matches a must_not => return None (exclude doc)
/// - Otherwise sum up matched subclause scores
/// - For "OR," doc must match at least one side
/// - For "AND," doc must match both sides
/// - For a "should" term, we add the BM25 if it matches; if the entire query has no must, then
///   at least one "should" must match in order to include the doc.
pub fn score_expr_bm25_optimized(expr: &Expr, params: &PrecomputedBm25Params) -> Option<f64> {
    use Expr::*;
    match expr {
        Term {
            keywords,
            required,
            excluded,
            ..
        } => {
            let score = score_term_bm25_optimized(keywords, params);

            if *excluded {
                // must_not => doc out if doc_score > 0
                if score > 0.0 {
                    None
                } else {
                    Some(0.0)
                }
            } else if *required {
                // must => doc out if doc_score=0
                if score > 0.0 {
                    Some(score)
                } else {
                    None
                }
            } else {
                // "should" => we don't exclude doc if score=0 here, because maybe it matches
                // something else in an OR. Return Some(0.0 or some positive).
                // The top-level logic ensures if no must in the entire query, we need at least one should>0.
                Some(score)
            }
        }
        And(left, right) => {
            let lscore = score_expr_bm25_optimized(left, params)?;
            let rscore = score_expr_bm25_optimized(right, params)?;
            Some(lscore + rscore)
        }
        Or(left, right) => {
            let l = score_expr_bm25_optimized(left, params);
            let r = score_expr_bm25_optimized(right, params);
            match (l, r) {
                (None, None) => None,
                (None, Some(rs)) => Some(rs),
                (Some(ls), None) => Some(ls),
                (Some(ls), Some(rs)) => Some(ls + rs),
            }
        }
    }
}

// -------------------------------------------------------------------------
// Main entry point for ranking: pure BM25 scoring (Elasticsearch-like) with
// bool-query must/should/must_not semantics.
// -------------------------------------------------------------------------
pub fn rank_documents(params: &RankingParams) -> Vec<(usize, f64)> {
    use rayon::prelude::*;
    use std::cmp::Ordering;

    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    // 1) Parse the user query into an AST (Expr) using parse_query from `elastic_query.rs`
    let parsed_expr = match crate::search::elastic_query::parse_query(params.query, false) {
        Ok(expr) => expr,
        Err(e) => {
            if debug_mode {
                eprintln!("DEBUG: parse_query failed: {e:?}");
            }
            // Instead of silently returning empty results, log a warning even in non-debug mode
            // to ensure errors are visible and can be addressed
            eprintln!("WARNING: Query parsing failed: {e:?}. Returning empty results.");
            // In a future version, consider changing the return type to Result<Vec<(usize, f64)>, QueryError>
            // to properly propagate errors to the caller
            return vec![];
        }
    };
    // 2) Extract query terms and create the token mapping
    let query_terms = extract_query_terms(&parsed_expr);

    // Generate query token map (maps each unique query term to a unique u8 index)
    let query_token_map = match generate_query_token_map(&query_terms) {
        Ok(map) => map,
        Err(e) => {
            if debug_mode {
                eprintln!("DEBUG: Failed to generate query token map: {e}");
            }
            eprintln!("WARNING: {e}");
            return vec![];
        }
    };

    if debug_mode {
        println!(
            "DEBUG: Generated query token map with {} entries",
            query_token_map.len()
        );
    }

    // 3) Precompute TF/DF for the documents
    let tf_df_result = if let Some(pre_tokenized) = &params.pre_tokenized {
        // Use pre-tokenized content if available
        if debug_mode {
            println!("DEBUG: Using pre-tokenized content for ranking");
        }
        compute_tf_df_from_tokenized(pre_tokenized, &query_token_map)
    } else {
        // Fallback to tokenizing the documents
        if debug_mode {
            println!("DEBUG: Tokenizing documents for ranking");
        }
        // Tokenize documents on the fly
        let tokenized_docs: Vec<Vec<String>> =
            params.documents.iter().map(|doc| tokenize(doc)).collect();
        compute_tf_df_from_tokenized(&tokenized_docs, &query_token_map)
    };

    let n_docs = params.documents.len();
    let avgdl = compute_avgdl(&tf_df_result.document_lengths);

    // Precompute IDF values
    let precomputed_idfs =
        precompute_idfs(&query_terms, &tf_df_result.document_frequencies, n_docs);

    if debug_mode {
        println!(
            "DEBUG: Precomputed IDF values for {} unique query terms",
            precomputed_idfs.len()
        );
    }

    // 4) BM25 parameters
    // These values are standard defaults for BM25 as established in academic literature:
    // k1=1.2 controls term frequency saturation (higher values give more weight to term frequency)
    // b=0.75 controls document length normalization (higher values give more penalty to longer documents)
    // See: Robertson, S. E., & Zaragoza, H. (2009). The Probabilistic Relevance Framework: BM25 and Beyond
    let k1 = 1.2;
    let b = 0.75;

    if debug_mode {
        println!("DEBUG: Starting parallel document scoring for {n_docs} documents");
    }

    // 5) Compute BM25 bool logic score for each doc in parallel
    // Use a stable collection method to ensure deterministic ordering
    let scored_docs: Vec<(usize, Option<f64>)> = (0..tf_df_result.term_frequencies.len())
        .collect::<Vec<_>>() // Collect indices first to ensure stable ordering
        .par_iter() // Then parallelize
        .map(|&i| {
            let doc_tf = &tf_df_result.term_frequencies[i];
            let doc_len = tf_df_result.document_lengths[i];

            // Create optimized BM25 parameters with precomputed IDF values
            let precomputed_bm25_params = PrecomputedBm25Params {
                doc_tf,
                doc_len,
                avgdl,
                idfs: &precomputed_idfs,
                query_token_map: &query_token_map,
                k1,
                b,
            };

            // Evaluate doc's BM25 sum or None if excluded using optimized function
            let bm25_score_opt = score_expr_bm25_optimized(&parsed_expr, &precomputed_bm25_params);

            (i, bm25_score_opt)
        })
        .collect();

    if debug_mode {
        println!("DEBUG: Parallel document scoring completed");
    }

    // Filter out documents that didn't match and collect scores
    let mut filtered_docs: Vec<(usize, f64)> = scored_docs
        .into_iter()
        .filter_map(|(i, score_opt)| score_opt.map(|score| (i, score)))
        .collect();

    // 6) Sort in descending order by BM25 score, with a stable secondary sort by document index
    filtered_docs.sort_by(|a, b| {
        // First compare by score (descending)
        // Note: unwrap_or(Ordering::Equal) handles NaN cases by treating them as equal
        // This ensures stable sorting even if a score calculation resulted in NaN
        // (which shouldn't happen with our implementation, but provides robustness)
        match b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal) {
            Ordering::Equal => {
                // If scores are equal, sort by document index (ascending) for stability
                a.0.cmp(&b.0)
            }
            other => other,
        }
    });

    if debug_mode {
        println!(
            "DEBUG: Sorted {} matching documents by score",
            filtered_docs.len()
        );
    }

    filtered_docs
}

/// Computes term frequencies (TF) for each document, document frequencies (DF) for each term,
/// and document lengths from pre-tokenized content.
///
/// Uses the query_token_map to convert string tokens to u8 indices for term frequencies.
pub fn compute_tf_df_from_tokenized(
    tokenized_docs: &[Vec<String>],
    query_token_map: &QueryTokenMap,
) -> TfDfResult {
    use rayon::prelude::*;

    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";

    if debug_mode {
        println!("DEBUG: Starting parallel TF-DF computation from pre-tokenized content for {docs_len} documents", docs_len = tokenized_docs.len());
    }

    // Process documents in parallel to compute term frequencies and document lengths
    #[allow(clippy::type_complexity)]
    let doc_results: Vec<(
        HashMap<u8, usize>,
        HashMap<String, usize>,
        usize,
        HashSet<String>,
    )> = tokenized_docs
        .par_iter()
        .map(|tokens| {
            let mut tf_u8 = HashMap::new(); // Term frequencies using u8 indices
            let mut tf_str = HashMap::new(); // Term frequencies using strings (for DF calculation)

            // Compute term frequency for the current document
            for token in tokens.iter() {
                // Update string-based term frequency (for document frequency calculation)
                *tf_str.entry(token.clone()).or_insert(0) += 1;

                // Update u8-based term frequency (only for tokens in the query)
                if let Some(&token_index) = query_token_map.get(token) {
                    *tf_u8.entry(token_index).or_insert(0) += 1;
                }
            }

            // Collect unique terms for document frequency calculation
            let unique_terms: HashSet<String> = tf_str.keys().cloned().collect();

            (tf_u8, tf_str, tokens.len(), unique_terms)
        })
        .collect();

    // Extract term frequencies and document lengths
    let mut term_frequencies = Vec::with_capacity(tokenized_docs.len());
    let mut document_lengths = Vec::with_capacity(tokenized_docs.len());

    // Compute document frequencies in parallel using adaptive chunking
    // This balances parallelism with reduced contention
    // The chunk size calculation:
    // - Divides total documents by available threads to distribute work evenly
    // - Ensures at least 1 document per chunk to prevent empty chunks
    // - Larger chunks reduce thread coordination overhead but may lead to load imbalance
    // - Smaller chunks improve load balancing but increase synchronization costs
    // Use checked_div to safely handle the case where there are no threads (which shouldn't happen)
    // and ensure we always have at least one item per chunk
    let min_chunk_size = tokenized_docs
        .len()
        .checked_div(rayon::current_num_threads())
        .unwrap_or(1)
        .max(1);
    let document_frequencies = doc_results
        .par_iter()
        .with_min_len(min_chunk_size) // Adaptive chunking based on document count
        .map(|(_, _, _, unique_terms)| {
            // Create a local document frequency map for this chunk
            let mut local_df = HashMap::new();
            for term in unique_terms {
                *local_df.entry(term.clone()).or_insert(0) += 1;
            }
            local_df
        })
        .reduce(HashMap::new, |mut acc, local_df| {
            // Merge local document frequency maps
            for (term, count) in local_df {
                *acc.entry(term).or_insert(0) += count;
            }
            acc
        });

    if debug_mode {
        println!(
            "DEBUG: Parallel DF computation completed with {} unique terms",
            document_frequencies.len()
        );
    }

    // Collect results in a deterministic order
    for (tf_u8, _, doc_len, _) in doc_results {
        term_frequencies.push(tf_u8);
        document_lengths.push(doc_len);
    }

    if debug_mode {
        println!("DEBUG: Parallel TF-DF computation from pre-tokenized content completed");
        println!("DEBUG: Using u8 indices for term frequencies (optimized storage)");
    }

    TfDfResult {
        term_frequencies,
        document_frequencies,
        document_lengths,
    }
}

// -------------------------------------------------------------------------
// Unit tests
// -------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_bm25_scoring() {
        // A trivial test: 2 docs, 1 query
        let docs = vec!["api process load", "another random text with process"];
        let query = "+api +process +load"; // must have "api", must have "process", must have "load"

        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: None,
        };

        let results = rank_documents(&params);
        // Only the first doc should match, because it has all 3 required words
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0); // doc index 0

        // Verify score is positive and within expected range for BM25
        // BM25 scores typically fall within certain ranges based on the algorithm's properties
        assert!(results[0].1 > 0.0);
        assert!(results[0].1 < 10.0); // Upper bound based on typical BM25 behavior with small documents
    }
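
    // Additional sketch tests for the ranking building blocks. Expected values are
    // derived directly from the code above: avgdl is the mean token count, and
    // idf = ln(1 + (N - df + 0.5) / (df + 0.5)) for terms with df > 0.
    #[test]
    fn test_compute_avgdl() {
        // Empty input returns 0.0 instead of dividing by zero
        assert_eq!(compute_avgdl(&[]), 0.0);

        // The mean of 2, 4 and 6 tokens is 4.0
        let avgdl = compute_avgdl(&[2, 4, 6]);
        assert!((avgdl - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_precompute_idfs_formula() {
        let mut terms = HashSet::new();
        terms.insert("apple".to_string());
        terms.insert("missing".to_string());

        let mut dfs = HashMap::new();
        dfs.insert("apple".to_string(), 1);
        // "missing" has no document frequency entry

        let idfs = precompute_idfs(&terms, &dfs, 3);

        // Terms with df == 0 are skipped entirely
        assert!(!idfs.contains_key("missing"));

        // idf = ln(1 + (3 - 1 + 0.5) / (1 + 0.5)) = ln(1 + 2.5 / 1.5)
        let expected = (1.0_f64 + 2.5 / 1.5).ln();
        let actual = *idfs.get("apple").unwrap();
        assert!((actual - expected).abs() < 1e-9);
    }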

    #[test]
    fn test_bm25_scoring_with_pre_tokenized() {
        // A trivial test: 2 docs, 1 query, with pre-tokenized content
        let docs = vec!["api process load", "another random text with process"];
        let query = "+api +process +load"; // must have "api", must have "process", must have "load"

        // Pre-tokenized content
        let pre_tokenized = vec![
            vec!["api".to_string(), "process".to_string(), "load".to_string()],
            vec![
                "another".to_string(),
                "random".to_string(),
                "text".to_string(),
                "with".to_string(),
                "process".to_string(),
            ],
        ];

        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: Some(&pre_tokenized),
        };

        let results = rank_documents(&params);
        // Only the first doc should match, because it has all 3 required words
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0); // doc index 0

        // Verify score is positive and within expected range for BM25
        assert!(results[0].1 > 0.0);
        assert!(results[0].1 < 10.0); // Upper bound based on typical BM25 behavior with small documents
    }

    #[test]
    fn test_relative_bm25_scoring() {
        // Test that documents with more matching terms get higher scores
        let docs = vec![
            "api process load data", // 4 matching terms
            "api process load",      // 3 matching terms
            "api process",           // 2 matching terms
            "api",                   // 1 matching term
        ];
        let query = "api process load data"; // All terms are optional

        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: None,
        };

        let results = rank_documents(&params);
        // All docs should match since all terms are optional
        assert_eq!(results.len(), 4);

        // Verify that scores decrease as fewer terms match
        // Doc with 4 matches should be first, then 3, then 2, then 1
        assert_eq!(results[0].0, 0); // First doc (4 matches)
        assert_eq!(results[1].0, 1); // Second doc (3 matches)
        assert_eq!(results[2].0, 2); // Third doc (2 matches)
        assert_eq!(results[3].0, 3); // Fourth doc (1 match)

        // Verify that scores decrease as expected
        assert!(results[0].1 > results[1].1); // 4 matches > 3 matches
        assert!(results[1].1 > results[2].1); // 3 matches > 2 matches
        assert!(results[2].1 > results[3].1); // 2 matches > 1 match
    }
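
    // A small sketch test for the tie-breaking rule in rank_documents: when two
    // documents end up with identical BM25 scores, the secondary sort keeps them
    // in ascending document-index order.
    #[test]
    fn test_equal_scores_tie_break_by_index() {
        // Two identical documents produce identical scores
        let docs = vec!["api", "api"];
        let query = "api";

        let params = RankingParams {
            documents: &docs,
            query,
            pre_tokenized: None,
        };

        let results = rank_documents(&params);
        assert_eq!(results.len(), 2);

        // Equal scores, so the original (ascending) index order is preserved
        assert_eq!(results[0].0, 0);
        assert_eq!(results[1].0, 1);
        assert!(results[0].1 > 0.0);
        assert!((results[0].1 - results[1].1).abs() < 1e-12);
    }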

    #[test]
    fn test_generate_query_token_map_basic() {
        // Create a set of query terms
        let mut query_terms = HashSet::new();
        query_terms.insert("apple".to_string());
        query_terms.insert("banana".to_string());
        query_terms.insert("cherry".to_string());

        // Generate the token map
        let token_map = generate_query_token_map(&query_terms).unwrap();

        // Verify the map contains all terms
        assert_eq!(token_map.len(), 3);

        // Verify each term has a unique index
        let mut indices = HashSet::new();
        for (_, &idx) in &token_map {
            assert!(indices.insert(idx), "Duplicate index found");
        }

        // Verify the indices are in the expected range
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        assert!(indices.contains(&2));
    }

    #[test]
    fn test_generate_query_token_map_empty() {
        // Test with empty query terms
        let query_terms = HashSet::new();

        // Generate the token map
        let token_map = generate_query_token_map(&query_terms).unwrap();

        // Verify the map is empty
        assert!(token_map.is_empty());
    }

    #[test]
    fn test_generate_query_token_map_deterministic() {
        // Create two identical sets of query terms
        let mut query_terms1 = HashSet::new();
        query_terms1.insert("apple".to_string());
        query_terms1.insert("banana".to_string());
        query_terms1.insert("cherry".to_string());

        let mut query_terms2 = HashSet::new();
        query_terms2.insert("cherry".to_string());
        query_terms2.insert("apple".to_string());
        query_terms2.insert("banana".to_string());

        // Generate token maps for both sets
        let token_map1 = generate_query_token_map(&query_terms1).unwrap();
        let token_map2 = generate_query_token_map(&query_terms2).unwrap();

        // Verify both maps have the same mappings despite different insertion order
        assert_eq!(token_map1.len(), token_map2.len());

        for (term, &idx1) in &token_map1 {
            assert_eq!(
                Some(&idx1),
                token_map2.get(term),
                "Term '{term}' has different indices in the two maps"
            );
        }
    }

    #[test]
    fn test_generate_query_token_map_too_many_terms() {
        // Create a set with more than 256 terms
        let query_terms: HashSet<String> = (0..257).map(|i| format!("term{i}")).collect();

        // Attempt to generate the token map
        let result = generate_query_token_map(&query_terms);

        // Verify it returns an error
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            "Query exceeds the 256 unique token limit for u8 mapping"
        );
    }
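
    // A sketch test for the boundary of the u8 mapping: exactly 256 unique terms
    // (indices 0..=255) is still accepted, while 257 fails as covered above.
    #[test]
    fn test_generate_query_token_map_exactly_256_terms() {
        let query_terms: HashSet<String> = (0..256).map(|i| format!("term{i}")).collect();

        let token_map = generate_query_token_map(&query_terms).unwrap();
        assert_eq!(token_map.len(), 256);

        // All 256 indices are distinct
        let indices: HashSet<u8> = token_map.values().copied().collect();
        assert_eq!(indices.len(), 256);
    }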

    #[test]
    fn test_compute_tf_df_with_u8_indices() {
        // Create a simple set of documents
        let docs = vec![
            vec![
                "apple".to_string(),
                "banana".to_string(),
                "cherry".to_string(),
            ],
            vec!["apple".to_string(), "banana".to_string()],
            vec!["apple".to_string()],
        ];

        // Create a query token map
        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);
        query_token_map.insert("banana".to_string(), 1);
        query_token_map.insert("cherry".to_string(), 2);

        // Compute TF/DF
        let tf_df_result = compute_tf_df_from_tokenized(&docs, &query_token_map);

        // Check document lengths
        assert_eq!(tf_df_result.document_lengths[0], 3);
        assert_eq!(tf_df_result.document_lengths[1], 2);
        assert_eq!(tf_df_result.document_lengths[2], 1);

        // Check term frequencies using u8 indices
        assert_eq!(*tf_df_result.term_frequencies[0].get(&0).unwrap(), 1); // "apple" in doc 0
        assert_eq!(*tf_df_result.term_frequencies[0].get(&1).unwrap(), 1); // "banana" in doc 0
        assert_eq!(*tf_df_result.term_frequencies[0].get(&2).unwrap(), 1); // "cherry" in doc 0

        assert_eq!(*tf_df_result.term_frequencies[1].get(&0).unwrap(), 1); // "apple" in doc 1
        assert_eq!(*tf_df_result.term_frequencies[1].get(&1).unwrap(), 1); // "banana" in doc 1
        assert!(tf_df_result.term_frequencies[1].get(&2).is_none()); // no "cherry" in doc 1

        assert_eq!(*tf_df_result.term_frequencies[2].get(&0).unwrap(), 1); // "apple" in doc 2
        assert!(tf_df_result.term_frequencies[2].get(&1).is_none()); // no "banana" in doc 2
        assert!(tf_df_result.term_frequencies[2].get(&2).is_none()); // no "cherry" in doc 2

        // Check document frequencies (still string-based)
        assert_eq!(*tf_df_result.document_frequencies.get("apple").unwrap(), 3); // in all 3 docs
        assert_eq!(*tf_df_result.document_frequencies.get("banana").unwrap(), 2); // in 2 docs
        assert_eq!(*tf_df_result.document_frequencies.get("cherry").unwrap(), 1);
        // in 1 doc
    }
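
    // A sketch test for repeated tokens: term frequencies count every occurrence,
    // while document frequencies count each document only once.
    #[test]
    fn test_compute_tf_df_with_repeated_tokens() {
        let docs = vec![
            vec![
                "apple".to_string(),
                "apple".to_string(),
                "banana".to_string(),
            ],
            vec!["banana".to_string()],
        ];

        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);
        query_token_map.insert("banana".to_string(), 1);

        let tf_df_result = compute_tf_df_from_tokenized(&docs, &query_token_map);

        // "apple" appears twice in doc 0, "banana" once
        assert_eq!(*tf_df_result.term_frequencies[0].get(&0).unwrap(), 2);
        assert_eq!(*tf_df_result.term_frequencies[0].get(&1).unwrap(), 1);
        assert_eq!(tf_df_result.document_lengths[0], 3);

        // Doc 1 contains only "banana"
        assert!(tf_df_result.term_frequencies[1].get(&0).is_none());
        assert_eq!(*tf_df_result.term_frequencies[1].get(&1).unwrap(), 1);
        assert_eq!(tf_df_result.document_lengths[1], 1);

        // Document frequencies count documents, not occurrences
        assert_eq!(*tf_df_result.document_frequencies.get("apple").unwrap(), 1);
        assert_eq!(*tf_df_result.document_frequencies.get("banana").unwrap(), 2);
    }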

    #[test]
    fn test_bm25_scoring_with_u8_indices() {
        // Create a simple document
        let _doc_content = "apple banana cherry"; // Kept for documentation purposes

        // Create a query token map
        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);
        query_token_map.insert("banana".to_string(), 1);
        query_token_map.insert("cherry".to_string(), 2);

        // Create term frequencies using u8 indices
        let mut doc_tf = HashMap::new();
        doc_tf.insert(0u8, 1); // "apple" appears once
        doc_tf.insert(1u8, 1); // "banana" appears once
        doc_tf.insert(2u8, 1); // "cherry" appears once

        // Create document frequencies
        let mut doc_freqs = HashMap::new();
        doc_freqs.insert("apple".to_string(), 1);
        doc_freqs.insert("banana".to_string(), 1);
        doc_freqs.insert("cherry".to_string(), 1);

        // Create IDF values (simplified for testing)
        let mut idfs = HashMap::new();
        idfs.insert("apple".to_string(), 1.0);
        idfs.insert("banana".to_string(), 1.0);
        idfs.insert("cherry".to_string(), 1.0);

        // Create BM25 parameters
        let params = PrecomputedBm25Params {
            doc_tf: &doc_tf,
            doc_len: 3,
            avgdl: 3.0,
            idfs: &idfs,
            query_token_map: &query_token_map,
            k1: 1.2,
            b: 0.75,
        };

        // Test bm25_single_token_optimized
        let apple_score = bm25_single_token_optimized("apple", &params);
        let banana_score = bm25_single_token_optimized("banana", &params);
        let cherry_score = bm25_single_token_optimized("cherry", &params);

        // All terms have the same frequency, IDF, and document length,
        // so they should have the same score
        assert!(apple_score > 0.0);
        assert_eq!(apple_score, banana_score);
        assert_eq!(banana_score, cherry_score);

        // Test with a term not in the query_token_map
        let unknown_score = bm25_single_token_optimized("unknown", &params);
        assert_eq!(unknown_score, 0.0);

        // Test score_term_bm25_optimized
        let keywords = vec!["apple".to_string(), "banana".to_string()];
        let term_score = score_term_bm25_optimized(&keywords, &params);

        // The score should be the sum of individual scores
        assert_eq!(term_score, apple_score + banana_score);
    }
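
    // A sketch test that pins the single-token score to the BM25 formula used in
    // bm25_single_token_optimized. With freq = 1, idf = 1.0, doc_len = avgdl = 3,
    // k1 = 1.2 and b = 0.75 the score reduces to
    // 1.0 * (1 * 2.2) / (1 + 1.2 * (1 - 0.75 + 0.75 * 1)) = 2.2 / 2.2 = 1.0.
    #[test]
    fn test_bm25_single_token_matches_formula() {
        let mut query_token_map = QueryTokenMap::new();
        query_token_map.insert("apple".to_string(), 0);

        let mut doc_tf = HashMap::new();
        doc_tf.insert(0u8, 1);

        let mut idfs = HashMap::new();
        idfs.insert("apple".to_string(), 1.0);

        let params = PrecomputedBm25Params {
            doc_tf: &doc_tf,
            doc_len: 3,
            avgdl: 3.0,
            idfs: &idfs,
            query_token_map: &query_token_map,
            k1: 1.2,
            b: 0.75,
        };

        let score = bm25_single_token_optimized("apple", &params);
        assert!((score - 1.0).abs() < 1e-9);
    }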
}