use anyhow::{Context, Result};
/// Number of results requested from the rerank backend when the config does
/// not set `kb.rerank.top_n`; also used by [`KbReranker::rsclaw_default`].
pub const DEFAULT_RERANK_TOP_N: usize = 20;

/// Per-request timeout, in seconds, applied to every rerank HTTP call.
const RERANK_TIMEOUT_SECS: u64 = 30;
/// Synchronous client for a remote `/rerank` endpoint used by the knowledge base.
///
/// Build one with [`KbReranker::from_config`], [`KbReranker::remote`] or
/// [`KbReranker::rsclaw_default`], then score documents with
/// [`KbReranker::rerank`].
pub struct KbReranker {
    /// Upper bound on the number of results requested per call. The
    /// constructors clamp it to `2..=100`.
    pub top_n: usize,
    /// Full endpoint URL, built by the constructors as `{base}/rerank`.
    url: String,
    /// Model name sent as the `model` field of the request body, if any.
    model: Option<String>,
    /// Credential handed to the HTTP client with each request. The
    /// constructors only populate it for rsclaw-hosted models.
    api_key: Option<String>,
    /// HTTP client used to POST rerank requests.
    client: rsclaw_embed::FleetHttp,
}
/// Resolve the API key to attach to rerank requests for `model`.
///
/// Only models recognised by [`rsclaw_embed::is_rsclaw_model`] get a key. It
/// is read from the rsclaw provider entry of a freshly loaded config.
/// Returns `None` in these cases:
/// - no model is given;
/// - the model is not an rsclaw model;
/// - the config fails to load;
/// - the config has no such key.
fn rerank_api_key(model: Option<&str>) -> Option<String> {
    let wants_key = matches!(model, Some(name) if rsclaw_embed::is_rsclaw_model(name));
    if !wants_key {
        return None;
    }
    let cfg = rsclaw_config::load().ok()?;
    crate::ocr::rsclaw_provider_key(&cfg)
}
impl KbReranker {
    /// Build a reranker from the `kb.rerank` section of the loaded rsclaw config.
    ///
    /// Returns `None` when any of the following holds:
    /// - the config cannot be loaded;
    /// - `kb` or `kb.rerank` is absent;
    /// - `enabled` is explicitly `false`;
    /// - `base_url` is blank and the model is not an rsclaw-hosted one.
    ///
    /// An rsclaw model with a blank `base_url` first falls back to the rsclaw
    /// provider's base URL, then to [`rsclaw_embed::RSCLAW_API_BASE_URL`].
    ///
    /// An API key is only attached for rsclaw-hosted models. `top_n` defaults
    /// to [`DEFAULT_RERANK_TOP_N`] and is clamped to `2..=100`.
    pub fn from_config() -> Option<std::sync::Arc<Self>> {
        let cfg = rsclaw_config::load().ok()?;
        let rr = cfg.raw.kb.as_ref()?.rerank.clone()?;
        if !rr.enabled.unwrap_or(true) {
            return None;
        }
        let model_is_rsclaw = rr
            .model
            .as_deref()
            .map(rsclaw_embed::is_rsclaw_model)
            .unwrap_or(false);
        let base_raw = rr.base_url.trim();
        let base = if base_raw.is_empty() {
            // Without an explicit endpoint, only rsclaw-hosted models have a
            // known default to fall back on.
            if model_is_rsclaw {
                crate::ocr::rsclaw_provider_base_url(&cfg)
                    .unwrap_or_else(|| rsclaw_embed::RSCLAW_API_BASE_URL.to_owned())
            } else {
                return None;
            }
        } else {
            base_raw.trim_end_matches('/').to_owned()
        };
        let client = rsclaw_embed::FleetHttp::new(None);
        let api_key = if model_is_rsclaw {
            crate::ocr::rsclaw_provider_key(&cfg)
        } else {
            None
        };
        Some(std::sync::Arc::new(Self {
            client,
            url: format!("{base}/rerank"),
            model: rr.model,
            api_key,
            top_n: rr.top_n.unwrap_or(DEFAULT_RERANK_TOP_N).clamp(2, 100),
        }))
    }

    /// Build a reranker that posts to `{base_url}/rerank` with `model`.
    ///
    /// Surrounding whitespace and trailing slashes are stripped from
    /// `base_url`, and `top_n` is clamped to `2..=100`. The API key comes from
    /// `rerank_api_key`, which loads the config and only yields a key for
    /// rsclaw-hosted models.
    pub fn remote(base_url: &str, model: impl Into<String>, top_n: usize) -> std::sync::Arc<Self> {
        let base = base_url.trim().trim_end_matches('/');
        let model = model.into();
        let api_key = rerank_api_key(Some(&model));
        std::sync::Arc::new(Self {
            client: rsclaw_embed::FleetHttp::new(None),
            url: format!("{base}/rerank"),
            model: Some(model),
            api_key,
            top_n: top_n.clamp(2, 100),
        })
    }

    /// Reranker for the hosted `rsclaw-reranker-v1` model at
    /// [`rsclaw_embed::RSCLAW_API_BASE_URL`], requesting
    /// [`DEFAULT_RERANK_TOP_N`] results.
    pub fn rsclaw_default() -> std::sync::Arc<Self> {
        Self::remote(
            rsclaw_embed::RSCLAW_API_BASE_URL,
            "rsclaw-reranker-v1",
            DEFAULT_RERANK_TOP_N,
        )
    }

    /// Score `docs` against `query` with the remote `/rerank` endpoint.
    ///
    /// Returns one score per document, index-aligned with `docs`. Documents
    /// the backend did not score keep `f32::NEG_INFINITY`, for example those
    /// outside the requested `top_n`. An empty `docs` slice yields an empty
    /// vector without contacting the backend.
    ///
    /// This is a blocking call. It works outside any tokio runtime and inside
    /// both multi-threaded and current-thread runtimes; see
    /// `block_on_request`.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the HTTP request or the status check fails;
    /// - the response has no `results` array;
    /// - no document received a finite score.
    pub fn rerank(&self, query: &str, docs: &[&str]) -> Result<Vec<f32>> {
        if docs.is_empty() {
            return Ok(Vec::new());
        }
        // Never ask for more results than there are documents.
        let requested = self.top_n.min(docs.len()).max(1);
        let mut body = serde_json::json!({
            "query": query,
            "documents": docs,
            "top_n": requested,
            "return_documents": false,
        });
        if let Some(m) = &self.model {
            body["model"] = serde_json::json!(m);
        }
        let send = || async {
            let resp = self
                .client
                .post_following_redirects(
                    self.url.as_str(),
                    &body,
                    self.api_key.as_deref(),
                    false,
                    None,
                    Some(std::time::Duration::from_secs(RERANK_TIMEOUT_SECS)),
                )
                .await?
                .error_for_status()?;
            anyhow::Ok(resp.json::<serde_json::Value>().await?)
        };
        let resp = Self::block_on_request(send)?;
        Self::scores_from_response(&resp, docs.len())
    }

    /// Drive the future produced by `make_request` to completion from
    /// synchronous code. The strategy depends on the caller's tokio context:
    ///
    /// * Inside a multi-threaded runtime, the worker is released with
    ///   [`tokio::task::block_in_place`] and the future runs on that runtime.
    /// * Inside a current-thread runtime, `block_in_place` panics, so the
    ///   future runs on a temporary runtime on a scoped helper thread.
    /// * Outside any runtime, a temporary runtime is created on this thread.
    fn block_on_request<F, Fut>(make_request: F) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut + Sync,
        Fut: std::future::Future<Output = Result<serde_json::Value>>,
    {
        let on_temp_runtime = || -> Result<serde_json::Value> {
            let tmp_rt = tokio::runtime::Runtime::new()
                .context("failed to create temp runtime for rerank")?;
            tmp_rt.block_on(make_request()).context("rerank request failed")
        };
        match tokio::runtime::Handle::try_current() {
            Ok(handle)
                if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread =>
            {
                tokio::task::block_in_place(|| handle.block_on(make_request()))
                    .context("rerank request failed")
            }
            // Current-thread (or other non multi-thread) runtime: this thread
            // cannot block in place, so hop to a helper thread that owns its
            // own runtime. A panic there is re-raised here unchanged.
            Ok(_) => std::thread::scope(|scope| {
                scope
                    .spawn(on_temp_runtime)
                    .join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
            }),
            Err(_) => on_temp_runtime(),
        }
    }

    /// Turn a rerank response body into per-document scores.
    ///
    /// Scores come from `results[*].relevance_score`, or `results[*].score`
    /// as a fallback, keyed by `results[*].index`. Entries with a missing or
    /// out-of-range index are ignored. Documents without an entry keep
    /// `f32::NEG_INFINITY`.
    fn scores_from_response(resp: &serde_json::Value, n_docs: usize) -> Result<Vec<f32>> {
        let results = resp
            .get("results")
            .and_then(serde_json::Value::as_array)
            .with_context(|| {
                // Surface the backend's own error text when it sends one, as
                // either `{"error": {"message": ..}}` or `{"error": ".."}`.
                let detail = resp
                    .pointer("/error/message")
                    .or_else(|| resp.get("error"))
                    .and_then(serde_json::Value::as_str)
                    .unwrap_or("no results array in response");
                format!("rerank backend returned no results: {detail}")
            })?;
        let mut scores = vec![f32::NEG_INFINITY; n_docs];
        for entry in results {
            // `try_from` rather than `as`: an oversized index must not wrap
            // onto a valid slot on 32-bit targets.
            let index = entry
                .get("index")
                .and_then(serde_json::Value::as_u64)
                .and_then(|i| usize::try_from(i).ok());
            let score = entry
                .get("relevance_score")
                .or_else(|| entry.get("score"))
                .and_then(serde_json::Value::as_f64)
                .unwrap_or(f64::NEG_INFINITY) as f32;
            if let Some(i) = index {
                if let Some(slot) = scores.get_mut(i) {
                    *slot = score;
                }
            }
        }
        if scores.iter().all(|s| !s.is_finite()) {
            anyhow::bail!("rerank response carried no usable scores");
        }
        Ok(scores)
    }
}
impl std::fmt::Debug for KbReranker {
    /// Prints only the endpoint settings (`url`, `model`, `top_n`).
    ///
    /// The HTTP client and the API key are left out, presumably so that
    /// credentials never reach logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            url, model, top_n, ..
        } = self;
        let mut out = f.debug_struct("KbReranker");
        out.field("url", url);
        out.field("model", model);
        out.field("top_n", top_n);
        out.finish()
    }
}