use crate::core::rerank::service::RerankProvider;
use crate::core::rerank::types::{RerankRequest, RerankResponse, RerankResult, RerankUsage};
use crate::utils::error::gateway_error::{GatewayError, Result};
use async_trait::async_trait;
use std::collections::HashMap;
/// Rerank provider that sends rerank requests to a Cohere-style HTTP endpoint.
pub struct CohereRerankProvider {
/// API key sent as a bearer token in the `Authorization` header.
api_key: String,
/// Base URL that `/rerank` is appended to; `new` sets it to `https://api.cohere.ai/v1`.
base_url: String,
/// HTTP client used to send rerank requests.
client: reqwest::Client,
}
impl CohereRerankProvider {
    /// Builds a provider that authenticates with `api_key`, targets the default
    /// base URL `https://api.cohere.ai/v1`, and uses a freshly created HTTP client.
    pub fn new(api_key: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let base_url = String::from("https://api.cohere.ai/v1");
        let client = reqwest::Client::new();
        Self {
            api_key,
            base_url,
            client,
        }
    }

    /// Consumes the provider and returns it with `url` as the new base URL.
    ///
    /// The API key and HTTP client are carried over unchanged.
    pub fn with_base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: url.into(),
            ..self
        }
    }
}
#[async_trait]
impl RerankProvider for CohereRerankProvider {
async fn rerank(&self, request: RerankRequest) -> Result<RerankResponse> {
let model = if request.model.contains('/') {
request
.model
.split('/')
.next_back()
.unwrap_or(&request.model)
} else {
&request.model
};
let documents: Vec<String> = request
.documents
.iter()
.map(|d| d.get_text().to_string())
.collect();
let mut body = serde_json::json!({
"model": model,
"query": request.query,
"documents": documents,
});
if let Some(top_n) = request.top_n {
body["top_n"] = serde_json::json!(top_n);
}
if let Some(return_docs) = request.return_documents {
body["return_documents"] = serde_json::json!(return_docs);
}
let response = self
.client
.post(format!("{}/rerank", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| GatewayError::Network(format!("Cohere rerank request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(GatewayError::Network(format!(
"Cohere rerank error ({}): {}",
status, error_text
)));
}
let cohere_response: serde_json::Value = response.json().await.map_err(|e| {
GatewayError::Validation(format!("Failed to parse Cohere response: {}", e))
})?;
let results = cohere_response["results"]
.as_array()
.ok_or_else(|| GatewayError::Validation("Missing results in response".to_string()))?
.iter()
.map(|r| {
let index = r["index"].as_u64().unwrap_or(0) as usize;
let relevance_score = r["relevance_score"].as_f64().unwrap_or(0.0);
let document = if request.return_documents.unwrap_or(true) {
request.documents.get(index).cloned()
} else {
None
};
RerankResult {
index,
relevance_score,
document,
}
})
.collect();
let usage = cohere_response.get("meta").and_then(|m| {
m.get("billed_units").map(|bu| RerankUsage {
query_tokens: None,
document_tokens: None,
total_tokens: None,
search_units: bu
.get("search_units")
.and_then(|s| s.as_u64())
.map(|s| s as u32),
})
});
Ok(RerankResponse {
id: cohere_response["id"]
.as_str()
.unwrap_or("unknown")
.to_string(),
results,
model: model.to_string(),
usage,
meta: HashMap::new(),
})
}
fn provider_name(&self) -> &'static str {
"cohere"
}
fn supports_model(&self, model: &str) -> bool {
let model_name = model.split('/').next_back().unwrap_or(model);
matches!(
model_name,
"rerank-english-v3.0"
| "rerank-multilingual-v3.0"
| "rerank-english-v2.0"
| "rerank-multilingual-v2.0"
)
}
fn supported_models(&self) -> Vec<&'static str> {
vec![
"rerank-english-v3.0",
"rerank-multilingual-v3.0",
"rerank-english-v2.0",
"rerank-multilingual-v2.0",
]
}
}