use anyhow::{anyhow, Result};
use async_trait::async_trait;
use smooth_operator::rerank::Reranker;
use smooth_operator_core::KnowledgeResult;
/// Model name that [`GatewayReranker::from_env`] passes to the HTTP backend.
pub const DEFAULT_RERANK_MODEL: &str = "rerank-english-v3.0";
/// A single relevance score produced by a [`RerankBackend`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RerankScore {
/// Position of the scored document in the `documents` slice handed to the backend.
/// [`GatewayReranker`] uses it to look the candidate up again.
pub index: usize,
/// Relevance of that document to the query; [`GatewayReranker`] places larger values first.
pub relevance_score: f32,
}
/// Source of relevance scores for [`GatewayReranker`].
///
/// The reranker stores its backend as `Arc<dyn RerankBackend>`, so implementations
/// must be `Send + Sync`.
#[async_trait]
pub trait RerankBackend: Send + Sync {
/// Scores `documents` against `query`.
///
/// Each returned [`RerankScore::index`] is interpreted by the caller as a position in
/// `documents`. `top_n` is the number of results being requested; the HTTP
/// implementation forwards it unchanged as the request's `top_n` field.
///
/// # Errors
///
/// Returns an error when the backend cannot produce scores.
async fn rerank(
&self,
query: &str,
documents: &[String],
top_n: usize,
) -> Result<Vec<RerankScore>>;
}
/// [`RerankBackend`] that POSTs to `{base_url}/v1/rerank` using bearer-token authentication.
#[derive(Clone)]
pub struct HttpRerankBackend {
/// HTTP client used to send every rerank request.
client: reqwest::Client,
/// Service root; trailing `/` characters are trimmed before `/v1/rerank` is appended.
base_url: String,
/// Credential sent as the bearer token.
api_key: String,
/// Value sent in the request's `model` field.
model: String,
}
impl HttpRerankBackend {
    /// Builds a backend that talks to the service rooted at `base_url`.
    ///
    /// `api_key` is sent as the bearer token and `model` as the `model` field of each
    /// request. A fresh `reqwest::Client` is created for this backend.
    #[must_use]
    pub fn new(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let client = reqwest::Client::new();
        let base_url: String = base_url.into();
        let api_key: String = api_key.into();
        let model: String = model.into();
        Self {
            client,
            base_url,
            api_key,
            model,
        }
    }
}
#[async_trait]
impl RerankBackend for HttpRerankBackend {
    /// Sends `query` and `documents` to `{base_url}/v1/rerank` and returns the scores
    /// listed in the response's `results` array.
    ///
    /// An empty `documents` slice yields an empty result without contacting the
    /// service, since there is nothing that could be scored.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be sent, when the service answers
    /// with a non-success status (the status and response body are included in the
    /// message), or when the response body cannot be decoded.
    async fn rerank(
        &self,
        query: &str,
        documents: &[String],
        top_n: usize,
    ) -> Result<Vec<RerankScore>> {
        // Wire format of one entry in the response's `results` array.
        #[derive(serde::Deserialize)]
        struct ResultItem {
            index: usize,
            relevance_score: f32,
        }

        // Wire format of the response body; only `results` is read.
        #[derive(serde::Deserialize)]
        struct RerankResponse {
            results: Vec<ResultItem>,
        }

        // Nothing to score: avoid an authenticated round trip that cannot
        // produce a useful answer.
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/v1/rerank", self.base_url.trim_end_matches('/'));
        let body = serde_json::json!({
            "model": self.model,
            "query": query,
            "documents": documents,
            "top_n": top_n,
        });

        let resp = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;

        // Capture the status first: reading the body below consumes `resp`.
        let status = resp.status();
        if !status.is_success() {
            // The body is best-effort diagnostic text; an unreadable body becomes "".
            let text = resp.text().await.unwrap_or_default();
            return Err(anyhow!("rerank request failed ({status}): {text}"));
        }

        let parsed: RerankResponse = resp.json().await?;
        Ok(parsed
            .results
            .into_iter()
            .map(|item| RerankScore {
                index: item.index,
                relevance_score: item.relevance_score,
            })
            .collect())
    }
}
/// [`Reranker`] that reorders candidates by the scores a [`RerankBackend`] returns.
///
/// When the backend call fails, the candidates are returned in their incoming order,
/// truncated to the requested `top_k`.
pub struct GatewayReranker {
/// Backend that supplies the relevance scores.
backend: std::sync::Arc<dyn RerankBackend>,
}
impl GatewayReranker {
/// Creates a reranker that obtains its scores from `backend`.
#[must_use]
pub fn with_backend(backend: std::sync::Arc<dyn RerankBackend>) -> Self {
Self { backend }
}
#[must_use]
pub fn new(
base_url: impl Into<String>,
api_key: impl Into<String>,
model: impl Into<String>,
) -> Self {
Self::with_backend(std::sync::Arc::new(HttpRerankBackend::new(
base_url, api_key, model,
)))
}
/// Creates a reranker from the `SMOOAI_GATEWAY_URL` and `SMOOAI_GATEWAY_KEY`
/// environment variables, using [`DEFAULT_RERANK_MODEL`] as the model.
///
/// # Errors
///
/// Returns an error naming the offending variable when it is not set, or when it
/// is set to a value that is not valid Unicode. `SMOOAI_GATEWAY_URL` is checked first.
pub fn from_env() -> Result<Self> {
    // Reads one required variable, keeping "missing" and "unreadable" apart so the
    // message does not claim a variable is unset when it merely has a bad value.
    fn required_var(name: &str) -> Result<String> {
        std::env::var(name).map_err(|err| match err {
            std::env::VarError::NotPresent => anyhow!("{name} is not set"),
            std::env::VarError::NotUnicode(_) => {
                anyhow!("{name} is set but is not valid Unicode")
            }
        })
    }

    let base_url = required_var("SMOOAI_GATEWAY_URL")?;
    let api_key = required_var("SMOOAI_GATEWAY_KEY")?;
    Ok(Self::new(base_url, api_key, DEFAULT_RERANK_MODEL))
}
/// Returns at most `top_k` of `candidates`, ordered by descending relevance score.
///
/// Each score's `index` selects a candidate by position. Scores whose index is out
/// of range, or which point at a candidate that was already placed, are skipped.
/// If fewer than `top_k` candidates were placed that way, the remaining candidates
/// are appended in their original order until `top_k` is reached.
///
/// NaN scores sort after every other score. Comparing them as "equal to everything"
/// is not a total order: it can leave a lower score ahead of a higher one, and
/// `sort_by` is allowed to panic on such a comparator.
fn reorder(
    mut scores: Vec<RerankScore>,
    candidates: Vec<KnowledgeResult>,
    top_k: usize,
) -> Vec<KnowledgeResult> {
    let limit = top_k.min(candidates.len());
    // `Some` = still available; `take()` both moves the candidate out and marks it used.
    let mut slots: Vec<Option<KnowledgeResult>> = candidates.into_iter().map(Some).collect();

    // Stable sort, so equal scores keep the order the backend returned them in.
    scores.sort_by(|a, b| {
        let (left, right) = (a.relevance_score, b.relevance_score);
        match (left.is_nan(), right.is_nan()) {
            // Neither is NaN, so `partial_cmp` always succeeds; operands are swapped
            // to sort descending.
            (false, false) => right
                .partial_cmp(&left)
                .unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, which pushes NaN entries to the end.
            (left_nan, right_nan) => left_nan.cmp(&right_nan),
        }
    });

    let mut out: Vec<KnowledgeResult> = Vec::with_capacity(limit);
    for score in &scores {
        if out.len() >= limit {
            break;
        }
        // `get_mut` rejects out-of-range indices; `take` rejects duplicates.
        if let Some(candidate) = slots.get_mut(score.index).and_then(|slot| slot.take()) {
            out.push(candidate);
        }
    }

    // Back-fill with candidates the backend did not (validly) score, in upstream order.
    let missing = limit - out.len();
    out.extend(slots.into_iter().flatten().take(missing));
    out
}
}
#[async_trait]
impl Reranker for GatewayReranker {
    /// Reorders `candidates` by backend relevance and keeps at most `top_k` of them.
    ///
    /// Returns an empty vector when there are no candidates or `top_k` is zero. Each
    /// candidate's `chunk` is sent to the backend as its document text. If the backend
    /// fails, a warning is logged and the first `top_k` candidates are returned in
    /// their incoming order.
    async fn rerank(
        &self,
        query: &str,
        candidates: Vec<KnowledgeResult>,
        top_k: usize,
    ) -> Vec<KnowledgeResult> {
        if top_k == 0 || candidates.is_empty() {
            return Vec::new();
        }

        let mut documents: Vec<String> = Vec::with_capacity(candidates.len());
        for candidate in &candidates {
            documents.push(candidate.chunk.clone());
        }

        let outcome = self.backend.rerank(query, &documents, top_k).await;
        match outcome {
            Err(err) => {
                tracing::warn!(
                    error = %err,
                    "GatewayReranker call failed; falling back to upstream candidate order"
                );
                candidates.into_iter().take(top_k).collect::<Vec<_>>()
            }
            Ok(scores) => Self::reorder(scores, candidates, top_k),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a candidate with the given id and chunk text and a fixed upstream score.
    fn result(id: &str, chunk: &str) -> KnowledgeResult {
        KnowledgeResult {
            document_id: id.to_string(),
            chunk: chunk.to_string(),
            score: 0.5,
            source: format!("{id}.md"),
        }
    }

    /// Shorthand for a [`RerankScore`] literal.
    fn scored(index: usize, relevance_score: f32) -> RerankScore {
        RerankScore {
            index,
            relevance_score,
        }
    }

    /// Document ids of `results`, in order, for compact assertions.
    fn ids(results: &[KnowledgeResult]) -> Vec<&str> {
        results.iter().map(|r| r.document_id.as_str()).collect()
    }

    /// Backend that ignores its input and returns a fixed list of scores.
    struct StubBackend {
        scores: Vec<RerankScore>,
    }

    #[async_trait]
    impl RerankBackend for StubBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Ok(self.scores.clone())
        }
    }

    /// Backend that always fails, to exercise the fallback path.
    struct ErrorBackend;

    #[async_trait]
    impl RerankBackend for ErrorBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Err(anyhow!("simulated rerank API failure"))
        }
    }

    /// Reranker wired to a [`StubBackend`] that returns `scores`.
    fn stub_reranker(scores: Vec<RerankScore>) -> GatewayReranker {
        GatewayReranker::with_backend(Arc::new(StubBackend { scores }))
    }

    #[tokio::test]
    async fn gateway_reranker_reorders_by_relevance() {
        let candidates = vec![
            result("shipping", "shipping takes 5-7 days"),
            result("warranty", "warranty is one year"),
            result("returns", "30 day refund window"),
        ];
        let reranker = stub_reranker(vec![scored(0, 0.1), scored(1, 0.4), scored(2, 0.95)]);
        let out = reranker.rerank("refund returns", candidates, 3).await;
        assert_eq!(
            ids(&out),
            vec!["returns", "warranty", "shipping"],
            "candidates must be reordered by descending relevance score"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_truncates_to_top_k() {
        let candidates = vec![
            result("a", "alpha"),
            result("b", "beta"),
            result("c", "gamma"),
        ];
        let reranker = stub_reranker(vec![scored(2, 0.9), scored(0, 0.5), scored(1, 0.1)]);
        let out = reranker.rerank("q", candidates, 1).await;
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].document_id, "c", "top_k=1 keeps only the best");
    }

    #[tokio::test]
    async fn gateway_reranker_error_falls_back_to_input_order() {
        let candidates = vec![
            result("first", "one"),
            result("second", "two"),
            result("third", "three"),
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(ErrorBackend));
        let out = reranker.rerank("anything", candidates, 2).await;
        assert_eq!(out.len(), 2, "fallback truncates to top_k");
        assert_eq!(
            ids(&out),
            vec!["first", "second"],
            "on error the upstream order is preserved"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_partial_scores_appends_unscored_in_order() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb"), result("c", "ccc")];
        let reranker = stub_reranker(vec![scored(2, 0.9)]);
        let out = reranker.rerank("q", candidates, 3).await;
        assert_eq!(
            ids(&out),
            vec!["c", "a", "b"],
            "scored candidate first, then unscored in upstream order"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_ignores_out_of_range_index() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub_reranker(vec![scored(99, 0.99), scored(1, 0.5)]);
        let out = reranker.rerank("q", candidates, 2).await;
        assert_eq!(ids(&out), vec!["b", "a"]);
    }

    #[tokio::test]
    async fn gateway_reranker_ignores_duplicate_index() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub_reranker(vec![scored(1, 0.9), scored(1, 0.8), scored(0, 0.1)]);
        let out = reranker.rerank("q", candidates, 2).await;
        assert_eq!(
            ids(&out),
            vec!["b", "a"],
            "a candidate scored twice must appear only once"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_empty_candidates_yields_empty() {
        let reranker = stub_reranker(vec![]);
        let out = reranker.rerank("q", vec![], 3).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn gateway_reranker_zero_top_k_yields_empty() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub_reranker(vec![scored(0, 0.9), scored(1, 0.1)]);
        let out = reranker.rerank("q", candidates, 0).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    #[ignore = "network + creds: gated on SMOOTH_AGENT_E2E=1 and a /v1/rerank route"]
    async fn live_rerank() {
        if std::env::var("SMOOTH_AGENT_E2E").as_deref() != Ok("1") {
            eprintln!("skipping live rerank: set SMOOTH_AGENT_E2E=1 to run");
            return;
        }
        let Ok(reranker) = GatewayReranker::from_env() else {
            eprintln!("skipping live rerank: SMOOAI_GATEWAY_URL / SMOOAI_GATEWAY_KEY not set");
            return;
        };
        let candidates = vec![
            result("shipping", "Standard shipping takes 5 to 7 business days."),
            result("warranty", "Warranty claims must be filed within one year."),
            result(
                "returns",
                "Our return policy: refunds within the 30 day window.",
            ),
        ];
        let out = reranker
            .rerank("how do refunds and returns work", candidates, 3)
            .await;
        eprintln!("live rerank order: {:?}", ids(&out));
        assert_eq!(out.len(), 3, "live rerank should return all 3 reordered");
    }
}