// mem0_rust/rerankers/cohere.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use crate::errors::MemoryError;
5use crate::models::ScoredMemory;
6use crate::config::CohereRerankerConfig;
7use super::Reranker;
8
9pub struct CohereReranker {
10    client: Client,
11    api_key: String,
12    model: String,
13}
14
15impl CohereReranker {
16    pub fn new(config: CohereRerankerConfig) -> Result<Self, MemoryError> {
17        let api_key = config.api_key
18             .or_else(|| std::env::var("COHERE_API_KEY").ok())
19             .ok_or_else(|| MemoryError::Config("COHERE_API_KEY not set".to_string()))?;
20             
21        Ok(Self {
22            client: Client::new(),
23            api_key,
24            model: config.model,
25        })
26    }
27}
28
29#[derive(Serialize)]
30struct RerankRequest<'a> {
31    model: &'a str,
32    query: &'a str,
33    documents: Vec<&'a str>,
34    top_n: usize,
35}
36
37#[derive(Deserialize)]
38struct RerankResponse {
39    results: Vec<RerankResult>,
40}
41
42#[derive(Deserialize)]
43struct RerankResult {
44    index: usize,
45    relevance_score: f32,
46}
47
48#[async_trait]
49impl Reranker for CohereReranker {
50    async fn rerank(&self, query: &str, results: Vec<ScoredMemory>) -> Result<Vec<ScoredMemory>, MemoryError> {
51        if results.is_empty() {
52            return Ok(results);
53        }
54
55        let documents: Vec<&str> = results.iter().map(|m| m.record.content.as_str()).collect();
56
57        let request = RerankRequest {
58            model: &self.model,
59            query,
60            documents,
61            top_n: results.len(),
62        };
63
64        let response = self.client.post("https://api.cohere.com/v1/rerank")
65            .header("Authorization", format!("Bearer {}", self.api_key))
66            .header("Content-Type", "application/json")
67            .json(&request)
68            .send()
69            .await
70            .map_err(|e| MemoryError::Reranker(e.to_string()))?;
71
72        if !response.status().is_success() {
73             let error_text = response.text().await.unwrap_or_default();
74             return Err(MemoryError::Reranker(format!("Cohere API error: {}", error_text)));
75        }
76
77        let rerank_response: RerankResponse = response.json().await
78            .map_err(|e| MemoryError::Reranker(format!("Failed to parse response: {}", e)))?;
79
80        let mut reranked = Vec::new();
81        for result in rerank_response.results {
82            if let Some(mut memory) = results.get(result.index).cloned() {
83                memory.score = result.relevance_score;
84                reranked.push(memory);
85            }
86        }
87
88        Ok(reranked)
89    }
90
91    fn model_name(&self) -> &str {
92        &self.model
93    }
94}