use super::cache::RerankCache;
use super::types::{RerankRequest, RerankResponse};
use crate::utils::error::gateway_error::{GatewayError, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info};
/// A backend that can reorder documents by relevance to a query.
///
/// Providers are stored as `Arc<dyn RerankProvider>` and shared by the service,
/// hence the `Send + Sync` bound.
#[async_trait]
pub trait RerankProvider: Send + Sync {
    /// Short static identifier for this provider.
    fn provider_name(&self) -> &'static str;

    /// Model identifiers this provider advertises.
    fn supported_models(&self) -> Vec<&'static str>;

    /// Whether this provider can serve `model`.
    ///
    /// When invoked through `RerankService::supports_model`, `model` is the
    /// caller's full model string, including any `provider/` prefix.
    fn supports_model(&self, model: &str) -> bool;

    /// Reranks the documents in `request` against its query.
    async fn rerank(&self, request: RerankRequest) -> Result<RerankResponse>;
}
/// Routes rerank requests to registered [`RerankProvider`]s, applying a
/// per-call timeout and an optional response cache.
pub struct RerankService {
    /// Registered providers, keyed by the name given at registration.
    providers: HashMap<String, Arc<dyn RerankProvider>>,
    /// Provider used for model names that carry no `provider/` prefix.
    default_provider: Option<String>,
    /// Maximum duration allowed for a single provider call.
    timeout: Duration,
    /// Response cache; consulted only while `enable_cache` is set.
    cache: Option<Arc<RerankCache>>,
    /// Whether `cache` should be read from and written to.
    enable_cache: bool,
}
impl RerankService {
    /// Largest number of documents a single request may carry.
    const MAX_DOCUMENTS: usize = 10_000;

    /// Creates a service with no providers, no default provider, a 30 second
    /// timeout and caching disabled.
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
            default_provider: None,
            timeout: Duration::from_secs(30),
            enable_cache: false,
            cache: None,
        }
    }

    /// Registers `provider` under `name`, replacing any provider previously
    /// registered under the same name.
    ///
    /// `name` is matched against the prefix of model identifiers of the form
    /// `"<name>/<model>"`.
    pub fn register_provider(
        &mut self,
        name: impl Into<String>,
        provider: Arc<dyn RerankProvider>,
    ) -> &mut Self {
        let name = name.into();
        info!("Registering rerank provider: {}", name);
        self.providers.insert(name, provider);
        self
    }

    /// Sets the provider used for model identifiers without a `provider/` prefix.
    ///
    /// The name is not checked against registered providers here; an unknown
    /// name surfaces as `GatewayError::NotFound` when a request is routed.
    pub fn set_default_provider(&mut self, name: impl Into<String>) -> &mut Self {
        self.default_provider = Some(name.into());
        self
    }

    /// Sets the maximum time a single provider call may take.
    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = timeout;
        self
    }

    /// Enables response caching backed by `cache`.
    pub fn enable_cache(&mut self, cache: Arc<RerankCache>) -> &mut Self {
        self.enable_cache = true;
        self.cache = Some(cache);
        self
    }

    /// Validates `request`, serves it from the cache if possible, and otherwise
    /// routes it to the provider named by the model prefix.
    ///
    /// A successful provider response is stored in the cache when caching is enabled.
    ///
    /// # Errors
    /// - `GatewayError::BadRequest` if the request fails [`Self::validate_request`].
    /// - `GatewayError::NotFound` if no provider is registered for the model.
    /// - `GatewayError::Timeout` if the provider does not answer within the timeout.
    /// - Any error returned by the provider itself.
    pub async fn rerank(&self, request: RerankRequest) -> Result<RerankResponse> {
        let start = Instant::now();
        self.validate_request(&request)?;

        let cache = self.active_cache();
        if let Some(cache) = cache
            && let Some(cached) = cache.get(&request).await
        {
            debug!("Rerank cache hit for query: {}", request.query);
            return Ok(cached);
        }

        let provider_name = self.extract_provider_name(&request.model);
        let provider = self.get_provider(&provider_name)?;
        let document_count = request.documents.len();

        // The provider consumes the request. Clone only when the original is still
        // needed as the cache key; uncached calls move it rather than copying up to
        // `MAX_DOCUMENTS` documents.
        let response = match cache {
            Some(cache) => {
                let response = self.call_provider(provider, request.clone()).await?;
                cache.set(&request, &response).await;
                response
            }
            None => self.call_provider(provider, request).await?,
        };

        info!(
            "Rerank completed in {:?}: {} documents -> {} results",
            start.elapsed(),
            document_count,
            response.results.len()
        );
        Ok(response)
    }

    /// Checks that `request` is well-formed before any provider is contacted.
    ///
    /// # Errors
    /// Returns `GatewayError::BadRequest` if:
    /// - the query is empty or whitespace-only;
    /// - there are no documents, or more than `MAX_DOCUMENTS`;
    /// - `top_n` is `Some(0)`.
    pub(crate) fn validate_request(&self, request: &RerankRequest) -> Result<()> {
        // A whitespace-only query gives the provider nothing to rank against.
        if request.query.trim().is_empty() {
            return Err(GatewayError::BadRequest(
                "Query cannot be empty".to_string(),
            ));
        }
        if request.documents.is_empty() {
            return Err(GatewayError::BadRequest(
                "Documents list cannot be empty".to_string(),
            ));
        }
        if request.documents.len() > Self::MAX_DOCUMENTS {
            return Err(GatewayError::BadRequest(format!(
                "Too many documents (max {})",
                Self::MAX_DOCUMENTS
            )));
        }
        if matches!(request.top_n, Some(0)) {
            return Err(GatewayError::BadRequest(
                "top_n must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }

    /// Returns the provider name for `model`.
    ///
    /// For `"<provider>/<rest>"`, this is the text before the first `/`.
    /// Otherwise it is the configured default provider, falling back to `"cohere"`.
    pub(crate) fn extract_provider_name(&self, model: &str) -> String {
        match model.split_once('/') {
            Some((prefix, _)) => prefix.to_string(),
            None => self
                .default_provider
                .clone()
                .unwrap_or_else(|| "cohere".to_string()),
        }
    }

    /// Looks up a registered provider by name.
    ///
    /// # Errors
    /// Returns `GatewayError::NotFound` if `name` is not registered.
    fn get_provider(&self, name: &str) -> Result<&Arc<dyn RerankProvider>> {
        self.providers
            .get(name)
            .ok_or_else(|| GatewayError::NotFound(format!("Rerank provider not found: {}", name)))
    }

    /// Returns the cache to use, or `None` when caching is disabled.
    fn active_cache(&self) -> Option<&Arc<RerankCache>> {
        self.cache.as_ref().filter(|_| self.enable_cache)
    }

    /// Calls `provider.rerank(request)`, failing with `GatewayError::Timeout`
    /// if it does not finish within the configured timeout.
    async fn call_provider(
        &self,
        provider: &Arc<dyn RerankProvider>,
        request: RerankRequest,
    ) -> Result<RerankResponse> {
        tokio::time::timeout(self.timeout, provider.rerank(request))
            .await
            .map_err(|_| {
                GatewayError::Timeout(format!("Rerank request timed out after {:?}", self.timeout))
            })?
    }

    /// Returns the names of all registered providers, sorted alphabetically.
    ///
    /// Sorting keeps the output stable across calls and processes, because
    /// `HashMap` iteration order is unspecified.
    pub fn providers(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.providers.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    /// Whether the provider selected for `model` is registered and reports
    /// that it supports `model`.
    ///
    /// The full model string, including any prefix, is passed to the provider.
    pub fn supports_model(&self, model: &str) -> bool {
        let provider_name = self.extract_provider_name(model);
        self.providers
            .get(&provider_name)
            .is_some_and(|provider| provider.supports_model(model))
    }
}
impl Default for RerankService {
fn default() -> Self {
Self::new()
}
}