use std::collections::HashSet;
use std::sync::Arc;
use async_trait::async_trait;
use smooth_operator_core::KnowledgeResult;
/// Second-stage ranking hook applied to retrieval candidates.
///
/// Implementations take the candidates produced for `query` and return the
/// ones to keep, in their preferred order. The `Send + Sync` bound allows a
/// reranker to be shared behind `Arc<dyn Reranker>`, as `apply_optional_rerank`
/// does.
#[async_trait]
pub trait Reranker: Send + Sync {
/// Reorders `candidates` with respect to `query` and keeps up to `top_k` of them.
///
/// NOTE(review): the `top_k` limit is not enforced by the trait itself; the
/// implementations in this module all return at most `top_k` results.
async fn rerank(
&self,
query: &str,
candidates: Vec<KnowledgeResult>,
top_k: usize,
) -> Vec<KnowledgeResult>;
}
/// Pass-through [`Reranker`] that keeps the retriever's ordering untouched.
///
/// The only thing it does is enforce the `top_k` limit: the first `top_k`
/// candidates are returned as-is and everything after them is dropped.
#[derive(Debug, Clone, Default)]
pub struct NoopReranker;
#[async_trait]
impl Reranker for NoopReranker {
    async fn rerank(
        &self,
        _query: &str,
        candidates: Vec<KnowledgeResult>,
        top_k: usize,
    ) -> Vec<KnowledgeResult> {
        // The query is ignored; only the tail beyond `top_k` is discarded.
        candidates.into_iter().take(top_k).collect()
    }
}
/// Dependency-free [`Reranker`] that orders candidates by lexical overlap with
/// the query.
///
/// For every distinct query token found in a chunk, the score gains a
/// saturated term-frequency weight `tf / (tf + k1)`, divided by the length
/// penalty `1 + ln(1 + n_tokens)` so that long chunks are not favoured merely
/// for containing more words.
#[derive(Debug, Clone)]
pub struct LexicalReranker {
    /// Term-frequency saturation constant: small values make a single
    /// occurrence nearly as valuable as many, larger values reward repetition
    /// more before the weight levels off.
    k1: f32,
}
impl LexicalReranker {
    /// Saturation constant used by [`LexicalReranker::new`] and `Default`.
    const DEFAULT_K1: f32 = 1.2;
    /// Creates a reranker with the default saturation constant (`k1 = 1.2`).
    #[must_use]
    pub fn new() -> Self {
        Self::with_k1(Self::DEFAULT_K1)
    }
    /// Creates a reranker with a caller-chosen saturation constant.
    ///
    /// `k1` is stored without validation. A positive, finite value is
    /// expected: `0.0` reduces scoring to plain term presence, while negative
    /// or NaN values can produce infinite, negative or NaN scores.
    #[must_use]
    pub fn with_k1(k1: f32) -> Self {
        Self { k1 }
    }
    /// Lowercases `text` and splits it into maximal runs of alphanumeric
    /// characters, discarding separators and empty fragments.
    fn tokenize(text: &str) -> Vec<String> {
        let lowered = text.to_lowercase();
        let mut tokens = Vec::new();
        for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
            if !piece.is_empty() {
                tokens.push(piece.to_owned());
            }
        }
        tokens
    }
    /// Scores `chunk` against already-tokenized, deduplicated query terms.
    ///
    /// Returns `0.0` for a chunk with no tokens or no overlap with the query.
    fn score(&self, query_terms: &HashSet<String>, chunk: &str) -> f32 {
        let tokens = Self::tokenize(chunk);
        if tokens.is_empty() {
            return 0.0;
        }
        // Logarithmic damping so sheer chunk length doesn't dominate the score.
        let length_penalty = 1.0 + (1.0 + tokens.len() as f32).ln();
        query_terms
            .iter()
            .map(|term| tokens.iter().filter(|tok| *tok == term).count() as f32)
            .filter(|&tf| tf > 0.0)
            .fold(0.0_f32, |acc, tf| {
                let saturated = tf / (tf + self.k1);
                acc + saturated / length_penalty
            })
    }
}
impl Default for LexicalReranker {
    fn default() -> Self {
        Self::with_k1(Self::DEFAULT_K1)
    }
}
#[async_trait]
impl Reranker for LexicalReranker {
    /// Reorders `candidates` by lexical score against `query` (highest first)
    /// and returns at most `top_k` of them.
    ///
    /// The sort is stable, so candidates with equal scores (including the case
    /// where nothing overlaps the query and every score is `0.0`) keep the
    /// order they arrived in. A NaN score (e.g. from a NaN `k1`) is ranked
    /// below every real score instead of breaking the comparator.
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<KnowledgeResult>,
        top_k: usize,
    ) -> Vec<KnowledgeResult> {
        if candidates.is_empty() || top_k == 0 {
            return Vec::new();
        }
        // Deduplicate query tokens so repeating a word in the query does not
        // multiply its weight.
        let query_terms: HashSet<String> = Self::tokenize(query).into_iter().collect();
        let mut scored: Vec<(f32, KnowledgeResult)> = candidates
            .into_iter()
            .map(|c| {
                let s = self.score(&query_terms, &c.chunk);
                // `sort_by` needs a total order (and may panic without one on
                // Rust >= 1.81); pin NaN to the bottom before comparing.
                let s = if s.is_nan() { f32::NEG_INFINITY } else { s };
                (s, c)
            })
            .collect();
        // Descending by score; `total_cmp` is a total order on the sanitized values.
        scored.sort_by(|a, b| b.0.total_cmp(&a.0));
        scored.truncate(top_k);
        scored.into_iter().map(|(_, c)| c).collect()
    }
}
/// Applies `reranker` to `candidates` when one is configured.
///
/// Without a reranker the incoming order is kept and the list is simply cut
/// down to its first `top_k` entries.
pub async fn apply_optional_rerank(
    reranker: Option<&Arc<dyn Reranker>>,
    query: &str,
    mut candidates: Vec<KnowledgeResult>,
    top_k: usize,
) -> Vec<KnowledgeResult> {
    if let Some(active) = reranker {
        return active.rerank(query, candidates, top_k).await;
    }
    // No reranker configured: keep retriever order, enforce the limit only.
    candidates.truncate(top_k);
    candidates
}
// Unit tests for `NoopReranker`, `LexicalReranker` and `apply_optional_rerank`.
#[cfg(test)]
mod tests {
use super::*;
/// Builds a `KnowledgeResult` fixture; `source` is derived as `"{id}.md"`.
fn result(id: &str, chunk: &str, score: f32) -> KnowledgeResult {
KnowledgeResult {
document_id: id.to_string(),
chunk: chunk.to_string(),
score,
source: format!("{id}.md"),
}
}
// The "returns" chunk has the lowest retrieval score but shares several query
// tokens ("return", "policy", "window"), so the lexical reranker must move it
// to the front.
#[tokio::test]
async fn lexical_reranker_promotes_best_lexical_match() {
let query = "return policy refund window";
let candidates = vec![
result(
"shipping",
"Standard shipping takes 5 to 7 business days.",
0.9,
),
result(
"warranty",
"Warranty claims must be filed within one year.",
0.8,
),
result(
"returns",
"Our return policy: refunds are issued within the 30 day return window.",
0.7,
),
];
let reranker = LexicalReranker::new();
let reranked = reranker.rerank(query, candidates, 3).await;
assert_eq!(
reranked[0].document_id,
"returns",
"the lexically-best doc should be promoted to the top, got order: {:?}",
reranked
.iter()
.map(|r| r.document_id.as_str())
.collect::<Vec<_>>()
);
}
// With `top_k` equal to the candidate count, the noop reranker must return the
// input order unchanged.
#[tokio::test]
async fn noop_reranker_is_identity() {
let query = "anything at all";
let candidates = vec![
result("a", "first chunk about returns and refunds", 0.9),
result("b", "second chunk about shipping", 0.8),
result("c", "third chunk about returns refund window", 0.7),
];
let original: Vec<String> = candidates.iter().map(|r| r.document_id.clone()).collect();
let reranked = NoopReranker.rerank(query, candidates, 3).await;
let after: Vec<String> = reranked.iter().map(|r| r.document_id.clone()).collect();
assert_eq!(after, original, "noop must preserve order");
}
// The noop reranker keeps the first `top_k` candidates and drops the rest.
#[tokio::test]
async fn noop_reranker_truncates_to_top_k() {
let query = "q";
let candidates = vec![
result("a", "alpha", 0.9),
result("b", "beta", 0.8),
result("c", "gamma", 0.7),
];
let reranked = NoopReranker.rerank(query, candidates, 2).await;
assert_eq!(reranked.len(), 2);
assert_eq!(reranked[0].document_id, "a");
assert_eq!(reranked[1].document_id, "b");
}
// Truncation must happen after sorting: the best lexical match is second in
// the input but must survive `top_k = 1`.
#[tokio::test]
async fn lexical_reranker_truncates_after_reorder() {
let query = "refund returns";
let candidates = vec![
result("shipping", "shipping times and delivery", 0.9),
result("returns", "refund and returns policy details", 0.8),
result("misc", "unrelated content here", 0.7),
];
let reranked = LexicalReranker::new().rerank(query, candidates, 1).await;
assert_eq!(reranked.len(), 1);
assert_eq!(reranked[0].document_id, "returns");
}
// No chunk shares a token with the query, so every score is equal and the
// stable sort must leave the input order intact.
#[tokio::test]
async fn lexical_reranker_no_overlap_preserves_order() {
let query = "quantum entanglement physics";
let candidates = vec![
result("a", "shipping and delivery", 0.9),
result("b", "returns and refunds", 0.8),
];
let reranked = LexicalReranker::new().rerank(query, candidates, 2).await;
assert_eq!(reranked[0].document_id, "a");
assert_eq!(reranked[1].document_id, "b");
}
// Without a reranker the order is untouched, even though "returns" is the
// better lexical match for the query.
#[tokio::test]
async fn apply_optional_rerank_none_truncates_only() {
let query = "refund";
let candidates = vec![
result("a", "shipping", 0.9),
result("returns", "refund refund refund window", 0.8),
];
let out = apply_optional_rerank(None, query, candidates, 2).await;
assert_eq!(out[0].document_id, "a");
}
// With a `LexicalReranker` supplied, the better lexical match moves ahead.
#[tokio::test]
async fn apply_optional_rerank_some_reorders() {
let query = "refund window";
let candidates = vec![
result("a", "shipping and delivery times", 0.9),
result("returns", "refund window details and policy", 0.8),
];
let reranker: Arc<dyn Reranker> = Arc::new(LexicalReranker::new());
let out = apply_optional_rerank(Some(&reranker), query, candidates, 2).await;
assert_eq!(out[0].document_id, "returns");
}
}