1use std::sync::Arc;
40
41use tracing::warn;
42
43use crate::error::{LlmError, Result};
44use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
45use crate::providers::anthropic::AnthropicProvider;
46use crate::providers::azure_openai::AzureOpenAIProvider;
47use crate::providers::gemini::GeminiProvider;
48use crate::providers::huggingface::HuggingFaceProvider;
49use crate::providers::lmstudio::LMStudioProvider;
50use crate::providers::mistral::MistralProvider;
51use crate::providers::openai_compatible::OpenAICompatibleProvider;
52use crate::providers::openrouter::OpenRouterProvider;
53use crate::providers::xai::XAIProvider;
54use crate::traits::{EmbeddingProvider, LLMProvider};
55use crate::{MockProvider, OllamaProvider, OpenAIProvider, VsCodeCopilotProvider};
56
/// Identifies which LLM backend the factory should instantiate.
///
/// Parse user-supplied names (case-insensitive, with common aliases) via
/// [`ProviderType::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    OpenAI,
    Anthropic,
    Gemini,
    OpenRouter,
    XAI,
    HuggingFace,
    Ollama,
    LMStudio,
    VsCodeCopilot,
    Mock,
    Mistral,
    AzureOpenAI,
    #[cfg(feature = "bedrock")]
    Bedrock,
}

impl ProviderType {
    /// Parses a provider name, ignoring case and accepting common aliases
    /// (e.g. `"claude"` for Anthropic, `"grok"` for xAI).
    ///
    /// Returns `None` when the name matches no known provider.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        let key = s.to_lowercase();
        match key.as_str() {
            "openai" => Some(Self::OpenAI),
            "anthropic" | "claude" => Some(Self::Anthropic),
            "gemini" | "google" | "vertex" | "vertexai" => Some(Self::Gemini),
            "openrouter" | "open-router" => Some(Self::OpenRouter),
            "xai" | "grok" => Some(Self::XAI),
            "huggingface" | "hf" | "hugging-face" | "hugging_face" => Some(Self::HuggingFace),
            "ollama" => Some(Self::Ollama),
            "lmstudio" | "lm-studio" | "lm_studio" => Some(Self::LMStudio),
            "vscode" | "vscode-copilot" | "copilot" => Some(Self::VsCodeCopilot),
            "mock" => Some(Self::Mock),
            "mistral" | "mistral-ai" | "mistralai" => Some(Self::Mistral),
            "azure" | "azure-openai" | "azure_openai" | "azureopenai" => Some(Self::AzureOpenAI),
            #[cfg(feature = "bedrock")]
            "bedrock" | "aws-bedrock" | "aws_bedrock" => Some(Self::Bedrock),
            _ => None,
        }
    }
}
123
/// Factory for constructing matched `(LLM, embedding)` provider pairs from
/// the environment or from an explicit [`ProviderConfig`].
pub struct ProviderFactory;
128
129impl ProviderFactory {
130 pub fn from_env() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
154 if let Ok(provider_str) = std::env::var("EDGEQUAKE_LLM_PROVIDER") {
156 if let Some(provider_type) = ProviderType::from_str(&provider_str) {
157 return Self::create(provider_type);
158 }
159 return Err(LlmError::ConfigError(format!(
160 "Unknown provider type: {}. Valid options: openai, anthropic, ollama, lmstudio, mock",
161 provider_str
162 )));
163 }
164
165 if std::env::var("OLLAMA_HOST").is_ok() || std::env::var("OLLAMA_MODEL").is_ok() {
168 return Self::create(ProviderType::Ollama);
169 }
170
171 if std::env::var("LMSTUDIO_HOST").is_ok() || std::env::var("LMSTUDIO_MODEL").is_ok() {
173 return Self::create(ProviderType::LMStudio);
174 }
175
176 if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
178 if !api_key.is_empty() {
179 return Self::create(ProviderType::Anthropic);
180 }
181 }
182
183 if let Ok(api_key) =
185 std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
186 {
187 if !api_key.is_empty() {
188 return Self::create(ProviderType::Gemini);
189 }
190 }
191
192 if let Ok(api_key) = std::env::var("MISTRAL_API_KEY") {
194 if !api_key.is_empty() {
195 return Self::create(ProviderType::Mistral);
196 }
197 }
198
199 let azure_key = std::env::var("AZURE_OPENAI_CONTENTGEN_API_KEY")
201 .or_else(|_| std::env::var("AZURE_OPENAI_API_KEY"));
202 if let Ok(api_key) = azure_key {
203 if !api_key.is_empty() {
204 let azure_endpoint = std::env::var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT")
206 .or_else(|_| std::env::var("AZURE_OPENAI_ENDPOINT"));
207 if azure_endpoint.is_ok() {
208 return Self::create(ProviderType::AzureOpenAI);
209 }
210 }
211 }
212
213 if let Ok(api_key) = std::env::var("XAI_API_KEY") {
215 if !api_key.is_empty() {
216 return Self::create(ProviderType::XAI);
217 }
218 }
219
220 if let Ok(api_key) =
222 std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGINGFACE_TOKEN"))
223 {
224 if !api_key.is_empty() {
225 return Self::create(ProviderType::HuggingFace);
226 }
227 }
228
229 if let Ok(api_key) = std::env::var("OPENROUTER_API_KEY") {
231 if !api_key.is_empty() {
232 return Self::create(ProviderType::OpenRouter);
233 }
234 }
235
236 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
237 if !api_key.is_empty() && api_key != "test-key" {
238 return Self::create(ProviderType::OpenAI);
239 }
240 }
241
242 Ok(Self::create_mock())
244 }
245
    /// Instantiates the `(LLM, embedding)` pair for an explicitly chosen
    /// provider, using each provider's default model / environment settings.
    ///
    /// # Errors
    /// Propagates `LlmError::ConfigError` from the underlying constructor
    /// (e.g. missing API keys or endpoints). `Mock` is infallible.
    pub fn create(
        provider_type: ProviderType,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        match provider_type {
            ProviderType::OpenAI => Self::create_openai(),
            ProviderType::Anthropic => Self::create_anthropic(),
            ProviderType::Gemini => Self::create_gemini(),
            ProviderType::OpenRouter => Self::create_openrouter(),
            ProviderType::XAI => Self::create_xai(),
            ProviderType::HuggingFace => Self::create_huggingface(),
            ProviderType::Ollama => Self::create_ollama(),
            ProviderType::LMStudio => Self::create_lmstudio(),
            ProviderType::VsCodeCopilot => Self::create_vscode_copilot(),
            ProviderType::Mock => Ok(Self::create_mock()),
            ProviderType::Mistral => Self::create_mistral(),
            ProviderType::AzureOpenAI => Self::create_azure_openai(),
            #[cfg(feature = "bedrock")]
            ProviderType::Bedrock => Self::create_bedrock(),
        }
    }
279
280 pub fn create_with_model(
298 provider_type: ProviderType,
299 model: Option<&str>,
300 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
301 match model {
302 Some(m) => match provider_type {
303 ProviderType::OpenRouter => Self::create_openrouter_with_model(m),
304 ProviderType::Anthropic => Self::create_anthropic_with_model(m),
305 ProviderType::Gemini => Self::create_gemini_with_model(m),
306 ProviderType::XAI => Self::create_xai_with_model(m),
307 ProviderType::OpenAI => Self::create_openai_with_model(m),
308 ProviderType::Ollama => Self::create_ollama_with_model(m),
309 ProviderType::LMStudio => Self::create_lmstudio_with_model(m),
310 ProviderType::HuggingFace => Self::create_huggingface(),
312 ProviderType::VsCodeCopilot => Self::create_vscode_copilot(),
313 ProviderType::Mock => Ok(Self::create_mock()),
314 ProviderType::Mistral => Self::create_mistral_with_model(m),
315 ProviderType::AzureOpenAI => Self::create_azure_openai_with_deployment(m),
316 #[cfg(feature = "bedrock")]
317 ProviderType::Bedrock => Self::create_bedrock_with_model(m),
318 },
319 None => Self::create(provider_type),
320 }
321 }
322
    /// Builds providers from a [`ProviderConfig`], using the config's default
    /// model. Equivalent to `Self::from_config_with_model(config, None)`.
    pub fn from_config(
        config: &ProviderConfig,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        Self::from_config_with_model(config, None)
    }
356
357 pub fn from_config_with_model(
371 config: &ProviderConfig,
372 model_name: Option<&str>,
373 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
374 match config.provider_type {
375 ConfigProviderType::OpenAI => Self::create_openai(),
376 ConfigProviderType::Ollama => Self::create_ollama(),
377 ConfigProviderType::LMStudio => Self::create_lmstudio(),
378 ConfigProviderType::Mock => Ok(Self::create_mock()),
379 ConfigProviderType::OpenAICompatible => {
380 Self::create_openai_compatible_with_model(config, model_name)
381 }
382 ConfigProviderType::Azure => {
383 Self::create_azure_openai()
385 }
386 ConfigProviderType::Anthropic => Self::create_anthropic_from_config(config, model_name),
387 ConfigProviderType::OpenRouter => {
388 Self::create_openrouter_from_config(config, model_name)
389 }
390 ConfigProviderType::Mistral => Self::create_mistral_from_config(config, model_name),
391 }
392 }
393
    /// Builds an OpenAI-compatible provider with the config's default model.
    /// Currently unused internally (hence `#[allow(dead_code)]`).
    #[allow(dead_code)]
    fn create_openai_compatible(
        config: &ProviderConfig,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        Self::create_openai_compatible_with_model(config, None)
    }
408
409 fn create_openai_compatible_with_model(
411 config: &ProviderConfig,
412 model_name: Option<&str>,
413 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
414 let mut provider_instance = OpenAICompatibleProvider::from_config(config.clone())?;
415
416 if let Some(model) = model_name {
418 provider_instance = provider_instance.with_model(model);
419 }
420
421 let provider = Arc::new(provider_instance);
422
423 let has_embedding = config.default_embedding_model.is_some();
425
426 if has_embedding {
427 Ok((provider.clone(), provider))
429 } else {
430 match Self::create_openai() {
432 Ok((_, embedding)) => Ok((provider, embedding)),
433 Err(_) => {
434 Ok((provider.clone(), provider))
437 }
438 }
439 }
440 }
441
442 fn create_openai() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
446 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
447 LlmError::ConfigError("OPENAI_API_KEY not set for OpenAI provider".to_string())
448 })?;
449
450 if api_key.is_empty() || api_key == "test-key" {
451 return Err(LlmError::ConfigError(
452 "OPENAI_API_KEY is empty or invalid".to_string(),
453 ));
454 }
455
456 let provider = Arc::new(OpenAIProvider::new(api_key));
457 Ok((provider.clone(), provider))
458 }
459
460 fn create_anthropic() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
484 let provider = Arc::new(AnthropicProvider::from_env()?);
486
487 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
489 Ok((_, embedding)) => embedding,
490 Err(_) => Arc::new(MockProvider::new()),
491 };
492
493 Ok((provider, embedding))
494 }
495
496 fn create_anthropic_from_config(
500 config: &ProviderConfig,
501 model_name: Option<&str>,
502 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
503 let api_key_var = config.api_key_env.as_deref().unwrap_or("ANTHROPIC_API_KEY");
505 let api_key = std::env::var(api_key_var).map_err(|_| {
506 LlmError::ConfigError(format!("{} not set for Anthropic provider", api_key_var))
507 })?;
508
509 if api_key.is_empty() {
510 return Err(LlmError::ConfigError(format!("{} is empty", api_key_var)));
511 }
512
513 let model = model_name
515 .map(|s| s.to_string())
516 .or_else(|| config.default_llm_model.clone())
517 .unwrap_or_else(|| "claude-sonnet-4-5-20250929".to_string());
518
519 let mut provider = AnthropicProvider::new(api_key).with_model(model);
521
522 if let Some(base_url) = &config.base_url {
523 provider = provider.with_base_url(base_url);
524 }
525
526 let llm_provider = Arc::new(provider);
527
528 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
530 Ok((_, embedding)) => embedding,
531 Err(_) => Arc::new(MockProvider::new()),
532 };
533
534 Ok((llm_provider, embedding))
535 }
536
537 fn create_gemini() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
546 use crate::GeminiProvider;
547
548 let provider = GeminiProvider::from_env()?;
549 let llm_provider: Arc<dyn LLMProvider> = Arc::new(provider);
552
553 let embedding_provider = GeminiProvider::from_env()?;
554 let embedding: Arc<dyn EmbeddingProvider> = Arc::new(embedding_provider);
555
556 Ok((llm_provider, embedding))
557 }
558
559 fn create_openrouter() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
567 let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
568 LlmError::ConfigError("OPENROUTER_API_KEY not set for OpenRouter provider".to_string())
569 })?;
570
571 if api_key.is_empty() {
572 return Err(LlmError::ConfigError(
573 "OPENROUTER_API_KEY is empty".to_string(),
574 ));
575 }
576
577 let model = std::env::var("OPENROUTER_MODEL")
578 .unwrap_or_else(|_| "anthropic/claude-3.5-sonnet".to_string());
579
580 let mut provider = OpenRouterProvider::new(api_key).with_model(model);
581
582 if let Ok(url) = std::env::var("OPENROUTER_SITE_URL") {
584 provider = provider.with_site_url(url);
585 }
586 if let Ok(name) = std::env::var("OPENROUTER_SITE_NAME") {
587 provider = provider.with_site_name(name);
588 }
589
590 let llm_provider = Arc::new(provider);
591
592 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
594 Ok((_, embedding)) => embedding,
595 Err(_) => Arc::new(MockProvider::new()),
596 };
597
598 Ok((llm_provider, embedding))
599 }
600
601 fn create_openrouter_with_model(
609 model: &str,
610 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
611 let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
612 LlmError::ConfigError("OPENROUTER_API_KEY not set for OpenRouter provider".to_string())
613 })?;
614
615 if api_key.is_empty() {
616 return Err(LlmError::ConfigError(
617 "OPENROUTER_API_KEY is empty".to_string(),
618 ));
619 }
620
621 let mut provider = OpenRouterProvider::new(api_key).with_model(model);
622
623 if let Ok(url) = std::env::var("OPENROUTER_SITE_URL") {
625 provider = provider.with_site_url(url);
626 }
627 if let Ok(name) = std::env::var("OPENROUTER_SITE_NAME") {
628 provider = provider.with_site_name(name);
629 }
630
631 let llm_provider = Arc::new(provider);
632
633 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
635 Ok((_, embedding)) => embedding,
636 Err(_) => Arc::new(MockProvider::new()),
637 };
638
639 Ok((llm_provider, embedding))
640 }
641
642 fn create_openai_with_model(
644 model: &str,
645 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
646 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
647 LlmError::ConfigError("OPENAI_API_KEY not set for OpenAI provider".to_string())
648 })?;
649
650 if api_key.is_empty() || api_key == "test-key" {
651 return Err(LlmError::ConfigError(
652 "OPENAI_API_KEY is empty or invalid".to_string(),
653 ));
654 }
655
656 let provider = Arc::new(OpenAIProvider::new(api_key).with_model(model));
657 Ok((provider.clone(), provider))
658 }
659
660 fn create_anthropic_with_model(
662 model: &str,
663 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
664 let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
665 LlmError::ConfigError("ANTHROPIC_API_KEY not set for Anthropic provider".to_string())
666 })?;
667
668 let mut provider = AnthropicProvider::new(&api_key);
669
670 if let Ok(base_url) = std::env::var("ANTHROPIC_BASE_URL") {
672 provider = provider.with_base_url(&base_url);
673 }
674 provider = provider.with_model(model);
675
676 let llm_provider = Arc::new(provider);
677
678 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
680 Ok((_, embedding)) => embedding,
681 Err(_) => Arc::new(MockProvider::new()),
682 };
683
684 Ok((llm_provider, embedding))
685 }
686
687 fn create_gemini_with_model(
690 model: &str,
691 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
692 if model.starts_with("vertexai:") {
694 let actual_model = model.strip_prefix("vertexai:").unwrap_or(model);
696
697 let provider = Arc::new(GeminiProvider::from_env_vertex_ai()?.with_model(actual_model));
699
700 let embedding: Arc<dyn EmbeddingProvider> =
702 Arc::new(GeminiProvider::from_env_vertex_ai()?);
703
704 return Ok((provider, embedding));
705 }
706
707 let api_key = std::env::var("GEMINI_API_KEY")
709 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
710 .map_err(|_| {
711 LlmError::ConfigError(
712 "GEMINI_API_KEY or GOOGLE_API_KEY not set for Gemini provider".to_string(),
713 )
714 })?;
715
716 let provider = Arc::new(GeminiProvider::new(&api_key).with_model(model));
717
718 let embedding: Arc<dyn EmbeddingProvider> = Arc::new(GeminiProvider::new(&api_key));
720
721 Ok((provider, embedding))
722 }
723
724 fn create_xai_with_model(
726 model: &str,
727 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
728 std::env::var("XAI_API_KEY").map_err(|_| {
730 LlmError::ConfigError("XAI_API_KEY not set for xAI provider".to_string())
731 })?;
732
733 std::env::set_var("XAI_MODEL", model);
735 let provider = XAIProvider::from_env()?;
736
737 let llm_provider = Arc::new(provider);
738
739 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
741 Ok((_, embedding)) => embedding,
742 Err(_) => Arc::new(MockProvider::new()),
743 };
744
745 Ok((llm_provider, embedding))
746 }
747
    /// Builds an Ollama provider pinned to `model`; one shared instance serves
    /// chat and embeddings.
    ///
    /// NOTE(review): this mutates the process-global `OLLAMA_MODEL` variable
    /// so that `OllamaProvider::from_env` picks it up; the override therefore
    /// leaks into later `from_env` calls and is not thread-safe.
    fn create_ollama_with_model(
        model: &str,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        std::env::set_var("OLLAMA_MODEL", model);
        let provider = Arc::new(OllamaProvider::from_env()?);
        Ok((provider.clone(), provider))
    }
757
    /// Builds an LM Studio provider pinned to `model`; one shared instance
    /// serves chat and embeddings.
    ///
    /// NOTE(review): this mutates the process-global `LMSTUDIO_MODEL` variable
    /// so that `LMStudioProvider::from_env` picks it up; the override
    /// therefore leaks into later `from_env` calls and is not thread-safe.
    fn create_lmstudio_with_model(
        model: &str,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        std::env::set_var("LMSTUDIO_MODEL", model);
        let provider = Arc::new(LMStudioProvider::from_env()?);
        Ok((provider.clone(), provider))
    }
767
768 fn create_openrouter_from_config(
772 config: &ProviderConfig,
773 model_name: Option<&str>,
774 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
775 let api_key_var = config
777 .api_key_env
778 .as_deref()
779 .unwrap_or("OPENROUTER_API_KEY");
780 let api_key = std::env::var(api_key_var).map_err(|_| {
781 LlmError::ConfigError(format!("{} not set for OpenRouter provider", api_key_var))
782 })?;
783
784 if api_key.is_empty() {
785 return Err(LlmError::ConfigError(format!("{} is empty", api_key_var)));
786 }
787
788 let model = model_name
790 .map(|s| s.to_string())
791 .or_else(|| config.default_llm_model.clone())
792 .unwrap_or_else(|| "anthropic/claude-3.5-sonnet".to_string());
793
794 let mut provider = OpenRouterProvider::new(api_key).with_model(model);
796
797 if let Some(base_url) = &config.base_url {
798 provider = provider.with_base_url(base_url);
799 }
800
801 let llm_provider = Arc::new(provider);
802
803 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
805 Ok((_, embedding)) => embedding,
806 Err(_) => Arc::new(MockProvider::new()),
807 };
808
809 Ok((llm_provider, embedding))
810 }
811
    /// Builds an Ollama provider from the environment; one shared instance
    /// serves chat and embeddings.
    fn create_ollama() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(OllamaProvider::from_env()?);
        Ok((provider.clone(), provider))
    }
822
    /// Builds an LM Studio provider from the environment; one shared instance
    /// serves chat and embeddings.
    fn create_lmstudio() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(LMStudioProvider::from_env()?);
        Ok((provider.clone(), provider))
    }
834
835 fn create_xai() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
843 let provider = Arc::new(XAIProvider::from_env()?);
844
845 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
847 Ok((_, embedding)) => embedding,
848 Err(_) => Arc::new(MockProvider::new()),
849 };
850
851 Ok((provider, embedding))
852 }
853
854 fn create_huggingface() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
863 let provider = Arc::new(HuggingFaceProvider::from_env()?);
864
865 let embedding: Arc<dyn EmbeddingProvider> = match Self::create_openai() {
867 Ok((_, embedding)) => embedding,
868 Err(_) => Arc::new(MockProvider::new()),
869 };
870
871 Ok((provider, embedding))
872 }
873
874 fn create_vscode_copilot() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
889 let model =
890 std::env::var("VSCODE_COPILOT_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string());
891
892 let builder = VsCodeCopilotProvider::new().model(model);
895
896 let provider = Arc::new(builder.build().map_err(|e| {
897 LlmError::ConfigError(format!("Failed to create VSCode Copilot provider: {}", e))
898 })?);
899
900 Ok((provider.clone(), provider))
903 }
904
    /// Builds the mock provider pair (infallible); used for tests and as the
    /// final fallback when no real provider is configured.
    fn create_mock() -> (Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>) {
        let provider = Arc::new(MockProvider::new());
        (provider.clone(), provider)
    }
912
    /// Builds a Mistral provider from the environment; one shared instance
    /// serves chat and embeddings.
    fn create_mistral() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(MistralProvider::from_env()?);
        Ok((provider.clone(), provider))
    }
925
    /// Builds a Mistral provider pinned to `model`; one shared instance serves
    /// chat and embeddings.
    fn create_mistral_with_model(
        model: &str,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(MistralProvider::from_env()?.with_model(model));
        Ok((provider.clone(), provider))
    }
933
934 fn create_mistral_from_config(
936 config: &ProviderConfig,
937 model_name: Option<&str>,
938 ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
939 let mut provider = MistralProvider::from_config(config)?;
940 if let Some(model) = model_name {
941 provider = provider.with_model(model);
942 }
943 let provider = Arc::new(provider);
944 Ok((provider.clone(), provider))
945 }
946
    /// Builds an Azure OpenAI provider via `AzureOpenAIProvider::from_env_auto`;
    /// one shared instance serves chat and embeddings.
    pub fn create_azure_openai() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(AzureOpenAIProvider::from_env_auto()?);
        Ok((provider.clone(), provider))
    }
972
    /// Builds an Azure OpenAI provider pinned to the given deployment name;
    /// one shared instance serves chat and embeddings.
    fn create_azure_openai_with_deployment(
        deployment: &str,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        let provider = Arc::new(AzureOpenAIProvider::from_env_auto()?.with_deployment(deployment));
        Ok((provider.clone(), provider))
    }
981
    /// Builds an AWS Bedrock provider; one shared instance serves chat and
    /// embeddings.
    ///
    /// `BedrockProvider::from_env` is async, while this factory is sync, so
    /// the current Tokio runtime handle is used with `block_in_place` +
    /// `block_on` to bridge the gap without spawning a new runtime.
    ///
    /// # Errors
    /// `LlmError::ConfigError` when no Tokio runtime is active, plus any
    /// error from `BedrockProvider::from_env` itself.
    #[cfg(feature = "bedrock")]
    fn create_bedrock() -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        use crate::providers::bedrock::BedrockProvider;

        let rt = tokio::runtime::Handle::try_current().map_err(|_| {
            LlmError::ConfigError(
                "Bedrock provider requires a Tokio runtime (use #[tokio::main] or Runtime::new())"
                    .to_string(),
            )
        })?;
        // block_in_place keeps the blocking block_on call from starving the
        // async executor's worker thread.
        let provider = Arc::new(tokio::task::block_in_place(|| {
            rt.block_on(BedrockProvider::from_env())
        })?);

        let embedding: Arc<dyn EmbeddingProvider> = provider.clone();

        Ok((provider, embedding))
    }
1011
    /// Builds an AWS Bedrock provider pinned to `model`; one shared instance
    /// serves chat and embeddings.
    ///
    /// Uses the same sync-over-async bridge as `create_bedrock`
    /// (`block_in_place` + `block_on` on the current runtime handle).
    #[cfg(feature = "bedrock")]
    fn create_bedrock_with_model(
        model: &str,
    ) -> Result<(Arc<dyn LLMProvider>, Arc<dyn EmbeddingProvider>)> {
        use crate::providers::bedrock::BedrockProvider;

        let rt = tokio::runtime::Handle::try_current().map_err(|_| {
            LlmError::ConfigError("Bedrock provider requires a Tokio runtime".to_string())
        })?;
        let provider = Arc::new(
            tokio::task::block_in_place(|| rt.block_on(BedrockProvider::from_env()))?
                .with_model(model),
        );

        let embedding: Arc<dyn EmbeddingProvider> = provider.clone();

        Ok((provider, embedding))
    }
1032
    /// Returns the embedding dimension of whichever embedding provider
    /// `Self::from_env` resolves to.
    ///
    /// # Errors
    /// Propagates provider-construction failures from `from_env`.
    pub fn embedding_dimension() -> Result<usize> {
        let (_, embedding_provider) = Self::from_env()?;
        Ok(embedding_provider.dimension())
    }
1048
1049 pub fn create_embedding_provider(
1081 provider_name: &str,
1082 model: &str,
1083 _dimension: usize,
1084 ) -> Result<Arc<dyn EmbeddingProvider>> {
1085 let provider_type = ProviderType::from_str(provider_name).ok_or_else(|| {
1086 LlmError::ConfigError(format!(
1087 "Unknown embedding provider: {}. Valid: openai, ollama, lmstudio, vscode-copilot, mock",
1088 provider_name
1089 ))
1090 })?;
1091
1092 match provider_type {
1093 ProviderType::OpenAI => {
1094 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
1095 LlmError::ConfigError(
1096 "OPENAI_API_KEY required for OpenAI embedding provider".to_string(),
1097 )
1098 })?;
1099 let provider = OpenAIProvider::new(api_key).with_embedding_model(model);
1102 Ok(Arc::new(provider))
1103 }
1104 ProviderType::Anthropic => {
1105 warn!("Anthropic doesn't support embeddings, using mock provider");
1107 Ok(Arc::new(MockProvider::new()))
1108 }
1109 ProviderType::OpenRouter => {
1110 warn!("OpenRouter doesn't support embeddings, using mock provider");
1112 Ok(Arc::new(MockProvider::new()))
1113 }
1114 ProviderType::XAI => {
1115 warn!("xAI doesn't support embeddings, using mock provider");
1117 Ok(Arc::new(MockProvider::new()))
1118 }
1119 ProviderType::HuggingFace => {
1120 warn!("HuggingFace LLM provider doesn't support embeddings, using mock provider");
1122 Ok(Arc::new(MockProvider::new()))
1123 }
1124 ProviderType::Gemini => {
1125 match GeminiProvider::from_env() {
1129 Ok(provider) => Ok(Arc::new(provider.with_embedding_model(model))),
1130 Err(e) => {
1131 warn!(
1132 "Gemini credentials unavailable ({}), falling back to mock embedding provider",
1133 e
1134 );
1135 Ok(Arc::new(MockProvider::new()))
1136 }
1137 }
1138 }
1139 ProviderType::Ollama => {
1140 let host = std::env::var("OLLAMA_HOST")
1142 .unwrap_or_else(|_| "http://localhost:11434".to_string());
1143 let provider = OllamaProvider::builder()
1144 .host(&host)
1145 .embedding_model(model)
1146 .build()?;
1147 Ok(Arc::new(provider))
1148 }
1149 ProviderType::LMStudio => {
1150 let host = std::env::var("LMSTUDIO_HOST")
1152 .unwrap_or_else(|_| "http://localhost:1234".to_string());
1153 let provider = LMStudioProvider::builder()
1154 .host(&host)
1155 .embedding_model(model)
1156 .build()?;
1157 Ok(Arc::new(provider))
1158 }
1159 ProviderType::Mock => {
1160 Ok(Arc::new(MockProvider::new()))
1162 }
1163 ProviderType::VsCodeCopilot => {
1164 let provider = VsCodeCopilotProvider::new()
1167 .embedding_model(model)
1168 .build()
1169 .map_err(|e| LlmError::ApiError(e.to_string()))?;
1170 Ok(Arc::new(provider))
1171 }
1172 ProviderType::Mistral => {
1173 let provider = MistralProvider::from_env()?.with_embedding_model(model);
1175 Ok(Arc::new(provider))
1176 }
1177 ProviderType::AzureOpenAI => {
1178 let provider =
1180 AzureOpenAIProvider::from_env_auto()?.with_embedding_deployment(model);
1181 Ok(Arc::new(provider))
1182 }
1183 #[cfg(feature = "bedrock")]
1184 ProviderType::Bedrock => {
1185 warn!("Bedrock doesn't support embeddings via Converse API, using mock provider");
1187 Ok(Arc::new(MockProvider::new()))
1188 }
1189 }
1190 }
1191
1192 pub fn create_llm_provider(provider_name: &str, model: &str) -> Result<Arc<dyn LLMProvider>> {
1222 let provider_type = ProviderType::from_str(provider_name).ok_or_else(|| {
1223 LlmError::ConfigError(format!(
1224 "Unknown LLM provider: {}. Valid: openai, ollama, lmstudio, mock",
1225 provider_name
1226 ))
1227 })?;
1228
1229 match provider_type {
1230 ProviderType::OpenAI => {
1231 let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
1232 LlmError::ConfigError(
1233 "OPENAI_API_KEY required for OpenAI LLM provider".to_string(),
1234 )
1235 })?;
1236 let provider = OpenAIProvider::new(api_key).with_model(model);
1238 Ok(Arc::new(provider))
1239 }
1240 ProviderType::Anthropic => {
1241 let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
1242 LlmError::ConfigError(
1243 "ANTHROPIC_API_KEY required for Anthropic LLM provider".to_string(),
1244 )
1245 })?;
1246 let provider = AnthropicProvider::new(api_key).with_model(model);
1248 Ok(Arc::new(provider))
1249 }
1250 ProviderType::OpenRouter => {
1251 let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
1252 LlmError::ConfigError(
1253 "OPENROUTER_API_KEY required for OpenRouter LLM provider".to_string(),
1254 )
1255 })?;
1256 let provider = OpenRouterProvider::new(api_key).with_model(model);
1258 Ok(Arc::new(provider))
1259 }
1260 ProviderType::XAI => {
1261 let provider = XAIProvider::from_env()?.with_model(model);
1263 Ok(Arc::new(provider))
1264 }
1265 ProviderType::HuggingFace => {
1266 let provider = HuggingFaceProvider::from_env()?.with_model(model);
1268 Ok(Arc::new(provider))
1269 }
1270 ProviderType::Gemini => {
1271 if model.starts_with("vertexai:") {
1274 let actual_model = model.strip_prefix("vertexai:").unwrap_or(model);
1275 let provider = GeminiProvider::from_env_vertex_ai()?.with_model(actual_model);
1276 Ok(Arc::new(provider))
1277 } else {
1278 let provider = GeminiProvider::from_env()?.with_model(model);
1279 Ok(Arc::new(provider))
1280 }
1281 }
1282 ProviderType::Ollama => {
1283 let host = std::env::var("OLLAMA_HOST")
1285 .unwrap_or_else(|_| "http://localhost:11434".to_string());
1286 let provider = OllamaProvider::builder().host(&host).model(model).build()?;
1287 Ok(Arc::new(provider))
1288 }
1289 ProviderType::LMStudio => {
1290 let host = std::env::var("LMSTUDIO_HOST")
1292 .unwrap_or_else(|_| "http://localhost:1234".to_string());
1293 let provider = LMStudioProvider::builder()
1294 .host(&host)
1295 .model(model)
1296 .build()?;
1297 Ok(Arc::new(provider))
1298 }
1299 ProviderType::Mock => {
1300 Ok(Arc::new(MockProvider::new()))
1302 }
1303 ProviderType::VsCodeCopilot => {
1304 let proxy_url = std::env::var("VSCODE_COPILOT_PROXY_URL")
1306 .unwrap_or_else(|_| "http://localhost:4141".to_string());
1307 let provider = VsCodeCopilotProvider::new()
1308 .proxy_url(&proxy_url)
1309 .model(model)
1310 .build()?;
1311 Ok(Arc::new(provider))
1312 }
1313 ProviderType::Mistral => {
1314 let provider = MistralProvider::from_env()?.with_model(model);
1316 Ok(Arc::new(provider))
1317 }
1318 ProviderType::AzureOpenAI => {
1319 let provider = AzureOpenAIProvider::from_env_auto()?.with_deployment(model);
1321 Ok(Arc::new(provider))
1322 }
1323 #[cfg(feature = "bedrock")]
1324 ProviderType::Bedrock => {
1325 use crate::providers::bedrock::BedrockProvider;
1327 let handle = tokio::runtime::Handle::try_current().map_err(|_| {
1328 LlmError::ConfigError("Bedrock provider requires a Tokio runtime".to_string())
1329 })?;
1330 let provider =
1331 tokio::task::block_in_place(|| handle.block_on(BedrockProvider::from_env()))
1332 .map_err(|e| {
1333 LlmError::ConfigError(format!(
1334 "Failed to initialize Bedrock provider: {e}"
1335 ))
1336 })?;
1337 let provider = provider.with_model(model);
1338 Ok(Arc::new(provider))
1339 }
1340 }
1341 }
1342}
1343
1344#[cfg(test)]
1345mod tests {
1346 use super::*;
1347 use serial_test::serial;
1348
1349 #[test]
1350 fn test_provider_type_parsing() {
1351 assert_eq!(ProviderType::from_str("openai"), Some(ProviderType::OpenAI));
1352 assert_eq!(ProviderType::from_str("OLLAMA"), Some(ProviderType::Ollama));
1353 assert_eq!(
1354 ProviderType::from_str("lmstudio"),
1355 Some(ProviderType::LMStudio)
1356 );
1357 assert_eq!(
1358 ProviderType::from_str("lm-studio"),
1359 Some(ProviderType::LMStudio)
1360 );
1361 assert_eq!(
1362 ProviderType::from_str("lm_studio"),
1363 Some(ProviderType::LMStudio)
1364 );
1365 assert_eq!(ProviderType::from_str("mock"), Some(ProviderType::Mock));
1366
1367 assert_eq!(ProviderType::from_str("gemini"), Some(ProviderType::Gemini));
1369 assert_eq!(ProviderType::from_str("google"), Some(ProviderType::Gemini));
1370 assert_eq!(ProviderType::from_str("vertex"), Some(ProviderType::Gemini));
1371 assert_eq!(
1372 ProviderType::from_str("vertexai"),
1373 Some(ProviderType::Gemini)
1374 );
1375
1376 assert_eq!(
1378 ProviderType::from_str("openrouter"),
1379 Some(ProviderType::OpenRouter)
1380 );
1381
1382 assert_eq!(ProviderType::from_str("xai"), Some(ProviderType::XAI));
1384 assert_eq!(ProviderType::from_str("grok"), Some(ProviderType::XAI));
1385
1386 assert_eq!(
1388 ProviderType::from_str("huggingface"),
1389 Some(ProviderType::HuggingFace)
1390 );
1391 assert_eq!(
1392 ProviderType::from_str("hf"),
1393 Some(ProviderType::HuggingFace)
1394 );
1395 assert_eq!(
1396 ProviderType::from_str("hugging-face"),
1397 Some(ProviderType::HuggingFace)
1398 );
1399
1400 assert_eq!(
1402 ProviderType::from_str("azure"),
1403 Some(ProviderType::AzureOpenAI)
1404 );
1405 assert_eq!(
1406 ProviderType::from_str("azure-openai"),
1407 Some(ProviderType::AzureOpenAI)
1408 );
1409 assert_eq!(
1410 ProviderType::from_str("azure_openai"),
1411 Some(ProviderType::AzureOpenAI)
1412 );
1413 assert_eq!(
1414 ProviderType::from_str("AZURE"),
1415 Some(ProviderType::AzureOpenAI)
1416 );
1417
1418 assert_eq!(ProviderType::from_str("invalid"), None);
1419 assert_eq!(ProviderType::from_str(""), None);
1420 }
1421
1422 #[test]
1423 fn test_mock_creation() {
1424 let (llm, embedding) = ProviderFactory::create_mock();
1425 assert_eq!(llm.name(), "mock");
1426 assert_eq!(embedding.name(), "mock");
1427 assert_eq!(embedding.dimension(), 1536);
1428 }
1429
1430 #[test]
1431 fn test_explicit_mock_creation() {
1432 let (llm, embedding) = ProviderFactory::create(ProviderType::Mock).unwrap();
1433 assert_eq!(llm.name(), "mock");
1434 assert_eq!(embedding.dimension(), 1536);
1435 }
1436
1437 #[test]
1438 #[serial]
1439 fn test_from_env_fallback_to_mock() {
1440 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1444 std::env::remove_var("OPENAI_API_KEY");
1445 std::env::remove_var("XAI_API_KEY");
1446 std::env::remove_var("GOOGLE_API_KEY");
1447 std::env::remove_var("GEMINI_API_KEY");
1448 std::env::remove_var("OPENROUTER_API_KEY");
1449 std::env::remove_var("ANTHROPIC_API_KEY");
1450 std::env::remove_var("AZURE_OPENAI_API_KEY");
1452 std::env::remove_var("AZURE_OPENAI_ENDPOINT");
1453 std::env::remove_var("AZURE_OPENAI_DEPLOYMENT_NAME");
1454 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
1456 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
1457 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
1458 std::env::remove_var("HUGGINGFACE_API_KEY"); std::env::remove_var("HF_TOKEN"); std::env::remove_var("HUGGINGFACE_TOKEN"); std::env::remove_var("OLLAMA_HOST");
1462 std::env::remove_var("OLLAMA_MODEL");
1463 std::env::remove_var("LMSTUDIO_HOST");
1464 std::env::remove_var("LMSTUDIO_MODEL");
1465 std::env::remove_var("MISTRAL_API_KEY");
1466
1467 let (llm, _) = ProviderFactory::from_env().unwrap();
1468 assert_eq!(llm.name(), "mock");
1469 }
1470
1471 #[test]
1472 #[serial]
1473 fn test_explicit_provider_env() {
1474 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1476 std::env::remove_var("OLLAMA_HOST");
1477 std::env::remove_var("OPENAI_API_KEY");
1478 std::env::remove_var("XAI_API_KEY");
1479 std::env::remove_var("GOOGLE_API_KEY");
1480 std::env::remove_var("GEMINI_API_KEY");
1481 std::env::remove_var("OPENROUTER_API_KEY");
1482 std::env::remove_var("ANTHROPIC_API_KEY");
1483 std::env::remove_var("AZURE_OPENAI_API_KEY");
1484 std::env::remove_var("LMSTUDIO_HOST");
1485
1486 std::env::set_var("EDGEQUAKE_LLM_PROVIDER", "mock");
1487 let (llm, _) = ProviderFactory::from_env().unwrap();
1488 assert_eq!(llm.name(), "mock");
1489
1490 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1492 }
1493
1494 #[test]
1495 #[serial]
1496 fn test_lmstudio_auto_detection() {
1497 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1499 std::env::remove_var("OLLAMA_HOST");
1500 std::env::remove_var("OLLAMA_MODEL");
1501 std::env::remove_var("OPENAI_API_KEY");
1502 std::env::remove_var("XAI_API_KEY");
1503 std::env::remove_var("GOOGLE_API_KEY");
1504 std::env::remove_var("GEMINI_API_KEY");
1505 std::env::remove_var("OPENROUTER_API_KEY");
1506 std::env::remove_var("ANTHROPIC_API_KEY");
1507 std::env::remove_var("AZURE_OPENAI_API_KEY");
1508
1509 std::env::set_var("LMSTUDIO_HOST", "http://localhost:1234");
1511 let (llm, embedding) = ProviderFactory::from_env().unwrap();
1512 assert_eq!(llm.name(), "lmstudio");
1513 assert_eq!(embedding.name(), "lmstudio");
1514
1515 std::env::remove_var("LMSTUDIO_HOST");
1517 }
1518
1519 #[test]
1520 #[serial]
1521 fn test_lmstudio_model_detection() {
1522 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1524 std::env::remove_var("OLLAMA_HOST");
1525 std::env::remove_var("OLLAMA_MODEL");
1526 std::env::remove_var("OPENAI_API_KEY");
1527 std::env::remove_var("XAI_API_KEY");
1528 std::env::remove_var("GOOGLE_API_KEY");
1529 std::env::remove_var("GEMINI_API_KEY");
1530 std::env::remove_var("OPENROUTER_API_KEY");
1531 std::env::remove_var("ANTHROPIC_API_KEY");
1532 std::env::remove_var("AZURE_OPENAI_API_KEY");
1533 std::env::remove_var("LMSTUDIO_HOST");
1534
1535 std::env::set_var("LMSTUDIO_MODEL", "mistral-7b");
1537 let (llm, _) = ProviderFactory::from_env().unwrap();
1538 assert_eq!(llm.name(), "lmstudio");
1539
1540 std::env::remove_var("LMSTUDIO_MODEL");
1542 }
1543
1544 #[test]
1545 fn test_explicit_lmstudio_creation() {
1546 let (llm, embedding) = ProviderFactory::create(ProviderType::LMStudio).unwrap();
1547 assert_eq!(llm.name(), "lmstudio");
1548 assert_eq!(embedding.name(), "lmstudio");
1549 assert_eq!(embedding.dimension(), 768);
1551 }
1552
1553 #[test]
1554 #[serial]
1555 fn test_invalid_provider_env() {
1556 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1558 std::env::remove_var("OLLAMA_HOST");
1559 std::env::remove_var("OPENAI_API_KEY");
1560 std::env::remove_var("XAI_API_KEY");
1561 std::env::remove_var("GOOGLE_API_KEY");
1562 std::env::remove_var("GEMINI_API_KEY");
1563 std::env::remove_var("OPENROUTER_API_KEY");
1564 std::env::remove_var("ANTHROPIC_API_KEY");
1565 std::env::remove_var("AZURE_OPENAI_API_KEY");
1566 std::env::remove_var("LMSTUDIO_HOST");
1567
1568 std::env::set_var("EDGEQUAKE_LLM_PROVIDER", "invalid_provider");
1569 let result = ProviderFactory::from_env();
1570 assert!(result.is_err());
1571 if let Err(e) = result {
1572 assert!(e.to_string().contains("Unknown provider type"));
1573 }
1574
1575 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1577 }
1578
1579 #[test]
1580 #[serial]
1581 fn test_openai_creation_requires_api_key() {
1582 std::env::remove_var("OPENAI_API_KEY");
1584 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1585 std::env::remove_var("OLLAMA_HOST");
1586 std::env::remove_var("LMSTUDIO_HOST");
1587
1588 let result = ProviderFactory::create(ProviderType::OpenAI);
1589 assert!(result.is_err());
1590 if let Err(e) = result {
1591 assert!(e.to_string().contains("OPENAI_API_KEY not set"));
1592 }
1593 }
1594
1595 #[test]
1596 #[serial]
1597 fn test_embedding_dimension_detection() {
1598 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1600 std::env::remove_var("OLLAMA_HOST");
1601 std::env::remove_var("OPENAI_API_KEY");
1602 std::env::remove_var("LMSTUDIO_HOST");
1603
1604 std::env::set_var("EDGEQUAKE_LLM_PROVIDER", "mock");
1605 let dim = ProviderFactory::embedding_dimension().unwrap();
1606 assert_eq!(dim, 1536);
1607
1608 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1610 }
1611
1612 #[test]
1613 #[serial]
1614 fn test_provider_priority_ollama_over_lmstudio() {
1615 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1617 std::env::remove_var("OPENAI_API_KEY");
1618 std::env::remove_var("LMSTUDIO_HOST");
1619 std::env::remove_var("LMSTUDIO_MODEL");
1620
1621 std::env::set_var("OLLAMA_HOST", "http://localhost:11434");
1623 std::env::set_var("LMSTUDIO_HOST", "http://localhost:1234");
1624
1625 let (llm, _) = ProviderFactory::from_env().unwrap();
1626 assert_eq!(llm.name(), "ollama");
1627
1628 std::env::remove_var("OLLAMA_HOST");
1630 std::env::remove_var("LMSTUDIO_HOST");
1631 }
1632
1633 #[test]
1634 fn test_create_with_model_mock_none() {
1635 let (llm, emb) = ProviderFactory::create_with_model(ProviderType::Mock, None).unwrap();
1637 assert_eq!(llm.name(), "mock");
1638 assert_eq!(emb.name(), "mock");
1639 }
1640
1641 #[test]
1642 fn test_create_with_model_mock_some() {
1643 let (llm, _) =
1645 ProviderFactory::create_with_model(ProviderType::Mock, Some("any-model")).unwrap();
1646 assert_eq!(llm.name(), "mock");
1647 }
1648
1649 #[test]
1650 fn test_create_embedding_provider_mock() {
1651 let provider =
1652 ProviderFactory::create_embedding_provider("mock", "mock-model", 1536).unwrap();
1653 assert_eq!(provider.name(), "mock");
1654 assert_eq!(provider.dimension(), 1536);
1655 }
1656
1657 #[test]
1658 fn test_create_embedding_provider_unknown() {
1659 let result = ProviderFactory::create_embedding_provider("unknown", "model", 1536);
1660 match result {
1661 Err(e) => assert!(e.to_string().contains("Unknown embedding provider")),
1662 Ok(_) => panic!("Expected error for unknown provider"),
1663 }
1664 }
1665
1666 #[test]
1667 fn test_create_llm_provider_mock() {
1668 let provider = ProviderFactory::create_llm_provider("mock", "mock-model").unwrap();
1669 assert_eq!(provider.name(), "mock");
1670 }
1671
1672 #[test]
1673 fn test_create_llm_provider_unknown() {
1674 let result = ProviderFactory::create_llm_provider("unknown", "model");
1675 match result {
1676 Err(e) => assert!(e.to_string().contains("Unknown LLM provider")),
1677 Ok(_) => panic!("Expected error for unknown provider"),
1678 }
1679 }
1680
1681 #[test]
1682 fn test_provider_type_debug() {
1683 let pt = ProviderType::Mock;
1685 let debug = format!("{:?}", pt);
1686 assert_eq!(debug, "Mock");
1687 }
1688
1689 #[test]
1690 fn test_provider_type_clone_eq() {
1691 let pt1 = ProviderType::OpenAI;
1692 let pt2 = pt1;
1693 assert_eq!(pt1, pt2);
1694 assert_ne!(pt1, ProviderType::Ollama);
1695 }
1696
1697 #[test]
1698 #[serial]
1699 fn test_from_config_azure_no_creds() {
1700 use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
1701 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_KEY", "");
1704 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT", "");
1705 std::env::set_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT", "");
1706 std::env::set_var("AZURE_OPENAI_API_KEY", "");
1707 std::env::set_var("AZURE_OPENAI_ENDPOINT", "");
1708 std::env::set_var("AZURE_OPENAI_DEPLOYMENT_NAME", "");
1709 let config = ProviderConfig {
1710 provider_type: ConfigProviderType::Azure,
1711 ..ProviderConfig::default()
1712 };
1713 let result = ProviderFactory::from_config(&config);
1714 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
1716 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
1717 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
1718 std::env::remove_var("AZURE_OPENAI_API_KEY");
1719 std::env::remove_var("AZURE_OPENAI_ENDPOINT");
1720 std::env::remove_var("AZURE_OPENAI_DEPLOYMENT_NAME");
1721 assert!(
1723 result.is_err(),
1724 "Expected error when Azure credentials are not set"
1725 );
1726 }
1727
1728 #[test]
1729 fn test_from_config_mock() {
1730 use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
1731 let config = ProviderConfig {
1732 provider_type: ConfigProviderType::Mock,
1733 ..ProviderConfig::default()
1734 };
1735 let (llm, emb) = ProviderFactory::from_config(&config).unwrap();
1736 assert_eq!(llm.name(), "mock");
1737 assert_eq!(emb.name(), "mock");
1738 }
1739
1740 #[test]
1741 fn test_vscode_copilot_parsing() {
1742 assert_eq!(
1743 ProviderType::from_str("vscode"),
1744 Some(ProviderType::VsCodeCopilot)
1745 );
1746 assert_eq!(
1747 ProviderType::from_str("copilot"),
1748 Some(ProviderType::VsCodeCopilot)
1749 );
1750 assert_eq!(
1751 ProviderType::from_str("vscode-copilot"),
1752 Some(ProviderType::VsCodeCopilot)
1753 );
1754 }
1755
1756 #[test]
1758 fn test_create_embedding_provider_anthropic_fallback() {
1759 let provider =
1761 ProviderFactory::create_embedding_provider("anthropic", "any-model", 1536).unwrap();
1762 assert_eq!(provider.name(), "mock");
1763 }
1764
1765 #[test]
1766 fn test_create_embedding_provider_openrouter_fallback() {
1767 let provider =
1769 ProviderFactory::create_embedding_provider("openrouter", "any-model", 1536).unwrap();
1770 assert_eq!(provider.name(), "mock");
1771 }
1772
1773 #[test]
1774 fn test_create_embedding_provider_xai_fallback() {
1775 let provider =
1777 ProviderFactory::create_embedding_provider("xai", "any-model", 1536).unwrap();
1778 assert_eq!(provider.name(), "mock");
1779 }
1780
1781 #[test]
1782 fn test_create_embedding_provider_huggingface_fallback() {
1783 let provider =
1785 ProviderFactory::create_embedding_provider("huggingface", "any-model", 768).unwrap();
1786 assert_eq!(provider.name(), "mock");
1787 }
1788
1789 #[test]
1790 fn test_create_embedding_provider_gemini_fallback() {
1791 let provider =
1794 ProviderFactory::create_embedding_provider("gemini", "any-model", 768).unwrap();
1795 let name = provider.name();
1796 assert!(
1797 name == "gemini" || name == "mock",
1798 "Expected 'gemini' (with API key) or 'mock' (without), got '{}'",
1799 name
1800 );
1801 }
1802
1803 #[test]
1804 fn test_create_embedding_provider_ollama() {
1805 let provider =
1807 ProviderFactory::create_embedding_provider("ollama", "nomic-embed-text", 768).unwrap();
1808 assert_eq!(provider.name(), "ollama");
1809 }
1810
1811 #[test]
1812 fn test_create_embedding_provider_lmstudio() {
1813 let provider =
1815 ProviderFactory::create_embedding_provider("lmstudio", "nomic-embed-text-v1.5", 768)
1816 .unwrap();
1817 assert_eq!(provider.name(), "lmstudio");
1818 }
1819
1820 #[test]
1821 fn test_create_embedding_provider_vscode_copilot() {
1822 let provider = ProviderFactory::create_embedding_provider(
1824 "vscode-copilot",
1825 "text-embedding-3-small",
1826 1536,
1827 )
1828 .unwrap();
1829 assert_eq!(provider.name(), "vscode-copilot");
1830 }
1831
1832 #[test]
1834 fn test_from_config_ollama() {
1835 use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
1836 let config = ProviderConfig {
1837 provider_type: ConfigProviderType::Ollama,
1838 ..ProviderConfig::default()
1839 };
1840 let (llm, emb) = ProviderFactory::from_config(&config).unwrap();
1841 assert_eq!(llm.name(), "ollama");
1842 assert_eq!(emb.name(), "ollama");
1843 }
1844
1845 #[test]
1846 fn test_from_config_lmstudio() {
1847 use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
1848 let config = ProviderConfig {
1849 provider_type: ConfigProviderType::LMStudio,
1850 ..ProviderConfig::default()
1851 };
1852 let (llm, emb) = ProviderFactory::from_config(&config).unwrap();
1853 assert_eq!(llm.name(), "lmstudio");
1854 assert_eq!(emb.name(), "lmstudio");
1855 }
1856
1857 #[test]
1858 #[serial]
1859 fn test_from_config_openai_requires_api_key() {
1860 use crate::model_config::{ProviderConfig, ProviderType as ConfigProviderType};
1861 std::env::remove_var("OPENAI_API_KEY");
1862 let config = ProviderConfig {
1863 provider_type: ConfigProviderType::OpenAI,
1864 ..ProviderConfig::default()
1865 };
1866 let result = ProviderFactory::from_config(&config);
1867 assert!(result.is_err());
1868 if let Err(e) = result {
1869 assert!(e.to_string().contains("OPENAI_API_KEY"));
1870 }
1871 }
1872
1873 #[test]
1875 fn test_create_with_model_ollama() {
1876 let (llm, _) =
1877 ProviderFactory::create_with_model(ProviderType::Ollama, Some("llama3:8b")).unwrap();
1878 assert_eq!(llm.name(), "ollama");
1879 assert_eq!(llm.model(), "llama3:8b");
1880 }
1881
1882 #[test]
1883 fn test_create_with_model_lmstudio() {
1884 let (llm, _) =
1885 ProviderFactory::create_with_model(ProviderType::LMStudio, Some("mistral-7b")).unwrap();
1886 assert_eq!(llm.name(), "lmstudio");
1887 assert_eq!(llm.model(), "mistral-7b");
1888 }
1889
1890 #[test]
1893 #[serial]
1894 fn test_provider_type_parsing_azure() {
1895 assert_eq!(
1896 ProviderType::from_str("azure"),
1897 Some(ProviderType::AzureOpenAI)
1898 );
1899 assert_eq!(
1900 ProviderType::from_str("azure-openai"),
1901 Some(ProviderType::AzureOpenAI)
1902 );
1903 assert_eq!(
1904 ProviderType::from_str("AZURE"),
1905 Some(ProviderType::AzureOpenAI)
1906 );
1907 }
1908
1909 #[test]
1910 #[serial]
1911 fn test_create_azure_openai_fails_without_env() {
1912 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_KEY", "");
1915 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT", "");
1916 std::env::set_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT", "");
1917 std::env::set_var("AZURE_OPENAI_API_KEY", "");
1918 std::env::set_var("AZURE_OPENAI_ENDPOINT", "");
1919 std::env::set_var("AZURE_OPENAI_DEPLOYMENT_NAME", "");
1920
1921 let result = ProviderFactory::create(ProviderType::AzureOpenAI);
1922 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
1924 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
1925 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
1926 std::env::remove_var("AZURE_OPENAI_API_KEY");
1927 std::env::remove_var("AZURE_OPENAI_ENDPOINT");
1928 std::env::remove_var("AZURE_OPENAI_DEPLOYMENT_NAME");
1929 assert!(
1930 result.is_err(),
1931 "Azure provider should fail when env vars are empty"
1932 );
1933 }
1934
1935 #[test]
1936 #[serial]
1937 fn test_from_env_auto_detects_azure_with_contentgen_vars() {
1938 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1940 std::env::remove_var("OLLAMA_HOST");
1941 std::env::remove_var("OLLAMA_MODEL");
1942 std::env::remove_var("LMSTUDIO_HOST");
1943 std::env::remove_var("LMSTUDIO_MODEL");
1944 std::env::remove_var("ANTHROPIC_API_KEY");
1945 std::env::remove_var("GEMINI_API_KEY");
1946 std::env::remove_var("GOOGLE_API_KEY");
1947 std::env::remove_var("MISTRAL_API_KEY");
1948
1949 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_KEY", "test-azure-key");
1951 std::env::set_var(
1952 "AZURE_OPENAI_CONTENTGEN_API_ENDPOINT",
1953 "https://test.openai.azure.com",
1954 );
1955 std::env::set_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT", "gpt-4o");
1956
1957 let result = ProviderFactory::from_env();
1958
1959 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
1961 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
1962 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
1963
1964 let (llm, _) = result.expect("Should detect Azure from CONTENTGEN vars");
1965 assert_eq!(llm.name(), "azure-openai");
1966 }
1967
1968 #[test]
1969 #[serial]
1970 fn test_from_env_auto_detects_azure_with_standard_vars() {
1971 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
1973 std::env::remove_var("OLLAMA_HOST");
1974 std::env::remove_var("OLLAMA_MODEL");
1975 std::env::remove_var("LMSTUDIO_HOST");
1976 std::env::remove_var("LMSTUDIO_MODEL");
1977 std::env::remove_var("ANTHROPIC_API_KEY");
1978 std::env::remove_var("GEMINI_API_KEY");
1979 std::env::remove_var("GOOGLE_API_KEY");
1980 std::env::remove_var("MISTRAL_API_KEY");
1981 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
1982
1983 std::env::set_var("AZURE_OPENAI_API_KEY", "test-azure-key");
1985 std::env::set_var("AZURE_OPENAI_ENDPOINT", "https://test.openai.azure.com");
1986 std::env::set_var("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o");
1987
1988 let result = ProviderFactory::from_env();
1989
1990 std::env::remove_var("AZURE_OPENAI_API_KEY");
1992 std::env::remove_var("AZURE_OPENAI_ENDPOINT");
1993 std::env::remove_var("AZURE_OPENAI_DEPLOYMENT_NAME");
1994
1995 let (llm, _) = result.expect("Should detect Azure from standard vars");
1996 assert_eq!(llm.name(), "azure-openai");
1997 }
1998
1999 #[test]
2000 #[serial]
2001 fn test_explicit_azure_provider_selection() {
2002 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
2004 std::env::remove_var("OLLAMA_HOST");
2005 std::env::remove_var("LMSTUDIO_HOST");
2006 std::env::remove_var("OPENAI_API_KEY");
2007
2008 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_KEY", "test-key");
2010 std::env::set_var(
2011 "AZURE_OPENAI_CONTENTGEN_API_ENDPOINT",
2012 "https://test.openai.azure.com",
2013 );
2014 std::env::set_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT", "gpt-4.1-mini");
2015
2016 std::env::set_var("EDGEQUAKE_LLM_PROVIDER", "azure");
2017 let result = ProviderFactory::from_env();
2018
2019 std::env::remove_var("EDGEQUAKE_LLM_PROVIDER");
2021 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
2022 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
2023 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
2024
2025 let (llm, _) = result.expect("Explicit azure provider selection should succeed");
2026 assert_eq!(llm.name(), "azure-openai");
2027 assert_eq!(llm.model(), "gpt-4.1-mini");
2028 }
2029
2030 #[test]
2031 #[serial]
2032 fn test_create_with_model_azure() {
2033 std::env::set_var("AZURE_OPENAI_CONTENTGEN_API_KEY", "test-key");
2034 std::env::set_var(
2035 "AZURE_OPENAI_CONTENTGEN_API_ENDPOINT",
2036 "https://test.openai.azure.com",
2037 );
2038 std::env::set_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT", "gpt-4o");
2039
2040 let result = ProviderFactory::create_with_model(
2041 ProviderType::AzureOpenAI,
2042 Some("my-custom-deployment"),
2043 );
2044
2045 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_KEY");
2047 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_API_ENDPOINT");
2048 std::env::remove_var("AZURE_OPENAI_CONTENTGEN_MODEL_DEPLOYMENT");
2049
2050 let (llm, _) = result.expect("Azure create_with_model should succeed");
2051 assert_eq!(llm.name(), "azure-openai");
2052 assert_eq!(llm.model(), "my-custom-deployment");
2053 }
2054}