// mentedb_embedding/http_provider.rs

//! Generic HTTP-based embedding provider for OpenAI, Cohere, Voyage, and other APIs.
//!
//! Network requests are only performed when the crate is built with the `http` feature.
2
3use std::collections::HashMap;
4
5use mentedb_core::MenteError;
6use mentedb_core::error::MenteResult;
7use serde::{Deserialize, Serialize};
8
9use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
10
11/// Configuration for an HTTP-based embedding API.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14    /// The API endpoint URL.
15    pub api_url: String,
16    /// The API key for authentication.
17    pub api_key: String,
18    /// The model name to request.
19    pub model_name: String,
20    /// The dimensionality of the returned embeddings.
21    pub dimensions: usize,
22    /// Additional headers to include in requests.
23    pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27    /// Create a configuration for OpenAI's embedding API.
28    ///
29    /// Default dimensions: 1536 for text-embedding-ada-002, 3072 for text-embedding-3-large.
30    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31        let model = model.into();
32        let dimensions = match model.as_str() {
33            "text-embedding-3-small" => 1536,
34            "text-embedding-3-large" => 3072,
35            "text-embedding-ada-002" => 1536,
36            _ => 1536,
37        };
38
39        let mut headers = HashMap::new();
40        headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42        Self {
43            api_url: "https://api.openai.com/v1/embeddings".to_string(),
44            api_key: api_key.into(),
45            model_name: model,
46            dimensions,
47            headers,
48        }
49    }
50
51    /// Create a configuration for Cohere's embedding API.
52    ///
53    /// Default dimensions: 1024 for embed-english-v3.0.
54    pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55        let model = model.into();
56        let dimensions = match model.as_str() {
57            "embed-english-v3.0" => 1024,
58            "embed-multilingual-v3.0" => 1024,
59            "embed-english-light-v3.0" => 384,
60            "embed-multilingual-light-v3.0" => 384,
61            _ => 1024,
62        };
63
64        let mut headers = HashMap::new();
65        headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67        Self {
68            api_url: "https://api.cohere.ai/v1/embed".to_string(),
69            api_key: api_key.into(),
70            model_name: model,
71            dimensions,
72            headers,
73        }
74    }
75
76    /// Create a configuration for Voyage AI's embedding API.
77    ///
78    /// Default dimensions: 1024 for voyage-2.
79    pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80        let model = model.into();
81        let dimensions = match model.as_str() {
82            "voyage-2" => 1024,
83            "voyage-large-2" => 1536,
84            "voyage-code-2" => 1536,
85            "voyage-lite-02-instruct" => 1024,
86            _ => 1024,
87        };
88
89        let mut headers = HashMap::new();
90        headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92        Self {
93            api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94            api_key: api_key.into(),
95            model_name: model,
96            dimensions,
97            headers,
98        }
99    }
100
101    /// Override the embedding dimensions.
102    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103        self.dimensions = dimensions;
104        self
105    }
106
107    /// Add a custom header.
108    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109        self.headers.insert(key.into(), value.into());
110        self
111    }
112}
113
114/// HTTP-based embedding provider.
115///
116/// Currently requires an external HTTP client feature to function.
117/// The structure and configuration are fully usable for setup and validation.
118pub struct HttpEmbeddingProvider {
119    config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123    /// Create a new HTTP embedding provider with the given configuration.
124    pub fn new(config: HttpEmbeddingConfig) -> Self {
125        Self { config }
126    }
127
128    /// Get a reference to the provider's configuration.
129    pub fn config(&self) -> &HttpEmbeddingConfig {
130        &self.config
131    }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135    async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136        Err(MenteError::Storage(
137            "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
138        ))
139    }
140
141    async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142        Err(MenteError::Storage(
143            "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
144        ))
145    }
146
147    fn dimensions(&self) -> usize {
148        self.config.dimensions
149    }
150
151    fn model_name(&self) -> &str {
152        &self.config.model_name
153    }
154}
155
/// Blocking HTTP backend, compiled only with the `http` feature.
///
/// Requests use the OpenAI wire format: a JSON body `{"model", "input"}` sent with a
/// `Bearer` token, answered by `{"data": [{"index", "embedding"}, ...]}`.
/// NOTE(review): Voyage accepts this format, but Cohere's `/v1/embed` is documented as
/// taking `texts` and returning `embeddings`, so configs built with
/// `HttpEmbeddingConfig::cohere` presumably fail to parse here — confirm.
#[cfg(feature = "http")]
mod http_impl {
    use super::*;
    use serde_json::json;

    /// Attempts (first request plus retries) made by the `EmbeddingProvider` methods.
    const MAX_ATTEMPTS: u32 = 3;
    /// The sleep before attempt `n` (n >= 1) is `BACKOFF_BASE_MS * 2^n` ms: 1 s, 2 s, ...
    const BACKOFF_BASE_MS: u64 = 500;
    /// Ceiling for a single backoff sleep, so large attempt counts cannot overflow.
    const BACKOFF_MAX_MS: u64 = 30_000;

    #[derive(Deserialize)]
    struct OpenAIEmbeddingResponse {
        data: Vec<OpenAIEmbeddingData>,
    }

    #[derive(Deserialize)]
    struct OpenAIEmbeddingData {
        /// Position of the matching input. Defaults to 0 when absent, in which case
        /// the stable sort in `into_ordered` keeps the server's order.
        #[serde(default)]
        index: usize,
        embedding: Vec<f32>,
    }

    /// Why a single attempt failed, and whether another attempt could help.
    enum AttemptError {
        /// Transport failure, 408/429/5xx, or an unreadable body: worth retrying.
        Transient(String),
        /// A status such as 400/401/403/404 that will not change on retry.
        Fatal(String),
    }

    /// Statuses worth retrying: request timeout, rate limiting, and server errors.
    fn is_retryable_status(code: u16) -> bool {
        code == 408 || code == 429 || code >= 500
    }

    /// Delay before attempt `attempt`; saturates instead of overflowing.
    fn backoff_delay(attempt: u32) -> std::time::Duration {
        let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
        let millis = BACKOFF_BASE_MS.saturating_mul(factor).min(BACKOFF_MAX_MS);
        std::time::Duration::from_millis(millis)
    }

    /// Restore input order via `index` and check that one vector came back per input.
    ///
    /// # Errors
    /// `MenteError::Storage` if the response is empty or its length differs from
    /// `expected`.
    fn into_ordered(
        mut data: Vec<OpenAIEmbeddingData>,
        expected: usize,
    ) -> MenteResult<Vec<Vec<f32>>> {
        if data.is_empty() {
            return Err(MenteError::Storage("Empty embedding response".to_string()));
        }
        if data.len() != expected {
            return Err(MenteError::Storage(format!(
                "Embedding response contained {} vectors for {} inputs",
                data.len(),
                expected
            )));
        }
        data.sort_by_key(|d| d.index);
        Ok(data.into_iter().map(|d| d.embedding).collect())
    }

    impl HttpEmbeddingProvider {
        /// Send one request whose `input` is a single string or a slice of strings.
        fn send_once<I: Serialize + ?Sized>(
            &self,
            input: &I,
        ) -> Result<Vec<OpenAIEmbeddingData>, AttemptError> {
            let body = json!({
                "model": self.config.model_name,
                "input": input,
            });

            let mut req = ureq::post(&self.config.api_url)
                .header("Authorization", &format!("Bearer {}", self.config.api_key));
            for (k, v) in &self.config.headers {
                // `send_json` supplies the JSON Content-Type; skip user copies of it.
                if !k.eq_ignore_ascii_case("content-type") {
                    req = req.header(k, v);
                }
            }

            match req.send_json(&body) {
                Ok(mut resp) => resp
                    .body_mut()
                    .read_json::<OpenAIEmbeddingResponse>()
                    .map(|parsed| parsed.data)
                    .map_err(|e| {
                        AttemptError::Transient(format!(
                            "Failed to parse embedding response: {}",
                            e
                        ))
                    }),
                // Client errors (bad key, bad request, unknown model) will not fix
                // themselves; retrying only adds seconds of backoff before failing.
                Err(ureq::Error::StatusCode(code)) if !is_retryable_status(code) => {
                    Err(AttemptError::Fatal(format!(
                        "HTTP embedding request failed with status {}",
                        code
                    )))
                }
                Err(e) => Err(AttemptError::Transient(format!(
                    "HTTP embedding request failed: {}",
                    e
                ))),
            }
        }

        /// Call `send_once` up to `max_attempts` times with exponential backoff,
        /// retrying only transient failures, and return `expected` ordered vectors.
        fn request_with_retry<I: Serialize + ?Sized>(
            &self,
            input: &I,
            expected: usize,
            max_attempts: u32,
        ) -> MenteResult<Vec<Vec<f32>>> {
            let mut last_err = None;
            for attempt in 0..max_attempts {
                if attempt > 0 {
                    std::thread::sleep(backoff_delay(attempt));
                }
                match self.send_once(input) {
                    Ok(data) => return into_ordered(data, expected),
                    Err(AttemptError::Fatal(msg)) => return Err(MenteError::Storage(msg)),
                    Err(AttemptError::Transient(msg)) => last_err = Some(msg),
                }
            }
            Err(MenteError::Storage(last_err.unwrap_or_else(|| {
                "embedding failed after retries".to_string()
            })))
        }

        /// Retry-aware single embedding call with exponential backoff.
        fn embed_with_retry(&self, text: &str, max_attempts: u32) -> MenteResult<Vec<f32>> {
            self.request_with_retry(text, 1, max_attempts)?
                .pop()
                .ok_or_else(|| MenteError::Storage("Empty embedding response".to_string()))
        }

        /// Retry-aware batch embedding call; returns one vector per input, in input order.
        fn embed_batch_with_retry(
            &self,
            texts: &[&str],
            max_attempts: u32,
        ) -> MenteResult<Vec<Vec<f32>>> {
            // Nothing to embed: skip the round-trip (an empty `input` is not expected
            // to be accepted by the API).
            if texts.is_empty() {
                return Ok(Vec::new());
            }
            self.request_with_retry(texts, texts.len(), max_attempts)
        }
    }

    impl EmbeddingProvider for HttpEmbeddingProvider {
        fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
            self.embed_with_retry(text, MAX_ATTEMPTS)
        }

        fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
            self.embed_batch_with_retry(texts, MAX_ATTEMPTS)
        }

        fn dimensions(&self) -> usize {
            self.config.dimensions
        }

        fn model_name(&self) -> &str {
            &self.config.model_name
        }
    }
}
284
285#[cfg(not(feature = "http"))]
286impl EmbeddingProvider for HttpEmbeddingProvider {
287    fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
288        Err(MenteError::Storage(
289            "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
290        ))
291    }
292
293    fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
294        Err(MenteError::Storage(
295            "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
296        ))
297    }
298
299    fn dimensions(&self) -> usize {
300        self.config.dimensions
301    }
302
303    fn model_name(&self) -> &str {
304        &self.config.model_name
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_config() {
        let config = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");
        assert_eq!(config.api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(config.api_key, "sk-test");
        assert_eq!(config.dimensions, 1536);
        assert_eq!(config.model_name, "text-embedding-3-small");
        assert_eq!(
            config.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_openai_known_and_unknown_models() {
        assert_eq!(
            HttpEmbeddingConfig::openai("k", "text-embedding-3-large").dimensions,
            3072
        );
        assert_eq!(
            HttpEmbeddingConfig::openai("k", "text-embedding-ada-002").dimensions,
            1536
        );
        // Unknown models fall back to 1536.
        assert_eq!(HttpEmbeddingConfig::openai("k", "future-model").dimensions, 1536);
    }

    #[test]
    fn test_cohere_config() {
        let config = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");
        assert_eq!(config.api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(config.dimensions, 1024);
        assert_eq!(
            HttpEmbeddingConfig::cohere("k", "embed-english-light-v3.0").dimensions,
            384
        );
        assert_eq!(
            HttpEmbeddingConfig::cohere("k", "embed-multilingual-light-v3.0").dimensions,
            384
        );
        assert_eq!(HttpEmbeddingConfig::cohere("k", "other").dimensions, 1024);
    }

    #[test]
    fn test_voyage_config() {
        let config = HttpEmbeddingConfig::voyage("key", "voyage-2");
        assert_eq!(config.api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(config.dimensions, 1024);
        assert_eq!(HttpEmbeddingConfig::voyage("k", "voyage-large-2").dimensions, 1536);
        assert_eq!(HttpEmbeddingConfig::voyage("k", "voyage-code-2").dimensions, 1536);
        assert_eq!(HttpEmbeddingConfig::voyage("k", "other").dimensions, 1024);
    }

    #[test]
    fn test_with_dimensions_override() {
        let config =
            HttpEmbeddingConfig::openai("key", "text-embedding-3-small").with_dimensions(256);
        assert_eq!(config.dimensions, 256);
    }

    #[test]
    fn test_with_header_adds_and_replaces() {
        let config = HttpEmbeddingConfig::openai("key", "text-embedding-3-small")
            .with_header("OpenAI-Organization", "org-1")
            .with_header("OpenAI-Organization", "org-2");
        assert_eq!(
            config.headers.get("OpenAI-Organization").map(String::as_str),
            Some("org-2")
        );
        // The default JSON header survives custom additions.
        assert_eq!(
            config.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );
    }

    #[test]
    fn test_provider_exposes_config() {
        let provider =
            HttpEmbeddingProvider::new(HttpEmbeddingConfig::voyage("key", "voyage-code-2"));
        assert_eq!(provider.config().model_name, "voyage-code-2");
        assert_eq!(EmbeddingProvider::dimensions(&provider), 1536);
        assert_eq!(EmbeddingProvider::model_name(&provider), "voyage-code-2");
        assert_eq!(AsyncEmbeddingProvider::dimensions(&provider), 1536);
        assert_eq!(AsyncEmbeddingProvider::model_name(&provider), "voyage-code-2");
    }

    #[cfg(not(feature = "http"))]
    #[test]
    fn test_embed_without_http_feature_errors() {
        let provider = HttpEmbeddingProvider::new(HttpEmbeddingConfig::openai(
            "key",
            "text-embedding-3-small",
        ));
        assert!(matches!(
            EmbeddingProvider::embed(&provider, "hi"),
            Err(MenteError::Storage(_))
        ));
        assert!(matches!(
            EmbeddingProvider::embed_batch(&provider, &["a", "b"]),
            Err(MenteError::Storage(_))
        ));
    }
}
340}