use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use super::{AsyncReranker, RerankDoc, RerankResult, Reranker};
use crate::error::Result;
/// Configuration for [`CohereReranker`].
///
/// Create one with [`CohereRerankerConfig::new`] or [`CohereRerankerConfig::from_env`]
/// and adjust it with the builder-style setters.
///
/// `Debug` is implemented by hand so the API key is never written to logs.
#[derive(Clone)]
pub struct CohereRerankerConfig {
    /// Cohere API key, sent as a bearer token on every rerank request.
    pub api_key: String,
    /// Name of the Cohere rerank model to call.
    pub model: String,
    /// Field names Cohere ranks on when every document is a field map.
    /// Must be non-empty in that case; unused for plain-text requests.
    pub rank_by: Vec<String>,
    /// Whether reranking results carry the Cohere relevance scores.
    pub return_score: bool,
}

impl std::fmt::Debug for CohereRerankerConfig {
    /// Formats every field except `api_key`, which is replaced by a placeholder
    /// so that `{:?}` output (including that of structs embedding this config)
    /// cannot leak the credential.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereRerankerConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("rank_by", &self.rank_by)
            .field("return_score", &self.return_score)
            .finish()
    }
}
impl CohereRerankerConfig {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
api_key: api_key.into(),
model: "rerank-english-v3.0".to_string(),
rank_by: Vec::new(),
return_score: true,
}
}
pub fn from_env() -> Result<Self> {
let api_key = std::env::var("COHERE_API_KEY")
.map_err(|_| crate::error::Error::InvalidInput("COHERE_API_KEY not set".into()))?;
Ok(Self::new(api_key))
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn rank_by(mut self, fields: Vec<String>) -> Self {
self.rank_by = fields;
self
}
pub fn return_score(mut self, val: bool) -> Self {
self.return_score = val;
self
}
}
/// Cohere v1 rerank endpoint; both the blocking and the async implementation POST here.
const COHERE_RERANK_URL: &str = "https://api.cohere.com/v1/rerank";
/// One document in the rerank request body, borrowing from the caller's docs.
///
/// `untagged` makes `Text` serialize as a bare JSON string and `Fields` as a
/// JSON object mapping field name to value.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum CohereDocument<'a> {
    /// Plain-text document.
    Text(&'a str),
    /// Structured document; sent together with the request's `rank_fields`.
    Fields(std::collections::HashMap<&'a str, &'a str>),
}
/// JSON body POSTed to the Cohere rerank endpoint.
#[derive(Debug, Serialize)]
struct CohereRerankRequest<'a> {
    /// Rerank model name.
    model: &'a str,
    /// Query the documents are ranked against.
    query: &'a str,
    /// Number of top results requested from Cohere.
    top_n: usize,
    /// Documents to rank; each result's `index` refers to a position in this list.
    documents: Vec<CohereDocument<'a>>,
    /// Fields to rank structured documents on; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    rank_fields: Option<Vec<&'a str>>,
}
/// The part of the Cohere rerank response this module reads; other JSON fields are ignored.
#[derive(Deserialize)]
struct CohereRerankResponse {
    /// Ranked entries, in the order Cohere returns them.
    results: Vec<CohereRerankResult>,
}
/// One ranked entry of the Cohere rerank response.
#[derive(Deserialize)]
struct CohereRerankResult {
    /// Position of the ranked document in the request's `documents` list.
    index: usize,
    /// Cohere's relevance score for that document.
    relevance_score: f64,
}
/// Reranker backed by the Cohere rerank HTTP API.
///
/// Implements both the blocking [`Reranker`] and the [`AsyncReranker`] traits;
/// each path uses its own HTTP client.
// NOTE(review): reqwest documents that `reqwest::blocking` must not be used from
// inside an async runtime; constructing or dropping this type inside a tokio task
// may panic because of `blocking_client` — confirm before relying on it there.
#[derive(Debug, Clone)]
pub struct CohereReranker {
    /// Request settings: API key, model, rank fields, score output.
    config: CohereRerankerConfig,
    /// Async HTTP client used by the `AsyncReranker` implementation.
    client: reqwest::Client,
    /// Blocking HTTP client used by the `Reranker` implementation.
    blocking_client: reqwest::blocking::Client,
}
impl CohereReranker {
    /// Creates a reranker that calls the Cohere rerank API using `config`.
    ///
    /// Both an async and a blocking HTTP client are built up front so one
    /// instance can serve [`Reranker`] and [`AsyncReranker`] callers.
    pub fn new(config: CohereRerankerConfig) -> Self {
        Self {
            config,
            client: reqwest::Client::new(),
            blocking_client: reqwest::blocking::Client::new(),
        }
    }

    /// Returns `true` when `docs` is sent as structured field maps, i.e. the
    /// slice is non-empty and every document is a [`RerankDoc::Fields`].
    ///
    /// The emptiness check matters: `Iterator::all` is vacuously true on an
    /// empty iterator, which would otherwise demand `rank_by` for a request
    /// with nothing to rank.
    fn uses_structured_docs(docs: &[RerankDoc]) -> bool {
        !docs.is_empty() && docs.iter().all(|d| matches!(d, RerankDoc::Fields(_)))
    }

    /// Plain-text form of `doc` for non-structured requests: text documents
    /// as-is, field maps via their `"content"` entry. `None` means the document
    /// is not sent (a field map without `"content"`).
    fn plain_text(doc: &RerankDoc) -> Option<&str> {
        match doc {
            RerankDoc::Text(s) => Some(s.as_str()),
            RerankDoc::Fields(map) => map.get("content").map(|s| s.as_str()),
        }
    }

    /// Indices into `docs` of the documents actually sent to Cohere, in send
    /// order. Cohere's result `index` points into that sent list, so this is
    /// the table used to map results back to the caller's documents.
    fn sent_doc_indices(docs: &[RerankDoc]) -> Vec<usize> {
        if Self::uses_structured_docs(docs) {
            (0..docs.len()).collect()
        } else {
            docs.iter()
                .enumerate()
                .filter_map(|(i, d)| Self::plain_text(d).map(|_| i))
                .collect()
        }
    }

    /// Builds the JSON request body for reranking `docs` against `query`.
    ///
    /// * If every document is a field map, documents are sent as JSON objects
    ///   and `config.rank_by` is sent as `rank_fields`.
    /// * Otherwise documents are sent as plain strings: text docs as-is, field
    ///   maps via their `"content"` entry. Field maps without `"content"` are
    ///   skipped; `build_result` maps Cohere's indices back accordingly.
    ///
    /// `top_n` is `limit` capped at the number of documents sent, or all of
    /// them when `limit` is `None`.
    ///
    /// # Errors
    /// Returns `InvalidInput` when every document is a field map but
    /// `config.rank_by` is empty.
    fn prepare_request<'a>(
        &'a self,
        query: &'a str,
        docs: &'a [RerankDoc],
        limit: Option<usize>,
    ) -> Result<CohereRerankRequest<'a>> {
        let (documents, rank_fields) = if Self::uses_structured_docs(docs) {
            if self.config.rank_by.is_empty() {
                return Err(crate::error::Error::InvalidInput(
                    "If reranking dictionary-like docs, you must provide a list of rank_by fields"
                        .into(),
                ));
            }
            let structured: Vec<CohereDocument<'a>> = docs
                .iter()
                .map(|d| match d {
                    RerankDoc::Fields(map) => CohereDocument::Fields(
                        map.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect(),
                    ),
                    // Not reachable (all docs are field maps here), but keeps the match exhaustive.
                    RerankDoc::Text(s) => CohereDocument::Text(s.as_str()),
                })
                .collect();
            let rf: Vec<&str> = self.config.rank_by.iter().map(|s| s.as_str()).collect();
            (structured, Some(rf))
        } else {
            // Must select exactly the documents listed by `sent_doc_indices`, in the same order.
            let plain: Vec<CohereDocument<'a>> = docs
                .iter()
                .filter_map(Self::plain_text)
                .map(CohereDocument::Text)
                .collect();
            (plain, None)
        };
        let sent = documents.len();
        let top_n = limit.map_or(sent, |n| n.min(sent));
        Ok(CohereRerankRequest {
            model: &self.config.model,
            query,
            top_n,
            documents,
            rank_fields,
        })
    }

    /// Converts a Cohere response into a [`RerankResult`] over the caller's `docs`.
    ///
    /// Each result's `index` is translated through `sent_doc_indices`, so
    /// documents skipped when building the request do not shift the mapping.
    /// Indices outside the sent list are ignored. Scores are included only when
    /// `config.return_score` is set.
    fn build_result(&self, docs: &[RerankDoc], response: CohereRerankResponse) -> RerankResult {
        let sent = Self::sent_doc_indices(docs);
        let (reranked, scores): (Vec<RerankDoc>, Vec<f64>) = response
            .results
            .iter()
            .filter_map(|item| {
                sent.get(item.index)
                    .map(|&orig| (docs[orig].clone(), item.relevance_score))
            })
            .unzip();
        RerankResult {
            docs: reranked,
            scores: self.config.return_score.then_some(scores),
        }
    }
}
impl Reranker for CohereReranker {
fn rank(&self, query: &str, docs: &[RerankDoc], limit: Option<usize>) -> Result<RerankResult> {
let request = self.prepare_request(query, docs, limit)?;
let resp: CohereRerankResponse = self
.blocking_client
.post(COHERE_RERANK_URL)
.bearer_auth(&self.config.api_key)
.json(&request)
.send()?
.error_for_status()?
.json()?;
Ok(self.build_result(docs, resp))
}
}
#[async_trait]
impl AsyncReranker for CohereReranker {
async fn rank(
&self,
query: &str,
docs: &[RerankDoc],
limit: Option<usize>,
) -> Result<RerankResult> {
let request = self.prepare_request(query, docs, limit)?;
let resp: CohereRerankResponse = self
.client
.post(COHERE_RERANK_URL)
.bearer_auth(&self.config.api_key)
.json(&request)
.send()
.await?
.error_for_status()?
.json()
.await?;
Ok(self.build_result(docs, resp))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Wraps each string in a `RerankDoc::Text`, preserving order.
    fn text_docs(items: &[&str]) -> Vec<RerankDoc> {
        items.iter().map(|s| RerankDoc::Text((*s).into())).collect()
    }

    /// Builds a field-map document whose only entry is `"content" => content`.
    fn content_doc(content: &str) -> RerankDoc {
        let mut fields = HashMap::new();
        fields.insert("content".to_string(), content.to_string());
        RerankDoc::Fields(fields)
    }

    /// Builds a fake Cohere response from `(index, relevance_score)` pairs.
    fn response(pairs: &[(usize, f64)]) -> CohereRerankResponse {
        CohereRerankResponse {
            results: pairs
                .iter()
                .map(|&(index, relevance_score)| CohereRerankResult {
                    index,
                    relevance_score,
                })
                .collect(),
        }
    }

    #[test]
    fn cohere_reranker_config_defaults() {
        let defaults = CohereRerankerConfig::new("test-key");
        assert_eq!(
            (defaults.api_key.as_str(), defaults.model.as_str()),
            ("test-key", "rerank-english-v3.0")
        );
        assert!(defaults.rank_by.is_empty() && defaults.return_score);
    }

    #[test]
    fn cohere_reranker_config_builder() {
        let built = CohereRerankerConfig::new("key")
            .model("rerank-multilingual-v3.0")
            .rank_by(vec!["content".into()])
            .return_score(false);
        assert_eq!(built.model, "rerank-multilingual-v3.0");
        assert_eq!(built.rank_by, ["content"]);
        assert!(!built.return_score);
    }

    #[test]
    fn cohere_reranker_build_result_with_scores() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("key"));
        let docs = text_docs(&["a", "b", "c"]);
        let result = reranker.build_result(&docs, response(&[(2, 0.9), (0, 0.5)]));
        assert_eq!(result.docs.len(), 2);
        let scores = result.scores.expect("return_score defaults to true");
        assert_eq!(scores.len(), 2);
        for (got, want) in scores.iter().zip([0.9, 0.5]) {
            assert!((got - want).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn cohere_reranker_build_result_without_scores() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("key").return_score(false));
        let docs = text_docs(&["a"]);
        let result = reranker.build_result(&docs, response(&[(0, 0.9)]));
        assert!(result.scores.is_none());
    }

    #[test]
    fn prepare_request_string_docs_sends_plain_text_no_rank_fields() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("key"));
        let docs = text_docs(&["doc1", "doc2"]);
        let req = reranker.prepare_request("query", &docs, Some(2)).unwrap();
        assert!(req.rank_fields.is_none());
        assert_eq!(req.documents.len(), 2);
        assert_eq!(req.top_n, 2);

        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(body["documents"], serde_json::json!(["doc1", "doc2"]));
        assert!(body.get("rank_fields").is_none());
    }

    #[test]
    fn prepare_request_dict_docs_with_rank_by_sends_structured_plus_rank_fields() {
        let cfg = CohereRerankerConfig::new("key").rank_by(vec!["content".into()]);
        let reranker = CohereReranker::new(cfg);
        let docs = vec![content_doc("document 1"), content_doc("document 2")];
        let req = reranker.prepare_request("query", &docs, None).unwrap();
        assert_eq!(req.rank_fields, Some(vec!["content"]));
        assert_eq!(req.documents.len(), 2);

        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(
            body["documents"][0]["content"],
            serde_json::json!("document 1")
        );
        assert_eq!(body["rank_fields"], serde_json::json!(["content"]));
    }

    #[test]
    fn prepare_request_dict_docs_without_rank_by_errors_like_python() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("key"));
        let docs = vec![content_doc("doc")];
        let msg = match reranker.prepare_request("query", &docs, None) {
            Ok(_) => panic!("expected an error for field-map docs without rank_by"),
            Err(e) => e.to_string(),
        };
        assert!(
            msg.contains("rank_by"),
            "Error should mention rank_by: {msg}"
        );
    }

    #[test]
    fn prepare_request_limit_defaults_to_doc_count() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("key"));
        let docs = text_docs(&["a", "b", "c"]);
        let req = reranker.prepare_request("q", &docs, None).unwrap();
        assert_eq!(req.top_n, docs.len());
    }

    #[test]
    fn cohere_reranker_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<CohereReranker>();
    }
}