mod common;
pub mod log_ratio;
mod merge_join;
mod scoring;
mod sharded;
#[allow(unused_imports)]
pub(crate) use common::collect_negative_minimizers_sharded;
pub use sharded::{
classify_batch_sharded_merge_join, classify_batch_sharded_parallel_rg,
classify_from_extracted_minimizers, classify_from_extracted_minimizers_parallel_rg,
classify_from_query_index, classify_from_query_index_parallel_rg,
classify_with_sharded_negative, extract_batch_minimizers,
};
pub use best_hit::filter_best_hits;
mod best_hit;
/// Unit tests for [`filter_best_hits`], which reduces a list of per-(query, bucket)
/// hits to a single best hit per query.
#[cfg(test)]
mod best_hit_tests {
    use super::filter_best_hits;
    use crate::types::HitResult;

    /// Mixed input: query 1 has two candidates, query 2 has one.
    /// The highest-scoring bucket must survive for each query.
    #[test]
    fn test_filter_best_hits_basic() {
        let hits = vec![
            HitResult {
                query_id: 1,
                bucket_id: 10,
                score: 0.5,
            },
            HitResult {
                query_id: 1,
                bucket_id: 20,
                score: 0.9,
            },
            HitResult {
                query_id: 2,
                bucket_id: 10,
                score: 0.8,
            },
        ];

        let best = filter_best_hits(hits);
        assert_eq!(best.len(), 2);

        // Output order is not asserted, so look each query up by id.
        let q1 = best.iter().find(|h| h.query_id == 1).unwrap();
        assert_eq!(q1.bucket_id, 20);
        assert!((q1.score - 0.9).abs() < 1e-10);

        let q2 = best.iter().find(|h| h.query_id == 2).unwrap();
        assert_eq!(q2.bucket_id, 10);
        assert!((q2.score - 0.8).abs() < 1e-10);
    }

    /// Two hits for the same query with identical scores: exactly one survives,
    /// and it is bucket 10.
    ///
    /// Bucket 10 is both the first hit in input order and the lower bucket id,
    /// so this test does not distinguish "first seen wins" from "lowest id wins".
    #[test]
    fn test_filter_best_hits_tie_breaking() {
        let hits = vec![
            HitResult {
                query_id: 1,
                bucket_id: 10,
                score: 0.5,
            },
            HitResult {
                query_id: 1,
                bucket_id: 20,
                score: 0.5,
            },
        ];

        let best = filter_best_hits(hits);
        assert_eq!(best.len(), 1);
        assert_eq!(best[0].query_id, 1);
        assert_eq!(best[0].bucket_id, 10);
        assert!((best[0].score - 0.5).abs() < 1e-10);
    }

    /// Empty input yields empty output.
    #[test]
    fn test_filter_best_hits_empty() {
        let hits: Vec<HitResult> = vec![];
        let best = filter_best_hits(hits);
        assert!(best.is_empty());
    }

    /// Each query has exactly one hit, so every hit must pass through with its
    /// original bucket. Checking only the count would miss a wrong or duplicated hit.
    #[test]
    fn test_filter_best_hits_unique_queries() {
        let hits = vec![
            HitResult {
                query_id: 1,
                bucket_id: 10,
                score: 0.5,
            },
            HitResult {
                query_id: 2,
                bucket_id: 20,
                score: 0.6,
            },
            HitResult {
                query_id: 3,
                bucket_id: 30,
                score: 0.7,
            },
        ];

        let best = filter_best_hits(hits);
        assert_eq!(best.len(), 3);

        for (query_id, bucket_id) in [(1, 10), (2, 20), (3, 30)] {
            let hit = best
                .iter()
                .find(|h| h.query_id == query_id)
                .unwrap_or_else(|| panic!("query {query_id} missing from best hits"));
            assert_eq!(hit.bucket_id, bucket_id);
        }
    }

    /// Many candidate buckets for one query, with the maximum placed in the
    /// middle of the input. The maximum must be selected regardless of position.
    #[test]
    fn test_filter_best_hits_many_buckets_per_query() {
        let hits = vec![
            HitResult {
                query_id: 1,
                bucket_id: 1,
                score: 0.1,
            },
            HitResult {
                query_id: 1,
                bucket_id: 2,
                score: 0.3,
            },
            HitResult {
                query_id: 1,
                bucket_id: 3,
                score: 0.7,
            },
            HitResult {
                query_id: 1,
                bucket_id: 4,
                score: 0.5,
            },
            HitResult {
                query_id: 1,
                bucket_id: 5,
                score: 0.2,
            },
        ];

        let best = filter_best_hits(hits);
        assert_eq!(best.len(), 1);
        assert_eq!(best[0].bucket_id, 3);
        assert!((best[0].score - 0.7).abs() < 1e-10);
    }
}