Skip to main content

redis_vl/vectorizers/
cohere.rs

1//! Cohere embedding adapter.
2//!
3//! Enabled by the `cohere` feature flag. Cohere uses a distinct `/embed` API
4//! that differs from the OpenAI-compatible format: it takes `texts`,
5//! `input_type`, and `embedding_types` fields.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use super::{AsyncVectorizer, Vectorizer};
11use crate::error::Result;
12
/// Configuration for the Cohere embedding provider.
///
/// The `Debug` implementation redacts `api_key` so the secret never ends up in
/// logs, panic messages, or `{:?}` output of types that embed this config.
#[derive(Clone)]
pub struct CohereConfig {
    /// API key for Cohere.
    pub api_key: String,
    /// Embedding model name (e.g. `embed-english-v3.0`).
    pub model: String,
    /// The Cohere `input_type` to use (e.g. `"search_document"`, `"search_query"`).
    pub input_type: String,
}

impl std::fmt::Debug for CohereConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereConfig")
            // Never print the real key; a derived Debug would expose it verbatim.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("input_type", &self.input_type)
            .finish()
    }
}
23
24impl CohereConfig {
25    /// Creates a new Cohere config.
26    pub fn new(
27        api_key: impl Into<String>,
28        model: impl Into<String>,
29        input_type: impl Into<String>,
30    ) -> Self {
31        Self {
32            api_key: api_key.into(),
33            model: model.into(),
34            input_type: input_type.into(),
35        }
36    }
37
38    /// Constructs from `COHERE_API_KEY` environment variable.
39    pub fn from_env(model: impl Into<String>, input_type: impl Into<String>) -> Result<Self> {
40        let api_key = std::env::var("COHERE_API_KEY")
41            .map_err(|_| crate::error::Error::InvalidInput("COHERE_API_KEY not set".into()))?;
42        Ok(Self::new(api_key, model, input_type))
43    }
44}
45
/// Endpoint of Cohere's v1 `/embed` API; every request in this module is POSTed here.
const COHERE_EMBED_URL: &str = "https://api.cohere.com/v1/embed";
47
/// JSON request body for the Cohere `/embed` endpoint.
///
/// All fields borrow (`&'a str`), so building a request does not copy the
/// model name, input type, or the text contents themselves.
#[derive(Serialize)]
struct CohereEmbedRequest<'a> {
    /// Embedding model name.
    model: &'a str,
    /// The input texts to embed, serialized as a JSON array of strings.
    texts: Vec<&'a str>,
    /// Cohere `input_type` (e.g. `"search_document"`, `"search_query"`).
    input_type: &'a str,
    /// Which embedding encodings to request from the API.
    embedding_types: Vec<&'a str>,
}
55
/// Top-level JSON response from the Cohere `/embed` endpoint.
///
/// Only the `embeddings` object is deserialized; serde ignores any other
/// top-level fields the API returns.
#[derive(Deserialize)]
struct CohereEmbedResponse {
    /// Embeddings keyed by encoding type.
    embeddings: CohereEmbeddings,
}
60
/// The `embeddings` object of a Cohere `/embed` response.
#[derive(Deserialize)]
struct CohereEmbeddings {
    /// Float embeddings, one inner `Vec` per input text.
    /// `None` when the response has no `float` key.
    // NOTE(review): one-per-text ordering is assumed from the API contract,
    // not enforced by this type.
    float: Option<Vec<Vec<f32>>>,
}
65
/// Cohere embedding adapter.
///
/// Uses the Cohere `/embed` API which differs from the OpenAI-compatible format.
/// Holds a separate HTTP client for the async and the blocking code paths.
#[derive(Debug, Clone)]
pub struct CohereTextVectorizer {
    /// Provider settings: API key, model name and `input_type`.
    config: CohereConfig,
    /// Async HTTP client.
    client: reqwest::Client,
    /// Blocking HTTP client.
    // NOTE(review): reqwest documents that its blocking client must not be
    // used from within an async runtime; confirm callers respect this.
    blocking_client: reqwest::blocking::Client,
}
75
76impl CohereTextVectorizer {
77    /// Creates a new Cohere adapter.
78    pub fn new(config: CohereConfig) -> Self {
79        Self {
80            config,
81            client: reqwest::Client::new(),
82            blocking_client: reqwest::blocking::Client::new(),
83        }
84    }
85
86    fn build_request_body<'a>(&'a self, texts: &[&'a str]) -> CohereEmbedRequest<'a> {
87        CohereEmbedRequest {
88            model: &self.config.model,
89            texts: texts.to_vec(),
90            input_type: &self.config.input_type,
91            embedding_types: vec!["float"],
92        }
93    }
94
95    fn parse_response(response: CohereEmbedResponse) -> Result<Vec<Vec<f32>>> {
96        response.embeddings.float.ok_or_else(|| {
97            crate::error::Error::InvalidInput("no float embeddings in response".into())
98        })
99    }
100
101    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
102        let resp: CohereEmbedResponse = self
103            .client
104            .post(COHERE_EMBED_URL)
105            .bearer_auth(&self.config.api_key)
106            .json(&self.build_request_body(texts))
107            .send()
108            .await?
109            .error_for_status()?
110            .json()
111            .await?;
112        Self::parse_response(resp)
113    }
114}
115
116impl Vectorizer for CohereTextVectorizer {
117    fn embed(&self, text: &str) -> Result<Vec<f32>> {
118        let resp: CohereEmbedResponse = self
119            .blocking_client
120            .post(COHERE_EMBED_URL)
121            .bearer_auth(&self.config.api_key)
122            .json(&self.build_request_body(&[text]))
123            .send()?
124            .error_for_status()?
125            .json()?;
126        let mut embeddings = Self::parse_response(resp)?;
127        Ok(embeddings.pop().unwrap_or_default())
128    }
129}
130
131#[async_trait]
132impl AsyncVectorizer for CohereTextVectorizer {
133    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
134        let mut v = self.embed_many_inner(&[text]).await?;
135        Ok(v.pop().unwrap_or_default())
136    }
137
138    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
139        self.embed_many_inner(texts).await
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor keeps every argument verbatim.
    #[test]
    fn cohere_config_stores_fields() {
        let config = CohereConfig::new("key", "embed-english-v3.0", "search_document");
        assert_eq!(config.api_key.as_str(), "key");
        assert_eq!(config.model.as_str(), "embed-english-v3.0");
        assert_eq!(config.input_type.as_str(), "search_document");
    }

    /// The request body carries the config values and the caller's texts.
    #[test]
    fn cohere_request_serializes_correctly() {
        let vectorizer = CohereTextVectorizer::new(CohereConfig::new("k", "model", "search_query"));
        let payload = serde_json::to_value(vectorizer.build_request_body(&["hello", "world"]))
            .expect("request body should serialize");

        let expected = [
            ("model", serde_json::json!("model")),
            ("input_type", serde_json::json!("search_query")),
            ("embedding_types", serde_json::json!(["float"])),
            ("texts", serde_json::json!(["hello", "world"])),
        ];
        for (key, want) in &expected {
            assert_eq!(&payload[*key], want, "field `{}` mismatch", key);
        }
    }

    /// Present `float` embeddings are returned unchanged and in order.
    #[test]
    fn cohere_parse_response_extracts_floats() {
        let float = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let response = CohereEmbedResponse {
            embeddings: CohereEmbeddings { float },
        };

        let vectors = CohereTextVectorizer::parse_response(response).unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.first(), Some(&vec![1.0, 2.0]));
    }

    /// A response without `float` embeddings is rejected.
    #[test]
    fn cohere_parse_response_errors_on_missing_float() {
        let response = CohereEmbedResponse {
            embeddings: CohereEmbeddings { float: None },
        };
        let outcome = CohereTextVectorizer::parse_response(response);
        assert!(outcome.is_err());
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn cohere_vectorizer_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<CohereTextVectorizer>();
    }
}