// graphrag_core/embeddings/config.rs
//! Configuration for embedding providers via TOML
//!
//! This module provides TOML-based configuration for all embedding providers.

use crate::core::error::{GraphRAGError, Result};
use crate::embeddings::{EmbeddingConfig, EmbeddingProviderType};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

/// TOML configuration for embeddings.
///
/// Mirrors a TOML document with a single `[embeddings]` table, e.g.:
///
/// ```toml
/// [embeddings]
/// provider = "openai"
/// model = "text-embedding-3-small"
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsTomlConfig {
    /// Embedding provider configuration. `#[serde(default)]` means a
    /// document without an `[embeddings]` table deserializes to
    /// `EmbeddingProviderConfig::default()` (local HuggingFace).
    #[serde(default)]
    pub embeddings: EmbeddingProviderConfig,
}
17
/// Embedding provider configuration (the `[embeddings]` TOML table).
///
/// `provider`, `model`, and `batch_size` have serde defaults, so a minimal
/// config file may omit any or all of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingProviderConfig {
    /// Provider type: "huggingface", "openai", "voyage", "cohere", "jina",
    /// "mistral", "together", "onnx", or "candle".
    /// Matched case-insensitively in `to_embedding_config`, which also
    /// accepts aliases such as "hf", "voyageai", "jina-ai".
    #[serde(default = "default_provider")]
    pub provider: String,

    /// Model identifier
    /// - HuggingFace: "sentence-transformers/all-MiniLM-L6-v2"
    /// - OpenAI: "text-embedding-3-small" or "text-embedding-3-large"
    /// - Voyage: "voyage-3-large", "voyage-code-3", etc.
    /// - Cohere: "embed-english-v3.0"
    /// - Jina: "jina-embeddings-v3"
    /// - Mistral: "mistral-embed"
    /// - Together: "BAAI/bge-large-en-v1.5"
    #[serde(default = "default_model")]
    pub model: String,

    /// API key (for API providers)
    /// Can also be set via environment variables (checked as a fallback
    /// by `to_embedding_config` when this field is `None`):
    /// - OPENAI_API_KEY
    /// - VOYAGE_API_KEY
    /// - COHERE_API_KEY
    /// - JINA_API_KEY
    /// - MISTRAL_API_KEY
    /// - TOGETHER_API_KEY
    pub api_key: Option<String>,

    /// Cache directory for downloaded models (HuggingFace)
    /// Default: ~/.cache/huggingface/hub
    pub cache_dir: Option<String>,

    /// Batch size for processing multiple texts
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,

    /// Embedding dimensions (read-only, determined by model).
    /// Informational only — not forwarded to `EmbeddingConfig` and
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<usize>,
}
58
59impl Default for EmbeddingProviderConfig {
60    fn default() -> Self {
61        Self {
62            provider: default_provider(),
63            model: default_model(),
64            api_key: None,
65            cache_dir: None,
66            batch_size: default_batch_size(),
67            dimensions: None,
68        }
69    }
70}
71
/// Serde default for `EmbeddingProviderConfig::provider`.
fn default_provider() -> String {
    String::from("huggingface")
}
75
/// Serde default for `EmbeddingProviderConfig::model`.
fn default_model() -> String {
    "sentence-transformers/all-MiniLM-L6-v2".to_owned()
}
79
/// Serde default for `EmbeddingProviderConfig::batch_size`.
fn default_batch_size() -> usize {
    32
}
83
84impl EmbeddingProviderConfig {
85    /// Convert TOML config to EmbeddingConfig
86    pub fn to_embedding_config(&self) -> Result<EmbeddingConfig> {
87        // Parse provider type
88        let provider = match self.provider.to_lowercase().as_str() {
89            "huggingface" | "hf" => EmbeddingProviderType::HuggingFace,
90            "openai" => EmbeddingProviderType::OpenAI,
91            "voyage" | "voyageai" | "voyage-ai" => EmbeddingProviderType::VoyageAI,
92            "cohere" => EmbeddingProviderType::Cohere,
93            "jina" | "jinaai" | "jina-ai" => EmbeddingProviderType::JinaAI,
94            "mistral" | "mistralai" | "mistral-ai" => EmbeddingProviderType::Mistral,
95            "together" | "togetherai" | "together-ai" => EmbeddingProviderType::TogetherAI,
96            "onnx" => EmbeddingProviderType::Onnx,
97            "candle" => EmbeddingProviderType::Candle,
98            _ => {
99                return Err(GraphRAGError::Config {
100                    message: format!("Unknown embedding provider: {}", self.provider),
101                })
102            },
103        };
104
105        // Get API key from config or environment
106        let api_key = self.api_key.clone().or_else(|| match provider {
107            EmbeddingProviderType::OpenAI => std::env::var("OPENAI_API_KEY").ok(),
108            EmbeddingProviderType::VoyageAI => std::env::var("VOYAGE_API_KEY").ok(),
109            EmbeddingProviderType::Cohere => std::env::var("COHERE_API_KEY").ok(),
110            EmbeddingProviderType::JinaAI => std::env::var("JINA_API_KEY").ok(),
111            EmbeddingProviderType::Mistral => std::env::var("MISTRAL_API_KEY").ok(),
112            EmbeddingProviderType::TogetherAI => std::env::var("TOGETHER_API_KEY").ok(),
113            _ => None,
114        });
115
116        Ok(EmbeddingConfig {
117            provider,
118            model: self.model.clone(),
119            api_key,
120            cache_dir: self.cache_dir.clone(),
121            batch_size: self.batch_size,
122        })
123    }
124
125    /// Load from TOML file
126    pub fn from_toml_file(path: impl Into<PathBuf>) -> Result<Self> {
127        let path = path.into();
128        let content = std::fs::read_to_string(&path).map_err(|e| GraphRAGError::Config {
129            message: format!("Failed to read config file {:?}: {}", path, e),
130        })?;
131
132        let config: EmbeddingsTomlConfig =
133            toml::from_str(&content).map_err(|e| GraphRAGError::Config {
134                message: format!("Failed to parse TOML config: {}", e),
135            })?;
136
137        Ok(config.embeddings)
138    }
139
140    /// Save to TOML file
141    pub fn to_toml_file(&self, path: impl Into<PathBuf>) -> Result<()> {
142        let path = path.into();
143        let config = EmbeddingsTomlConfig {
144            embeddings: self.clone(),
145        };
146
147        let toml_string = toml::to_string_pretty(&config).map_err(|e| GraphRAGError::Config {
148            message: format!("Failed to serialize TOML: {}", e),
149        })?;
150
151        std::fs::write(&path, toml_string).map_err(|e| GraphRAGError::Config {
152            message: format!("Failed to write config file {:?}: {}", path, e),
153        })?;
154
155        Ok(())
156    }
157
158    /// Create example configurations for different use cases
159    pub fn examples() -> Vec<(String, Self)> {
160        vec![
161            (
162                "HuggingFace (Free, Offline)".to_string(),
163                Self {
164                    provider: "huggingface".to_string(),
165                    model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
166                    api_key: None,
167                    cache_dir: Some("~/.cache/huggingface".to_string()),
168                    batch_size: 32,
169                    dimensions: Some(384),
170                },
171            ),
172            (
173                "HuggingFace (High Quality)".to_string(),
174                Self {
175                    provider: "huggingface".to_string(),
176                    model: "BAAI/bge-large-en-v1.5".to_string(),
177                    api_key: None,
178                    cache_dir: None,
179                    batch_size: 16,
180                    dimensions: Some(1024),
181                },
182            ),
183            (
184                "OpenAI (Production)".to_string(),
185                Self {
186                    provider: "openai".to_string(),
187                    model: "text-embedding-3-small".to_string(),
188                    api_key: Some("sk-...".to_string()),
189                    cache_dir: None,
190                    batch_size: 100,
191                    dimensions: Some(1536),
192                },
193            ),
194            (
195                "Voyage AI (Recommended by Anthropic)".to_string(),
196                Self {
197                    provider: "voyage".to_string(),
198                    model: "voyage-3-large".to_string(),
199                    api_key: Some("pa-...".to_string()),
200                    cache_dir: None,
201                    batch_size: 128,
202                    dimensions: Some(1024),
203                },
204            ),
205            (
206                "Voyage AI (Code Search)".to_string(),
207                Self {
208                    provider: "voyage".to_string(),
209                    model: "voyage-code-3".to_string(),
210                    api_key: Some("pa-...".to_string()),
211                    cache_dir: None,
212                    batch_size: 64,
213                    dimensions: Some(1024),
214                },
215            ),
216            (
217                "Cohere (Multilingual)".to_string(),
218                Self {
219                    provider: "cohere".to_string(),
220                    model: "embed-multilingual-v3.0".to_string(),
221                    api_key: Some("...".to_string()),
222                    cache_dir: None,
223                    batch_size: 96,
224                    dimensions: Some(1024),
225                },
226            ),
227            (
228                "Jina AI (Cost Optimized)".to_string(),
229                Self {
230                    provider: "jina".to_string(),
231                    model: "jina-embeddings-v3".to_string(),
232                    api_key: Some("jina_...".to_string()),
233                    cache_dir: None,
234                    batch_size: 200,
235                    dimensions: Some(1024),
236                },
237            ),
238            (
239                "Mistral (RAG Optimized)".to_string(),
240                Self {
241                    provider: "mistral".to_string(),
242                    model: "mistral-embed".to_string(),
243                    api_key: Some("...".to_string()),
244                    cache_dir: None,
245                    batch_size: 50,
246                    dimensions: Some(1024),
247                },
248            ),
249            (
250                "Together AI (Cheapest)".to_string(),
251                Self {
252                    provider: "together".to_string(),
253                    model: "BAAI/bge-large-en-v1.5".to_string(),
254                    api_key: Some("...".to_string()),
255                    cache_dir: None,
256                    batch_size: 128,
257                    dimensions: Some(1024),
258                },
259            ),
260        ]
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config targets the free, offline HuggingFace backend.
    #[test]
    fn test_default_config() {
        let cfg = EmbeddingProviderConfig::default();
        assert_eq!(cfg.provider, "huggingface");
        assert_eq!(cfg.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(cfg.batch_size, 32);
    }

    /// TOML-level settings carry over into the runtime config unchanged.
    #[test]
    fn test_to_embedding_config() {
        let source = EmbeddingProviderConfig {
            provider: "openai".to_string(),
            model: "text-embedding-3-small".to_string(),
            api_key: Some("sk-test".to_string()),
            cache_dir: None,
            batch_size: 50,
            dimensions: None,
        };

        let converted = source.to_embedding_config().unwrap();
        assert_eq!(converted.provider, EmbeddingProviderType::OpenAI);
        assert_eq!(converted.model, "text-embedding-3-small");
        assert_eq!(converted.batch_size, 50);
    }

    /// Every documented provider alias resolves to the right enum variant.
    #[test]
    fn test_provider_aliases() {
        let cases = [
            ("huggingface", EmbeddingProviderType::HuggingFace),
            ("hf", EmbeddingProviderType::HuggingFace),
            ("openai", EmbeddingProviderType::OpenAI),
            ("voyage", EmbeddingProviderType::VoyageAI),
            ("voyageai", EmbeddingProviderType::VoyageAI),
            ("voyage-ai", EmbeddingProviderType::VoyageAI),
            ("cohere", EmbeddingProviderType::Cohere),
            ("jina", EmbeddingProviderType::JinaAI),
            ("jinaai", EmbeddingProviderType::JinaAI),
            ("mistral", EmbeddingProviderType::Mistral),
            ("together", EmbeddingProviderType::TogetherAI),
        ];

        for (alias, expected) in cases {
            let cfg = EmbeddingProviderConfig {
                provider: alias.to_string(),
                ..Default::default()
            };
            let resolved = cfg.to_embedding_config().unwrap();
            assert_eq!(resolved.provider, expected, "Failed for alias: {}", alias);
        }
    }

    /// Serialized output contains the expected TOML keys and values.
    #[test]
    fn test_toml_serialization() {
        let wrapper = EmbeddingsTomlConfig {
            embeddings: EmbeddingProviderConfig {
                provider: "openai".to_string(),
                model: "text-embedding-3-small".to_string(),
                api_key: Some("sk-test".to_string()),
                cache_dir: Some("/custom/cache".to_string()),
                batch_size: 100,
                dimensions: Some(1536),
            },
        };

        let rendered = toml::to_string_pretty(&wrapper).unwrap();

        assert!(rendered.contains("provider = \"openai\""));
        assert!(rendered.contains("model = \"text-embedding-3-small\""));
        assert!(rendered.contains("batch_size = 100"));
    }

    /// All shipped example configs are well-formed and convertible.
    #[test]
    fn test_examples() {
        let examples = EmbeddingProviderConfig::examples();
        assert!(!examples.is_empty());

        for (name, cfg) in examples {
            println!("Testing example: {}", name);
            assert!(!cfg.provider.is_empty());
            assert!(!cfg.model.is_empty());
            assert!(cfg.batch_size > 0);

            // Conversion to the runtime config must succeed for each one.
            assert!(cfg.to_embedding_config().is_ok(), "Failed for: {}", name);
        }
    }
}