
graphrag_core/embeddings/api_providers.rs

//! API-based embedding providers (OpenAI, Voyage AI, Cohere, Jina AI, Mistral, etc.)
//!
//! This module provides embedding generation using external API services.
//! All providers implement the `EmbeddingProvider` trait for consistency.
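//!
//! # Example
//!
//! A minimal usage sketch, assuming the `ureq` feature is enabled, the module is
//! reachable at this path, and an async runtime drives the call:
//!
//! ```rust,ignore
//! use graphrag_core::embeddings::{api_providers::HttpEmbeddingProvider, EmbeddingProvider};
//!
//! let provider = HttpEmbeddingProvider::openai(
//!     std::env::var("OPENAI_API_KEY").expect("set OPENAI_API_KEY"),
//!     "text-embedding-3-small".to_string(),
//! );
//! let embedding = provider.embed("GraphRAG builds knowledge graphs").await?;
//! assert_eq!(embedding.len(), provider.dimensions());
//! ```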

use crate::core::error::{GraphRAGError, Result};
use crate::embeddings::{EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType};

#[cfg(feature = "ureq")]
use ureq;

/// Generic HTTP-based embedding provider
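///
/// Wraps a synchronous `ureq` client and speaks the OpenAI-compatible embeddings
/// protocol (plus Cohere's variant) with bearer-token auth. Network calls require
/// the `ureq` feature; without it, every request returns a `GraphRAGError::Embedding`.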
pub struct HttpEmbeddingProvider {
    provider_type: EmbeddingProviderType,
    api_key: String,
    model: String,
    endpoint: String,
    dimensions: usize,

    #[cfg(feature = "ureq")]
    client: ureq::Agent,
}

impl HttpEmbeddingProvider {
    /// Create OpenAI embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::openai(
    ///     "sk-...".to_string(),
    ///     "text-embedding-3-small".to_string()
    /// );
    /// ```
    pub fn openai(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "text-embedding-3-large" => 3072,
            "text-embedding-3-small" => 1536,
            "text-embedding-ada-002" => 1536,
            _ => 1536,
        };

        Self {
            provider_type: EmbeddingProviderType::OpenAI,
            api_key,
            model,
            endpoint: "https://api.openai.com/v1/embeddings".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create Voyage AI embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::voyage_ai(
    ///     "pa-...".to_string(),
    ///     "voyage-3-large".to_string()
    /// );
    /// ```
    pub fn voyage_ai(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "voyage-3-large" => 1024,
            "voyage-3.5" => 1024,
            "voyage-3.5-lite" => 1024,
            "voyage-code-3" => 1024,
            "voyage-finance-2" => 1024,
            "voyage-law-2" => 1024,
            _ => 1024,
        };

        Self {
            provider_type: EmbeddingProviderType::VoyageAI,
            api_key,
            model,
            endpoint: "https://api.voyageai.com/v1/embeddings".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create Cohere embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::cohere(
    ///     "...".to_string(),
    ///     "embed-english-v3.0".to_string()
    /// );
    /// ```
    pub fn cohere(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "embed-v4" | "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
            "embed-english-light-v3.0" => 384,
            _ => 1024,
        };

        Self {
            provider_type: EmbeddingProviderType::Cohere,
            api_key,
            model,
            endpoint: "https://api.cohere.ai/v1/embed".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create Jina AI embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::jina_ai(
    ///     "jina_...".to_string(),
    ///     "jina-embeddings-v3".to_string()
    /// );
    /// ```
    pub fn jina_ai(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "jina-embeddings-v4" => 1024,
            "jina-clip-v2" => 768,
            "jina-embeddings-v3" => 1024,
            _ => 1024,
        };

        Self {
            provider_type: EmbeddingProviderType::JinaAI,
            api_key,
            model,
            endpoint: "https://api.jina.ai/v1/embeddings".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create Mistral AI embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::mistral(
    ///     "...".to_string(),
    ///     "mistral-embed".to_string()
    /// );
    /// ```
    pub fn mistral(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "mistral-embed" | "codestral-embed" => 1024,
            _ => 1024,
        };

        Self {
            provider_type: EmbeddingProviderType::Mistral,
            api_key,
            model,
            endpoint: "https://api.mistral.ai/v1/embeddings".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create Together AI embeddings provider
    ///
    /// # Example
    /// ```rust,ignore
    /// let provider = HttpEmbeddingProvider::together_ai(
    ///     "...".to_string(),
    ///     "BAAI/bge-large-en-v1.5".to_string()
    /// );
    /// ```
    pub fn together_ai(api_key: String, model: String) -> Self {
        let dimensions = match model.as_str() {
            "BAAI/bge-large-en-v1.5" | "WhereIsAI/UAE-Large-V1" => 1024,
            "BAAI/bge-base-en-v1.5" => 768,
            _ => 768,
        };

        Self {
            provider_type: EmbeddingProviderType::TogetherAI,
            api_key,
            model,
            endpoint: "https://api.together.xyz/v1/embeddings".to_string(),
            dimensions,
            #[cfg(feature = "ureq")]
            client: ureq::Agent::new(),
        }
    }

    /// Create provider from configuration
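    ///
    /// # Example
    /// ```rust,ignore
    /// // Illustrative sketch; the field values mirror the test configs below.
    /// let config = EmbeddingConfig {
    ///     provider: EmbeddingProviderType::OpenAI,
    ///     model: "text-embedding-3-small".to_string(),
    ///     api_key: Some("sk-...".to_string()),
    ///     cache_dir: None,
    ///     batch_size: 32,
    /// };
    /// let provider = HttpEmbeddingProvider::from_config(&config)?;
    /// ```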
    pub fn from_config(config: &EmbeddingConfig) -> Result<Self> {
        let api_key = config
            .api_key
            .clone()
            .ok_or_else(|| GraphRAGError::Embedding {
                message: format!("API key required for {} provider", config.provider),
            })?;

        let provider = match config.provider {
            EmbeddingProviderType::OpenAI => Self::openai(api_key, config.model.clone()),
            EmbeddingProviderType::VoyageAI => Self::voyage_ai(api_key, config.model.clone()),
            EmbeddingProviderType::Cohere => Self::cohere(api_key, config.model.clone()),
            EmbeddingProviderType::JinaAI => Self::jina_ai(api_key, config.model.clone()),
            EmbeddingProviderType::Mistral => Self::mistral(api_key, config.model.clone()),
            EmbeddingProviderType::TogetherAI => Self::together_ai(api_key, config.model.clone()),
            _ => {
                return Err(GraphRAGError::Embedding {
                    message: format!("Unsupported API provider: {}", config.provider),
                })
            },
        };

        Ok(provider)
    }

    #[cfg(feature = "ureq")]
    fn make_request(&self, input: &str) -> Result<Vec<f32>> {
        // Build request body based on provider
        let request_body = match self.provider_type {
            EmbeddingProviderType::OpenAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": input,
                })
            },
            EmbeddingProviderType::VoyageAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": input,
                    "input_type": "document",
                })
            },
            EmbeddingProviderType::Cohere => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "texts": vec![input],
                    "input_type": "search_document",
                    "embedding_types": vec!["float"],
                })
            },
            EmbeddingProviderType::JinaAI
            | EmbeddingProviderType::Mistral
            | EmbeddingProviderType::TogetherAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": input,
                })
            },
            _ => {
                return Err(GraphRAGError::Embedding {
                    message: "Unsupported provider type".to_string(),
                })
            },
        };

        // Make HTTP request
        let response = self
            .client
            .post(&self.endpoint)
            .set("Authorization", &format!("Bearer {}", self.api_key))
            .set("Content-Type", "application/json")
            .send_json(request_body)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("HTTP request failed: {}", e),
            })?;

        // Parse response
        let json_response: serde_json::Value =
            response.into_json().map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to parse JSON response: {}", e),
            })?;

        // Extract embedding based on provider response format
        let embedding = match self.provider_type {
            EmbeddingProviderType::OpenAI
            | EmbeddingProviderType::VoyageAI
            | EmbeddingProviderType::JinaAI
            | EmbeddingProviderType::Mistral
            | EmbeddingProviderType::TogetherAI => {
                // OpenAI-compatible format: { "data": [{ "embedding": [...] }] }
                json_response["data"][0]["embedding"]
                    .as_array()
                    .ok_or_else(|| GraphRAGError::Embedding {
                        message: "Invalid response format: expected array".to_string(),
                    })?
                    .iter()
                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                    .collect()
            },
            EmbeddingProviderType::Cohere => {
                // Cohere format: { "embeddings": [[...]] }
                json_response["embeddings"][0]
                    .as_array()
                    .ok_or_else(|| GraphRAGError::Embedding {
                        message: "Invalid response format: expected array".to_string(),
                    })?
                    .iter()
                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                    .collect()
            },
            _ => vec![],
        };

        if embedding.is_empty() {
            return Err(GraphRAGError::Embedding {
                message: "No embedding returned from API".to_string(),
            });
        }

        Ok(embedding)
    }

    #[cfg(not(feature = "ureq"))]
    fn make_request(&self, _input: &str) -> Result<Vec<f32>> {
        Err(GraphRAGError::Embedding {
            message: "ureq feature required for HTTP-based embeddings".to_string(),
        })
    }

    /// Make batch embedding request for multiple texts
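    ///
    /// Sends all `inputs` in a single HTTP call. OpenAI-compatible providers answer
    /// with `{ "data": [{ "embedding": [...] }, ...] }`, while Cohere answers with
    /// `{ "embeddings": [[...], ...] }`; both shapes are handled below.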
    #[cfg(feature = "ureq")]
    fn make_batch_request(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Build request body based on provider
        let request_body = match self.provider_type {
            EmbeddingProviderType::OpenAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": inputs,
                })
            },
            EmbeddingProviderType::VoyageAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": inputs,
                    "input_type": "document",
                })
            },
            EmbeddingProviderType::Cohere => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "texts": inputs,
                    "input_type": "search_document",
                    "embedding_types": vec!["float"],
                })
            },
            EmbeddingProviderType::JinaAI
            | EmbeddingProviderType::Mistral
            | EmbeddingProviderType::TogetherAI => {
                serde_json::json!({
                    "model": self.model.clone(),
                    "input": inputs,
                })
            },
            _ => {
                return Err(GraphRAGError::Embedding {
                    message: "Unsupported provider type for batch".to_string(),
                })
            },
        };

        // Make HTTP request
        let response = self
            .client
            .post(&self.endpoint)
            .set("Authorization", &format!("Bearer {}", self.api_key))
            .set("Content-Type", "application/json")
            .send_json(request_body)
            .map_err(|e| GraphRAGError::Embedding {
                message: format!("Batch HTTP request failed: {}", e),
            })?;

        // Parse response
        let json_response: serde_json::Value =
            response.into_json().map_err(|e| GraphRAGError::Embedding {
                message: format!("Failed to parse batch JSON response: {}", e),
            })?;

        // Extract embeddings based on provider response format
        let embeddings = match self.provider_type {
            EmbeddingProviderType::OpenAI
            | EmbeddingProviderType::VoyageAI
            | EmbeddingProviderType::JinaAI
            | EmbeddingProviderType::Mistral
            | EmbeddingProviderType::TogetherAI => {
                // OpenAI-compatible format: { "data": [{ "embedding": [...] }, ...] }
                let data_array =
                    json_response["data"]
                        .as_array()
                        .ok_or_else(|| GraphRAGError::Embedding {
                            message: "Invalid batch response format: expected data array"
                                .to_string(),
                        })?;

                data_array
                    .iter()
                    .map(|item| {
                        item["embedding"]
                            .as_array()
                            .ok_or_else(|| GraphRAGError::Embedding {
                                message: "Invalid embedding format in batch".to_string(),
                            })
                            .map(|arr| {
                                arr.iter()
                                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                                    .collect()
                            })
                    })
                    .collect::<Result<Vec<Vec<f32>>>>()?
            },
            EmbeddingProviderType::Cohere => {
                // Cohere format: { "embeddings": [[...], [...], ...] }
                let embeddings_array = json_response["embeddings"].as_array().ok_or_else(|| {
                    GraphRAGError::Embedding {
                        message: "Invalid Cohere batch response format".to_string(),
                    }
                })?;

                embeddings_array
                    .iter()
                    .map(|emb| {
                        emb.as_array()
                            .ok_or_else(|| GraphRAGError::Embedding {
                                message: "Invalid embedding array in Cohere batch".to_string(),
                            })
                            .map(|arr| {
                                arr.iter()
                                    .filter_map(|v| v.as_f64().map(|f| f as f32))
                                    .collect()
                            })
                    })
                    .collect::<Result<Vec<Vec<f32>>>>()?
            },
            _ => vec![],
        };

        if embeddings.is_empty() || embeddings.len() != inputs.len() {
            return Err(GraphRAGError::Embedding {
                message: format!(
                    "Batch embedding count mismatch: expected {}, got {}",
                    inputs.len(),
                    embeddings.len()
                ),
            });
        }

        Ok(embeddings)
    }

    #[cfg(not(feature = "ureq"))]
    fn make_batch_request(&self, _inputs: &[&str]) -> Result<Vec<Vec<f32>>> {
        Err(GraphRAGError::Embedding {
            message: "ureq feature required for batch embeddings".to_string(),
        })
    }
}

#[async_trait::async_trait]
impl EmbeddingProvider for HttpEmbeddingProvider {
    async fn initialize(&mut self) -> Result<()> {
        // API providers don't need initialization
        Ok(())
    }

    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.make_request(text)
    }

    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Use batch API for providers that support it
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // For single text, use regular embed
        if texts.len() == 1 {
            return Ok(vec![self.embed(texts[0]).await?]);
        }

        #[cfg(feature = "ureq")]
        {
            // Try batch request for supported providers
            match self.make_batch_request(texts) {
                Ok(embeddings) => return Ok(embeddings),
                Err(_) => {
                    // Fallback to sequential requests if batch fails
                },
            }
        }

        // Fallback: sequential requests
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed(text).await?);
        }
        Ok(embeddings)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn is_available(&self) -> bool {
        #[cfg(feature = "ureq")]
        {
            !self.api_key.is_empty()
        }

        #[cfg(not(feature = "ureq"))]
        {
            false
        }
    }

    fn provider_name(&self) -> &str {
        match self.provider_type {
            EmbeddingProviderType::OpenAI => "OpenAI",
            EmbeddingProviderType::VoyageAI => "Voyage AI",
            EmbeddingProviderType::Cohere => "Cohere",
            EmbeddingProviderType::JinaAI => "Jina AI",
            EmbeddingProviderType::Mistral => "Mistral AI",
            EmbeddingProviderType::TogetherAI => "Together AI",
            _ => "Unknown",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_provider_creation() {
        let provider = HttpEmbeddingProvider::openai(
            "sk-test".to_string(),
            "text-embedding-3-small".to_string(),
        );

        assert_eq!(provider.provider_name(), "OpenAI");
        assert_eq!(provider.dimensions(), 1536);
        assert_eq!(provider.endpoint, "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn test_voyage_provider_creation() {
        let provider =
            HttpEmbeddingProvider::voyage_ai("pa-test".to_string(), "voyage-3-large".to_string());

        assert_eq!(provider.provider_name(), "Voyage AI");
        assert_eq!(provider.dimensions(), 1024);
    }

    #[test]
    fn test_provider_from_config() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: "text-embedding-3-small".to_string(),
            api_key: Some("sk-test".to_string()),
            cache_dir: None,
            batch_size: 32,
        };

        let provider = HttpEmbeddingProvider::from_config(&config);
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert_eq!(provider.provider_name(), "OpenAI");
    }

    #[test]
    fn test_config_without_api_key_fails() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: "text-embedding-3-small".to_string(),
            api_key: None,
            cache_dir: None,
            batch_size: 32,
        };

        let result = HttpEmbeddingProvider::from_config(&config);
        assert!(result.is_err());
    }
}