use anyhow::anyhow;
use tokenizers::Tokenizer;
use crate::backend::{Encoding, RerankBackend};
/// Model repo identifier of the default cross-encoder; this is the value the
/// module's test passes to `Reranker::from_pretrained`.
pub const DEFAULT_RERANK_MODEL: &str = "cross-encoder/ms-marco-MiniLM-L-12-v2";
/// Default candidate count for reranking.
///
/// NOTE(review): not referenced in this file; presumably the number of
/// first-stage hits handed to the reranker — confirm against callers.
pub const DEFAULT_RERANK_CANDIDATES: usize = 100;
/// Scores text pairs by tokenizing them and handing the encodings to a
/// rerank backend.
///
/// Both fields are populated together by `Reranker::from_pretrained` from the
/// same model repo identifier.
pub struct Reranker {
// Backend that scores batches of encodings and reports its token limit.
backend: Box<dyn RerankBackend>,
// Tokenizer used to encode each pair before scoring.
tokenizer: Tokenizer,
}
impl Reranker {
    /// Builds a reranker for `model_repo`.
    ///
    /// The backend is loaded first through `crate::backend::load_reranker_cpu`,
    /// then the tokenizer through `crate::tokenize::load_tokenizer`; both
    /// receive the same `model_repo` string.
    ///
    /// # Errors
    ///
    /// Returns whatever error either loader reports.
    pub fn from_pretrained(model_repo: &str) -> crate::Result<Self> {
        let scorer = crate::backend::load_reranker_cpu(model_repo)?;
        let pair_tokenizer = crate::tokenize::load_tokenizer(model_repo)?;
        Ok(Self {
            backend: scorer,
            tokenizer: pair_tokenizer,
        })
    }

    /// Scores every pair in `pairs` and returns the backend's scores.
    ///
    /// Each tuple is encoded as a two-sequence input (the test in this file
    /// passes `(query, document)`), cut down to at most
    /// `self.backend.max_tokens()` tokens, and the whole batch is passed to the
    /// backend in a single `score_batch` call. An empty `pairs` slice returns
    /// an empty vector without consulting the backend.
    ///
    /// NOTE(review): the cut is a plain prefix slice of the encoded sequence,
    /// so any trailing tokens of an over-long pair are dropped — presumably
    /// including a final separator token; confirm the backend tolerates this.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Other` if tokenizing any pair fails (the first
    /// failure aborts the batch), or whatever error `score_batch` reports.
    pub fn score_pairs(&self, pairs: &[(&str, &str)]) -> crate::Result<Vec<f32>> {
        if pairs.is_empty() {
            return Ok(Vec::new());
        }

        let limit = self.backend.max_tokens();
        let mut batch: Vec<Encoding> = Vec::with_capacity(pairs.len());

        for &(first, second) in pairs {
            // `true` asks the tokenizer to add its special tokens.
            let tokenized = self
                .tokenizer
                .encode((first, second), true)
                .map_err(|e| crate::Error::Other(anyhow!("rerank tokenize failed: {e}")))?;

            // Keep the same prefix length for all three parallel columns.
            let keep = tokenized.get_ids().len().min(limit);
            let ids = &tokenized.get_ids()[..keep];
            let mask = &tokenized.get_attention_mask()[..keep];
            let type_ids = &tokenized.get_type_ids()[..keep];

            batch.push(Encoding {
                input_ids: ids.iter().copied().map(i64::from).collect(),
                attention_mask: mask.iter().copied().map(i64::from).collect(),
                token_type_ids: type_ids.iter().copied().map(i64::from).collect(),
            });
        }

        self.backend.score_batch(&batch)
    }

    /// Returns the backend's token limit, i.e. the length `score_pairs`
    /// truncates each encoded pair to.
    #[must_use]
    pub fn max_tokens(&self) -> usize {
        self.backend.max_tokens()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against the default model: loads it, scores one
    /// on-topic and one off-topic document for the same query, and expects the
    /// on-topic document to score higher. Ignored by default because it needs
    /// the model to be downloaded.
    #[test]
    #[ignore = "requires network + model download (~22MB)"]
    fn loads_and_ranks_default_cross_encoder() {
        let query = "how to make pasta";
        let candidates = [
            (query, "Boil water, add salt, cook pasta for 8 minutes."),
            (query, "The mitochondria is the powerhouse of the cell."),
        ];

        let reranker = Reranker::from_pretrained(DEFAULT_RERANK_MODEL)
            .expect("default cross-encoder should load");
        let scores = reranker
            .score_pairs(&candidates)
            .expect("scoring should succeed");

        // One score per candidate, each expected to lie in [0, 1].
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|&s| (0.0..=1.0).contains(&s)));

        let (relevant, irrelevant) = (scores[0], scores[1]);
        assert!(
            relevant > irrelevant,
            "relevant doc ({}) should beat irrelevant ({})",
            relevant,
            irrelevant
        );
    }
}