1use crate::core::error::{GraphRAGError, Result};
6use crate::embeddings::{EmbeddingConfig, EmbeddingProviderType};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9
/// Top-level shape of an embeddings TOML file: a single `[embeddings]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsTomlConfig {
    /// Provider settings; when the `[embeddings]` table is absent the
    /// `Default` impl of `EmbeddingProviderConfig` is used.
    #[serde(default)]
    pub embeddings: EmbeddingProviderConfig,
}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct EmbeddingProviderConfig {
21 #[serde(default = "default_provider")]
23 pub provider: String,
24
25 #[serde(default = "default_model")]
34 pub model: String,
35
36 pub api_key: Option<String>,
45
46 pub cache_dir: Option<String>,
49
50 #[serde(default = "default_batch_size")]
52 pub batch_size: usize,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub dimensions: Option<usize>,
57}
58
59impl Default for EmbeddingProviderConfig {
60 fn default() -> Self {
61 Self {
62 provider: default_provider(),
63 model: default_model(),
64 api_key: None,
65 cache_dir: None,
66 batch_size: default_batch_size(),
67 dimensions: None,
68 }
69 }
70}
71
/// Serde default for `provider`: HuggingFace works offline without a key.
fn default_provider() -> String {
    String::from("huggingface")
}
75
/// Serde default for `model`: a small, widely available sentence encoder.
fn default_model() -> String {
    String::from("sentence-transformers/all-MiniLM-L6-v2")
}
79
/// Serde default for `batch_size`.
fn default_batch_size() -> usize {
    32
}
83
84impl EmbeddingProviderConfig {
85 pub fn to_embedding_config(&self) -> Result<EmbeddingConfig> {
87 let provider = match self.provider.to_lowercase().as_str() {
89 "huggingface" | "hf" => EmbeddingProviderType::HuggingFace,
90 "openai" => EmbeddingProviderType::OpenAI,
91 "voyage" | "voyageai" | "voyage-ai" => EmbeddingProviderType::VoyageAI,
92 "cohere" => EmbeddingProviderType::Cohere,
93 "jina" | "jinaai" | "jina-ai" => EmbeddingProviderType::JinaAI,
94 "mistral" | "mistralai" | "mistral-ai" => EmbeddingProviderType::Mistral,
95 "together" | "togetherai" | "together-ai" => EmbeddingProviderType::TogetherAI,
96 "onnx" => EmbeddingProviderType::Onnx,
97 "candle" => EmbeddingProviderType::Candle,
98 _ => {
99 return Err(GraphRAGError::Config {
100 message: format!("Unknown embedding provider: {}", self.provider),
101 })
102 },
103 };
104
105 let api_key = self.api_key.clone().or_else(|| match provider {
107 EmbeddingProviderType::OpenAI => std::env::var("OPENAI_API_KEY").ok(),
108 EmbeddingProviderType::VoyageAI => std::env::var("VOYAGE_API_KEY").ok(),
109 EmbeddingProviderType::Cohere => std::env::var("COHERE_API_KEY").ok(),
110 EmbeddingProviderType::JinaAI => std::env::var("JINA_API_KEY").ok(),
111 EmbeddingProviderType::Mistral => std::env::var("MISTRAL_API_KEY").ok(),
112 EmbeddingProviderType::TogetherAI => std::env::var("TOGETHER_API_KEY").ok(),
113 _ => None,
114 });
115
116 Ok(EmbeddingConfig {
117 provider,
118 model: self.model.clone(),
119 api_key,
120 cache_dir: self.cache_dir.clone(),
121 batch_size: self.batch_size,
122 })
123 }
124
125 pub fn from_toml_file(path: impl Into<PathBuf>) -> Result<Self> {
127 let path = path.into();
128 let content = std::fs::read_to_string(&path).map_err(|e| GraphRAGError::Config {
129 message: format!("Failed to read config file {:?}: {}", path, e),
130 })?;
131
132 let config: EmbeddingsTomlConfig =
133 toml::from_str(&content).map_err(|e| GraphRAGError::Config {
134 message: format!("Failed to parse TOML config: {}", e),
135 })?;
136
137 Ok(config.embeddings)
138 }
139
140 pub fn to_toml_file(&self, path: impl Into<PathBuf>) -> Result<()> {
142 let path = path.into();
143 let config = EmbeddingsTomlConfig {
144 embeddings: self.clone(),
145 };
146
147 let toml_string = toml::to_string_pretty(&config).map_err(|e| GraphRAGError::Config {
148 message: format!("Failed to serialize TOML: {}", e),
149 })?;
150
151 std::fs::write(&path, toml_string).map_err(|e| GraphRAGError::Config {
152 message: format!("Failed to write config file {:?}: {}", path, e),
153 })?;
154
155 Ok(())
156 }
157
158 pub fn examples() -> Vec<(String, Self)> {
160 vec![
161 (
162 "HuggingFace (Free, Offline)".to_string(),
163 Self {
164 provider: "huggingface".to_string(),
165 model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
166 api_key: None,
167 cache_dir: Some("~/.cache/huggingface".to_string()),
168 batch_size: 32,
169 dimensions: Some(384),
170 },
171 ),
172 (
173 "HuggingFace (High Quality)".to_string(),
174 Self {
175 provider: "huggingface".to_string(),
176 model: "BAAI/bge-large-en-v1.5".to_string(),
177 api_key: None,
178 cache_dir: None,
179 batch_size: 16,
180 dimensions: Some(1024),
181 },
182 ),
183 (
184 "OpenAI (Production)".to_string(),
185 Self {
186 provider: "openai".to_string(),
187 model: "text-embedding-3-small".to_string(),
188 api_key: Some("sk-...".to_string()),
189 cache_dir: None,
190 batch_size: 100,
191 dimensions: Some(1536),
192 },
193 ),
194 (
195 "Voyage AI (Recommended by Anthropic)".to_string(),
196 Self {
197 provider: "voyage".to_string(),
198 model: "voyage-3-large".to_string(),
199 api_key: Some("pa-...".to_string()),
200 cache_dir: None,
201 batch_size: 128,
202 dimensions: Some(1024),
203 },
204 ),
205 (
206 "Voyage AI (Code Search)".to_string(),
207 Self {
208 provider: "voyage".to_string(),
209 model: "voyage-code-3".to_string(),
210 api_key: Some("pa-...".to_string()),
211 cache_dir: None,
212 batch_size: 64,
213 dimensions: Some(1024),
214 },
215 ),
216 (
217 "Cohere (Multilingual)".to_string(),
218 Self {
219 provider: "cohere".to_string(),
220 model: "embed-multilingual-v3.0".to_string(),
221 api_key: Some("...".to_string()),
222 cache_dir: None,
223 batch_size: 96,
224 dimensions: Some(1024),
225 },
226 ),
227 (
228 "Jina AI (Cost Optimized)".to_string(),
229 Self {
230 provider: "jina".to_string(),
231 model: "jina-embeddings-v3".to_string(),
232 api_key: Some("jina_...".to_string()),
233 cache_dir: None,
234 batch_size: 200,
235 dimensions: Some(1024),
236 },
237 ),
238 (
239 "Mistral (RAG Optimized)".to_string(),
240 Self {
241 provider: "mistral".to_string(),
242 model: "mistral-embed".to_string(),
243 api_key: Some("...".to_string()),
244 cache_dir: None,
245 batch_size: 50,
246 dimensions: Some(1024),
247 },
248 ),
249 (
250 "Together AI (Cheapest)".to_string(),
251 Self {
252 provider: "together".to_string(),
253 model: "BAAI/bge-large-en-v1.5".to_string(),
254 api_key: Some("...".to_string()),
255 cache_dir: None,
256 batch_size: 128,
257 dimensions: Some(1024),
258 },
259 ),
260 ]
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults must match the serde default helpers exactly.
    #[test]
    fn test_default_config() {
        let config = EmbeddingProviderConfig::default();
        assert_eq!(config.provider, "huggingface");
        assert_eq!(config.model, "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(config.batch_size, 32);
    }

    // A fully specified TOML-level config converts to the runtime config
    // with provider resolved and scalar fields carried over.
    #[test]
    fn test_to_embedding_config() {
        let toml_config = EmbeddingProviderConfig {
            provider: "openai".to_string(),
            model: "text-embedding-3-small".to_string(),
            api_key: Some("sk-test".to_string()),
            cache_dir: None,
            batch_size: 50,
            dimensions: None,
        };

        let config = toml_config.to_embedding_config().unwrap();
        assert_eq!(config.provider, EmbeddingProviderType::OpenAI);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.batch_size, 50);
    }

    // Every listed alias must resolve to the same provider variant.
    // (Not exhaustive: e.g. "onnx"/"candle" are covered implicitly via
    // test_examples below.)
    #[test]
    fn test_provider_aliases() {
        let configs = vec![
            ("huggingface", EmbeddingProviderType::HuggingFace),
            ("hf", EmbeddingProviderType::HuggingFace),
            ("openai", EmbeddingProviderType::OpenAI),
            ("voyage", EmbeddingProviderType::VoyageAI),
            ("voyageai", EmbeddingProviderType::VoyageAI),
            ("voyage-ai", EmbeddingProviderType::VoyageAI),
            ("cohere", EmbeddingProviderType::Cohere),
            ("jina", EmbeddingProviderType::JinaAI),
            ("jinaai", EmbeddingProviderType::JinaAI),
            ("mistral", EmbeddingProviderType::Mistral),
            ("together", EmbeddingProviderType::TogetherAI),
        ];

        for (alias, expected) in configs {
            let config = EmbeddingProviderConfig {
                provider: alias.to_string(),
                ..Default::default()
            };
            let result = config.to_embedding_config().unwrap();
            assert_eq!(result.provider, expected, "Failed for alias: {}", alias);
        }
    }

    // Round-trip sanity check on the TOML output: spot-check that the
    // serialized text contains the expected key/value lines.
    #[test]
    fn test_toml_serialization() {
        let config = EmbeddingProviderConfig {
            provider: "openai".to_string(),
            model: "text-embedding-3-small".to_string(),
            api_key: Some("sk-test".to_string()),
            cache_dir: Some("/custom/cache".to_string()),
            batch_size: 100,
            dimensions: Some(1536),
        };

        let toml_string = toml::to_string_pretty(&EmbeddingsTomlConfig {
            embeddings: config.clone(),
        })
        .unwrap();

        assert!(toml_string.contains("provider = \"openai\""));
        assert!(toml_string.contains("model = \"text-embedding-3-small\""));
        assert!(toml_string.contains("batch_size = 100"));
    }

    // Every curated example must be non-empty and resolvable to a runtime
    // config (i.e. its provider string is recognized).
    #[test]
    fn test_examples() {
        let examples = EmbeddingProviderConfig::examples();
        assert!(!examples.is_empty());

        for (name, config) in examples {
            println!("Testing example: {}", name);
            assert!(!config.provider.is_empty());
            assert!(!config.model.is_empty());
            assert!(config.batch_size > 0);

            let embedding_config = config.to_embedding_config();
            assert!(embedding_config.is_ok(), "Failed for: {}", name);
        }
    }
}