use std::sync::{Arc, Mutex};
use crate::error::{MemeError, Result};
/// Reranker backed by a fastembed [`fastembed::TextRerank`] model.
///
/// The model is held behind `Arc<Mutex<..>>` so a handle can be cloned into a
/// blocking task and the model used by one caller at a time.
pub(crate) struct OnnxReranker {
/// Shared handle to the loaded reranking model; locked for each inference call.
model: Arc<Mutex<fastembed::TextRerank>>,
}
// Opaque Debug output: prints `OnnxReranker { .. }` and omits the model field.
impl std::fmt::Debug for OnnxReranker {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("OnnxReranker");
        view.finish_non_exhaustive()
    }
}
impl OnnxReranker {
    /// Loads the fastembed reranking model identified by `model_name`.
    ///
    /// The name is looked up with `resolve_model`. Download progress is enabled
    /// while fastembed initialises the model.
    ///
    /// # Errors
    ///
    /// Propagates the error from `resolve_model` when the name is not
    /// recognised. Returns [`MemeError::Internal`] when fastembed fails to
    /// initialise the model.
    pub(crate) fn new(model_name: &str) -> Result<Self> {
        let chosen = resolve_model(model_name)?;
        let options =
            fastembed::RerankInitOptions::new(chosen).with_show_download_progress(true);
        fastembed::TextRerank::try_new(options)
            .map(|loaded| Self {
                model: Arc::new(Mutex::new(loaded)),
            })
            .map_err(|e| MemeError::Internal(format!("reranker init failed: {e}")))
    }

    /// Scores `documents` against `query` and returns the indices of the top
    /// `top_n` results.
    ///
    /// - Indices refer to positions in `documents`.
    /// - They are returned in the order fastembed yields its results (presumably
    ///   best score first; confirm for the fastembed version in use).
    /// - At most `min(top_n, documents.len())` indices are returned.
    /// - An empty `documents` slice or `top_n == 0` returns an empty vector
    ///   without touching the model.
    ///
    /// Inference is synchronous, so it runs on tokio's blocking pool. The inputs
    /// are copied into owned strings because that task must be `'static`.
    ///
    /// # Errors
    ///
    /// Returns [`MemeError::Internal`] in any of these cases:
    /// - the model mutex is poisoned;
    /// - fastembed inference fails;
    /// - the blocking task cannot be joined.
    pub(crate) async fn rerank(
        &self,
        query: &str,
        documents: &[&str],
        top_n: usize,
    ) -> Result<Vec<usize>> {
        if top_n == 0 || documents.is_empty() {
            return Ok(Vec::new());
        }

        let shared_model = Arc::clone(&self.model);
        let owned_query = String::from(query);
        let owned_docs: Vec<String> = documents.iter().copied().map(String::from).collect();
        let limit = owned_docs.len().min(top_n);

        let task = tokio::task::spawn_blocking(move || -> Result<Vec<usize>> {
            let doc_slices: Vec<&str> = owned_docs.iter().map(String::as_str).collect();
            let mut session = shared_model
                .lock()
                .map_err(|e| MemeError::Internal(format!("reranker lock poisoned: {e}")))?;
            let scored = session
                .rerank(owned_query.as_str(), doc_slices.as_slice(), false, None)
                .map_err(|e| MemeError::Internal(format!("reranker inference failed: {e}")))?;
            // Release the model before post-processing so other callers can proceed.
            drop(session);
            Ok(scored.into_iter().take(limit).map(|hit| hit.index).collect())
        });

        match task.await {
            Ok(outcome) => outcome,
            Err(join_err) => Err(MemeError::Internal(format!(
                "reranker spawn_blocking failed: {join_err}"
            ))),
        }
    }
}
/// Maps a reranker model code to its [`fastembed::RerankerModel`] variant.
///
/// The lookup is an exact, case-sensitive comparison against the `model_code`
/// of each entry returned by [`fastembed::TextRerank::list_supported_models`].
/// The first matching entry wins.
///
/// # Errors
///
/// Returns [`MemeError::Config`] when no supported model has a matching code.
fn resolve_model(name: &str) -> Result<fastembed::RerankerModel> {
    fastembed::TextRerank::list_supported_models()
        .into_iter()
        .find(|candidate| candidate.model_code == name)
        .map(|candidate| candidate.model)
        .ok_or_else(|| MemeError::Config(format!("unknown reranker model: {name}")))
}