use std::sync::Arc;
/// One candidate document handed to a [`Reranker`].
#[derive(Clone, Debug, PartialEq)]
pub struct RerankInput {
    /// Identifier that rerankers in this file copy verbatim into [`RerankOutput::doc_id`].
    pub doc_id: String,
    /// Document body available for scoring; `IdentityReranker` does not read it.
    pub text: String,
    /// Score carried in from an earlier stage (presumably first-stage retrieval — confirm
    /// with callers); `IdentityReranker` passes it through unchanged as the output score.
    pub prior_score: f32,
}
/// One reranked result produced by a [`Reranker`].
#[derive(Clone, Debug, PartialEq)]
pub struct RerankOutput {
    /// Identifier of the document this score belongs to, copied from the input candidate.
    pub doc_id: String,
    /// Score assigned by the reranker; its scale is implementation-defined
    /// (`IdentityReranker` simply reuses the candidate's `prior_score`).
    pub score: f32,
}
/// Asynchronous second-pass scorer that turns a list of candidates into ranked results.
///
/// The trait is object-safe via `async_trait` and requires `Send + Sync`, so
/// implementations can be shared behind an `Arc<dyn Reranker>`.
#[async_trait::async_trait]
pub trait Reranker: Sync + Send {
    /// Score `candidates` for `query` and return the results in ranked order.
    ///
    /// `top_k` is the caller's cap on the number of results; the implementations
    /// in this file never return more than `top_k` entries (and return nothing
    /// when `top_k == 0`).
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<RerankInput>,
        top_k: usize,
    ) -> Vec<RerankOutput>;
}
/// A no-op [`Reranker`]: keeps the input order, truncates to `top_k`, and reports
/// each candidate's `prior_score` as its score.
///
/// Being a stateless unit struct, it is trivially `Copy` and `Default`; `Debug`
/// lets it appear in diagnostics like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct IdentityReranker;
#[async_trait::async_trait]
impl Reranker for IdentityReranker {
async fn rerank(
&self,
_query: &str,
candidates: Vec<RerankInput>,
top_k: usize,
) -> Vec<RerankOutput> {
candidates
.into_iter()
.take(top_k)
.map(|c| RerankOutput {
doc_id: c.doc_id,
score: c.prior_score,
})
.collect()
}
}
/// Shared, dynamically dispatched handle to any [`Reranker`] implementation.
///
/// Cloning it only bumps the `Arc` refcount; because `Reranker` has `Send + Sync`
/// as supertraits, the handle can be shared across threads/tasks.
pub type DynReranker = Arc<dyn Reranker>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate whose text is derived from its id.
    fn input(id: &str, prior: f32) -> RerankInput {
        RerankInput {
            doc_id: id.into(),
            text: format!("text for {}", id),
            prior_score: prior,
        }
    }

    /// Doc ids of `out`, in order, for compact whole-result assertions.
    fn ids(out: &[RerankOutput]) -> Vec<&str> {
        out.iter().map(|o| o.doc_id.as_str()).collect()
    }

    #[tokio::test]
    async fn identity_preserves_order_and_truncates() {
        let r = IdentityReranker;
        // Priors are deliberately NOT in descending order: a reranker that sorted
        // by prior_score would yield ["b", "c"], so this really checks that the
        // identity reranker keeps the input order.
        let candidates = vec![input("a", 0.1), input("b", 0.9), input("c", 0.5)];
        let out = r.rerank("query", candidates, 2).await;
        assert_eq!(ids(&out), vec!["a", "b"]);
    }

    #[tokio::test]
    async fn identity_promotes_prior_score_to_score() {
        let r = IdentityReranker;
        let candidates = vec![input("a", 0.42), input("b", 0.13)];
        let out = r.rerank("q", candidates, 10).await;
        // Check the length before indexing so a short result fails clearly.
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].score, 0.42);
        assert_eq!(out[1].score, 0.13);
    }

    #[tokio::test]
    async fn identity_handles_empty_candidates() {
        let r = IdentityReranker;
        let out = r.rerank("q", Vec::new(), 10).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn identity_top_k_zero_returns_empty() {
        let r = IdentityReranker;
        let out = r.rerank("q", vec![input("a", 0.9)], 0).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn identity_top_k_larger_than_candidates_returns_all() {
        let r = IdentityReranker;
        let out = r
            .rerank("q", vec![input("a", 0.9), input("b", 0.7)], 100)
            .await;
        assert_eq!(ids(&out), vec!["a", "b"]);
    }

    #[tokio::test]
    async fn dyn_dispatch_via_trait_object() {
        let r: DynReranker = Arc::new(IdentityReranker);
        let out = r.rerank("q", vec![input("a", 0.5)], 10).await;
        assert_eq!(ids(&out), vec!["a"]);
        assert_eq!(out[0].score, 0.5);
    }

    /// Test-only reranker that scores `1.0 - prior_score`, sorts descending by
    /// that score (ties broken by ascending doc id), then truncates to `top_k`.
    struct InvertingReranker;

    #[async_trait::async_trait]
    impl Reranker for InvertingReranker {
        async fn rerank(
            &self,
            _query: &str,
            candidates: Vec<RerankInput>,
            top_k: usize,
        ) -> Vec<RerankOutput> {
            let mut out: Vec<RerankOutput> = candidates
                .into_iter()
                .map(|c| RerankOutput {
                    doc_id: c.doc_id,
                    score: 1.0 - c.prior_score,
                })
                .collect();
            out.sort_by(|a, b| {
                // NaN scores compare as Equal and fall through to the doc-id tie-break.
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.doc_id.cmp(&b.doc_id))
            });
            out.truncate(top_k);
            out
        }
    }

    #[tokio::test]
    async fn custom_reranker_can_reorder() {
        let r = InvertingReranker;
        let candidates = vec![input("a", 0.9), input("b", 0.5), input("c", 0.1)];
        let out = r.rerank("q", candidates, 10).await;
        assert_eq!(ids(&out), vec!["c", "b", "a"]);
    }

    #[tokio::test]
    async fn custom_reranker_truncates_after_reordering() {
        let r = InvertingReranker;
        let candidates = vec![input("a", 0.9), input("b", 0.5), input("c", 0.1)];
        let out = r.rerank("q", candidates, 2).await;
        assert_eq!(ids(&out), vec!["c", "b"]);
    }
}