Skip to main content

synaptic_cohere/
reranker.rs

1#[cfg(feature = "retrieval")]
2use async_trait::async_trait;
3use serde_json::{json, Value};
4use synaptic_core::{Document, SynapticError};
5
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------
9
/// Configuration for the Cohere Reranker.
///
/// Construct one with [`CohereRerankerConfig::new`] and refine it with the
/// `with_*` builder methods.
///
/// The `Debug` implementation redacts `api_key` so that logging a
/// configuration with `{:?}` never writes the secret to logs.
#[derive(Clone)]
pub struct CohereRerankerConfig {
    /// Cohere API key.
    pub api_key: String,
    /// Reranker model name (default: `"rerank-v3.5"`).
    pub model: String,
    /// Maximum number of documents to return. If `None`, all documents are returned.
    pub top_n: Option<usize>,
    /// Base URL for the Cohere API (default: `"https://api.cohere.ai/v2"`).
    pub base_url: String,
}

impl std::fmt::Debug for CohereRerankerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereRerankerConfig")
            // Never print the credential itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("top_n", &self.top_n)
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl CohereRerankerConfig {
    /// Create a new configuration with the given API key and default settings.
    ///
    /// Defaults: model `"rerank-v3.5"`, no `top_n` limit, and base URL
    /// `"https://api.cohere.ai/v2"`.
    #[must_use]
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: "rerank-v3.5".to_string(),
            top_n: None,
            base_url: "https://api.cohere.ai/v2".to_string(),
        }
    }

    /// Set the reranker model.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Set the maximum number of results to return.
    #[must_use]
    pub fn with_top_n(mut self, top_n: usize) -> Self {
        self.top_n = Some(top_n);
        self
    }

    /// Set a custom base URL for the API.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}
52
// ---------------------------------------------------------------------------
// CohereReranker
// ---------------------------------------------------------------------------
56
57/// A reranker that uses the Cohere Rerank API to reorder documents by
58/// relevance to a query.
59///
60/// Each returned document has a `relevance_score` entry added to its metadata.
61pub struct CohereReranker {
62    config: CohereRerankerConfig,
63    client: reqwest::Client,
64}
65
66impl CohereReranker {
67    /// Create a new `CohereReranker` with the given configuration.
68    pub fn new(config: CohereRerankerConfig) -> Self {
69        Self {
70            config,
71            client: reqwest::Client::new(),
72        }
73    }
74
75    /// Create a new `CohereReranker` with a custom HTTP client.
76    pub fn with_client(config: CohereRerankerConfig, client: reqwest::Client) -> Self {
77        Self { config, client }
78    }
79
80    /// Rerank documents by relevance to a query.
81    ///
82    /// Returns documents sorted by descending relevance score. Each document's
83    /// metadata will contain a `"relevance_score"` entry.
84    ///
85    /// # Arguments
86    ///
87    /// * `query` - The query to rank against.
88    /// * `documents` - The documents to rerank.
89    /// * `top_n` - Override the configured `top_n`. If `None`, uses the
90    ///   configured value, or returns all documents.
91    pub async fn rerank(
92        &self,
93        query: &str,
94        documents: Vec<Document>,
95        top_n: Option<usize>,
96    ) -> Result<Vec<Document>, SynapticError> {
97        if documents.is_empty() {
98            return Ok(Vec::new());
99        }
100
101        let top_n = top_n.or(self.config.top_n).unwrap_or(documents.len());
102
103        let doc_texts: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
104
105        let body = json!({
106            "model": self.config.model,
107            "query": query,
108            "documents": doc_texts,
109            "top_n": top_n,
110        });
111
112        let response = self
113            .client
114            .post(format!("{}/rerank", self.config.base_url))
115            .header("Authorization", format!("Bearer {}", self.config.api_key))
116            .header("Content-Type", "application/json")
117            .json(&body)
118            .send()
119            .await
120            .map_err(|e| SynapticError::Model(format!("Cohere rerank request failed: {e}")))?;
121
122        if !response.status().is_success() {
123            let status = response.status().as_u16();
124            let text = response.text().await.unwrap_or_default();
125            return Err(SynapticError::Model(format!(
126                "Cohere rerank API error ({status}): {text}"
127            )));
128        }
129
130        let resp_body: Value = response
131            .json()
132            .await
133            .map_err(|e| SynapticError::Model(format!("Cohere rerank parse error: {e}")))?;
134
135        let results = resp_body["results"]
136            .as_array()
137            .ok_or_else(|| SynapticError::Model("missing 'results' in response".to_string()))?;
138
139        let mut reranked = Vec::with_capacity(results.len());
140        for result in results {
141            let index = result["index"].as_u64().unwrap_or(0) as usize;
142            let score = result["relevance_score"].as_f64().unwrap_or(0.0);
143            if index < documents.len() {
144                let mut doc = documents[index].clone();
145                doc.metadata
146                    .insert("relevance_score".to_string(), json!(score));
147                reranked.push(doc);
148            }
149        }
150
151        Ok(reranked)
152    }
153}
154
// ---------------------------------------------------------------------------
// DocumentCompressor implementation (behind `retrieval` feature)
// ---------------------------------------------------------------------------
158
#[cfg(feature = "retrieval")]
#[async_trait]
impl synaptic_retrieval::DocumentCompressor for CohereReranker {
    /// Compress `documents` for `query` by delegating to
    /// [`CohereReranker::rerank`] without a per-call `top_n` override.
    async fn compress_documents(
        &self,
        documents: Vec<Document>,
        query: &str,
    ) -> Result<Vec<Document>, SynapticError> {
        // `None` means "no override": `rerank` falls back to the configured limit.
        let no_override: Option<usize> = None;
        CohereReranker::rerank(self, query, documents, no_override).await
    }
}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created configuration carries the documented defaults.
    #[test]
    fn config_defaults() {
        let CohereRerankerConfig {
            api_key,
            model,
            top_n,
            base_url,
        } = CohereRerankerConfig::new("test-key");

        assert_eq!(api_key, "test-key");
        assert_eq!(model, "rerank-v3.5");
        assert_eq!(base_url, "https://api.cohere.ai/v2");
        assert_eq!(top_n, None);
    }

    /// Every builder method overrides exactly the field it is named after.
    #[test]
    fn config_builder() {
        let base = CohereRerankerConfig::new("key");
        let with_model = base.with_model("rerank-english-v3.0");
        let with_limit = with_model.with_top_n(5);
        let config = with_limit.with_base_url("https://custom.api.com");

        assert_eq!(config.model, "rerank-english-v3.0");
        assert_eq!(config.top_n, Some(5));
        assert_eq!(config.base_url, "https://custom.api.com");
    }

    /// Reranking an empty document list yields an empty result.
    #[tokio::test]
    async fn rerank_empty_documents() {
        let reranker = CohereReranker::new(CohereRerankerConfig::new("test-key"));

        let reranked = reranker.rerank("query", Vec::new(), None).await.unwrap();
        assert!(reranked.is_empty());
    }
}