graphrag_core/embeddings/
api_providers.rs1use crate::core::error::{GraphRAGError, Result};
7use crate::embeddings::{EmbeddingConfig, EmbeddingProvider, EmbeddingProviderType};
8
9#[cfg(feature = "ureq")]
10use ureq;
11
12pub struct HttpEmbeddingProvider {
14 provider_type: EmbeddingProviderType,
15 api_key: String,
16 model: String,
17 endpoint: String,
18 dimensions: usize,
19
20 #[cfg(feature = "ureq")]
21 client: ureq::Agent,
22}
23
24impl HttpEmbeddingProvider {
25 pub fn openai(api_key: String, model: String) -> Self {
35 let dimensions = match model.as_str() {
36 "text-embedding-3-large" => 3072,
37 "text-embedding-3-small" => 1536,
38 "text-embedding-ada-002" => 1536,
39 _ => 1536,
40 };
41
42 Self {
43 provider_type: EmbeddingProviderType::OpenAI,
44 api_key,
45 model,
46 endpoint: "https://api.openai.com/v1/embeddings".to_string(),
47 dimensions,
48 #[cfg(feature = "ureq")]
49 client: ureq::Agent::new(),
50 }
51 }
52
53 pub fn voyage_ai(api_key: String, model: String) -> Self {
63 let dimensions = match model.as_str() {
64 "voyage-3-large" => 1024,
65 "voyage-3.5" => 1024,
66 "voyage-3.5-lite" => 1024,
67 "voyage-code-3" => 1024,
68 "voyage-finance-2" => 1024,
69 "voyage-law-2" => 1024,
70 _ => 1024,
71 };
72
73 Self {
74 provider_type: EmbeddingProviderType::VoyageAI,
75 api_key,
76 model,
77 endpoint: "https://api.voyageai.com/v1/embeddings".to_string(),
78 dimensions,
79 #[cfg(feature = "ureq")]
80 client: ureq::Agent::new(),
81 }
82 }
83
84 pub fn cohere(api_key: String, model: String) -> Self {
94 let dimensions = match model.as_str() {
95 "embed-v4" | "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
96 "embed-english-light-v3.0" => 384,
97 _ => 1024,
98 };
99
100 Self {
101 provider_type: EmbeddingProviderType::Cohere,
102 api_key,
103 model,
104 endpoint: "https://api.cohere.ai/v1/embed".to_string(),
105 dimensions,
106 #[cfg(feature = "ureq")]
107 client: ureq::Agent::new(),
108 }
109 }
110
111 pub fn jina_ai(api_key: String, model: String) -> Self {
121 let dimensions = match model.as_str() {
122 "jina-embeddings-v4" => 1024,
123 "jina-clip-v2" => 768,
124 "jina-embeddings-v3" => 1024,
125 _ => 1024,
126 };
127
128 Self {
129 provider_type: EmbeddingProviderType::JinaAI,
130 api_key,
131 model,
132 endpoint: "https://api.jina.ai/v1/embeddings".to_string(),
133 dimensions,
134 #[cfg(feature = "ureq")]
135 client: ureq::Agent::new(),
136 }
137 }
138
139 pub fn mistral(api_key: String, model: String) -> Self {
149 let dimensions = match model.as_str() {
150 "mistral-embed" | "codestral-embed" => 1024,
151 _ => 1024,
152 };
153
154 Self {
155 provider_type: EmbeddingProviderType::Mistral,
156 api_key,
157 model,
158 endpoint: "https://api.mistral.ai/v1/embeddings".to_string(),
159 dimensions,
160 #[cfg(feature = "ureq")]
161 client: ureq::Agent::new(),
162 }
163 }
164
165 pub fn together_ai(api_key: String, model: String) -> Self {
175 let dimensions = match model.as_str() {
176 "BAAI/bge-large-en-v1.5" | "WhereIsAI/UAE-Large-V1" => 1024,
177 "BAAI/bge-base-en-v1.5" => 768,
178 _ => 768,
179 };
180
181 Self {
182 provider_type: EmbeddingProviderType::TogetherAI,
183 api_key,
184 model,
185 endpoint: "https://api.together.xyz/v1/embeddings".to_string(),
186 dimensions,
187 #[cfg(feature = "ureq")]
188 client: ureq::Agent::new(),
189 }
190 }
191
192 pub fn from_config(config: &EmbeddingConfig) -> Result<Self> {
194 let api_key = config.api_key.clone().ok_or_else(|| {
195 GraphRAGError::Embedding {
196 message: format!("API key required for {} provider", config.provider),
197 }
198 })?;
199
200 let provider = match config.provider {
201 EmbeddingProviderType::OpenAI => Self::openai(api_key, config.model.clone()),
202 EmbeddingProviderType::VoyageAI => Self::voyage_ai(api_key, config.model.clone()),
203 EmbeddingProviderType::Cohere => Self::cohere(api_key, config.model.clone()),
204 EmbeddingProviderType::JinaAI => Self::jina_ai(api_key, config.model.clone()),
205 EmbeddingProviderType::Mistral => Self::mistral(api_key, config.model.clone()),
206 EmbeddingProviderType::TogetherAI => Self::together_ai(api_key, config.model.clone()),
207 _ => {
208 return Err(GraphRAGError::Embedding {
209 message: format!("Unsupported API provider: {}", config.provider),
210 })
211 }
212 };
213
214 Ok(provider)
215 }
216
217 #[cfg(feature = "ureq")]
218 fn make_request(&self, input: &str) -> Result<Vec<f32>> {
219 let request_body = match self.provider_type {
221 EmbeddingProviderType::OpenAI => {
222 serde_json::json!({
223 "model": self.model.clone(),
224 "input": input,
225 })
226 }
227 EmbeddingProviderType::VoyageAI => {
228 serde_json::json!({
229 "model": self.model.clone(),
230 "input": input,
231 "input_type": "document",
232 })
233 }
234 EmbeddingProviderType::Cohere => {
235 serde_json::json!({
236 "model": self.model.clone(),
237 "texts": vec![input],
238 "input_type": "search_document",
239 "embedding_types": vec!["float"],
240 })
241 }
242 EmbeddingProviderType::JinaAI | EmbeddingProviderType::Mistral | EmbeddingProviderType::TogetherAI => {
243 serde_json::json!({
244 "model": self.model.clone(),
245 "input": input,
246 })
247 }
248 _ => {
249 return Err(GraphRAGError::Embedding {
250 message: "Unsupported provider type".to_string(),
251 })
252 }
253 };
254
255 let response = self
257 .client
258 .post(&self.endpoint)
259 .set("Authorization", &format!("Bearer {}", self.api_key))
260 .set("Content-Type", "application/json")
261 .send_json(request_body)
262 .map_err(|e| GraphRAGError::Embedding {
263 message: format!("HTTP request failed: {}", e),
264 })?;
265
266 let json_response: serde_json::Value =
268 response.into_json().map_err(|e| GraphRAGError::Embedding {
269 message: format!("Failed to parse JSON response: {}", e),
270 })?;
271
272 let embedding = match self.provider_type {
274 EmbeddingProviderType::OpenAI
275 | EmbeddingProviderType::VoyageAI
276 | EmbeddingProviderType::JinaAI
277 | EmbeddingProviderType::Mistral
278 | EmbeddingProviderType::TogetherAI => {
279 json_response["data"][0]["embedding"]
281 .as_array()
282 .ok_or_else(|| GraphRAGError::Embedding {
283 message: "Invalid response format: expected array".to_string(),
284 })?
285 .iter()
286 .filter_map(|v| v.as_f64().map(|f| f as f32))
287 .collect()
288 }
289 EmbeddingProviderType::Cohere => {
290 json_response["embeddings"][0]
292 .as_array()
293 .ok_or_else(|| GraphRAGError::Embedding {
294 message: "Invalid response format: expected array".to_string(),
295 })?
296 .iter()
297 .filter_map(|v| v.as_f64().map(|f| f as f32))
298 .collect()
299 }
300 _ => vec![],
301 };
302
303 if embedding.is_empty() {
304 return Err(GraphRAGError::Embedding {
305 message: "No embedding returned from API".to_string(),
306 });
307 }
308
309 Ok(embedding)
310 }
311
312 #[cfg(not(feature = "ureq"))]
313 fn make_request(&self, _input: &str) -> Result<Vec<f32>> {
314 Err(GraphRAGError::Embedding {
315 message: "ureq feature required for HTTP-based embeddings".to_string(),
316 })
317 }
318}
319
320#[async_trait::async_trait]
321impl EmbeddingProvider for HttpEmbeddingProvider {
322 async fn initialize(&mut self) -> Result<()> {
323 Ok(())
325 }
326
327 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
328 self.make_request(text)
329 }
330
331 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
332 let mut embeddings = Vec::with_capacity(texts.len());
334 for text in texts {
335 embeddings.push(self.embed(text).await?);
336 }
337 Ok(embeddings)
338 }
339
340 fn dimensions(&self) -> usize {
341 self.dimensions
342 }
343
344 fn is_available(&self) -> bool {
345 #[cfg(feature = "ureq")]
346 {
347 !self.api_key.is_empty()
348 }
349
350 #[cfg(not(feature = "ureq"))]
351 {
352 false
353 }
354 }
355
356 fn provider_name(&self) -> &str {
357 match self.provider_type {
358 EmbeddingProviderType::OpenAI => "OpenAI",
359 EmbeddingProviderType::VoyageAI => "Voyage AI",
360 EmbeddingProviderType::Cohere => "Cohere",
361 EmbeddingProviderType::JinaAI => "Jina AI",
362 EmbeddingProviderType::Mistral => "Mistral AI",
363 EmbeddingProviderType::TogetherAI => "Together AI",
364 _ => "Unknown",
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Model name shared by the OpenAI-based tests.
    const OPENAI_SMALL_MODEL: &str = "text-embedding-3-small";

    /// Builds an OpenAI config that differs only in whether a key is present.
    fn openai_config(api_key: Option<&str>) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: EmbeddingProviderType::OpenAI,
            model: OPENAI_SMALL_MODEL.to_string(),
            api_key: api_key.map(str::to_string),
            cache_dir: None,
            batch_size: 32,
        }
    }

    #[test]
    fn test_openai_provider_creation() {
        let openai =
            HttpEmbeddingProvider::openai("sk-test".to_string(), OPENAI_SMALL_MODEL.to_string());

        assert_eq!(openai.provider_name(), "OpenAI");
        assert_eq!(openai.dimensions(), 1536);
        assert_eq!(openai.endpoint, "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn test_voyage_provider_creation() {
        let voyage =
            HttpEmbeddingProvider::voyage_ai("pa-test".to_string(), "voyage-3-large".to_string());

        assert_eq!(voyage.provider_name(), "Voyage AI");
        assert_eq!(voyage.dimensions(), 1024);
    }

    #[test]
    fn test_provider_from_config() {
        let built = HttpEmbeddingProvider::from_config(&openai_config(Some("sk-test")));
        assert!(built.is_ok());

        let provider = built.unwrap();
        assert_eq!(provider.provider_name(), "OpenAI");
    }

    #[test]
    fn test_config_without_api_key_fails() {
        let outcome = HttpEmbeddingProvider::from_config(&openai_config(None));
        assert!(outcome.is_err());
    }
}