1use crate::core::error::{GraphRAGError, Result};
7use crate::embeddings::{EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType};
8
9#[cfg(feature = "ureq")]
10use ureq;
11
/// Embedding provider backed by a remote HTTP embeddings API
/// (OpenAI, Voyage AI, Cohere, Jina AI, Mistral, or Together AI).
pub struct HttpEmbeddingProvider {
    // Selects which request payload and response shape to use.
    provider_type: EmbeddingProviderType,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Model identifier included in the request body.
    model: String,
    // Full URL of the provider's embeddings endpoint.
    endpoint: String,
    // Expected embedding vector length for the configured model.
    dimensions: usize,

    // Reusable HTTP agent (connection reuse); only present with the `ureq` feature.
    #[cfg(feature = "ureq")]
    client: ureq::Agent,
}
23
24impl HttpEmbeddingProvider {
25 pub fn openai(api_key: String, model: String) -> Self {
35 let dimensions = match model.as_str() {
36 "text-embedding-3-large" => 3072,
37 "text-embedding-3-small" => 1536,
38 "text-embedding-ada-002" => 1536,
39 _ => 1536,
40 };
41
42 Self {
43 provider_type: EmbeddingProviderType::OpenAI,
44 api_key,
45 model,
46 endpoint: "https://api.openai.com/v1/embeddings".to_string(),
47 dimensions,
48 #[cfg(feature = "ureq")]
49 client: ureq::Agent::new(),
50 }
51 }
52
53 pub fn voyage_ai(api_key: String, model: String) -> Self {
63 let dimensions = match model.as_str() {
64 "voyage-3-large" => 1024,
65 "voyage-3.5" => 1024,
66 "voyage-3.5-lite" => 1024,
67 "voyage-code-3" => 1024,
68 "voyage-finance-2" => 1024,
69 "voyage-law-2" => 1024,
70 _ => 1024,
71 };
72
73 Self {
74 provider_type: EmbeddingProviderType::VoyageAI,
75 api_key,
76 model,
77 endpoint: "https://api.voyageai.com/v1/embeddings".to_string(),
78 dimensions,
79 #[cfg(feature = "ureq")]
80 client: ureq::Agent::new(),
81 }
82 }
83
84 pub fn cohere(api_key: String, model: String) -> Self {
94 let dimensions = match model.as_str() {
95 "embed-v4" | "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
96 "embed-english-light-v3.0" => 384,
97 _ => 1024,
98 };
99
100 Self {
101 provider_type: EmbeddingProviderType::Cohere,
102 api_key,
103 model,
104 endpoint: "https://api.cohere.ai/v1/embed".to_string(),
105 dimensions,
106 #[cfg(feature = "ureq")]
107 client: ureq::Agent::new(),
108 }
109 }
110
111 pub fn jina_ai(api_key: String, model: String) -> Self {
121 let dimensions = match model.as_str() {
122 "jina-embeddings-v4" => 1024,
123 "jina-clip-v2" => 768,
124 "jina-embeddings-v3" => 1024,
125 _ => 1024,
126 };
127
128 Self {
129 provider_type: EmbeddingProviderType::JinaAI,
130 api_key,
131 model,
132 endpoint: "https://api.jina.ai/v1/embeddings".to_string(),
133 dimensions,
134 #[cfg(feature = "ureq")]
135 client: ureq::Agent::new(),
136 }
137 }
138
139 pub fn mistral(api_key: String, model: String) -> Self {
149 let dimensions = match model.as_str() {
150 "mistral-embed" | "codestral-embed" => 1024,
151 _ => 1024,
152 };
153
154 Self {
155 provider_type: EmbeddingProviderType::Mistral,
156 api_key,
157 model,
158 endpoint: "https://api.mistral.ai/v1/embeddings".to_string(),
159 dimensions,
160 #[cfg(feature = "ureq")]
161 client: ureq::Agent::new(),
162 }
163 }
164
165 pub fn together_ai(api_key: String, model: String) -> Self {
175 let dimensions = match model.as_str() {
176 "BAAI/bge-large-en-v1.5" | "WhereIsAI/UAE-Large-V1" => 1024,
177 "BAAI/bge-base-en-v1.5" => 768,
178 _ => 768,
179 };
180
181 Self {
182 provider_type: EmbeddingProviderType::TogetherAI,
183 api_key,
184 model,
185 endpoint: "https://api.together.xyz/v1/embeddings".to_string(),
186 dimensions,
187 #[cfg(feature = "ureq")]
188 client: ureq::Agent::new(),
189 }
190 }
191
192 pub fn from_config(config: &EmbeddingConfig) -> Result<Self> {
194 let api_key = config
195 .api_key
196 .clone()
197 .ok_or_else(|| GraphRAGError::Embedding {
198 message: format!("API key required for {} provider", config.provider),
199 })?;
200
201 let provider = match config.provider {
202 EmbeddingProviderType::OpenAI => Self::openai(api_key, config.model.clone()),
203 EmbeddingProviderType::VoyageAI => Self::voyage_ai(api_key, config.model.clone()),
204 EmbeddingProviderType::Cohere => Self::cohere(api_key, config.model.clone()),
205 EmbeddingProviderType::JinaAI => Self::jina_ai(api_key, config.model.clone()),
206 EmbeddingProviderType::Mistral => Self::mistral(api_key, config.model.clone()),
207 EmbeddingProviderType::TogetherAI => Self::together_ai(api_key, config.model.clone()),
208 _ => {
209 return Err(GraphRAGError::Embedding {
210 message: format!("Unsupported API provider: {}", config.provider),
211 })
212 },
213 };
214
215 Ok(provider)
216 }
217
218 #[cfg(feature = "ureq")]
219 fn make_request(&self, input: &str) -> Result<Vec<f32>> {
220 let request_body = match self.provider_type {
222 EmbeddingProviderType::OpenAI => {
223 serde_json::json!({
224 "model": self.model.clone(),
225 "input": input,
226 })
227 },
228 EmbeddingProviderType::VoyageAI => {
229 serde_json::json!({
230 "model": self.model.clone(),
231 "input": input,
232 "input_type": "document",
233 })
234 },
235 EmbeddingProviderType::Cohere => {
236 serde_json::json!({
237 "model": self.model.clone(),
238 "texts": vec![input],
239 "input_type": "search_document",
240 "embedding_types": vec!["float"],
241 })
242 },
243 EmbeddingProviderType::JinaAI
244 | EmbeddingProviderType::Mistral
245 | EmbeddingProviderType::TogetherAI => {
246 serde_json::json!({
247 "model": self.model.clone(),
248 "input": input,
249 })
250 },
251 _ => {
252 return Err(GraphRAGError::Embedding {
253 message: "Unsupported provider type".to_string(),
254 })
255 },
256 };
257
258 let response = self
260 .client
261 .post(&self.endpoint)
262 .set("Authorization", &format!("Bearer {}", self.api_key))
263 .set("Content-Type", "application/json")
264 .send_json(request_body)
265 .map_err(|e| GraphRAGError::Embedding {
266 message: format!("HTTP request failed: {}", e),
267 })?;
268
269 let json_response: serde_json::Value =
271 response.into_json().map_err(|e| GraphRAGError::Embedding {
272 message: format!("Failed to parse JSON response: {}", e),
273 })?;
274
275 let embedding = match self.provider_type {
277 EmbeddingProviderType::OpenAI
278 | EmbeddingProviderType::VoyageAI
279 | EmbeddingProviderType::JinaAI
280 | EmbeddingProviderType::Mistral
281 | EmbeddingProviderType::TogetherAI => {
282 json_response["data"][0]["embedding"]
284 .as_array()
285 .ok_or_else(|| GraphRAGError::Embedding {
286 message: "Invalid response format: expected array".to_string(),
287 })?
288 .iter()
289 .filter_map(|v| v.as_f64().map(|f| f as f32))
290 .collect()
291 },
292 EmbeddingProviderType::Cohere => {
293 json_response["embeddings"][0]
295 .as_array()
296 .ok_or_else(|| GraphRAGError::Embedding {
297 message: "Invalid response format: expected array".to_string(),
298 })?
299 .iter()
300 .filter_map(|v| v.as_f64().map(|f| f as f32))
301 .collect()
302 },
303 _ => vec![],
304 };
305
306 if embedding.is_empty() {
307 return Err(GraphRAGError::Embedding {
308 message: "No embedding returned from API".to_string(),
309 });
310 }
311
312 Ok(embedding)
313 }
314
315 #[cfg(not(feature = "ureq"))]
316 fn make_request(&self, _input: &str) -> Result<Vec<f32>> {
317 Err(GraphRAGError::Embedding {
318 message: "ureq feature required for HTTP-based embeddings".to_string(),
319 })
320 }
321
322 #[cfg(feature = "ureq")]
324 fn make_batch_request(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>> {
325 let request_body = match self.provider_type {
327 EmbeddingProviderType::OpenAI => {
328 serde_json::json!({
329 "model": self.model.clone(),
330 "input": inputs,
331 })
332 },
333 EmbeddingProviderType::VoyageAI => {
334 serde_json::json!({
335 "model": self.model.clone(),
336 "input": inputs,
337 "input_type": "document",
338 })
339 },
340 EmbeddingProviderType::Cohere => {
341 serde_json::json!({
342 "model": self.model.clone(),
343 "texts": inputs,
344 "input_type": "search_document",
345 "embedding_types": vec!["float"],
346 })
347 },
348 EmbeddingProviderType::JinaAI
349 | EmbeddingProviderType::Mistral
350 | EmbeddingProviderType::TogetherAI => {
351 serde_json::json!({
352 "model": self.model.clone(),
353 "input": inputs,
354 })
355 },
356 _ => {
357 return Err(GraphRAGError::Embedding {
358 message: "Unsupported provider type for batch".to_string(),
359 })
360 },
361 };
362
363 let response = self
365 .client
366 .post(&self.endpoint)
367 .set("Authorization", &format!("Bearer {}", self.api_key))
368 .set("Content-Type", "application/json")
369 .send_json(request_body)
370 .map_err(|e| GraphRAGError::Embedding {
371 message: format!("Batch HTTP request failed: {}", e),
372 })?;
373
374 let json_response: serde_json::Value =
376 response.into_json().map_err(|e| GraphRAGError::Embedding {
377 message: format!("Failed to parse batch JSON response: {}", e),
378 })?;
379
380 let embeddings = match self.provider_type {
382 EmbeddingProviderType::OpenAI
383 | EmbeddingProviderType::VoyageAI
384 | EmbeddingProviderType::JinaAI
385 | EmbeddingProviderType::Mistral
386 | EmbeddingProviderType::TogetherAI => {
387 let data_array =
389 json_response["data"]
390 .as_array()
391 .ok_or_else(|| GraphRAGError::Embedding {
392 message: "Invalid batch response format: expected data array"
393 .to_string(),
394 })?;
395
396 data_array
397 .iter()
398 .map(|item| {
399 item["embedding"]
400 .as_array()
401 .ok_or_else(|| GraphRAGError::Embedding {
402 message: "Invalid embedding format in batch".to_string(),
403 })
404 .map(|arr| {
405 arr.iter()
406 .filter_map(|v| v.as_f64().map(|f| f as f32))
407 .collect()
408 })
409 })
410 .collect::<Result<Vec<Vec<f32>>>>()?
411 },
412 EmbeddingProviderType::Cohere => {
413 let embeddings_array = json_response["embeddings"].as_array().ok_or_else(|| {
415 GraphRAGError::Embedding {
416 message: "Invalid Cohere batch response format".to_string(),
417 }
418 })?;
419
420 embeddings_array
421 .iter()
422 .map(|emb| {
423 emb.as_array()
424 .ok_or_else(|| GraphRAGError::Embedding {
425 message: "Invalid embedding array in Cohere batch".to_string(),
426 })
427 .map(|arr| {
428 arr.iter()
429 .filter_map(|v| v.as_f64().map(|f| f as f32))
430 .collect()
431 })
432 })
433 .collect::<Result<Vec<Vec<f32>>>>()?
434 },
435 _ => vec![],
436 };
437
438 if embeddings.is_empty() || embeddings.len() != inputs.len() {
439 return Err(GraphRAGError::Embedding {
440 message: format!(
441 "Batch embedding count mismatch: expected {}, got {}",
442 inputs.len(),
443 embeddings.len()
444 ),
445 });
446 }
447
448 Ok(embeddings)
449 }
450
451 #[cfg(not(feature = "ureq"))]
452 fn make_batch_request(&self, _inputs: &[&str]) -> Result<Vec<Vec<f32>>> {
453 Err(GraphRAGError::Embedding {
454 message: "ureq feature required for batch embeddings".to_string(),
455 })
456 }
457}
458
459#[async_trait::async_trait]
460impl EmbeddingProvider for HttpEmbeddingProvider {
461 async fn initialize(&mut self) -> Result<()> {
462 Ok(())
464 }
465
466 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
467 self.make_request(text)
468 }
469
470 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
471 if texts.is_empty() {
473 return Ok(Vec::new());
474 }
475
476 if texts.len() == 1 {
478 return Ok(vec![self.embed(texts[0]).await?]);
479 }
480
481 #[cfg(feature = "ureq")]
482 {
483 match self.make_batch_request(texts) {
485 Ok(embeddings) => return Ok(embeddings),
486 Err(_) => {
487 },
489 }
490 }
491
492 let mut embeddings = Vec::with_capacity(texts.len());
494 for text in texts {
495 embeddings.push(self.embed(text).await?);
496 }
497 Ok(embeddings)
498 }
499
500 fn dimensions(&self) -> usize {
501 self.dimensions
502 }
503
504 fn is_available(&self) -> bool {
505 #[cfg(feature = "ureq")]
506 {
507 !self.api_key.is_empty()
508 }
509
510 #[cfg(not(feature = "ureq"))]
511 {
512 false
513 }
514 }
515
516 fn provider_name(&self) -> &str {
517 match self.provider_type {
518 EmbeddingProviderType::OpenAI => "OpenAI",
519 EmbeddingProviderType::VoyageAI => "Voyage AI",
520 EmbeddingProviderType::Cohere => "Cohere",
521 EmbeddingProviderType::JinaAI => "Jina AI",
522 EmbeddingProviderType::Mistral => "Mistral AI",
523 EmbeddingProviderType::TogetherAI => "Together AI",
524 _ => "Unknown",
525 }
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor wires up name, dimensions, and endpoint for OpenAI.
    #[test]
    fn test_openai_provider_creation() {
        let provider = HttpEmbeddingProvider::openai(
            String::from("sk-test"),
            String::from("text-embedding-3-small"),
        );

        assert_eq!("OpenAI", provider.provider_name());
        assert_eq!(1536, provider.dimensions());
        assert_eq!("https://api.openai.com/v1/embeddings", provider.endpoint);
    }

    // Constructor wires up name and dimensions for Voyage AI.
    #[test]
    fn test_voyage_provider_creation() {
        let provider = HttpEmbeddingProvider::voyage_ai(
            String::from("pa-test"),
            String::from("voyage-3-large"),
        );

        assert_eq!("Voyage AI", provider.provider_name());
        assert_eq!(1024, provider.dimensions());
    }

    // A config with an API key produces the matching provider.
    #[test]
    fn test_provider_from_config() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: String::from("text-embedding-3-small"),
            api_key: Some(String::from("sk-test")),
            cache_dir: None,
            batch_size: 32,
        };

        let built = HttpEmbeddingProvider::from_config(&config);
        assert!(built.is_ok());
        assert_eq!("OpenAI", built.unwrap().provider_name());
    }

    // A missing API key is a hard error for HTTP providers.
    #[test]
    fn test_config_without_api_key_fails() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: String::from("text-embedding-3-small"),
            api_key: None,
            cache_dir: None,
            batch_size: 32,
        };

        assert!(HttpEmbeddingProvider::from_config(&config).is_err());
    }
}