use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
pub use anda_db_btree::RangeQuery;
pub use anda_db_schema::{Fv, bf16};
pub use anda_db_tfs::BM25Params;
/// A complete query: an optional search component, an optional filter, and an
/// optional result limit.
///
/// All fields are optional, so `Query::default()` is an empty query.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct Query {
    /// Text and/or vector search parameters; `None` when no search is requested.
    pub search: Option<Search>,
    /// Field-level filter expression; `None` when no filtering is requested.
    pub filter: Option<Filter>,
    /// Presumably the maximum number of results to return — TODO confirm against
    /// the executor; `None` leaves the limit to the caller/executor.
    pub limit: Option<usize>,
}
/// Search parameters for a [`Query`].
///
/// Text and vector inputs are both optional. When several result lists are
/// produced, `reranker` presumably fuses them (NOTE(review): confirm against the
/// executor).
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct Search {
    /// Query text, if a text search is requested.
    pub text: Option<String>,
    /// Query vector, if a vector search is requested (presumably an embedding
    /// whose dimension must match the index — TODO confirm).
    pub vector: Option<Vec<f32>>,
    /// Optional BM25 parameters; `None` presumably means the index defaults.
    pub bm25_params: Option<BM25Params>,
    /// Optional Reciprocal Rank Fusion settings used to merge ranked result lists.
    pub reranker: Option<RRFReranker>,
    /// NOTE(review): semantics are not visible in this file. Presumably this toggles
    /// logical/boolean parsing of `text`; confirm. Defaults to `false`.
    pub logical_search: bool,
}
/// A recursive boolean filter expression over document fields.
///
/// Leaves are [`Filter::Field`] range queries; interior nodes combine sub-filters.
/// The `Box` inside `Vec` is not required for recursion (a `Vec` already provides
/// indirection) but is part of the public API, so it is kept.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub enum Filter {
    /// A `(name, range query)` pair. Presumably the string is the field name and the
    /// range query is evaluated against that field's value — TODO confirm.
    Field((String, RangeQuery<Fv>)),
    /// Logical OR of the contained sub-filters.
    Or(Vec<Box<Filter>>),
    /// Logical AND of the contained sub-filters.
    And(Vec<Box<Filter>>),
    /// Logical negation of the contained sub-filter.
    Not(Box<Filter>),
}
/// Reciprocal Rank Fusion (RRF) reranker: merges several ranked lists of
/// document ids into a single scored ranking.
///
/// `k` is a public field and the type is deserializable, so its value is not
/// guaranteed to have passed the check in `RRFReranker::new`. `rerank` therefore
/// falls back to a default when `k` is unusable (e.g. not positive).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RRFReranker {
    /// The RRF smoothing constant added to each rank before taking the reciprocal.
    pub k: f32,
}
impl Default for RRFReranker {
fn default() -> Self {
Self { k: 60.0 }
}
}
impl RRFReranker {
    /// Creates a reranker with the given RRF constant `k`.
    ///
    /// A larger `k` flattens the difference between adjacent rank positions. A
    /// smaller `k` gives more weight to the top of each list.
    ///
    /// # Panics
    ///
    /// Panics if `k` is not a finite, strictly positive number (NaN, zero,
    /// negative values and infinities are rejected). An infinite `k` would make
    /// every fused score zero.
    pub fn new(k: f32) -> Self {
        assert!(
            k.is_finite() && k > 0.0,
            "RRFReranker k must be positive and finite, got {}",
            k
        );
        Self { k }
    }

    /// Fuses several ranked lists of document ids with Reciprocal Rank Fusion.
    ///
    /// Each occurrence of a document at 0-based position `rank` in an input list
    /// adds `1 / (k + rank)` to that document's score. The top hit of a list
    /// therefore contributes `1 / k`. (The original RRF formulation uses 1-based
    /// ranks, i.e. `1 / (k + 1)` for the top hit.) A doc id repeated within one
    /// list contributes once per occurrence.
    ///
    /// If `self.k` is not a finite positive number, the default `k` is used
    /// instead. This can happen because the field is public and the type can be
    /// deserialized without going through [`RRFReranker::new`].
    ///
    /// # Returns
    ///
    /// Every distinct doc id found in `ranked_lists`, paired with its fused score.
    /// Results are sorted by descending score, and equal scores are ordered by
    /// ascending doc id so the output does not depend on hash-map iteration order.
    /// An empty input yields an empty vector.
    pub fn rerank(&self, ranked_lists: &[Vec<u64>]) -> Vec<(u64, f32)> {
        let k = if self.k.is_finite() && self.k > 0.0 {
            self.k
        } else {
            Self::default().k
        };

        // The total number of entries is an upper bound on distinct ids; reserving
        // it up front avoids rehashing while accumulating.
        let capacity: usize = ranked_lists.iter().map(Vec::len).sum();
        let mut scores: FxHashMap<u64, f32> =
            FxHashMap::with_capacity_and_hasher(capacity, Default::default());
        for ranked in ranked_lists {
            for (rank, &doc_id) in ranked.iter().enumerate() {
                *scores.entry(doc_id).or_default() += 1.0 / (k + rank as f32);
            }
        }

        let mut results: Vec<(u64, f32)> = scores.into_iter().collect();
        // With finite positive `k` every score is finite and positive, so
        // `total_cmp` matches numeric order. Doc ids are unique map keys, so the
        // (score desc, id asc) order is total and an unstable sort is deterministic.
        results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        results
    }
}