1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::OnceLock;
4
5use serde::{Deserialize, Serialize};
6
7use crate::types::ReasoningEffort;
8
/// Process-wide cache of the key/value pairs parsed from the Garudust `.env` file.
/// Populated by the first call to [`load_dotenv_once`]; read by [`get_secret`].
static DOTENV_VARS: OnceLock<HashMap<String, String>> = OnceLock::new();

/// Parses the dotenv file at `path` on the first call and caches the result for the
/// rest of the process.
///
/// Only the first call reads the file. Later calls return the cached map and ignore
/// `path`. A missing or unreadable file yields an empty map, which is also cached.
fn load_dotenv_once(path: &Path) -> &'static HashMap<String, String> {
    DOTENV_VARS.get_or_init(|| {
        std::fs::read_to_string(path)
            .map(|content| parse_dotenv(&content))
            .unwrap_or_default()
    })
}

/// Parses dotenv-formatted text into a key/value map.
///
/// Supported syntax, one entry per line:
/// - `KEY=value`, with whitespace around the key and the value trimmed;
/// - an optional leading `export ` before the key, as in shell-sourced files;
/// - a value wrapped in one matching pair of `"` or `'` quotes, which are removed.
///
/// Blank lines, `#` comment lines, lines without `=`, and entries whose key is empty
/// are skipped. When a key repeats, the last occurrence wins.
fn parse_dotenv(content: &str) -> HashMap<String, String> {
    let mut vars = HashMap::new();
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let Some((key, value)) = line.split_once('=') else {
            continue;
        };
        let key = key.trim();
        let key = key.strip_prefix("export ").map_or(key, str::trim_start);
        if key.is_empty() {
            continue;
        }
        vars.insert(
            key.to_string(),
            unquote_dotenv_value(value.trim()).to_string(),
        );
    }
    vars
}

/// Removes exactly one pair of matching surrounding quotes (`"…"` or `'…'`), if present.
///
/// Unlike trimming every quote character from both ends, this leaves unbalanced quotes
/// intact (for example a secret that legitimately ends in `"`). It also keeps nested
/// quotes: `"'x'"` becomes `'x'`.
fn unquote_dotenv_value(raw: &str) -> &str {
    ['"', '\'']
        .iter()
        .find_map(|&quote| raw.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(raw)
}
33
/// Looks up `key`, preferring the process environment over the parsed `.env` map.
///
/// An empty value in either source counts as "not set". A non-UTF-8 environment value
/// is also treated as unset, so the lookup falls through to `dotenv`.
fn env_or_dotenv(key: &str, dotenv: &HashMap<String, String>) -> Option<String> {
    if let Ok(from_env) = std::env::var(key) {
        if !from_env.is_empty() {
            return Some(from_env);
        }
    }
    match dotenv.get(key) {
        Some(from_file) if !from_file.is_empty() => Some(from_file.clone()),
        _ => None,
    }
}
41
42pub fn get_secret(key: &str) -> Option<String> {
45 std::env::var(key)
46 .ok()
47 .filter(|v| !v.is_empty())
48 .or_else(|| {
49 DOTENV_VARS
50 .get()?
51 .get(key)
52 .filter(|v| !v.is_empty())
53 .cloned()
54 })
55}
56
/// Model override entry, used as the value type of `AgentConfig::tools` and
/// `AgentConfig::skills`.
///
/// Both fields default to an empty string when omitted from YAML.
/// NOTE(review): an empty string presumably means "no override, use the agent's model".
/// Confirm at the call sites.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolOverrideConfig {
    /// Model to use for this tool/skill; empty when unset.
    #[serde(default)]
    pub model: String,
    /// Fallback model (presumably tried when `model` fails; confirm); empty when unset.
    #[serde(default)]
    pub fallback_model: String,
}
72
/// Top-level agent configuration.
///
/// It is deserialized from `config.yaml` in [`AgentConfig::garudust_dir`], and
/// [`AgentConfig::load`] then overlays environment / `.env` values. Field order below
/// is the key order `save_yaml` writes. Fields marked `#[serde(skip)]` are never read
/// from or written to YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Garudust home directory. Set by `load`, not stored in YAML.
    #[serde(skip)]
    pub home_dir: PathBuf,
    /// Model identifier. Defaults to [`DEFAULT_MODEL`]; overridden by `GARUDUST_MODEL`.
    #[serde(default = "default_model")]
    pub model: String,
    /// Iteration cap (default 90). NOTE(review): presumably per agent task; confirm.
    #[serde(default = "default_max_iterations")]
    pub max_iterations: u32,
    /// Iteration cap for sub-agents; `None` when unset.
    /// NOTE(review): presumably falls back to `max_iterations`; confirm.
    #[serde(default)]
    pub sub_agent_max_iterations: Option<u32>,
    /// Delay in milliseconds (per the name), default 0.
    /// NOTE(review): presumably applied between tool calls; confirm.
    #[serde(default)]
    pub tool_delay_ms: u64,
    /// LLM provider name. Defaults to [`DEFAULT_PROVIDER`]; auto-detected from
    /// environment keys when config.yaml is absent.
    #[serde(default = "default_provider")]
    pub provider: String,
    /// Optional API base URL. Also set by `GARUDUST_BASE_URL` or by Ollama/vLLM detection.
    pub base_url: Option<String>,
    /// String-to-string routing table; empty by default.
    /// NOTE(review): key/value meaning is not visible in this file.
    #[serde(default)]
    pub routing: std::collections::HashMap<String, String>,
    /// Per-tool model overrides, presumably keyed by tool name; empty by default.
    #[serde(default)]
    pub tools: std::collections::HashMap<String, ToolOverrideConfig>,
    /// Per-skill model overrides, presumably keyed by skill name; empty by default.
    #[serde(default)]
    pub skills: std::collections::HashMap<String, ToolOverrideConfig>,
    /// Provider API key. Never stored in YAML; resolved from env / `.env` by `load`.
    #[serde(skip)]
    pub api_key: Option<String>,
    /// Extra API keys from the comma-separated `LLM_FALLBACK_API_KEYS`. Never stored in YAML.
    #[serde(skip)]
    pub fallback_api_keys: Vec<String>,
    /// Context-compression settings.
    #[serde(default)]
    pub compression: CompressionConfig,
    /// Network settings (IPv4 forcing, proxy).
    #[serde(default)]
    pub network: NetworkConfig,
    /// MCP servers to configure; empty by default.
    #[serde(default)]
    pub mcp_servers: Vec<McpServerConfig>,
    /// Optional cap on concurrent requests; `None` when unset.
    #[serde(default)]
    pub max_concurrent_requests: Option<usize>,
    /// Access-control and sandbox settings.
    #[serde(default)]
    pub security: SecurityConfig,
    /// Memory retention windows per category.
    #[serde(default)]
    pub memory_expiry: MemoryExpiryConfig,
    /// Nudge interval (default 5). NOTE(review): unit (turns? iterations?) not visible here.
    #[serde(default = "default_nudge_interval")]
    pub nudge_interval: u32,
    /// Maximum LLM request retries (default 3).
    #[serde(default = "default_llm_max_retries")]
    pub llm_max_retries: u32,
    /// Base retry delay in milliseconds (default 1000).
    #[serde(default = "default_llm_retry_base_ms")]
    pub llm_retry_base_ms: u64,
    /// Chat-platform access settings.
    #[serde(default)]
    pub platform: PlatformConfig,
    /// Threshold for automatic skill creation (default 5).
    /// NOTE(review): what is counted is not visible here.
    #[serde(default = "default_auto_skill_threshold")]
    pub auto_skill_threshold: u32,
    /// LLM request timeout in seconds (default 120).
    #[serde(default = "default_llm_timeout_secs")]
    pub llm_timeout_secs: u64,
    /// Tool execution timeout in seconds (default 60).
    #[serde(default = "default_tool_timeout_secs")]
    pub tool_timeout_secs: u64,
    /// Graceful-shutdown timeout in seconds (default 30).
    #[serde(default = "default_shutdown_timeout_secs")]
    pub shutdown_timeout_secs: u64,
    /// Optional token budget per task; `None` when unset.
    #[serde(default)]
    pub max_tokens_per_task: Option<u32>,
    /// Optional cap on output tokens per response; `None` when unset.
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
    /// Optional reasoning effort; `None` when unset.
    #[serde(default)]
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Optional context-window size override (presumably in tokens); `None` when unset.
    #[serde(default)]
    pub context_window: Option<usize>,
    /// Names of toolsets to disable; empty by default.
    #[serde(default)]
    pub disabled_toolsets: Vec<String>,
    /// Names of individual tools to disable; empty by default.
    #[serde(default)]
    pub disabled_tools: Vec<String>,
    /// Whether to show a usage footer (default `false`).
    #[serde(default)]
    pub show_usage_footer: bool,
    /// Webhook platform listeners. When omitted from YAML, every platform is `None`.
    /// `AgentConfig::default()` instead enables the generic webhook.
    #[serde(default)]
    pub platforms: WebhookPlatformsConfig,
    /// HTTP server settings. The port can be overridden by `GARUDUST_PORT`.
    #[serde(default)]
    pub server: ServerConfig,
    /// Scheduled jobs. Can be overridden by the `GARUDUST_*CRON*` variables.
    #[serde(default)]
    pub cron: CronConfig,
}
213
/// Model id used when config.yaml does not set `model` and `GARUDUST_MODEL` is unset.
pub const DEFAULT_MODEL: &str = "anthropic/claude-sonnet-4-6";
/// Provider used when config.yaml does not set `provider` and none is detected.
pub const DEFAULT_PROVIDER: &str = "openrouter";

// Serde field defaults for `AgentConfig`. Each helper is named in a
// `#[serde(default = "...")]` attribute and supplies the value used when that key is
// missing from config.yaml.

fn default_model() -> String {
    String::from(DEFAULT_MODEL)
}

fn default_provider() -> String {
    String::from(DEFAULT_PROVIDER)
}

const fn default_max_iterations() -> u32 {
    90
}

const fn default_nudge_interval() -> u32 {
    5
}

const fn default_auto_skill_threshold() -> u32 {
    5
}

const fn default_llm_max_retries() -> u32 {
    3
}

const fn default_llm_retry_base_ms() -> u64 {
    1_000
}

const fn default_llm_timeout_secs() -> u64 {
    120
}

const fn default_tool_timeout_secs() -> u64 {
    60
}

const fn default_shutdown_timeout_secs() -> u64 {
    30
}
249
/// Retention windows, in days (per the field names), for stored memory entries by category.
///
/// When the key is omitted, `fact_days`, `project_days` and `other_days` default to
/// 90, 30 and 60. `preference_days` and `skill_days` default to `None`. An explicit
/// YAML `null` yields `None` rather than the default, because serde applies field
/// defaults only to missing keys.
/// NOTE(review): `None` presumably means "never expire"; confirm in the expiry job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryExpiryConfig {
    /// Retention for fact entries (default 90).
    #[serde(default = "default_fact_days")]
    pub fact_days: Option<u32>,
    /// Retention for project entries (default 30).
    #[serde(default = "default_project_days")]
    pub project_days: Option<u32>,
    /// Retention for entries in other categories (default 60).
    #[serde(default = "default_other_days")]
    pub other_days: Option<u32>,
    /// Retention for preference entries (default `None`).
    #[serde(default)]
    pub preference_days: Option<u32>,
    /// Retention for skill entries (default `None`).
    #[serde(default)]
    pub skill_days: Option<u32>,
}
271
// `#[serde(default = "...")]` needs a function whose return type matches the field
// exactly (`Option<u32>`), so these helpers always return `Some(..)`. That is why
// clippy's `unnecessary_wraps` lint is allowed on each of them.

/// Default `MemoryExpiryConfig::fact_days`: 90.
#[allow(clippy::unnecessary_wraps)]
fn default_fact_days() -> Option<u32> {
    let days: u32 = 90;
    Some(days)
}

/// Default `MemoryExpiryConfig::project_days`: 30.
#[allow(clippy::unnecessary_wraps)]
fn default_project_days() -> Option<u32> {
    let days: u32 = 30;
    Some(days)
}

/// Default `MemoryExpiryConfig::other_days`: 60.
#[allow(clippy::unnecessary_wraps)]
fn default_other_days() -> Option<u32> {
    let days: u32 = 60;
    Some(days)
}
284
285impl Default for MemoryExpiryConfig {
286 fn default() -> Self {
287 Self {
288 fact_days: default_fact_days(),
289 project_days: default_project_days(),
290 other_days: default_other_days(),
291 preference_days: None,
292 skill_days: None,
293 }
294 }
295}
296
/// Sandbox mode for terminal commands (per the name).
///
/// Serialized in lowercase (`none`, `docker`) and defaults to [`TerminalSandbox::None`].
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum TerminalSandbox {
    /// No sandbox (the default).
    #[default]
    None,
    /// Docker-based sandbox. NOTE(review): presumably runs commands in
    /// `SecurityConfig::terminal_sandbox_image`; confirm.
    Docker,
}
307
/// Access-control and sandbox settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Gateway API key. Never stored in YAML; set from `GARUDUST_API_KEY` by `AgentConfig::load`.
    #[serde(skip)]
    pub gateway_api_key: Option<String>,

    /// Paths the agent may read. `AgentConfig::load` seeds defaults when this is empty.
    #[serde(default)]
    pub allowed_read_paths: Vec<PathBuf>,

    /// Paths the agent may write. May be seeded by `AgentConfig::load` when
    /// `allowed_read_paths` is empty.
    #[serde(default)]
    pub allowed_write_paths: Vec<PathBuf>,

    /// Approval policy name. Default `"smart"`; overridden by `GARUDUST_APPROVAL_MODE`.
    /// NOTE(review): the accepted values are not visible in this file.
    #[serde(default = "default_approval_mode")]
    pub approval_mode: String,

    /// Rate limit in requests per minute (per the name). `None` when unset;
    /// overridden by `GARUDUST_RATE_LIMIT` when it parses as `u32`.
    #[serde(default)]
    pub rate_limit_rpm: Option<u32>,

    /// Terminal sandbox mode. `GARUDUST_TERMINAL_SANDBOX=docker` (case-insensitive)
    /// selects Docker; any other value selects `None`.
    #[serde(default)]
    pub terminal_sandbox: TerminalSandbox,

    /// Sandbox container image. Default `"ubuntu:24.04"`; overridden by `GARUDUST_SANDBOX_IMAGE`.
    #[serde(default = "default_sandbox_image")]
    pub terminal_sandbox_image: String,

    /// Extra sandbox options. NOTE(review): presumably passed to the container runtime; confirm.
    #[serde(default)]
    pub terminal_sandbox_opts: Vec<String>,
}
344
/// Serde / `Default` value for `SecurityConfig::approval_mode`.
fn default_approval_mode() -> String {
    String::from("smart")
}

/// Serde / `Default` value for `SecurityConfig::terminal_sandbox_image`.
fn default_sandbox_image() -> String {
    String::from("ubuntu:24.04")
}
352
353impl Default for SecurityConfig {
354 fn default() -> Self {
355 Self {
356 gateway_api_key: None,
357 allowed_read_paths: Vec::new(),
358 allowed_write_paths: Vec::new(),
359 approval_mode: default_approval_mode(),
360 rate_limit_rpm: None,
361 terminal_sandbox: TerminalSandbox::None,
362 terminal_sandbox_image: default_sandbox_image(),
363 terminal_sandbox_opts: Vec::new(),
364 }
365 }
366}
367
/// Chat-platform access settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformConfig {
    /// User ids permitted to interact; empty by default.
    /// NOTE(review): confirm whether an empty list means "allow everyone".
    #[serde(default)]
    pub allowed_user_ids: Vec<String>,

    /// Whether the bot must be mentioned to respond (default `false`).
    #[serde(default)]
    pub require_mention: bool,

    /// Bot username, presumably used for mention detection; empty by default.
    #[serde(default)]
    pub bot_username: String,

    /// Whether each user gets a separate session (default `true`).
    #[serde(default = "default_true")]
    pub session_per_user: bool,
}
392
/// Serde helper for boolean fields that should default to `true` when omitted.
const fn default_true() -> bool {
    !false
}
396
397impl Default for PlatformConfig {
398 fn default() -> Self {
399 Self {
400 allowed_user_ids: Vec::new(),
401 require_mention: false,
402 bot_username: String::new(),
403 session_per_user: true,
404 }
405 }
406}
407
/// An MCP server entry, presumably launched as `command args...`; confirm at the use site.
///
/// `name` and `command` are required in YAML. `args` defaults to empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Identifier for this server.
    pub name: String,
    /// Executable to run.
    pub command: String,
    /// Arguments for `command`; empty by default.
    #[serde(default)]
    pub args: Vec<String>,
}
415
/// Listener settings for one webhook-based platform.
///
/// When the section is present in YAML, `port` and `webhook_path` are required.
/// `enabled` defaults to `false`. See `default_webhook`, `default_line`, and
/// `default_whatsapp` for preset values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookPlatformConfig {
    /// Whether this listener is active (default `false` when omitted).
    #[serde(default)]
    pub enabled: bool,
    /// Port to listen on.
    pub port: u16,
    /// HTTP path that receives webhook requests.
    pub webhook_path: String,
}
429
/// Optional listener settings per webhook platform. Each is `None` when omitted.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WebhookPlatformsConfig {
    /// `line` platform listener.
    #[serde(default)]
    pub line: Option<WebhookPlatformConfig>,
    /// `whatsapp` platform listener.
    #[serde(default)]
    pub whatsapp: Option<WebhookPlatformConfig>,
    /// Generic webhook listener.
    #[serde(default)]
    pub webhook: Option<WebhookPlatformConfig>,
}
440
/// HTTP server settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Listen port. Default 3000; overridden by `GARUDUST_PORT` when it parses as `u16`.
    #[serde(default = "default_server_port")]
    pub port: u16,
}
448
/// Default HTTP server port, shared by serde and `ServerConfig::default`.
const fn default_server_port() -> u16 {
    3_000
}
452
453fn parse_cron_jobs_str(s: &str) -> Vec<CronJob> {
456 s.split(',')
457 .filter_map(|entry| {
458 let (expr, task) = entry.trim().split_once('=')?;
459 Some(CronJob {
460 schedule: expr.trim().to_string(),
461 task: task.trim().to_string(),
462 })
463 })
464 .collect()
465}
466
467impl Default for ServerConfig {
468 fn default() -> Self {
469 Self {
470 port: default_server_port(),
471 }
472 }
473}
474
/// A scheduled task. Both fields are required in YAML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CronJob {
    /// Schedule expression, presumably standard cron syntax; confirm with the scheduler.
    pub schedule: String,
    /// Task text to run on each trigger.
    pub task: String,
}
483
/// Scheduled-job settings. Every field is optional in YAML.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CronConfig {
    /// User-defined jobs. Replaced entirely by `GARUDUST_CRON_JOBS` when that is set.
    #[serde(default)]
    pub jobs: Vec<CronJob>,
    /// Schedule for memory consolidation; overridden by `GARUDUST_MEMORY_CRON`.
    #[serde(default)]
    pub memory_consolidation: Option<String>,
    /// Schedule for memory expiry; overridden by `GARUDUST_MEMORY_EXPIRY_CRON`.
    #[serde(default)]
    pub memory_expiry: Option<String>,
}
497
498impl WebhookPlatformConfig {
499 pub fn default_webhook() -> Self {
502 Self {
503 enabled: true,
504 port: 3001,
505 webhook_path: "/webhook".to_string(),
506 }
507 }
508
509 pub fn default_line() -> Self {
513 Self {
514 enabled: true,
515 port: 3002,
516 webhook_path: "/line".to_string(),
517 }
518 }
519
520 pub fn default_whatsapp() -> Self {
522 Self {
523 enabled: true,
524 port: 3003,
525 webhook_path: "/whatsapp".to_string(),
526 }
527 }
528}
529
530impl Default for AgentConfig {
531 fn default() -> Self {
532 let cwd = std::env::current_dir().unwrap_or_default();
533 let home = dirs::home_dir().unwrap_or_default();
534 Self {
535 home_dir: Self::garudust_dir(),
536 model: DEFAULT_MODEL.into(),
537 max_iterations: 90,
538 sub_agent_max_iterations: None,
539 tool_delay_ms: 0,
540 provider: DEFAULT_PROVIDER.into(),
541 base_url: None,
542 routing: std::collections::HashMap::new(),
543 tools: std::collections::HashMap::new(),
544 skills: std::collections::HashMap::new(),
545 api_key: None,
546 fallback_api_keys: Vec::new(),
547 compression: CompressionConfig::default(),
548 network: NetworkConfig::default(),
549 mcp_servers: Vec::new(),
550 max_concurrent_requests: None,
551 security: SecurityConfig {
552 gateway_api_key: None,
553 allowed_read_paths: vec![cwd.clone(), home],
554 allowed_write_paths: vec![cwd],
555 approval_mode: default_approval_mode(),
556 rate_limit_rpm: None,
557 terminal_sandbox: TerminalSandbox::None,
558 terminal_sandbox_image: default_sandbox_image(),
559 terminal_sandbox_opts: Vec::new(),
560 },
561 memory_expiry: MemoryExpiryConfig::default(),
562 nudge_interval: default_nudge_interval(),
563 llm_max_retries: default_llm_max_retries(),
564 llm_retry_base_ms: default_llm_retry_base_ms(),
565 platform: PlatformConfig::default(),
566 auto_skill_threshold: default_auto_skill_threshold(),
567 llm_timeout_secs: default_llm_timeout_secs(),
568 tool_timeout_secs: default_tool_timeout_secs(),
569 shutdown_timeout_secs: default_shutdown_timeout_secs(),
570 max_tokens_per_task: None,
571 max_output_tokens: None,
572 reasoning_effort: None,
573 context_window: None,
574 disabled_toolsets: Vec::new(),
575 disabled_tools: Vec::new(),
576 show_usage_footer: false,
577 platforms: WebhookPlatformsConfig {
578 webhook: Some(WebhookPlatformConfig::default_webhook()),
579 line: None,
580 whatsapp: None,
581 },
582 server: ServerConfig::default(),
583 cron: CronConfig::default(),
584 }
585 }
586}
587
588pub(crate) fn resolve_key_for_provider(
591 provider: &str,
592 dotenv: &HashMap<String, String>,
593) -> Option<String> {
594 match provider {
595 "anthropic" => env_or_dotenv("ANTHROPIC_API_KEY", dotenv),
596 "openai" => env_or_dotenv("OPENAI_API_KEY", dotenv),
597 "gemini" => env_or_dotenv("GEMINI_API_KEY", dotenv),
598 "groq" => env_or_dotenv("GROQ_API_KEY", dotenv),
599 "mistral" => env_or_dotenv("MISTRAL_API_KEY", dotenv),
600 "deepseek" => env_or_dotenv("DEEPSEEK_API_KEY", dotenv),
601 "xai" => env_or_dotenv("XAI_API_KEY", dotenv),
602 "vllm" => env_or_dotenv("VLLM_API_KEY", dotenv),
603 "thaillm" => env_or_dotenv("THAILLM_API_KEY", dotenv),
604 "ollama" | "bedrock" | "codex" => None,
605 _ => env_or_dotenv("OPENROUTER_API_KEY", dotenv),
606 }
607}
608
609pub(crate) fn detect_provider_from_env(config: &mut AgentConfig, dotenv: &HashMap<String, String>) {
613 if let Some(k) = env_or_dotenv("ANTHROPIC_API_KEY", dotenv) {
614 config.api_key = Some(k);
615 config.provider = "anthropic".into();
616 } else if let Some(k) = env_or_dotenv("OPENAI_API_KEY", dotenv) {
617 config.api_key = Some(k);
618 config.provider = "openai".into();
619 } else if let Some(k) = env_or_dotenv("GEMINI_API_KEY", dotenv) {
620 config.api_key = Some(k);
621 config.provider = "gemini".into();
622 } else if let Some(k) = env_or_dotenv("GROQ_API_KEY", dotenv) {
623 config.api_key = Some(k);
624 config.provider = "groq".into();
625 } else if let Some(k) = env_or_dotenv("MISTRAL_API_KEY", dotenv) {
626 config.api_key = Some(k);
627 config.provider = "mistral".into();
628 } else if let Some(k) = env_or_dotenv("DEEPSEEK_API_KEY", dotenv) {
629 config.api_key = Some(k);
630 config.provider = "deepseek".into();
631 } else if let Some(k) = env_or_dotenv("XAI_API_KEY", dotenv) {
632 config.api_key = Some(k);
633 config.provider = "xai".into();
634 } else if let Some(url) = env_or_dotenv("OLLAMA_BASE_URL", dotenv) {
635 config.provider = "ollama".into();
636 config.base_url = Some(url);
637 } else if let Some(url) = env_or_dotenv("VLLM_BASE_URL", dotenv) {
638 config.provider = "vllm".into();
639 config.base_url = Some(url);
640 config.api_key = env_or_dotenv("VLLM_API_KEY", dotenv);
641 } else if let Some(k) = env_or_dotenv("THAILLM_API_KEY", dotenv) {
642 config.api_key = Some(k);
643 config.provider = "thaillm".into();
644 } else if let Some(k) = env_or_dotenv("OPENROUTER_API_KEY", dotenv) {
645 config.api_key = Some(k);
646 config.provider = "openrouter".into();
647 }
648}
649
650impl AgentConfig {
651 pub fn garudust_dir() -> PathBuf {
653 dirs::home_dir()
654 .unwrap_or_else(|| PathBuf::from("/tmp"))
655 .join(".garudust")
656 }
657
658 pub fn load() -> Self {
666 let home_dir = Self::garudust_dir();
667
668 let env_file = home_dir.join(".env");
670 let dotenv = load_dotenv_once(&env_file);
671
672 let yaml_path = home_dir.join("config.yaml");
674 let mut config: AgentConfig = if yaml_path.exists() {
675 let src = std::fs::read_to_string(&yaml_path).unwrap_or_default();
676 serde_yaml::from_str(&src).unwrap_or_default()
677 } else {
678 AgentConfig::default()
679 };
680
681 config.home_dir = home_dir;
682
683 if config.security.allowed_read_paths.is_empty() {
685 let cwd = std::env::current_dir().unwrap_or_default();
686 let home = dirs::home_dir().unwrap_or_default();
687 config.security.allowed_read_paths = vec![cwd.clone(), home];
688 config.security.allowed_write_paths = vec![cwd];
689 }
690
691 let yaml_authoritative = yaml_path.exists();
702
703 if yaml_authoritative {
704 if config.api_key.is_none() {
705 config.api_key = resolve_key_for_provider(&config.provider, dotenv);
706 }
707 } else {
708 detect_provider_from_env(&mut config, dotenv);
709 }
710 if let Some(m) = env_or_dotenv("GARUDUST_MODEL", dotenv) {
711 config.model = m;
712 }
713 if let Some(u) = env_or_dotenv("GARUDUST_BASE_URL", dotenv) {
714 config.base_url = Some(u);
715 }
716 if let Some(v) = env_or_dotenv("LLM_FALLBACK_API_KEYS", dotenv) {
717 config.fallback_api_keys = v
718 .split(',')
719 .map(str::trim)
720 .filter(|s| !s.is_empty())
721 .map(str::to_string)
722 .collect();
723 }
724 if let Some(k) = env_or_dotenv("GARUDUST_API_KEY", dotenv) {
725 config.security.gateway_api_key = Some(k);
726 }
727 if let Some(v) = env_or_dotenv("GARUDUST_RATE_LIMIT", dotenv) {
728 if let Ok(n) = v.parse::<u32>() {
729 config.security.rate_limit_rpm = Some(n);
730 }
731 }
732 if let Some(mode) = env_or_dotenv("GARUDUST_APPROVAL_MODE", dotenv) {
733 config.security.approval_mode = mode;
734 }
735 if let Some(sandbox) = env_or_dotenv("GARUDUST_TERMINAL_SANDBOX", dotenv) {
736 config.security.terminal_sandbox = match sandbox.to_lowercase().as_str() {
737 "docker" => TerminalSandbox::Docker,
738 _ => TerminalSandbox::None,
739 };
740 }
741 if let Some(image) = env_or_dotenv("GARUDUST_SANDBOX_IMAGE", dotenv) {
742 config.security.terminal_sandbox_image = image;
743 }
744
745 if let Some(v) = env_or_dotenv("GARUDUST_PORT", dotenv) {
750 if let Ok(n) = v.parse::<u16>() {
751 config.server.port = n;
752 }
753 }
754 if let Some(v) = env_or_dotenv("GARUDUST_MEMORY_CRON", dotenv) {
755 config.cron.memory_consolidation = Some(v);
756 }
757 if let Some(v) = env_or_dotenv("GARUDUST_MEMORY_EXPIRY_CRON", dotenv) {
758 config.cron.memory_expiry = Some(v);
759 }
760 if let Some(v) = env_or_dotenv("GARUDUST_CRON_JOBS", dotenv) {
761 config.cron.jobs = parse_cron_jobs_str(&v);
762 }
763
764 config
765 }
766
767 pub fn save_yaml(&self) -> std::io::Result<()> {
769 std::fs::create_dir_all(&self.home_dir)?;
770 let yaml = serde_yaml::to_string(self).map_err(std::io::Error::other)?;
771 std::fs::write(self.home_dir.join("config.yaml"), yaml)
772 }
773
774 pub fn set_env_var(home_dir: &Path, key: &str, value: &str) -> std::io::Result<()> {
776 std::fs::create_dir_all(home_dir)?;
777 let env_path = home_dir.join(".env");
778 let existing = if env_path.exists() {
779 std::fs::read_to_string(&env_path)?
780 } else {
781 String::new()
782 };
783
784 let prefix = format!("{key}=");
785 let mut lines: Vec<String> = existing
786 .lines()
787 .filter(|l| !l.starts_with(&prefix))
788 .map(String::from)
789 .collect();
790 lines.push(format!("{key}={value}"));
791
792 std::fs::write(&env_path, lines.join("\n") + "\n")
793 }
794}
795
796#[derive(Debug, Clone, Serialize, Deserialize)]
799pub struct CompressionConfig {
800 pub enabled: bool,
801 pub threshold_fraction: f32,
802 pub model: Option<String>,
803}
804
805impl Default for CompressionConfig {
806 fn default() -> Self {
807 Self {
808 enabled: true,
809 threshold_fraction: 0.8,
810 model: None,
811 }
812 }
813}
814
815#[derive(Debug, Clone, Serialize, Deserialize, Default)]
816pub struct NetworkConfig {
817 pub force_ipv4: bool,
818 pub proxy: Option<String>,
819}
820
#[cfg(test)]
mod tests {
    //! Unit tests for provider/key resolution and `.env` updating.
    //!
    //! NOTE: `env_or_dotenv` consults the real process environment before the supplied
    //! map. These tests therefore assume none of the provider variables (`*_API_KEY`,
    //! `OLLAMA_BASE_URL`, `VLLM_BASE_URL`) are exported in the shell running `cargo test`.

    use std::collections::HashMap;

    use super::{detect_provider_from_env, resolve_key_for_provider, AgentConfig};

    fn dotenv(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect()
    }

    #[test]
    fn resolve_openai_key() {
        let map = dotenv(&[("OPENAI_API_KEY", "sk-test-openai")]);
        assert_eq!(
            resolve_key_for_provider("openai", &map),
            Some("sk-test-openai".into())
        );
    }

    #[test]
    fn resolve_anthropic_key() {
        let map = dotenv(&[("ANTHROPIC_API_KEY", "sk-ant-test")]);
        assert_eq!(
            resolve_key_for_provider("anthropic", &map),
            Some("sk-ant-test".into())
        );
    }

    #[test]
    fn resolve_gemini_key() {
        let map = dotenv(&[("GEMINI_API_KEY", "AIza-test")]);
        assert_eq!(
            resolve_key_for_provider("gemini", &map),
            Some("AIza-test".into())
        );
    }

    #[test]
    fn resolve_groq_key() {
        let map = dotenv(&[("GROQ_API_KEY", "gsk-test")]);
        assert_eq!(
            resolve_key_for_provider("groq", &map),
            Some("gsk-test".into())
        );
    }

    #[test]
    fn resolve_mistral_key() {
        let map = dotenv(&[("MISTRAL_API_KEY", "ms-test")]);
        assert_eq!(
            resolve_key_for_provider("mistral", &map),
            Some("ms-test".into())
        );
    }

    #[test]
    fn resolve_deepseek_key() {
        let map = dotenv(&[("DEEPSEEK_API_KEY", "ds-test")]);
        assert_eq!(
            resolve_key_for_provider("deepseek", &map),
            Some("ds-test".into())
        );
    }

    #[test]
    fn resolve_xai_key() {
        let map = dotenv(&[("XAI_API_KEY", "xai-test")]);
        assert_eq!(
            resolve_key_for_provider("xai", &map),
            Some("xai-test".into())
        );
    }

    #[test]
    fn resolve_vllm_and_thaillm_keys() {
        let map = dotenv(&[("VLLM_API_KEY", "vllm-test"), ("THAILLM_API_KEY", "th-test")]);
        assert_eq!(
            resolve_key_for_provider("vllm", &map),
            Some("vllm-test".into())
        );
        assert_eq!(
            resolve_key_for_provider("thaillm", &map),
            Some("th-test".into())
        );
    }

    #[test]
    fn resolve_ollama_returns_none() {
        let map = dotenv(&[("OPENROUTER_API_KEY", "or-test")]);
        assert_eq!(resolve_key_for_provider("ollama", &map), None);
    }

    #[test]
    fn resolve_bedrock_and_codex_return_none() {
        let map = dotenv(&[("OPENROUTER_API_KEY", "or-test")]);
        assert_eq!(resolve_key_for_provider("bedrock", &map), None);
        assert_eq!(resolve_key_for_provider("codex", &map), None);
    }

    #[test]
    fn resolve_ignores_empty_dotenv_value() {
        let map = dotenv(&[("OPENAI_API_KEY", "")]);
        assert_eq!(resolve_key_for_provider("openai", &map), None);
    }

    #[test]
    fn resolve_unknown_provider_falls_back_to_openrouter() {
        let map = dotenv(&[("OPENROUTER_API_KEY", "or-test")]);
        assert_eq!(
            resolve_key_for_provider("custom-provider", &map),
            Some("or-test".into())
        );
    }

    /// Runs provider detection against a fresh default config and the given `.env` pairs.
    fn detect(pairs: &[(&str, &str)]) -> AgentConfig {
        let mut cfg = AgentConfig::default();
        detect_provider_from_env(&mut cfg, &dotenv(pairs));
        cfg
    }

    #[test]
    fn detect_openai_only() {
        let cfg = detect(&[("OPENAI_API_KEY", "sk-test-openai")]);
        assert_eq!(cfg.provider, "openai");
        assert_eq!(cfg.api_key.as_deref(), Some("sk-test-openai"));
    }

    #[test]
    fn detect_gemini_only() {
        let cfg = detect(&[("GEMINI_API_KEY", "AIza-test")]);
        assert_eq!(cfg.provider, "gemini");
        assert_eq!(cfg.api_key.as_deref(), Some("AIza-test"));
    }

    #[test]
    fn detect_groq_only() {
        let cfg = detect(&[("GROQ_API_KEY", "gsk-test")]);
        assert_eq!(cfg.provider, "groq");
        assert_eq!(cfg.api_key.as_deref(), Some("gsk-test"));
    }

    #[test]
    fn detect_mistral_only() {
        let cfg = detect(&[("MISTRAL_API_KEY", "ms-test")]);
        assert_eq!(cfg.provider, "mistral");
        assert_eq!(cfg.api_key.as_deref(), Some("ms-test"));
    }

    #[test]
    fn detect_deepseek_only() {
        let cfg = detect(&[("DEEPSEEK_API_KEY", "ds-test")]);
        assert_eq!(cfg.provider, "deepseek");
        assert_eq!(cfg.api_key.as_deref(), Some("ds-test"));
    }

    #[test]
    fn detect_xai_only() {
        let cfg = detect(&[("XAI_API_KEY", "xai-test")]);
        assert_eq!(cfg.provider, "xai");
        assert_eq!(cfg.api_key.as_deref(), Some("xai-test"));
    }

    #[test]
    fn detect_thaillm_only() {
        let cfg = detect(&[("THAILLM_API_KEY", "th-test")]);
        assert_eq!(cfg.provider, "thaillm");
        assert_eq!(cfg.api_key.as_deref(), Some("th-test"));
    }

    #[test]
    fn detect_openrouter_only() {
        let cfg = detect(&[("OPENROUTER_API_KEY", "or-test")]);
        assert_eq!(cfg.provider, "openrouter");
        assert_eq!(cfg.api_key.as_deref(), Some("or-test"));
    }

    #[test]
    fn detect_ollama_sets_base_url_not_key() {
        let cfg = detect(&[("OLLAMA_BASE_URL", "http://localhost:11434")]);
        assert_eq!(cfg.provider, "ollama");
        assert_eq!(cfg.base_url.as_deref(), Some("http://localhost:11434"));
        assert!(cfg.api_key.is_none());
    }

    #[test]
    fn detect_ollama_wins_over_openrouter() {
        let cfg = detect(&[
            ("OLLAMA_BASE_URL", "http://localhost:11434"),
            ("OPENROUTER_API_KEY", "or-test"),
        ]);
        assert_eq!(cfg.provider, "ollama");
        assert!(cfg.api_key.is_none());
    }

    #[test]
    fn detect_vllm_sets_base_url_and_key() {
        let cfg = detect(&[
            ("VLLM_BASE_URL", "http://localhost:8000/v1"),
            ("VLLM_API_KEY", "vllm-test"),
        ]);
        assert_eq!(cfg.provider, "vllm");
        assert_eq!(cfg.base_url.as_deref(), Some("http://localhost:8000/v1"));
        assert_eq!(cfg.api_key.as_deref(), Some("vllm-test"));
    }

    #[test]
    fn detect_empty_env_leaves_defaults() {
        let cfg = detect(&[]);
        assert_eq!(cfg.provider, "openrouter");
        assert!(cfg.api_key.is_none());
    }

    #[test]
    fn detect_anthropic_wins_over_openai_in_dotenv() {
        let cfg = detect(&[
            ("ANTHROPIC_API_KEY", "sk-ant-test"),
            ("OPENAI_API_KEY", "sk-oai-test"),
        ]);
        assert_eq!(cfg.provider, "anthropic");
        assert_eq!(cfg.api_key.as_deref(), Some("sk-ant-test"));
    }

    #[test]
    fn set_env_var_replaces_existing_key() {
        let dir = std::env::temp_dir().join(format!(
            "garudust-config-test-{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join(".env"), "A=1\nB=2\n").unwrap();
        AgentConfig::set_env_var(&dir, "A", "3").unwrap();
        let written = std::fs::read_to_string(dir.join(".env")).unwrap();
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(written, "B=2\nA=3\n");
    }
}