synaptic_cohere/
reranker.rs1#[cfg(feature = "retrieval")]
2use async_trait::async_trait;
3use serde_json::{json, Value};
4use synaptic_core::{Document, SynapticError};
5
/// Configuration for the Cohere reranker.
///
/// Build one with [`CohereRerankerConfig::new`] and refine it with the `with_*` methods.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone)]
pub struct CohereRerankerConfig {
    /// Cohere API key, sent as a bearer token.
    pub api_key: String,
    /// Rerank model name (defaults to `"rerank-v3.5"`).
    pub model: String,
    /// Default cap on the number of reranked documents requested; `None` requests all of them.
    pub top_n: Option<usize>,
    /// API base URL, without the trailing `/rerank` path (defaults to `"https://api.cohere.ai/v2"`).
    pub base_url: String,
}

impl std::fmt::Debug for CohereRerankerConfig {
    /// Formats the config with `api_key` redacted, so secrets cannot leak through logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereRerankerConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("top_n", &self.top_n)
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl CohereRerankerConfig {
    /// Creates a configuration for `api_key` with default settings.
    ///
    /// Defaults: model `"rerank-v3.5"`, no `top_n` limit, base URL `"https://api.cohere.ai/v2"`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "rerank-v3.5".to_string(),
            top_n: None,
            base_url: "https://api.cohere.ai/v2".to_string(),
        }
    }

    /// Returns the config with the rerank model replaced by `model`.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Returns the config with the default result limit set to `top_n`.
    #[must_use]
    pub fn with_top_n(mut self, top_n: usize) -> Self {
        self.top_n = Some(top_n);
        self
    }

    /// Returns the config with the API base URL replaced by `base_url`.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}
52
53pub struct CohereReranker {
62 config: CohereRerankerConfig,
63 client: reqwest::Client,
64}
65
66impl CohereReranker {
67 pub fn new(config: CohereRerankerConfig) -> Self {
69 Self {
70 config,
71 client: reqwest::Client::new(),
72 }
73 }
74
75 pub fn with_client(config: CohereRerankerConfig, client: reqwest::Client) -> Self {
77 Self { config, client }
78 }
79
80 pub async fn rerank(
92 &self,
93 query: &str,
94 documents: Vec<Document>,
95 top_n: Option<usize>,
96 ) -> Result<Vec<Document>, SynapticError> {
97 if documents.is_empty() {
98 return Ok(Vec::new());
99 }
100
101 let top_n = top_n.or(self.config.top_n).unwrap_or(documents.len());
102
103 let doc_texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
104
105 let body = json!({
106 "model": self.config.model,
107 "query": query,
108 "documents": doc_texts,
109 "top_n": top_n,
110 });
111
112 let response = self
113 .client
114 .post(format!("{}/rerank", self.config.base_url))
115 .header("Authorization", format!("Bearer {}", self.config.api_key))
116 .header("Content-Type", "application/json")
117 .json(&body)
118 .send()
119 .await
120 .map_err(|e| SynapticError::Model(format!("Cohere rerank request failed: {e}")))?;
121
122 if !response.status().is_success() {
123 let status = response.status().as_u16();
124 let text = response.text().await.unwrap_or_default();
125 return Err(SynapticError::Model(format!(
126 "Cohere rerank API error ({status}): {text}"
127 )));
128 }
129
130 let resp_body: Value = response
131 .json()
132 .await
133 .map_err(|e| SynapticError::Model(format!("Cohere rerank parse error: {e}")))?;
134
135 let results = resp_body["results"]
136 .as_array()
137 .ok_or_else(|| SynapticError::Model("missing 'results' in response".to_string()))?;
138
139 let mut reranked = Vec::with_capacity(results.len());
140 for result in results {
141 let index = result["index"].as_u64().unwrap_or(0) as usize;
142 let score = result["relevance_score"].as_f64().unwrap_or(0.0);
143 if index < documents.len() {
144 let mut doc = documents[index].clone();
145 doc.metadata
146 .insert("relevance_score".to_string(), json!(score));
147 reranked.push(doc);
148 }
149 }
150
151 Ok(reranked)
152 }
153}
154
#[cfg(feature = "retrieval")]
#[async_trait]
impl synaptic_retrieval::DocumentCompressor for CohereReranker {
    /// Compresses `documents` by reranking them against `query` through `rerank`.
    ///
    /// No per-call limit is supplied, so the configured `top_n` applies. Without one,
    /// every document is requested.
    async fn compress_documents(
        &self,
        documents: Vec<Document>,
        query: &str,
    ) -> Result<Vec<Document>, SynapticError> {
        let per_call_limit: Option<usize> = None;
        CohereReranker::rerank(self, query, documents, per_call_limit).await
    }
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the key verbatim and fills in the default model, URL and limit.
    #[test]
    fn config_defaults() {
        let cfg = CohereRerankerConfig::new("test-key");

        assert!(cfg.top_n.is_none());
        assert_eq!(cfg.base_url, "https://api.cohere.ai/v2");
        assert_eq!(cfg.model, "rerank-v3.5");
        assert_eq!(cfg.api_key, "test-key");
    }

    /// Each `with_*` builder overrides exactly its own field.
    #[test]
    fn config_builder() {
        let cfg = CohereRerankerConfig::new("key")
            .with_model("rerank-english-v3.0")
            .with_top_n(5)
            .with_base_url("https://custom.api.com");

        let CohereRerankerConfig {
            model,
            top_n,
            base_url,
            ..
        } = cfg;
        assert_eq!(model, "rerank-english-v3.0");
        assert_eq!(top_n, Some(5));
        assert_eq!(base_url, "https://custom.api.com");
    }

    /// An empty input returns immediately without contacting the API.
    #[tokio::test]
    async fn rerank_empty_documents() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("test-key"));

        let reranked = reranker
            .rerank("query", Vec::new(), None)
            .await
            .unwrap();
        assert!(reranked.is_empty());
    }
}