1use std::collections::HashMap;
4
5use mentedb_core::MenteError;
6use mentedb_core::error::MenteResult;
7use serde::{Deserialize, Serialize};
8
9use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct HttpEmbeddingConfig {
14 pub api_url: String,
16 pub api_key: String,
18 pub model_name: String,
20 pub dimensions: usize,
22 pub headers: HashMap<String, String>,
24}
25
26impl HttpEmbeddingConfig {
27 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
31 let model = model.into();
32 let dimensions = match model.as_str() {
33 "text-embedding-3-small" => 1536,
34 "text-embedding-3-large" => 3072,
35 "text-embedding-ada-002" => 1536,
36 _ => 1536,
37 };
38
39 let mut headers = HashMap::new();
40 headers.insert("Content-Type".to_string(), "application/json".to_string());
41
42 Self {
43 api_url: "https://api.openai.com/v1/embeddings".to_string(),
44 api_key: api_key.into(),
45 model_name: model,
46 dimensions,
47 headers,
48 }
49 }
50
51 pub fn cohere(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55 let model = model.into();
56 let dimensions = match model.as_str() {
57 "embed-english-v3.0" => 1024,
58 "embed-multilingual-v3.0" => 1024,
59 "embed-english-light-v3.0" => 384,
60 "embed-multilingual-light-v3.0" => 384,
61 _ => 1024,
62 };
63
64 let mut headers = HashMap::new();
65 headers.insert("Content-Type".to_string(), "application/json".to_string());
66
67 Self {
68 api_url: "https://api.cohere.ai/v1/embed".to_string(),
69 api_key: api_key.into(),
70 model_name: model,
71 dimensions,
72 headers,
73 }
74 }
75
76 pub fn voyage(api_key: impl Into<String>, model: impl Into<String>) -> Self {
80 let model = model.into();
81 let dimensions = match model.as_str() {
82 "voyage-2" => 1024,
83 "voyage-large-2" => 1536,
84 "voyage-code-2" => 1536,
85 "voyage-lite-02-instruct" => 1024,
86 _ => 1024,
87 };
88
89 let mut headers = HashMap::new();
90 headers.insert("Content-Type".to_string(), "application/json".to_string());
91
92 Self {
93 api_url: "https://api.voyageai.com/v1/embeddings".to_string(),
94 api_key: api_key.into(),
95 model_name: model,
96 dimensions,
97 headers,
98 }
99 }
100
101 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
103 self.dimensions = dimensions;
104 self
105 }
106
107 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
109 self.headers.insert(key.into(), value.into());
110 self
111 }
112}
113
114pub struct HttpEmbeddingProvider {
119 config: HttpEmbeddingConfig,
120}
121
122impl HttpEmbeddingProvider {
123 pub fn new(config: HttpEmbeddingConfig) -> Self {
125 Self { config }
126 }
127
128 pub fn config(&self) -> &HttpEmbeddingConfig {
130 &self.config
131 }
132}
133
134impl AsyncEmbeddingProvider for HttpEmbeddingProvider {
135 async fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
136 Err(MenteError::Storage(
137 "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
138 ))
139 }
140
141 async fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
142 Err(MenteError::Storage(
143 "HTTP embedding requires the 'http' feature for async, use sync EmbeddingProvider instead".to_string(),
144 ))
145 }
146
147 fn dimensions(&self) -> usize {
148 self.config.dimensions
149 }
150
151 fn model_name(&self) -> &str {
152 &self.config.model_name
153 }
154}
155
/// Blocking HTTP implementation of [`EmbeddingProvider`], compiled only when
/// the `http` feature is enabled.
///
/// Requests carry `{"model": ..., "input": ...}` and responses are parsed as
/// `{"data": [{"embedding": [...]}, ...]}`.
///
/// NOTE(review): every preset is sent this same payload shape; confirm that
/// each configured endpoint (e.g. the Cohere preset) accepts and returns it.
#[cfg(feature = "http")]
mod http_impl {
    use super::*;
    use serde_json::json;

    /// Response envelope: a list of per-input results under `data`.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingResponse {
        data: Vec<OpenAIEmbeddingData>,
    }

    /// One result entry holding the embedding vector.
    #[derive(Deserialize)]
    struct OpenAIEmbeddingData {
        embedding: Vec<f32>,
    }

    /// Sends a single embeddings request for `input` and returns the vectors
    /// in the order they appear in the response.
    ///
    /// # Errors
    ///
    /// Returns `MenteError::Storage` when the request fails or the response
    /// body cannot be parsed.
    fn request_embeddings(
        config: &HttpEmbeddingConfig,
        input: serde_json::Value,
    ) -> MenteResult<Vec<Vec<f32>>> {
        let body = json!({
            "model": config.model_name,
            "input": input,
        });

        let mut req = ureq::post(&config.api_url)
            .header("Authorization", &format!("Bearer {}", config.api_key));

        // `send_json` sets the content type itself, so a configured
        // Content-Type entry is skipped rather than sent twice.
        for (k, v) in &config.headers {
            if !k.eq_ignore_ascii_case("content-type") {
                req = req.header(k, v);
            }
        }

        let mut resp = req.send_json(&body).map_err(|e| {
            MenteError::Storage(format!("HTTP embedding request failed: {}", e))
        })?;

        let parsed: OpenAIEmbeddingResponse = resp.body_mut().read_json().map_err(|e| {
            MenteError::Storage(format!("Failed to parse embedding response: {}", e))
        })?;

        Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
    }

    impl EmbeddingProvider for HttpEmbeddingProvider {
        /// Embeds one text and returns the first vector of the response.
        ///
        /// # Errors
        ///
        /// Returns `MenteError::Storage` when the request or parsing fails,
        /// or when the response contains no embedding.
        fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
            request_embeddings(&self.config, json!(text))?
                .into_iter()
                .next()
                .ok_or_else(|| MenteError::Storage("Empty embedding response".to_string()))
        }

        /// Embeds all `texts` with one request.
        ///
        /// An empty slice yields an empty result without contacting the
        /// endpoint.
        ///
        /// # Errors
        ///
        /// Returns `MenteError::Storage` when the request or parsing fails,
        /// or when the number of returned vectors differs from
        /// `texts.len()` (callers pair results with inputs by position).
        fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            let embeddings = request_embeddings(&self.config, json!(texts))?;
            if embeddings.len() != texts.len() {
                return Err(MenteError::Storage(format!(
                    "Embedding response contained {} vectors for {} inputs",
                    embeddings.len(),
                    texts.len()
                )));
            }

            Ok(embeddings)
        }

        /// Vector length taken from the configuration.
        fn dimensions(&self) -> usize {
            self.config.dimensions
        }

        /// Model identifier taken from the configuration.
        fn model_name(&self) -> &str {
            &self.config.model_name
        }
    }
}
252
253#[cfg(not(feature = "http"))]
254impl EmbeddingProvider for HttpEmbeddingProvider {
255 fn embed(&self, _text: &str) -> MenteResult<Vec<f32>> {
256 Err(MenteError::Storage(
257 "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
258 ))
259 }
260
261 fn embed_batch(&self, _texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
262 Err(MenteError::Storage(
263 "HTTP embedding requires the 'http' feature. Enable it in Cargo.toml.".to_string(),
264 ))
265 }
266
267 fn dimensions(&self) -> usize {
268 self.config.dimensions
269 }
270
271 fn model_name(&self) -> &str {
272 &self.config.model_name
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// The OpenAI preset targets the OpenAI endpoint and keeps the model name.
    #[test]
    fn test_openai_config() {
        let HttpEmbeddingConfig {
            api_url,
            model_name,
            dimensions,
            ..
        } = HttpEmbeddingConfig::openai("sk-test", "text-embedding-3-small");

        assert_eq!(api_url, "https://api.openai.com/v1/embeddings");
        assert_eq!(dimensions, 1536);
        assert_eq!(model_name, "text-embedding-3-small");
    }

    /// The Cohere preset targets the Cohere endpoint with 1024 dimensions.
    #[test]
    fn test_cohere_config() {
        let HttpEmbeddingConfig {
            api_url,
            dimensions,
            ..
        } = HttpEmbeddingConfig::cohere("key", "embed-english-v3.0");

        assert_eq!(api_url, "https://api.cohere.ai/v1/embed");
        assert_eq!(dimensions, 1024);
    }

    /// The Voyage preset targets the Voyage endpoint with 1024 dimensions.
    #[test]
    fn test_voyage_config() {
        let HttpEmbeddingConfig {
            api_url,
            dimensions,
            ..
        } = HttpEmbeddingConfig::voyage("key", "voyage-2");

        assert_eq!(api_url, "https://api.voyageai.com/v1/embeddings");
        assert_eq!(dimensions, 1024);
    }

    /// `with_dimensions` replaces the dimension inferred by the preset.
    #[test]
    fn test_with_dimensions_override() {
        let base = HttpEmbeddingConfig::openai("key", "text-embedding-3-small");
        let overridden = base.with_dimensions(256);
        assert_eq!(overridden.dimensions, 256);
    }
}