pub mod batch;
pub mod calibrate;
pub mod cascade;
pub mod engine;
pub mod fusion;
pub mod models;
pub mod multi_gpu;
pub mod tokenize;
pub mod types;
pub use types::{
CacheMetadata, Device, ModelConfig, ModelFile, ModelManifest, Precision, RerankConfig,
RerankRequest, RerankResult, ScorerType,
};
use thiserror::Error;
/// Crate-wide error type.
///
/// Most variants carry a free-form `String` that is interpolated into the
/// variant's display message. `Io` and `Serde` wrap the underlying error
/// value and can be produced from it through `From` (and therefore through
/// the `?` operator).
#[derive(Debug, Error)]
pub enum Error {
/// Model-related failure; the payload is the descriptive message.
#[error("model error: {0}")]
Model(String),
/// Tokenizer-related failure; the payload is the descriptive message.
#[error("tokenizer error: {0}")]
Tokenizer(String),
/// Inference-related failure; the payload is the descriptive message.
#[error("inference error: {0}")]
Inference(String),
/// Download-related failure; the payload is the descriptive message.
#[error("download error: {0}")]
Download(String),
/// Cache-related failure; the payload is the descriptive message.
#[error("cache error: {0}")]
Cache(String),
/// Configuration-related failure; the payload is the descriptive message.
#[error("config error: {0}")]
Config(String),
/// Calibration-related failure; the payload is the descriptive message.
#[error("calibration error: {0}")]
Calibration(String),
/// Wrapped [`std::io::Error`], convertible via `From`.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// Wrapped [`serde_json::Error`], convertible via `From`.
#[error("serialization error: {0}")]
Serde(#[from] serde_json::Error),
}
/// Shorthand for [`std::result::Result`] with this crate's [`Error`] as the
/// error type.
pub type Result<T> = std::result::Result<T, Error>;
/// One-shot convenience wrapper: builds a scorer, scores the request's
/// documents against its query, and post-processes the results.
///
/// An `OrtScorer` is constructed on every call from a clone of `config` and
/// from `model_dir`. After scoring:
///
/// * if `request.top_k` is set, the result list is truncated to at most that
///   many entries, keeping them in the order the scorer returned them;
/// * if `request.return_documents` is set, each remaining result whose
///   `index` is within `request.documents` gets a clone of that document
///   stored in its `document` field. Results with an out-of-range index are
///   left untouched.
///
/// # Errors
///
/// Propagates any error returned by `OrtScorer::new` or by `score`.
pub fn rerank(
    model_dir: &std::path::Path,
    config: &ModelConfig,
    request: &RerankRequest,
) -> Result<Vec<RerankResult>> {
    use engine::{ort_backend::OrtScorer, Scorer};

    let backend = OrtScorer::new(config.clone(), model_dir)?;
    let mut ranked = backend.score(&request.query, &request.documents)?;

    if let Some(limit) = request.top_k {
        ranked.truncate(limit);
    }

    // Nothing more to do unless the caller asked for document text.
    if !request.return_documents {
        return Ok(ranked);
    }

    let documents = &request.documents;
    for hit in ranked.iter_mut() {
        // `get` yields `None` for an out-of-range index, leaving the hit as is.
        if let Some(text) = documents.get(hit.index) {
            hit.document = Some(text.clone());
        }
    }

    Ok(ranked)
}