1use std::collections::HashMap;
4
5use mentedb_core::MenteError;
6use mentedb_core::error::MenteResult;
7use serde::{Deserialize, Serialize};
8
9use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14 pub api_url: String,
16 pub api_key: String,
18 pub model_name: String,
20 pub dimensions: usize,
22 pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31 let model = model.into();
32 let dimensions = match model.as_str() {
33 "text-embedding-3-small" => 1536,
34 "text-embedding-3-large" => 3072,
35 "text-embedding-ada-002" => 1536,
36 _ => 1536,
37 };
38
39 let mut headers = HashMap::new();
40 headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42 Self {
43 api_url: "https://api.openai.com/v1/embeddings".to_string(),
44 api_key: api_key.into(),
45 model_name: model,
46 dimensions,
47 headers,
48 }
49 }
50
51 pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55 let model = model.into();
56 let dimensions = match model.as_str() {
57 "embed-english-v3.0" => 1024,
58 "embed-multilingual-v3.0" => 1024,
59 "embed-english-light-v3.0" => 384,
60 "embed-multilingual-light-v3.0" => 384,
61 _ => 1024,
62 };
63
64 let mut headers = HashMap::new();
65 headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67 Self {
68 api_url: "https://api.cohere.ai/v1/embed".to_string(),
69 api_key: api_key.into(),
70 model_name: model,
71 dimensions,
72 headers,
73 }
74 }
75
76 pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80 let model = model.into();
81 let dimensions = match model.as_str() {
82 "voyage-2" => 1024,
83 "voyage-large-2" => 1536,
84 "voyage-code-2" => 1536,
85 "voyage-lite-02-instruct" => 1024,
86 _ => 1024,
87 };
88
89 let mut headers = HashMap::new();
90 headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92 Self {
93 api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94 api_key: api_key.into(),
95 model_name: model,
96 dimensions,
97 headers,
98 }
99 }
100
101 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103 self.dimensions = dimensions;
104 self
105 }
106
107 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109 self.headers.insert(key.into(), value.into());
110 self
111 }
112}
113
114pub struct HttpEmbeddingProvider {
119 config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123 pub fn new(config: HttpEmbeddingConfig) -> Self {
125 Self { config }
126 }
127
128 pub fn config(&self) -> &HttpEmbeddingConfig {
130 &self.config
131 }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135 async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136 Err(MenteError::Storage(
137 "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
138 ))
139 }
140
141 async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142 Err(MenteError::Storage(
143 "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
144 ))
145 }
146
147 fn dimensions(&self) -> usize {
148 self.config.dimensions
149 }
150
151 fn model_name(&self) -> &str {
152 &self.config.model_name
153 }
154}
155
#[cfg(feature = "http")]
mod http_impl {
    use super::*;
    use serde_json::json;

    /// Number of attempts the `EmbeddingProvider` entry points make per call.
    const DEFAULT_MAX_ATTEMPTS: u32 = 3;
    /// Base unit of the retry delay, in milliseconds.
    const BACKOFF_BASE_MS: u64 = 500;
    /// Upper bound on the backoff exponent so the shift can never overflow.
    const MAX_BACKOFF_SHIFT: u32 = 10;

    /// Top-level shape of the JSON response: a `data` array of embeddings.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingResponse {
        data: Vec<OpenAIEmbeddingData>,
    }

    /// One element of the response's `data` array.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingData {
        embedding: Vec<f32>,
    }

    /// Delay before retry number `attempt` (1-based): `500ms * 2^attempt`,
    /// with the exponent capped at `MAX_BACKOFF_SHIFT`.
    fn backoff_delay(attempt: u32) -> std::time::Duration {
        std::time::Duration::from_millis(BACKOFF_BASE_MS << attempt.min(MAX_BACKOFF_SHIFT))
    }

    /// Decides whether a failed request is worth repeating.
    ///
    /// A 4xx status means the request itself was rejected (bad key, bad
    /// payload, unknown model), so sending it again cannot succeed. The
    /// exceptions are 408 (request timeout) and 429 (rate limited). Every
    /// other error, including 5xx statuses and transport failures, is retried.
    fn is_retryable(err: &ureq::Error) -> bool {
        match err {
            ureq::Error::StatusCode(code) => {
                !(400..500).contains(code) || matches!(*code, 408 | 429)
            }
            _ => true,
        }
    }

    impl HttpEmbeddingProvider {
        /// Posts `input` to the configured endpoint and returns every
        /// embedding in the response, in response order.
        ///
        /// Makes up to `max_attempts` attempts, sleeping the current thread
        /// between them. Request errors that `is_retryable` rejects are
        /// returned immediately. If no attempt was made or recorded an error,
        /// `exhausted_message` becomes the error text.
        ///
        /// NOTE(review): the request (`model`/`input`) and response
        /// (`data[].embedding`) follow the OpenAI embeddings JSON shape
        /// regardless of `api_url` -- confirm every configured provider
        /// accepts that shape.
        fn request_embeddings(
            &self,
            input: serde_json::Value,
            max_attempts: u32,
            exhausted_message: &str,
        ) -> MenteResult<Vec<Vec<f32>>> {
            // Neither the payload nor the auth header changes between attempts.
            let body = json!({
                "model": self.config.model_name,
                "input": input,
            });
            let authorization = format!("Bearer {}", self.config.api_key);

            let mut last_err: Option<String> = None;
            for attempt in 0..max_attempts {
                if attempt > 0 {
                    std::thread::sleep(backoff_delay(attempt));
                }

                let mut req =
                    ureq::post(&self.config.api_url).header("Authorization", &authorization);
                for (name, value) in &self.config.headers {
                    // A configured Content-Type is never forwarded; the body
                    // is always sent through `send_json`.
                    if !name.eq_ignore_ascii_case("content-type") {
                        req = req.header(name, value);
                    }
                }

                match req.send_json(&body) {
                    Ok(mut resp) => {
                        match resp.body_mut().read_json::<OpenAIEmbeddingResponse>() {
                            Ok(parsed) => {
                                return Ok(parsed
                                    .data
                                    .into_iter()
                                    .map(|item| item.embedding)
                                    .collect());
                            }
                            Err(e) => {
                                last_err =
                                    Some(format!("Failed to parse embedding response: {}", e));
                            }
                        }
                    }
                    Err(e) => {
                        let message = format!("HTTP embedding request failed: {}", e);
                        if !is_retryable(&e) {
                            return Err(MenteError::Storage(message));
                        }
                        last_err = Some(message);
                    }
                }
            }

            Err(MenteError::Storage(
                last_err.unwrap_or_else(|| exhausted_message.to_string()),
            ))
        }

        /// Embeds a single text, making up to `max_attempts` attempts.
        ///
        /// Returns the first embedding of the response, or an error if the
        /// response contains none.
        fn embed_with_retry(&self, text: &str, max_attempts: u32) -> MenteResult<Vec<f32>> {
            self.request_embeddings(json!(text), max_attempts, "embedding failed after retries")?
                .into_iter()
                .next()
                .ok_or_else(|| MenteError::Storage("Empty embedding response".to_string()))
        }

        /// Embeds several texts with one request, making up to `max_attempts`
        /// attempts.
        ///
        /// An empty `texts` slice yields an empty result without any network
        /// traffic. A response whose embedding count differs from
        /// `texts.len()` is reported as an error, because callers pair
        /// results with inputs by position.
        fn embed_batch_with_retry(
            &self,
            texts: &[&str],
            max_attempts: u32,
        ) -> MenteResult<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            let embeddings = self.request_embeddings(
                json!(texts),
                max_attempts,
                "batch embedding failed after retries",
            )?;

            if embeddings.len() != texts.len() {
                return Err(MenteError::Storage(format!(
                    "Embedding response contained {} embeddings for {} inputs",
                    embeddings.len(),
                    texts.len()
                )));
            }
            Ok(embeddings)
        }
    }

    impl EmbeddingProvider for HttpEmbeddingProvider {
        /// Embeds one text via a blocking HTTP request with retries.
        fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
            self.embed_with_retry(text, DEFAULT_MAX_ATTEMPTS)
        }

        /// Embeds a batch of texts via one blocking HTTP request with retries.
        fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
            self.embed_batch_with_retry(texts, DEFAULT_MAX_ATTEMPTS)
        }

        /// Vector length taken from the configuration.
        fn dimensions(&self) -> usize {
            self.config.dimensions
        }

        /// Model identifier taken from the configuration.
        fn model_name(&self) -> &str {
            &self.config.model_name
        }
    }
}
284
285#[cfg(not(feature = "http"))]
286impl EmbeddingProvider for HttpEmbeddingProvider {
287 fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
288 Err(MenteError::Storage(
289 "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
290 ))
291 }
292
293 fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
294 Err(MenteError::Storage(
295 "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
296 ))
297 }
298
299 fn dimensions(&self) -> usize {
300 self.config.dimensions
301 }
302
303 fn model_name(&self) -> &str {
304 &self.config.model_name
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    // OpenAI preset: endpoint, table-derived dimensions, model pass-through.
    #[test]
    fn test_openai_config() {
        let cfg = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");
        assert_eq!(cfg.api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(cfg.dimensions, 1536);
        assert_eq!(cfg.model_name, "text-embedding-3-small");
    }

    // Cohere preset: endpoint and table-derived dimensions.
    #[test]
    fn test_cohere_config() {
        let cfg = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");
        assert_eq!(cfg.api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(cfg.dimensions, 1024);
    }

    // Voyage preset: endpoint and table-derived dimensions.
    #[test]
    fn test_voyage_config() {
        let cfg = HttpEmbeddingConfig::voyage("key", "voyage-2");
        assert_eq!(cfg.api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(cfg.dimensions, 1024);
    }

    // An explicit override replaces the table-derived dimensions.
    #[test]
    fn test_with_dimensions_override() {
        let base = HttpEmbeddingConfig::openai("key", "text-embedding-3-small");
        let overridden = base.with_dimensions(256);
        assert_eq!(overridden.dimensions, 256);
    }
}