1use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13 pub api_key: String,
15 pub base_url: String,
17 pub model: String,
19 #[serde(default = "default_api_format")]
21 pub api_format: String,
22 #[serde(default = "default_max_tokens")]
24 pub default_max_tokens: u32,
25 #[serde(default = "default_temperature")]
27 pub default_temperature: f32,
28}
29
/// Serde default for `ProviderConfig::api_format`.
fn default_api_format() -> String {
    String::from("openai")
}
33
/// Serde default for `ProviderConfig::default_max_tokens`.
fn default_max_tokens() -> u32 {
    4_096
}
37
/// Serde default for `ProviderConfig::default_temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}
41
42impl Default for ProviderConfig {
43 fn default() -> Self {
44 Self {
45 api_key: String::new(),
46 base_url: String::new(),
47 model: "claude-sonnet-4-6".to_string(),
48 api_format: default_api_format(),
49 default_max_tokens: 4096,
50 default_temperature: 0.7,
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GlobalSettings {
58 #[serde(default = "default_true")]
60 pub session_auto_save: bool,
61 #[serde(default = "default_session_max_history")]
63 pub session_max_history: usize,
64 #[serde(default = "default_true")]
66 pub checkpoint_enabled: bool,
67 #[serde(default = "default_checkpoint_interval")]
69 pub checkpoint_interval_sec: u32,
70 #[serde(default = "default_true")]
72 pub audit_enabled: bool,
73 #[serde(default)]
75 pub mcp_enabled: bool,
76}
77
78impl Default for GlobalSettings {
79 fn default() -> Self {
80 Self {
81 session_auto_save: true,
82 session_max_history: 100,
83 checkpoint_enabled: true,
84 checkpoint_interval_sec: 60,
85 audit_enabled: true,
86 mcp_enabled: false,
87 }
88 }
89}
90
/// Serde default for boolean settings that are on unless disabled explicitly.
fn default_true() -> bool {
    true
}
94
/// Serde default for `GlobalSettings::session_max_history`.
fn default_session_max_history() -> usize {
    100_usize
}
98
/// Serde default for `GlobalSettings::checkpoint_interval_sec`.
fn default_checkpoint_interval() -> u32 {
    60_u32
}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct ConfigManager {
106 #[serde(default = "default_provider")]
108 pub active_provider: String,
109 #[serde(default)]
111 pub providers: HashMap<String, ProviderConfig>,
112 #[serde(default)]
114 pub settings: GlobalSettings,
115 #[serde(default)]
117 pub extra: HashMap<String, String>,
118}
119
/// Serde default for `ConfigManager::active_provider`.
fn default_provider() -> String {
    String::from("anthropic")
}
123
124impl Default for ConfigManager {
125 fn default() -> Self {
126 Self {
127 active_provider: "anthropic".to_string(),
128 providers: HashMap::new(),
129 settings: GlobalSettings::default(),
130 extra: HashMap::new(),
131 }
132 }
133}
134
135impl ConfigManager {
136 pub fn new() -> Self {
138 Self::default()
139 }
140
141 pub fn from_env() -> Self {
143 let mut config = Self::default();
144
145 if let Ok(provider) = std::env::var("CONTINUUM_PROVIDER") {
147 config.active_provider = provider;
148 }
149
150 if let Ok(api_key) = std::env::var("CONTINUUM_API_KEY") {
152 let provider_name = config.active_provider.clone();
153 let provider_config = config.providers.entry(provider_name).or_default();
154 provider_config.api_key = api_key;
155 }
156
157 if let Ok(base_url) = std::env::var("CONTINUUM_BASE_URL") {
159 let provider_name = config.active_provider.clone();
160 let provider_config = config.providers.entry(provider_name).or_default();
161 provider_config.base_url = base_url;
162 }
163
164 if let Ok(model) = std::env::var("CONTINUUM_MODEL") {
166 let provider_name = config.active_provider.clone();
167 let provider_config = config.providers.entry(provider_name).or_default();
168 provider_config.model = model;
169 }
170
171 if let Ok(val) = std::env::var("CONTINUUM_CHECKPOINT_ENABLED") {
173 if let Ok(enabled) = val.parse::<bool>() {
174 config.settings.checkpoint_enabled = enabled;
175 }
176 }
177
178 if let Ok(val) = std::env::var("CONTINUUM_AUDIT_ENABLED") {
179 if let Ok(enabled) = val.parse::<bool>() {
180 config.settings.audit_enabled = enabled;
181 }
182 }
183
184 config
185 }
186
187 pub async fn load_from_file(&mut self, path: &Path) -> Result<()> {
189 if !path.exists() {
190 return Ok(());
191 }
192
193 let content = tokio::fs::read_to_string(path).await?;
194 let loaded: ConfigManager = toml::from_str(&content)?;
195
196 self.merge(loaded);
198 Ok(())
199 }
200
201 pub fn load_from_file_sync(&mut self, path: &Path) -> Result<()> {
203 if !path.exists() {
204 return Ok(());
205 }
206
207 let content = std::fs::read_to_string(path)?;
208 let loaded: ConfigManager = toml::from_str(&content)?;
209 self.merge(loaded);
210 Ok(())
211 }
212
213 pub fn merge(&mut self, other: ConfigManager) {
215 for (name, provider) in other.providers {
217 if !provider.api_key.is_empty() {
219 self.providers.insert(name, provider);
220 }
221 }
222
223 if other.settings.session_max_history > 0 {
225 self.settings.session_max_history = other.settings.session_max_history;
226 }
227 if other.settings.checkpoint_interval_sec > 0 {
228 self.settings.checkpoint_interval_sec = other.settings.checkpoint_interval_sec;
229 }
230
231 self.extra.extend(other.extra);
233
234 if !other.active_provider.is_empty() && self.providers.contains_key(&other.active_provider)
236 {
237 self.active_provider = other.active_provider;
238 }
239 }
240
241 pub fn default_config_path() -> PathBuf {
243 let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
245 home.join(".continuum").join("config.toml")
246 }
247
248 pub fn project_config_path() -> PathBuf {
250 PathBuf::from(".continuum").join("config.toml")
251 }
252
253 pub async fn load_full() -> Result<Self> {
255 let mut config = Self::new();
257
258 let user_path = Self::default_config_path();
260 config.load_from_file(&user_path).await?;
261
262 let project_path = Self::project_config_path();
264 config.load_from_file(&project_path).await?;
265
266 let env_config = Self::from_env();
268 config.merge_env(env_config);
269
270 Ok(config)
271 }
272
273 fn merge_env(&mut self, env: ConfigManager) {
275 if !env.active_provider.is_empty() {
277 self.active_provider = env.active_provider;
278 }
279
280 for (name, provider) in env.providers {
282 self.providers.insert(name, provider);
283 }
284
285 self.settings.audit_enabled = env.settings.audit_enabled;
287 self.settings.checkpoint_enabled = env.settings.checkpoint_enabled;
288 }
289
290 pub fn use_provider(&mut self, name: &str) -> Result<()> {
292 if !self.providers.contains_key(name) {
293 return Err(anyhow!(
294 "Provider '{}' not found. Use 'config add-provider' first.",
295 name
296 ));
297 }
298 self.active_provider = name.to_string();
299 Ok(())
300 }
301
302 pub fn current(&self) -> Result<&ProviderConfig> {
304 self.providers
305 .get(&self.active_provider)
306 .ok_or_else(|| anyhow!("No provider '{}' configured", self.active_provider))
307 }
308
309 pub fn add_provider(&mut self, name: &str, config: ProviderConfig) {
311 self.providers.insert(name.to_string(), config);
312 }
313
314 pub fn list_providers(&self) -> Vec<&String> {
316 self.providers.keys().collect()
317 }
318
319 pub fn get(&self, key: &str) -> Option<&String> {
321 self.extra.get(key)
322 }
323
324 pub fn set(&mut self, key: String, value: String) {
326 self.extra.insert(key, value);
327 }
328
329 pub async fn save(&self, path: &Path) -> Result<()> {
331 if let Some(parent) = path.parent() {
333 tokio::fs::create_dir_all(parent).await?;
334 }
335
336 let content = toml::to_string_pretty(&self)?;
337 tokio::fs::write(path, content).await?;
338 Ok(())
339 }
340
341 pub fn save_sync(&self, path: &Path) -> Result<()> {
343 if let Some(parent) = path.parent() {
344 std::fs::create_dir_all(parent)?;
345 }
346
347 let content = toml::to_string_pretty(&self)?;
348 std::fs::write(path, content)?;
349 Ok(())
350 }
351
352 pub fn resolve_env_refs(&mut self) {
354 for provider in self.providers.values_mut() {
356 provider.api_key = Self::resolve_env_string(&provider.api_key);
357 provider.base_url = Self::resolve_env_string(&provider.base_url);
358 provider.model = Self::resolve_env_string(&provider.model);
359 }
360
361 for value in self.extra.values_mut() {
363 *value = Self::resolve_env_string(value);
364 }
365 }
366
367 fn resolve_env_string(s: &str) -> String {
369 let mut result = s.to_string();
370 while let Some(start) = result.find("${") {
372 if let Some(end) = result[start..].find('}') {
373 let var_name = &result[start + 2..start + end];
374 if let Ok(val) = std::env::var(var_name) {
375 result.replace_range(start..start + end + 1, &val);
376 } else {
377 result.replace_range(start..start + end + 1, "");
379 }
380 } else {
381 break;
382 }
383 }
384 result
385 }
386
387 pub fn init_default_config(&self) -> Result<PathBuf> {
389 let path = Self::default_config_path();
390
391 if path.exists() {
392 return Err(anyhow!("Config file already exists at {:?}", path));
393 }
394
395 let default_config = Self {
397 active_provider: "anthropic".to_string(),
398 providers: {
399 let mut map = HashMap::new();
400 map.insert(
401 "anthropic".to_string(),
402 ProviderConfig {
403 api_key: "${ANTHROPIC_API_KEY}".to_string(),
404 base_url: "https://api.anthropic.com/v1".to_string(),
405 model: "claude-sonnet-4-6".to_string(),
406 api_format: "anthropic".to_string(),
407 default_max_tokens: 4096,
408 default_temperature: 0.7,
409 },
410 );
411 map.insert(
412 "openai".to_string(),
413 ProviderConfig {
414 api_key: "${OPENAI_API_KEY}".to_string(),
415 base_url: "https://api.openai.com/v1".to_string(),
416 model: "gpt-4".to_string(),
417 api_format: "openai".to_string(),
418 default_max_tokens: 4096,
419 default_temperature: 0.7,
420 },
421 );
422 map.insert(
423 "gemini".to_string(),
424 ProviderConfig {
425 api_key: "${GEMINI_API_KEY}".to_string(),
426 base_url: "https://generativelanguage.googleapis.com/v1".to_string(),
427 model: "gemini-pro".to_string(),
428 api_format: "google".to_string(),
429 default_max_tokens: 4096,
430 default_temperature: 0.7,
431 },
432 );
433 map
434 },
435 settings: GlobalSettings::default(),
436 extra: HashMap::new(),
437 };
438
439 default_config.save_sync(&path)?;
440 Ok(path)
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `openai`-format provider with every field set explicitly.
    fn sample_provider(
        api_key: &str,
        base_url: &str,
        model: &str,
        default_max_tokens: u32,
        default_temperature: f32,
    ) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.to_string(),
            base_url: base_url.to_string(),
            model: model.to_string(),
            api_format: "openai".to_string(),
            default_max_tokens,
            default_temperature,
        }
    }

    #[test]
    fn test_config_manager_creation() {
        assert_eq!(ConfigManager::new().active_provider, "anthropic");
    }

    #[test]
    fn test_provider_config_default() {
        let defaults = ProviderConfig::default();
        assert_eq!(defaults.default_max_tokens, 4096);
        assert_eq!(defaults.default_temperature, 0.7);
    }

    #[test]
    fn test_global_settings_default() {
        let defaults = GlobalSettings::default();
        assert!(defaults.session_auto_save);
        assert!(defaults.checkpoint_enabled);
    }

    #[test]
    fn test_add_provider() {
        let mut manager = ConfigManager::new();
        manager.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 8192, 0.5),
        );
        assert!(manager.providers.contains_key("test"));
    }

    #[test]
    fn test_use_provider() {
        let mut manager = ConfigManager::new();
        manager.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 4096, 0.7),
        );

        manager.use_provider("test").unwrap();
        assert_eq!(manager.active_provider, "test");
    }

    #[test]
    fn test_use_provider_not_found() {
        let mut manager = ConfigManager::new();
        assert!(manager.use_provider("nonexistent").is_err());
    }

    #[test]
    fn test_resolve_env_string() {
        std::env::set_var("TEST_VAR", "test_value");
        assert_eq!(ConfigManager::resolve_env_string("${TEST_VAR}"), "test_value");
        std::env::remove_var("TEST_VAR");
    }

    #[test]
    fn test_set_get_config() {
        let mut manager = ConfigManager::new();
        manager.set("test_key".to_string(), "test_value".to_string());
        assert_eq!(manager.get("test_key").map(String::as_str), Some("test_value"));
    }

    #[test]
    fn test_list_providers() {
        let mut manager = ConfigManager::new();
        manager.add_provider(
            "provider1",
            sample_provider("key1", "url1", "model1", 4096, 0.7),
        );

        let names = manager.list_providers();
        assert!(names.iter().any(|name| name.as_str() == "provider1"));
    }

    #[test]
    fn test_config_serialization() {
        let rendered = toml::to_string(&ConfigManager::new()).unwrap();
        assert!(rendered.contains("active_provider"));
    }
}