use crate::collection::Collection;
use crate::error::Result;
use crate::query::VectorQuery;
use crate::rerank::{Hit, RrfReRanker, WeightedReRanker};
/// Builder for a multi-query ("hybrid") search whose per-query hit lists are
/// fused into a single ranked list.
///
/// Add sub-queries with `query`, optionally choose a fuser and a `top_k` cap,
/// then call `execute` against a [`Collection`]. Every builder method consumes
/// and returns the builder, so dropping the returned value discards the
/// configuration; `#[must_use]` turns that mistake into a compiler warning.
#[must_use = "a HybridSearch does nothing until `execute` is called"]
pub struct HybridSearch {
    /// Sub-queries, run in insertion order by `execute`.
    queries: Vec<VectorQuery>,
    /// Strategy used to merge the per-query hit lists.
    fuser: FuserKind,
    /// Optional cap on the number of fused hits returned; `None` means no cap.
    top_k: Option<usize>,
}
/// Which reranker `HybridSearch::execute` uses to fuse per-query hit lists.
///
/// Exactly one fuser is active at a time; selecting a new one replaces the old.
enum FuserKind {
    /// Fuse with a caller-configured [`WeightedReRanker`].
    Weighted(WeightedReRanker),
    /// Fuse with an [`RrfReRanker`]; this is the variant `HybridSearch::new` selects.
    Rrf(RrfReRanker),
}
impl Default for HybridSearch {
fn default() -> Self {
Self::new()
}
}
impl HybridSearch {
    /// Creates an empty search: no sub-queries, an [`RrfReRanker`] built with
    /// `RrfReRanker::default()` as the fuser, and no `top_k` cap.
    pub fn new() -> Self {
        Self {
            queries: Vec::new(),
            fuser: FuserKind::Rrf(RrfReRanker::default()),
            top_k: None,
        }
    }

    /// Appends a sub-query.
    ///
    /// Sub-queries are executed in the order they were added. Each one yields
    /// its own hit list, and those lists are handed to the fuser in that order.
    pub fn query(mut self, query: VectorQuery) -> Self {
        self.queries.push(query);
        self
    }

    /// Selects `r` as the fuser.
    ///
    /// This replaces any previously selected fuser, including one set via
    /// [`weighted_reranker`](Self::weighted_reranker).
    pub fn reranker(mut self, r: RrfReRanker) -> Self {
        self.fuser = FuserKind::Rrf(r);
        self
    }

    /// Selects `r` as the fuser.
    ///
    /// This replaces any previously selected fuser, including one set via
    /// [`reranker`](Self::reranker).
    pub fn weighted_reranker(mut self, r: WeightedReRanker) -> Self {
        self.fuser = FuserKind::Weighted(r);
        self
    }

    /// Caps the fused result at `n` hits.
    ///
    /// The last call wins. `n == 0` yields an empty result.
    pub fn top_k(mut self, n: usize) -> Self {
        self.top_k = Some(n);
        self
    }

    /// Runs every sub-query against `collection` and fuses the hit lists.
    ///
    /// The per-query hit lists are fused with the selected reranker. The fused
    /// list is then truncated to `top_k` if one was set.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while running a sub-query through
    /// `Collection::query`. Sub-queries after the failing one are not executed.
    pub fn execute(self, collection: &Collection) -> Result<Vec<Hit>> {
        // Collecting into `Result` stops at the first `Err`, same as an early `?`.
        // The explicit closure return type keeps the `?` error conversion into the
        // crate's error type exactly as in a plain loop.
        let per_query = self
            .queries
            .iter()
            .map(|q| -> Result<Vec<Hit>> { Ok(collection.query(q)?.to_hits()) })
            .collect::<Result<Vec<Vec<Hit>>>>()?;

        let mut fused = match self.fuser {
            FuserKind::Rrf(r) => r.fuse(per_query),
            FuserKind::Weighted(r) => r.fuse(per_query),
        };
        if let Some(k) = self.top_k {
            fused.truncate(k);
        }
        Ok(fused)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `HybridSearch` belong here; the module is currently empty.
    // NOTE(review): `execute` needs a `Collection` to run against, so exercising it
    // presumably requires a small in-memory collection fixture — TODO add one and
    // cover fuser selection and `top_k` truncation (including `top_k(0)`).
}