redis_vl/vectorizers/
cohere.rs1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use super::{AsyncVectorizer, Vectorizer};
11use crate::error::Result;
12
/// Connection settings for the Cohere embedding API.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs, panic messages, or the `Debug` output of types that
/// embed this config; the key is rendered as `"<redacted>"`.
#[derive(Clone)]
pub struct CohereConfig {
    /// Secret API key, sent as a bearer token on every request.
    pub api_key: String,
    /// Embedding model name, e.g. `embed-english-v3.0`.
    pub model: String,
    /// Value sent as Cohere's `input_type`, e.g. `search_document` or `search_query`.
    pub input_type: String,
}

impl std::fmt::Debug for CohereConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("input_type", &self.input_type)
            .finish()
    }
}
23
24impl CohereConfig {
25 pub fn new(
27 api_key: impl Into<String>,
28 model: impl Into<String>,
29 input_type: impl Into<String>,
30 ) -> Self {
31 Self {
32 api_key: api_key.into(),
33 model: model.into(),
34 input_type: input_type.into(),
35 }
36 }
37
38 pub fn from_env(model: impl Into<String>, input_type: impl Into<String>) -> Result<Self> {
40 let api_key = std::env::var("COHERE_API_KEY")
41 .map_err(|_| crate::error::Error::InvalidInput("COHERE_API_KEY not set".into()))?;
42 Ok(Self::new(api_key, model, input_type))
43 }
44}
45
/// Cohere v1 embed endpoint; both the async and blocking code paths POST here.
46const COHERE_EMBED_URL: &str = "https://api.cohere.com/v1/embed";
47
/// JSON body sent to the embed endpoint.
///
/// Field names, order and types form the wire format. All string data is
/// borrowed for `'a` (from the vectorizer's config and the caller's texts).
48#[derive(Serialize)]
49struct CohereEmbedRequest<'a> {
50 model: &'a str,
51 texts: Vec<&'a str>,
52 input_type: &'a str,
 // Requested embedding encodings; this module only reads the `float` variant back.
53 embedding_types: Vec<&'a str>,
54}
55
/// Subset of the embed endpoint's JSON response that this module reads.
56#[derive(Deserialize)]
57struct CohereEmbedResponse {
58 embeddings: CohereEmbeddings,
59}
60
/// The `embeddings` object of the response, keyed by embedding type.
61#[derive(Deserialize)]
62struct CohereEmbeddings {
 // One vector per input text; `None` if the response has no `float` key.
63 float: Option<Vec<Vec<f32>>>,
64}
65
66#[derive(Debug, Clone)]
70pub struct CohereTextVectorizer {
71 config: CohereConfig,
72 client: reqwest::Client,
73 blocking_client: reqwest::blocking::Client,
74}
75
76impl CohereTextVectorizer {
77 pub fn new(config: CohereConfig) -> Self {
79 Self {
80 config,
81 client: reqwest::Client::new(),
82 blocking_client: reqwest::blocking::Client::new(),
83 }
84 }
85
86 fn build_request_body<'a>(&'a self, texts: &[&'a str]) -> CohereEmbedRequest<'a> {
87 CohereEmbedRequest {
88 model: &self.config.model,
89 texts: texts.to_vec(),
90 input_type: &self.config.input_type,
91 embedding_types: vec!["float"],
92 }
93 }
94
95 fn parse_response(response: CohereEmbedResponse) -> Result<Vec<Vec<f32>>> {
96 response.embeddings.float.ok_or_else(|| {
97 crate::error::Error::InvalidInput("no float embeddings in response".into())
98 })
99 }
100
101 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
102 let resp: CohereEmbedResponse = self
103 .client
104 .post(COHERE_EMBED_URL)
105 .bearer_auth(&self.config.api_key)
106 .json(&self.build_request_body(texts))
107 .send()
108 .await?
109 .error_for_status()?
110 .json()
111 .await?;
112 Self::parse_response(resp)
113 }
114}
115
116impl Vectorizer for CohereTextVectorizer {
117 fn embed(&self, text: &str) -> Result<Vec<f32>> {
118 let resp: CohereEmbedResponse = self
119 .blocking_client
120 .post(COHERE_EMBED_URL)
121 .bearer_auth(&self.config.api_key)
122 .json(&self.build_request_body(&[text]))
123 .send()?
124 .error_for_status()?
125 .json()?;
126 let mut embeddings = Self::parse_response(resp)?;
127 Ok(embeddings.pop().unwrap_or_default())
128 }
129}
130
131#[async_trait]
132impl AsyncVectorizer for CohereTextVectorizer {
133 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
134 let mut v = self.embed_many_inner(&[text]).await?;
135 Ok(v.pop().unwrap_or_default())
136 }
137
138 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
139 self.embed_many_inner(texts).await
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vectorizer with a dummy key and the given model / input type.
    fn vectorizer_with(model: &str, input_type: &str) -> CohereTextVectorizer {
        CohereTextVectorizer::new(CohereConfig::new("k", model, input_type))
    }

    #[test]
    fn cohere_config_stores_fields() {
        let CohereConfig {
            api_key,
            model,
            input_type,
        } = CohereConfig::new("key", "embed-english-v3.0", "search_document");
        assert_eq!(api_key, "key");
        assert_eq!(model, "embed-english-v3.0");
        assert_eq!(input_type, "search_document");
    }

    #[test]
    fn cohere_request_serializes_correctly() {
        let vectorizer = vectorizer_with("model", "search_query");
        let request = vectorizer.build_request_body(&["hello", "world"]);
        let body = serde_json::to_value(&request).unwrap();

        // Each (field, expected JSON) pair must match the serialized body.
        let expectations = [
            ("model", serde_json::json!("model")),
            ("input_type", serde_json::json!("search_query")),
            ("embedding_types", serde_json::json!(["float"])),
            ("texts", serde_json::json!(["hello", "world"])),
        ];
        for (field, expected) in expectations.iter() {
            assert_eq!(body[*field], *expected, "mismatch for `{}`", field);
        }
    }

    #[test]
    fn cohere_parse_response_extracts_floats() {
        let floats = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let response = CohereEmbedResponse {
            embeddings: CohereEmbeddings {
                float: Some(floats),
            },
        };
        let vectors = CohereTextVectorizer::parse_response(response).unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors.first(), Some(&vec![1.0, 2.0]));
    }

    #[test]
    fn cohere_parse_response_errors_on_missing_float() {
        let response = CohereEmbedResponse {
            embeddings: CohereEmbeddings { float: None },
        };
        let outcome = CohereTextVectorizer::parse_response(response);
        assert!(outcome.is_err());
    }

    #[test]
    fn cohere_vectorizer_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<CohereTextVectorizer>();
    }
}