use std::fmt::Debug;
use thiserror::Error;
/// Errors reported by [`Reranker`] implementations.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum RerankError {
    /// A network failure, carrying a human-readable description.
    #[error("network error: {0}")]
    Network(String),

    /// Authentication failed, carrying a human-readable description.
    #[error("authentication failed: {0}")]
    Auth(String),

    /// The request was rate limited, carrying a human-readable description.
    #[error("rate limited: {0}")]
    RateLimited(String),

    /// The request was rejected as a bad request.
    #[error("bad request ({status}): {body}")]
    BadRequest {
        /// Status code attached to the rejection (presumably HTTP — confirm at call sites).
        status: u16,
        /// Body text, included verbatim in the error message.
        body: String,
    },

    /// The backend reported a server-side failure.
    #[error("server error ({status}): {body}")]
    Server {
        /// Status code attached to the failure (presumably HTTP — confirm at call sites).
        status: u16,
        /// Body text, included verbatim in the error message.
        body: String,
    },

    /// A response could not be decoded, carrying a description.
    #[error("decode error: {0}")]
    Decode(String),

    /// Invalid or missing configuration, carrying a description.
    #[error("config error: {0}")]
    Config(String),

    /// Model inference failed, carrying a description.
    #[error("inference error: {0}")]
    Inference(String),

    /// The number of scores produced differs from the number expected.
    #[error("score count mismatch: expected {expected}, got {got}")]
    ScoreCountMismatch {
        /// Number of scores that should have been produced.
        expected: usize,
        /// Number of scores actually produced.
        got: usize,
    },
}
/// A backend that scores how relevant each candidate text is to a query.
///
/// Implementors must be `Send + Sync + Debug`, so a reranker can be shared
/// across threads and printed in diagnostics.
pub trait Reranker: Send + Sync + Debug {
    /// Identifier of the model behind this reranker (e.g. `"mock:jaccard"`).
    fn model(&self) -> &str;

    /// Scores every entry of `candidates` against `query`.
    ///
    /// NOTE(review): implementations appear to be expected to return exactly
    /// one score per candidate, in input order (the mocks in this module do,
    /// and `RerankError::ScoreCountMismatch` exists to report violations) —
    /// confirm with callers before relying on it.
    ///
    /// # Errors
    ///
    /// Returns a [`RerankError`] when scoring fails.
    fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError>;
}
/// Deterministic, dependency-free [`Reranker`] that scores each candidate by
/// the Jaccard similarity of its token set with the query's token set.
///
/// Tokens come from `token_set`: lowercased runs of alphanumeric characters.
/// Scores lie in `[0.0, 1.0]`. A candidate scores `0.0` when neither it nor
/// the query contains any token.
#[derive(Debug, Clone, Default)]
pub struct MockJaccardReranker;

impl Reranker for MockJaccardReranker {
    fn model(&self) -> &str {
        "mock:jaccard"
    }

    /// Returns one score per candidate, in input order. Never returns `Err`.
    fn rerank(&self, query: &str, candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
        let query_tokens = token_set(query);

        // |Q ∩ C| / |Q ∪ C|, guarding the empty-union case to avoid 0/0 = NaN.
        let jaccard = |candidate: &&str| -> f32 {
            let candidate_tokens = token_set(candidate);
            let union_size = query_tokens.union(&candidate_tokens).count();
            if union_size == 0 {
                return 0.0;
            }
            let shared = query_tokens.intersection(&candidate_tokens).count();
            shared as f32 / union_size as f32
        };

        let scores: Vec<f32> = candidates.iter().map(jaccard).collect();
        Ok(scores)
    }
}
/// [`Reranker`] whose `rerank` always fails with [`RerankError::Network`].
///
/// Useful for exercising error-handling paths. Its error message marks the
/// failure as intentional.
#[derive(Debug, Clone, Default)]
pub struct AlwaysFailReranker;

impl Reranker for AlwaysFailReranker {
    fn model(&self) -> &str {
        "mock:always-fail"
    }

    /// Ignores its inputs and unconditionally returns a network error.
    fn rerank(&self, _query: &str, _candidates: &[&str]) -> Result<Vec<f32>, RerankError> {
        let reason = String::from("intentional failure for test");
        Err(RerankError::Network(reason))
    }
}
/// Splits `text` into a set of lowercased tokens.
///
/// A token is a maximal run of characters for which [`char::is_alphanumeric`]
/// holds. Every other character acts as a separator.
///
/// Each character is lowercased on its own via [`char::to_lowercase`]. Unlike
/// [`str::to_lowercase`], this applies no context-sensitive mapping such as the
/// word-final Greek sigma. Duplicate tokens collapse, and input with no
/// alphanumeric characters yields an empty set.
fn token_set(text: &str) -> std::collections::HashSet<String> {
    text.split(|ch: char| !ch.is_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(|piece| piece.chars().flat_map(char::to_lowercase).collect::<String>())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing `f32` similarity scores.
    const EPS: f32 = 1e-6;

    #[test]
    fn mock_jaccard_identical_query_and_candidate_scores_one() {
        let r = MockJaccardReranker;
        let s = r.rerank("alice bob", &["alice bob"]).unwrap();
        assert_eq!(s.len(), 1);
        assert!((s[0] - 1.0).abs() < EPS);
    }

    #[test]
    fn mock_jaccard_disjoint_scores_zero() {
        let r = MockJaccardReranker;
        let s = r.rerank("alice", &["zed"]).unwrap();
        assert_eq!(s.len(), 1);
        assert_eq!(s[0], 0.0);
    }

    #[test]
    fn mock_jaccard_partial_overlap() {
        let r = MockJaccardReranker;
        let s = r.rerank("alice bob", &["alice carol"]).unwrap();
        assert_eq!(s.len(), 1);
        assert!((s[0] - (1.0 / 3.0)).abs() < EPS);
    }

    #[test]
    fn mock_jaccard_length_matches_candidates_len() {
        let r = MockJaccardReranker;
        let s = r
            .rerank("alpha", &["alpha beta", "alpha gamma", "delta"])
            .unwrap();
        assert_eq!(s.len(), 3);
    }

    #[test]
    fn mock_jaccard_preserves_candidate_order() {
        let r = MockJaccardReranker;
        let s = r.rerank("alpha", &["alpha", "beta", "alpha beta"]).unwrap();
        assert_eq!(s.len(), 3);
        assert!((s[0] - 1.0).abs() < EPS);
        assert_eq!(s[1], 0.0);
        assert!((s[2] - 0.5).abs() < EPS);
    }

    #[test]
    fn mock_jaccard_ignores_case_punctuation_and_duplicates() {
        let r = MockJaccardReranker;
        let s = r.rerank("Alice, BOB!", &["alice alice bob"]).unwrap();
        assert_eq!(s.len(), 1);
        assert!((s[0] - 1.0).abs() < EPS);
    }

    #[test]
    fn mock_jaccard_tokenless_query_and_candidates_score_zero_not_nan() {
        // Both token sets are empty, so the union is empty: the guard must
        // yield 0.0 rather than 0/0 = NaN.
        let r = MockJaccardReranker;
        let s = r.rerank("", &["", "  ,. "]).unwrap();
        assert_eq!(s, vec![0.0, 0.0]);
    }

    #[test]
    fn mock_jaccard_empty_candidates_empty_output() {
        let r = MockJaccardReranker;
        let s = r.rerank("alpha", &[]).unwrap();
        assert!(s.is_empty());
    }

    #[test]
    fn always_fail_reranker_returns_err() {
        let r = AlwaysFailReranker;
        let err = r.rerank("q", &["c"]).unwrap_err();
        assert!(matches!(err, RerankError::Network(_)));
        assert_eq!(err.to_string(), "network error: intentional failure for test");
    }

    #[test]
    fn model_id_contains_provider_prefix() {
        assert_eq!(MockJaccardReranker.model(), "mock:jaccard");
        assert_eq!(AlwaysFailReranker.model(), "mock:always-fail");
    }

    #[test]
    fn error_display_includes_details() {
        let bad = RerankError::BadRequest {
            status: 400,
            body: "oops".to_string(),
        };
        assert_eq!(bad.to_string(), "bad request (400): oops");

        let mismatch = RerankError::ScoreCountMismatch {
            expected: 3,
            got: 2,
        };
        assert_eq!(
            mismatch.to_string(),
            "score count mismatch: expected 3, got 2"
        );
    }

    #[test]
    fn token_set_splits_on_non_alphanumeric_and_lowercases() {
        let t = token_set("Hello, WORLD! hello");
        assert_eq!(t.len(), 2);
        assert!(t.contains("hello"));
        assert!(t.contains("world"));
        assert!(token_set("").is_empty());
        assert!(token_set(" ,.; ").is_empty());
    }
}