mem0_rust/rerankers/
cohere.rs1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use crate::errors::MemoryError;
5use crate::models::ScoredMemory;
6use crate::config::CohereRerankerConfig;
7use super::Reranker;
8
9pub struct CohereReranker {
10 client: Client,
11 api_key: String,
12 model: String,
13}
14
15impl CohereReranker {
16 pub fn new(config: CohereRerankerConfig) -> Result<Self, MemoryError> {
17 let api_key = config.api_key
18 .or_else(|| std::env::var("COHERE_API_KEY").ok())
19 .ok_or_else(|| MemoryError::Config("COHERE_API_KEY not set".to_string()))?;
20
21 Ok(Self {
22 client: Client::new(),
23 api_key,
24 model: config.model,
25 })
26 }
27}
28
29#[derive(Serialize)]
30struct RerankRequest<'a> {
31 model: &'a str,
32 query: &'a str,
33 documents: Vec<&'a str>,
34 top_n: usize,
35}
36
37#[derive(Deserialize)]
38struct RerankResponse {
39 results: Vec<RerankResult>,
40}
41
42#[derive(Deserialize)]
43struct RerankResult {
44 index: usize,
45 relevance_score: f32,
46}
47
48#[async_trait]
49impl Reranker for CohereReranker {
50 async fn rerank(&self, query: &str, results: Vec<ScoredMemory>) -> Result<Vec<ScoredMemory>, MemoryError> {
51 if results.is_empty() {
52 return Ok(results);
53 }
54
55 let documents: Vec<&str> = results.iter().map(|m| m.record.content.as_str()).collect();
56
57 let request = RerankRequest {
58 model: &self.model,
59 query,
60 documents,
61 top_n: results.len(),
62 };
63
64 let response = self.client.post("https://api.cohere.com/v1/rerank")
65 .header("Authorization", format!("Bearer {}", self.api_key))
66 .header("Content-Type", "application/json")
67 .json(&request)
68 .send()
69 .await
70 .map_err(|e| MemoryError::Reranker(e.to_string()))?;
71
72 if !response.status().is_success() {
73 let error_text = response.text().await.unwrap_or_default();
74 return Err(MemoryError::Reranker(format!("Cohere API error: {}", error_text)));
75 }
76
77 let rerank_response: RerankResponse = response.json().await
78 .map_err(|e| MemoryError::Reranker(format!("Failed to parse response: {}", e)))?;
79
80 let mut reranked = Vec::new();
81 for result in rerank_response.results {
82 if let Some(mut memory) = results.get(result.index).cloned() {
83 memory.score = result.relevance_score;
84 reranked.push(memory);
85 }
86 }
87
88 Ok(reranked)
89 }
90
91 fn model_name(&self) -> &str {
92 &self.model
93 }
94}