// graphrag_core/embeddings/api_providers.rs
1//! API-based embedding providers (OpenAI, Voyage AI, Cohere, Jina AI, Mistral, etc.)
2//!
3//! This module provides embedding generation using external API services.
4//! All providers implement the `EmbeddingProvider` trait for consistency.
5
6use crate::core::error::{GraphRAGError, Result};
7use crate::embeddings::{EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType};
8
9#[cfg(feature = "ureq")]
10use ureq;
11
/// Generic HTTP-based embedding provider
///
/// One instance talks to a single external embeddings API (OpenAI, Voyage AI,
/// Cohere, Jina AI, Mistral, Together AI). Build it with the provider-specific
/// constructors (`openai`, `voyage_ai`, ...) or with `from_config`.
pub struct HttpEmbeddingProvider {
    /// Selects both the request-body shape and the response parsing in
    /// `make_request`.
    provider_type: EmbeddingProviderType,
    /// Bearer token sent in the `Authorization` header of every request.
    api_key: String,
    /// Model identifier forwarded verbatim in the request body.
    model: String,
    /// Full URL of the provider's embeddings endpoint.
    endpoint: String,
    /// Expected output vector length for the configured model (reported via
    /// `dimensions()`); chosen by the constructor from the model name.
    dimensions: usize,

    /// Reusable blocking HTTP agent; only present when the `ureq` feature is
    /// enabled.
    #[cfg(feature = "ureq")]
    client: ureq::Agent,
}
23
24impl HttpEmbeddingProvider {
25    /// Create OpenAI embeddings provider
26    ///
27    /// # Example
28    /// ```rust,ignore
29    /// let provider = HttpEmbeddingProvider::openai(
30    ///     "sk-...".to_string(),
31    ///     "text-embedding-3-small".to_string()
32    /// );
33    /// ```
34    pub fn openai(api_key: String, model: String) -> Self {
35        let dimensions = match model.as_str() {
36            "text-embedding-3-large" => 3072,
37            "text-embedding-3-small" => 1536,
38            "text-embedding-ada-002" => 1536,
39            _ => 1536,
40        };
41
42        Self {
43            provider_type: EmbeddingProviderType::OpenAI,
44            api_key,
45            model,
46            endpoint: "https://api.openai.com/v1/embeddings".to_string(),
47            dimensions,
48            #[cfg(feature = "ureq")]
49            client: ureq::Agent::new(),
50        }
51    }
52
53    /// Create Voyage AI embeddings provider
54    ///
55    /// # Example
56    /// ```rust,ignore
57    /// let provider = HttpEmbeddingProvider::voyage_ai(
58    ///     "pa-...".to_string(),
59    ///     "voyage-3-large".to_string()
60    /// );
61    /// ```
62    pub fn voyage_ai(api_key: String, model: String) -> Self {
63        let dimensions = match model.as_str() {
64            "voyage-3-large" => 1024,
65            "voyage-3.5" => 1024,
66            "voyage-3.5-lite" => 1024,
67            "voyage-code-3" => 1024,
68            "voyage-finance-2" => 1024,
69            "voyage-law-2" => 1024,
70            _ => 1024,
71        };
72
73        Self {
74            provider_type: EmbeddingProviderType::VoyageAI,
75            api_key,
76            model,
77            endpoint: "https://api.voyageai.com/v1/embeddings".to_string(),
78            dimensions,
79            #[cfg(feature = "ureq")]
80            client: ureq::Agent::new(),
81        }
82    }
83
84    /// Create Cohere embeddings provider
85    ///
86    /// # Example
87    /// ```rust,ignore
88    /// let provider = HttpEmbeddingProvider::cohere(
89    ///     "...".to_string(),
90    ///     "embed-english-v3.0".to_string()
91    /// );
92    /// ```
93    pub fn cohere(api_key: String, model: String) -> Self {
94        let dimensions = match model.as_str() {
95            "embed-v4" | "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
96            "embed-english-light-v3.0" => 384,
97            _ => 1024,
98        };
99
100        Self {
101            provider_type: EmbeddingProviderType::Cohere,
102            api_key,
103            model,
104            endpoint: "https://api.cohere.ai/v1/embed".to_string(),
105            dimensions,
106            #[cfg(feature = "ureq")]
107            client: ureq::Agent::new(),
108        }
109    }
110
111    /// Create Jina AI embeddings provider
112    ///
113    /// # Example
114    /// ```rust,ignore
115    /// let provider = HttpEmbeddingProvider::jina_ai(
116    ///     "jina_...".to_string(),
117    ///     "jina-embeddings-v3".to_string()
118    /// );
119    /// ```
120    pub fn jina_ai(api_key: String, model: String) -> Self {
121        let dimensions = match model.as_str() {
122            "jina-embeddings-v4" => 1024,
123            "jina-clip-v2" => 768,
124            "jina-embeddings-v3" => 1024,
125            _ => 1024,
126        };
127
128        Self {
129            provider_type: EmbeddingProviderType::JinaAI,
130            api_key,
131            model,
132            endpoint: "https://api.jina.ai/v1/embeddings".to_string(),
133            dimensions,
134            #[cfg(feature = "ureq")]
135            client: ureq::Agent::new(),
136        }
137    }
138
139    /// Create Mistral AI embeddings provider
140    ///
141    /// # Example
142    /// ```rust,ignore
143    /// let provider = HttpEmbeddingProvider::mistral(
144    ///     "...".to_string(),
145    ///     "mistral-embed".to_string()
146    /// );
147    /// ```
148    pub fn mistral(api_key: String, model: String) -> Self {
149        let dimensions = match model.as_str() {
150            "mistral-embed" | "codestral-embed" => 1024,
151            _ => 1024,
152        };
153
154        Self {
155            provider_type: EmbeddingProviderType::Mistral,
156            api_key,
157            model,
158            endpoint: "https://api.mistral.ai/v1/embeddings".to_string(),
159            dimensions,
160            #[cfg(feature = "ureq")]
161            client: ureq::Agent::new(),
162        }
163    }
164
165    /// Create Together AI embeddings provider
166    ///
167    /// # Example
168    /// ```rust,ignore
169    /// let provider = HttpEmbeddingProvider::together_ai(
170    ///     "...".to_string(),
171    ///     "BAAI/bge-large-en-v1.5".to_string()
172    /// );
173    /// ```
174    pub fn together_ai(api_key: String, model: String) -> Self {
175        let dimensions = match model.as_str() {
176            "BAAI/bge-large-en-v1.5" | "WhereIsAI/UAE-Large-V1" => 1024,
177            "BAAI/bge-base-en-v1.5" => 768,
178            _ => 768,
179        };
180
181        Self {
182            provider_type: EmbeddingProviderType::TogetherAI,
183            api_key,
184            model,
185            endpoint: "https://api.together.xyz/v1/embeddings".to_string(),
186            dimensions,
187            #[cfg(feature = "ureq")]
188            client: ureq::Agent::new(),
189        }
190    }
191
192    /// Create provider from configuration
193    pub fn from_config(config: &EmbeddingConfig) -> Result<Self> {
194        let api_key = config.api_key.clone().ok_or_else(|| {
195            GraphRAGError::Embedding {
196                message: format!("API key required for {} provider", config.provider),
197            }
198        })?;
199
200        let provider = match config.provider {
201            EmbeddingProviderType::OpenAI => Self::openai(api_key, config.model.clone()),
202            EmbeddingProviderType::VoyageAI => Self::voyage_ai(api_key, config.model.clone()),
203            EmbeddingProviderType::Cohere => Self::cohere(api_key, config.model.clone()),
204            EmbeddingProviderType::JinaAI => Self::jina_ai(api_key, config.model.clone()),
205            EmbeddingProviderType::Mistral => Self::mistral(api_key, config.model.clone()),
206            EmbeddingProviderType::TogetherAI => Self::together_ai(api_key, config.model.clone()),
207            _ => {
208                return Err(GraphRAGError::Embedding {
209                    message: format!("Unsupported API provider: {}", config.provider),
210                })
211            }
212        };
213
214        Ok(provider)
215    }
216
217    #[cfg(feature = "ureq")]
218    fn make_request(&self, input: &str) -> Result<Vec<f32>> {
219        // Build request body based on provider
220        let request_body = match self.provider_type {
221            EmbeddingProviderType::OpenAI => {
222                serde_json::json!({
223                    "model": self.model.clone(),
224                    "input": input,
225                })
226            }
227            EmbeddingProviderType::VoyageAI => {
228                serde_json::json!({
229                    "model": self.model.clone(),
230                    "input": input,
231                    "input_type": "document",
232                })
233            }
234            EmbeddingProviderType::Cohere => {
235                serde_json::json!({
236                    "model": self.model.clone(),
237                    "texts": vec![input],
238                    "input_type": "search_document",
239                    "embedding_types": vec!["float"],
240                })
241            }
242            EmbeddingProviderType::JinaAI | EmbeddingProviderType::Mistral | EmbeddingProviderType::TogetherAI => {
243                serde_json::json!({
244                    "model": self.model.clone(),
245                    "input": input,
246                })
247            }
248            _ => {
249                return Err(GraphRAGError::Embedding {
250                    message: "Unsupported provider type".to_string(),
251                })
252            }
253        };
254
255        // Make HTTP request
256        let response = self
257            .client
258            .post(&self.endpoint)
259            .set("Authorization", &format!("Bearer {}", self.api_key))
260            .set("Content-Type", "application/json")
261            .send_json(request_body)
262            .map_err(|e| GraphRAGError::Embedding {
263                message: format!("HTTP request failed: {}", e),
264            })?;
265
266        // Parse response
267        let json_response: serde_json::Value =
268            response.into_json().map_err(|e| GraphRAGError::Embedding {
269                message: format!("Failed to parse JSON response: {}", e),
270            })?;
271
272        // Extract embedding based on provider response format
273        let embedding = match self.provider_type {
274            EmbeddingProviderType::OpenAI
275            | EmbeddingProviderType::VoyageAI
276            | EmbeddingProviderType::JinaAI
277            | EmbeddingProviderType::Mistral
278            | EmbeddingProviderType::TogetherAI => {
279                // OpenAI-compatible format: { "data": [{ "embedding": [...] }] }
280                json_response["data"][0]["embedding"]
281                    .as_array()
282                    .ok_or_else(|| GraphRAGError::Embedding {
283                        message: "Invalid response format: expected array".to_string(),
284                    })?
285                    .iter()
286                    .filter_map(|v| v.as_f64().map(|f| f as f32))
287                    .collect()
288            }
289            EmbeddingProviderType::Cohere => {
290                // Cohere format: { "embeddings": [[...]] }
291                json_response["embeddings"][0]
292                    .as_array()
293                    .ok_or_else(|| GraphRAGError::Embedding {
294                        message: "Invalid response format: expected array".to_string(),
295                    })?
296                    .iter()
297                    .filter_map(|v| v.as_f64().map(|f| f as f32))
298                    .collect()
299            }
300            _ => vec![],
301        };
302
303        if embedding.is_empty() {
304            return Err(GraphRAGError::Embedding {
305                message: "No embedding returned from API".to_string(),
306            });
307        }
308
309        Ok(embedding)
310    }
311
    /// Fallback compiled when the `ureq` feature is disabled: there is no
    /// HTTP client available, so every call fails with an explanatory error.
    #[cfg(not(feature = "ureq"))]
    fn make_request(&self, _input: &str) -> Result<Vec<f32>> {
        Err(GraphRAGError::Embedding {
            message: "ureq feature required for HTTP-based embeddings".to_string(),
        })
    }
318}
319
320#[async_trait::async_trait]
321impl EmbeddingProvider for HttpEmbeddingProvider {
322    async fn initialize(&mut self) -> Result<()> {
323        // API providers don't need initialization
324        Ok(())
325    }
326
327    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
328        self.make_request(text)
329    }
330
331    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
332        // TODO: Implement batch API calls for providers that support it
333        let mut embeddings = Vec::with_capacity(texts.len());
334        for text in texts {
335            embeddings.push(self.embed(text).await?);
336        }
337        Ok(embeddings)
338    }
339
340    fn dimensions(&self) -> usize {
341        self.dimensions
342    }
343
344    fn is_available(&self) -> bool {
345        #[cfg(feature = "ureq")]
346        {
347            !self.api_key.is_empty()
348        }
349
350        #[cfg(not(feature = "ureq"))]
351        {
352            false
353        }
354    }
355
356    fn provider_name(&self) -> &str {
357        match self.provider_type {
358            EmbeddingProviderType::OpenAI => "OpenAI",
359            EmbeddingProviderType::VoyageAI => "Voyage AI",
360            EmbeddingProviderType::Cohere => "Cohere",
361            EmbeddingProviderType::JinaAI => "Jina AI",
362            EmbeddingProviderType::Mistral => "Mistral AI",
363            EmbeddingProviderType::TogetherAI => "Together AI",
364            _ => "Unknown",
365        }
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_provider_creation() {
        let provider =
            HttpEmbeddingProvider::openai("sk-test".into(), "text-embedding-3-small".into());

        assert_eq!(provider.provider_name(), "OpenAI");
        assert_eq!(provider.dimensions(), 1536);
        assert_eq!(provider.endpoint, "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn test_voyage_provider_creation() {
        let provider =
            HttpEmbeddingProvider::voyage_ai("pa-test".into(), "voyage-3-large".into());

        assert_eq!(provider.provider_name(), "Voyage AI");
        assert_eq!(provider.dimensions(), 1024);
    }

    #[test]
    fn test_provider_from_config() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: "text-embedding-3-small".to_string(),
            api_key: Some("sk-test".to_string()),
            cache_dir: None,
            batch_size: 32,
        };

        let result = HttpEmbeddingProvider::from_config(&config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().provider_name(), "OpenAI");
    }

    #[test]
    fn test_config_without_api_key_fails() {
        let config = EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: "text-embedding-3-small".to_string(),
            api_key: None,
            cache_dir: None,
            batch_size: 32,
        };

        assert!(HttpEmbeddingProvider::from_config(&config).is_err());
    }
}