Skip to main content

mentedb_embedding/
http_provider.rs

1//! Generic HTTP-based embedding provider for OpenAI, Cohere, Voyage, and other APIs.
2
3use std::collections::HashMap;
4
5use mentedb_core::MenteError;
6use mentedb_core::error::MenteResult;
7use serde::{Deserialize, Serialize};
8
9use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
10
11/// Configuration for an HTTP-based embedding API.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14    /// The API endpoint URL.
15    pub api_url: String,
16    /// The API key for authentication.
17    pub api_key: String,
18    /// The model name to request.
19    pub model_name: String,
20    /// The dimensionality of the returned embeddings.
21    pub dimensions: usize,
22    /// Additional headers to include in requests.
23    pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27    /// Create a configuration for OpenAI's embedding API.
28    ///
29    /// Default dimensions: 1536 for text-embedding-ada-002, 3072 for text-embedding-3-large.
30    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31        let model = model.into();
32        let dimensions = match model.as_str() {
33            "text-embedding-3-small" => 1536,
34            "text-embedding-3-large" => 3072,
35            "text-embedding-ada-002" => 1536,
36            _ => 1536,
37        };
38
39        let mut headers = HashMap::new();
40        headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42        Self {
43            api_url: "https://api.openai.com/v1/embeddings".to_string(),
44            api_key: api_key.into(),
45            model_name: model,
46            dimensions,
47            headers,
48        }
49    }
50
51    /// Create a configuration for Cohere's embedding API.
52    ///
53    /// Default dimensions: 1024 for embed-english-v3.0.
54    pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55        let model = model.into();
56        let dimensions = match model.as_str() {
57            "embed-english-v3.0" => 1024,
58            "embed-multilingual-v3.0" => 1024,
59            "embed-english-light-v3.0" => 384,
60            "embed-multilingual-light-v3.0" => 384,
61            _ => 1024,
62        };
63
64        let mut headers = HashMap::new();
65        headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67        Self {
68            api_url: "https://api.cohere.ai/v1/embed".to_string(),
69            api_key: api_key.into(),
70            model_name: model,
71            dimensions,
72            headers,
73        }
74    }
75
76    /// Create a configuration for Voyage AI's embedding API.
77    ///
78    /// Default dimensions: 1024 for voyage-2.
79    pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80        let model = model.into();
81        let dimensions = match model.as_str() {
82            "voyage-2" => 1024,
83            "voyage-large-2" => 1536,
84            "voyage-code-2" => 1536,
85            "voyage-lite-02-instruct" => 1024,
86            _ => 1024,
87        };
88
89        let mut headers = HashMap::new();
90        headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92        Self {
93            api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94            api_key: api_key.into(),
95            model_name: model,
96            dimensions,
97            headers,
98        }
99    }
100
101    /// Override the embedding dimensions.
102    pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103        self.dimensions = dimensions;
104        self
105    }
106
107    /// Add a custom header.
108    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109        self.headers.insert(key.into(), value.into());
110        self
111    }
112}
113
114/// HTTP-based embedding provider.
115///
116/// Currently requires an external HTTP client feature to function.
117/// The structure and configuration are fully usable for setup and validation.
118pub struct HttpEmbeddingProvider {
119    config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123    /// Create a new HTTP embedding provider with the given configuration.
124    pub fn new(config: HttpEmbeddingConfig) -> Self {
125        Self { config }
126    }
127
128    /// Get a reference to the provider's configuration.
129    pub fn config(&self) -> &HttpEmbeddingConfig {
130        &self.config
131    }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135    async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136        Err(MenteError::Storage(
137            "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
138        ))
139    }
140
141    async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142        Err(MenteError::Storage(
143            "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
144        ))
145    }
146
147    fn dimensions(&self) -> usize {
148        self.config.dimensions
149    }
150
151    fn model_name(&self) -> &str {
152        &self.config.model_name
153    }
154}
155
/// Blocking `EmbeddingProvider` implementation, compiled only with the `http` feature.
///
/// Requests are sent with `ureq` as JSON of the form `{"model": ..., "input": ...}`
/// and the response is parsed as `{"data": [{"embedding": [...]}, ...]}`.
// NOTE(review): only this one response shape is parsed. The Cohere preset points at
// `/v1/embed`, whose request/response format may differ — confirm before relying on it.
#[cfg(feature = "http")]
mod http_impl {
    use super::*;
    use serde_json::json;

    /// Top-level response body: a list of embedding items under `data`.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingResponse {
        data: Vec<OpenAIEmbeddingData>,
    }

    /// One embedding item of the response.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingData {
        embedding: Vec<f32>,
        /// Position of the corresponding input, when the API reports it.
        /// Optional so that responses without the field still parse.
        #[serde(default)]
        index: Option<usize>,
    }

    impl HttpEmbeddingProvider {
        /// POST `input` to the configured endpoint and return the embeddings it yields.
        ///
        /// `input` is placed verbatim under the `"input"` key of the request body.
        /// When every returned item carries an `index`, the items are ordered by it
        /// so the result lines up with the inputs; otherwise response order is kept.
        ///
        /// # Errors
        ///
        /// Returns `MenteError::Storage` when the request fails or the response
        /// body cannot be parsed.
        fn request_embeddings(&self, input: serde_json::Value) -> MenteResult<Vec<Vec<f32>>> {
            let body = json!({
                "model": self.config.model_name,
                "input": input,
            });

            let mut req = ureq::post(&self.config.api_url)
                .header("Authorization", &format!("Bearer {}", self.config.api_key));

            // `send_json` sets the content type itself, so a configured
            // Content-Type header is skipped rather than sent twice.
            for (k, v) in &self.config.headers {
                if !k.eq_ignore_ascii_case("content-type") {
                    req = req.header(k, v);
                }
            }

            let mut resp = req.send_json(&body).map_err(|e| {
                MenteError::Storage(format!("HTTP embedding request failed: {}", e))
            })?;

            let mut parsed: OpenAIEmbeddingResponse =
                resp.body_mut().read_json().map_err(|e| {
                    MenteError::Storage(format!("Failed to parse embedding response: {}", e))
                })?;

            // Stable sort: items sharing an index keep their response order.
            if parsed.data.iter().all(|d| d.index.is_some()) {
                parsed.data.sort_by_key(|d| d.index);
            }

            Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
        }
    }

    impl EmbeddingProvider for HttpEmbeddingProvider {
        /// Embed a single text and return the first embedding of the response.
        ///
        /// # Errors
        ///
        /// Returns `MenteError::Storage` when the request fails, the response
        /// cannot be parsed, or the response contains no embedding.
        fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
            self.request_embeddings(json!(text))?
                .into_iter()
                .next()
                .ok_or_else(|| MenteError::Storage("Empty embedding response".to_string()))
        }

        /// Embed several texts with a single request.
        ///
        /// An empty `texts` slice returns an empty vector without contacting the API.
        ///
        /// # Errors
        ///
        /// Returns `MenteError::Storage` when the request fails, the response
        /// cannot be parsed, or the number of returned embeddings differs from
        /// the number of inputs.
        fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            let embeddings = self.request_embeddings(json!(texts))?;

            // A short or long response would silently misalign embeddings with
            // their inputs, so treat it as an error instead.
            if embeddings.len() != texts.len() {
                return Err(MenteError::Storage(format!(
                    "Embedding response contained {} embeddings for {} inputs",
                    embeddings.len(),
                    texts.len()
                )));
            }

            Ok(embeddings)
        }

        /// Embedding dimensionality taken from the configuration.
        fn dimensions(&self) -> usize {
            self.config.dimensions
        }

        /// Model name taken from the configuration.
        fn model_name(&self) -> &str {
            &self.config.model_name
        }
    }
}
252
253#[cfg(not(feature = "http"))]
254impl EmbeddingProvider for HttpEmbeddingProvider {
255    fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
256        Err(MenteError::Storage(
257            "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
258        ))
259    }
260
261    fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
262        Err(MenteError::Storage(
263            "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
264        ))
265    }
266
267    fn dimensions(&self) -> usize {
268        self.config.dimensions
269    }
270
271    fn model_name(&self) -> &str {
272        &self.config.model_name
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// The OpenAI preset fills in the endpoint, model name and 1536 dimensions
    /// for `text-embedding-3-small`.
    #[test]
    fn test_openai_config() {
        let HttpEmbeddingConfig {
            api_url,
            model_name,
            dimensions,
            ..
        } = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");

        assert_eq!(api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(dimensions, 1536);
        assert_eq!(model_name, "text-embedding-3-small");
    }

    /// The Cohere preset fills in the endpoint and 1024 dimensions for
    /// `embed-english-v3.0`.
    #[test]
    fn test_cohere_config() {
        let HttpEmbeddingConfig {
            api_url,
            dimensions,
            ..
        } = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");

        assert_eq!(api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(dimensions, 1024);
    }

    /// The Voyage preset fills in the endpoint and 1024 dimensions for `voyage-2`.
    #[test]
    fn test_voyage_config() {
        let HttpEmbeddingConfig {
            api_url,
            dimensions,
            ..
        } = HttpEmbeddingConfig::voyage("key", "voyage-2");

        assert_eq!(api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(dimensions, 1024);
    }

    /// `with_dimensions` replaces the dimensionality chosen by the preset.
    #[test]
    fn test_with_dimensions_override() {
        let preset = HttpEmbeddingConfig::openai("key", "text-embedding-3-small");
        let overridden = preset.with_dimensions(256);
        assert_eq!(overridden.dimensions, 256);
    }
}