use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Connection settings and generation defaults for a single model provider.
///
/// String fields may hold `${VAR}` placeholders; `ConfigManager::resolve_env_refs`
/// expands them in `api_key`, `base_url` and `model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Credential for this provider; may be a `${VAR}` placeholder.
    pub api_key: String,
    /// Base URL of the provider's API endpoint.
    pub base_url: String,
    /// Model identifier used with this provider.
    pub model: String,
    /// API format identifier (this file uses "openai", "anthropic" and "google");
    /// deserialization falls back to "openai" when the key is omitted.
    #[serde(default = "default_api_format")]
    pub api_format: String,
    /// Default token limit (consumer not visible in this file); 4096 when omitted.
    #[serde(default = "default_max_tokens")]
    pub default_max_tokens: u32,
    /// Default temperature (consumer not visible in this file); 0.7 when omitted.
    #[serde(default = "default_temperature")]
    pub default_temperature: f32,
}
/// Serde fallback for `ProviderConfig::api_format` when the key is absent.
fn default_api_format() -> String {
    String::from("openai")
}
/// Serde fallback for `ProviderConfig::default_max_tokens` when the key is absent.
fn default_max_tokens() -> u32 {
    4_096
}
/// Serde fallback for `ProviderConfig::default_temperature` when the key is absent.
fn default_temperature() -> f32 {
    0.7_f32
}
impl Default for ProviderConfig {
fn default() -> Self {
Self {
api_key: String::new(),
base_url: String::new(),
model: "claude-sonnet-4-6".to_string(),
api_format: default_api_format(),
default_max_tokens: 4096,
default_temperature: 0.7,
}
}
}
/// Global flags and limits, serialized as the `settings` field of `ConfigManager`.
///
/// Every field has a serde default, so any subset may be omitted from a config file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalSettings {
    /// Session auto-save flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub session_auto_save: bool,
    /// Session history limit (consumer not visible in this file); 100 when omitted.
    #[serde(default = "default_session_max_history")]
    pub session_max_history: usize,
    /// Checkpointing flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub checkpoint_enabled: bool,
    /// Checkpoint interval (the `_sec` suffix indicates seconds); 60 when omitted.
    #[serde(default = "default_checkpoint_interval")]
    pub checkpoint_interval_sec: u32,
    /// Audit flag; `true` when omitted.
    #[serde(default = "default_true")]
    pub audit_enabled: bool,
    /// MCP flag; `bool::default()` (false) when omitted.
    #[serde(default)]
    pub mcp_enabled: bool,
}
impl Default for GlobalSettings {
fn default() -> Self {
Self {
session_auto_save: true,
session_max_history: 100,
checkpoint_enabled: true,
checkpoint_interval_sec: 60,
audit_enabled: true,
mcp_enabled: false,
}
}
}
/// Serde fallback for boolean settings that are enabled unless stated otherwise.
fn default_true() -> bool {
    true
}
/// Serde fallback for `GlobalSettings::session_max_history`.
fn default_session_max_history() -> usize {
    100_usize
}
/// Serde fallback for `GlobalSettings::checkpoint_interval_sec`.
fn default_checkpoint_interval() -> u32 {
    60_u32
}
/// Top-level configuration document, (de)serialized as TOML.
///
/// Holds the named provider table, the active provider selection, global settings
/// and free-form string key/value pairs. Every field has a serde default, so an
/// empty file deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigManager {
    /// Key into `providers` selecting the provider in use; "anthropic" when omitted.
    #[serde(default = "default_provider")]
    pub active_provider: String,
    /// Provider configurations keyed by user-chosen name.
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,
    /// Global flags and limits.
    #[serde(default)]
    pub settings: GlobalSettings,
    /// Arbitrary string settings, read and written via `get` / `set`.
    #[serde(default)]
    pub extra: HashMap<String, String>,
}
/// Serde fallback for `ConfigManager::active_provider` when the key is absent.
fn default_provider() -> String {
    String::from("anthropic")
}
impl Default for ConfigManager {
fn default() -> Self {
Self {
active_provider: "anthropic".to_string(),
providers: HashMap::new(),
settings: GlobalSettings::default(),
extra: HashMap::new(),
}
}
}
/// Loading, layering, querying and persisting of configuration.
impl ConfigManager {
    /// Creates a manager with the built-in defaults and no providers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a configuration from the built-in defaults plus `CONTINUUM_*`
    /// environment variable overrides.
    ///
    /// Recognised variables: `CONTINUUM_PROVIDER`, `CONTINUUM_API_KEY`,
    /// `CONTINUUM_BASE_URL`, `CONTINUUM_MODEL`, `CONTINUUM_CHECKPOINT_ENABLED`,
    /// `CONTINUUM_AUDIT_ENABLED`. Unset or empty variables are ignored.
    pub fn from_env() -> Self {
        let mut config = Self::default();
        config.apply_env_overrides();
        config
    }

    /// Applies `CONTINUUM_*` environment variables on top of `self`, field by field.
    ///
    /// * `CONTINUUM_PROVIDER` replaces `active_provider`.
    /// * `CONTINUUM_API_KEY` / `CONTINUUM_BASE_URL` / `CONTINUUM_MODEL` override only
    ///   that field of the active provider. The provider is created with
    ///   `ProviderConfig::default()` if missing; its other fields are kept.
    /// * `CONTINUUM_CHECKPOINT_ENABLED` / `CONTINUUM_AUDIT_ENABLED` accept
    ///   `true|false|1|0|yes|no|on|off` (case-insensitive); other values are ignored.
    ///
    /// Unset or empty variables leave the current value untouched, so values loaded
    /// from config files are never reset by an absent variable.
    fn apply_env_overrides(&mut self) {
        if let Some(provider) = Self::env_non_empty("CONTINUUM_PROVIDER") {
            self.active_provider = provider;
        }

        let api_key = Self::env_non_empty("CONTINUUM_API_KEY");
        let base_url = Self::env_non_empty("CONTINUUM_BASE_URL");
        let model = Self::env_non_empty("CONTINUUM_MODEL");
        // Only touch (and possibly create) the provider entry when there is
        // something to apply to it.
        if api_key.is_some() || base_url.is_some() || model.is_some() {
            let entry = self
                .providers
                .entry(self.active_provider.clone())
                .or_default();
            if let Some(api_key) = api_key {
                entry.api_key = api_key;
            }
            if let Some(base_url) = base_url {
                entry.base_url = base_url;
            }
            if let Some(model) = model {
                entry.model = model;
            }
        }

        if let Some(enabled) = Self::env_bool("CONTINUUM_CHECKPOINT_ENABLED") {
            self.settings.checkpoint_enabled = enabled;
        }
        if let Some(enabled) = Self::env_bool("CONTINUUM_AUDIT_ENABLED") {
            self.settings.audit_enabled = enabled;
        }
    }

    /// Reads an environment variable; unset, non-UTF-8 or empty values yield `None`.
    fn env_non_empty(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }

    /// Reads an environment variable as a boolean flag.
    ///
    /// Returns `None` when the variable is unset or not a recognised spelling.
    fn env_bool(name: &str) -> Option<bool> {
        let raw = std::env::var(name).ok()?;
        match raw.trim().to_ascii_lowercase().as_str() {
            "true" | "1" | "yes" | "on" => Some(true),
            "false" | "0" | "no" | "off" => Some(false),
            _ => None,
        }
    }

    /// Merges the TOML file at `path` into `self` (see [`ConfigManager::merge`]).
    ///
    /// A missing file is not an error and leaves `self` unchanged.
    ///
    /// # Errors
    /// Fails if the file exists but cannot be read, or is not a valid configuration
    /// document. The error message names `path`.
    pub async fn load_from_file(&mut self, path: &Path) -> Result<()> {
        // Matching on NotFound instead of probing `path.exists()` first avoids a
        // blocking stat inside async code and the check-then-read race.
        match tokio::fs::read_to_string(path).await {
            Ok(content) => self.merge_toml(&content, path),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => {
                Err(err).with_context(|| format!("failed to read config file {}", path.display()))
            }
        }
    }

    /// Blocking counterpart of [`ConfigManager::load_from_file`].
    ///
    /// # Errors
    /// Same conditions as the async version.
    pub fn load_from_file_sync(&mut self, path: &Path) -> Result<()> {
        match std::fs::read_to_string(path) {
            Ok(content) => self.merge_toml(&content, path),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => {
                Err(err).with_context(|| format!("failed to read config file {}", path.display()))
            }
        }
    }

    /// Parses `content` as a configuration document and merges it into `self`.
    /// `path` is used only to label parse errors.
    fn merge_toml(&mut self, content: &str, path: &Path) -> Result<()> {
        let loaded: ConfigManager = toml::from_str(content)
            .with_context(|| format!("failed to parse config file {}", path.display()))?;
        self.merge(loaded);
        Ok(())
    }

    /// Overlays `other`, a later and higher-priority layer, onto `self`.
    ///
    /// * Providers: an entry from `other` replaces the one in `self` when it has a
    ///   non-empty `api_key`, or when `self` has no provider of that name yet. A
    ///   keyless entry never wipes out existing credentials, but keyless providers
    ///   can still be added.
    /// * Settings: numeric fields are taken when `> 0`; boolean flags are always taken.
    /// * `extra`: keys from `other` overwrite keys in `self`.
    /// * `active_provider`: taken when non-empty and naming a known provider.
    ///
    /// NOTE: `other` is fully deserialized, so keys omitted from its source file hold
    /// their serde defaults and cannot be told apart from explicit values. Such
    /// defaults still override `self`.
    pub fn merge(&mut self, other: ConfigManager) {
        for (name, provider) in other.providers {
            if !provider.api_key.is_empty() || !self.providers.contains_key(&name) {
                self.providers.insert(name, provider);
            }
        }

        let incoming = other.settings;
        if incoming.session_max_history > 0 {
            self.settings.session_max_history = incoming.session_max_history;
        }
        if incoming.checkpoint_interval_sec > 0 {
            self.settings.checkpoint_interval_sec = incoming.checkpoint_interval_sec;
        }
        // Previously dropped, which made these flags impossible to set from any file.
        self.settings.session_auto_save = incoming.session_auto_save;
        self.settings.checkpoint_enabled = incoming.checkpoint_enabled;
        self.settings.audit_enabled = incoming.audit_enabled;
        self.settings.mcp_enabled = incoming.mcp_enabled;

        self.extra.extend(other.extra);

        if !other.active_provider.is_empty() && self.providers.contains_key(&other.active_provider)
        {
            self.active_provider = other.active_provider;
        }
    }

    /// Path of the per-user config file: `<home>/.continuum/config.toml`, or
    /// `./.continuum/config.toml` when no home directory is known.
    pub fn default_config_path() -> PathBuf {
        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
        home.join(".continuum").join("config.toml")
    }

    /// Path of the project config file, relative to the current working directory.
    pub fn project_config_path() -> PathBuf {
        PathBuf::from(".continuum").join("config.toml")
    }

    /// Loads the layered configuration. Later layers win:
    ///
    /// 1. built-in defaults,
    /// 2. the user file ([`ConfigManager::default_config_path`]),
    /// 3. the project file ([`ConfigManager::project_config_path`]),
    /// 4. `CONTINUUM_*` environment variables.
    ///
    /// Environment variables are applied field by field, so unset variables do not
    /// reset `active_provider`, provider fields or flags loaded from files.
    ///
    /// # Errors
    /// Propagates read/parse errors from either config file.
    pub async fn load_full() -> Result<Self> {
        let mut config = Self::new();
        config.load_from_file(&Self::default_config_path()).await?;
        config.load_from_file(&Self::project_config_path()).await?;
        config.apply_env_overrides();
        Ok(config)
    }

    /// Makes `name` the active provider.
    ///
    /// # Errors
    /// Fails if no provider named `name` is configured.
    pub fn use_provider(&mut self, name: &str) -> Result<()> {
        if !self.providers.contains_key(name) {
            return Err(anyhow!(
                "Provider '{}' not found. Use 'config add-provider' first.",
                name
            ));
        }
        self.active_provider = name.to_string();
        Ok(())
    }

    /// Returns the configuration of the active provider.
    ///
    /// # Errors
    /// Fails if `active_provider` does not name a configured provider.
    pub fn current(&self) -> Result<&ProviderConfig> {
        self.providers
            .get(&self.active_provider)
            .ok_or_else(|| anyhow!("No provider '{}' configured", self.active_provider))
    }

    /// Adds or replaces the provider named `name`.
    pub fn add_provider(&mut self, name: &str, config: ProviderConfig) {
        self.providers.insert(name.to_string(), config);
    }

    /// Names of all configured providers, sorted so listings are stable across runs.
    pub fn list_providers(&self) -> Vec<&String> {
        let mut names: Vec<&String> = self.providers.keys().collect();
        names.sort_unstable();
        names
    }

    /// Looks up a free-form setting from `extra`.
    pub fn get(&self, key: &str) -> Option<&String> {
        self.extra.get(key)
    }

    /// Stores a free-form setting in `extra`, replacing any previous value.
    pub fn set(&mut self, key: String, value: String) {
        self.extra.insert(key, value);
    }

    /// Writes the configuration to `path` as pretty TOML, creating parent
    /// directories as needed.
    ///
    /// NOTE(review): the file gets the process's default permissions and may contain
    /// plaintext API keys (e.g. after `resolve_env_refs`). Consider restricting them.
    ///
    /// # Errors
    /// Fails on serialization, directory-creation or write errors.
    pub async fn save(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            tokio::fs::create_dir_all(parent)
                .await
                .with_context(|| format!("failed to create directory {}", parent.display()))?;
        }
        let content = toml::to_string_pretty(self)?;
        tokio::fs::write(path, content)
            .await
            .with_context(|| format!("failed to write config file {}", path.display()))?;
        Ok(())
    }

    /// Blocking counterpart of [`ConfigManager::save`].
    ///
    /// # Errors
    /// Same conditions as the async version.
    pub fn save_sync(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("failed to create directory {}", parent.display()))?;
        }
        let content = toml::to_string_pretty(self)?;
        std::fs::write(path, content)
            .with_context(|| format!("failed to write config file {}", path.display()))?;
        Ok(())
    }

    /// Expands `${VAR}` placeholders in each provider's `api_key`, `base_url` and
    /// `model`, and in every `extra` value. Unset variables expand to "".
    pub fn resolve_env_refs(&mut self) {
        for provider in self.providers.values_mut() {
            provider.api_key = Self::resolve_env_string(&provider.api_key);
            provider.base_url = Self::resolve_env_string(&provider.base_url);
            provider.model = Self::resolve_env_string(&provider.model);
        }
        for value in self.extra.values_mut() {
            *value = Self::resolve_env_string(value);
        }
    }

    /// Replaces every `${NAME}` in `s` with the value of environment variable `NAME`,
    /// or with "" when it is unset or not valid UTF-8.
    ///
    /// Scanning moves strictly forward, so substituted values are copied verbatim and
    /// never re-expanded. The old rescan-from-start loop would spin forever on a
    /// self-referential value such as `A="${A}"`. A `${` with no closing `}` and
    /// everything after it are kept literally.
    fn resolve_env_string(s: &str) -> String {
        let mut resolved = String::with_capacity(s.len());
        let mut rest = s;
        while let Some(open) = rest.find("${") {
            let after_open = &rest[open + 2..];
            let close = match after_open.find('}') {
                Some(idx) => idx,
                None => break,
            };
            resolved.push_str(&rest[..open]);
            if let Ok(value) = std::env::var(&after_open[..close]) {
                resolved.push_str(&value);
            }
            rest = &after_open[close + 1..];
        }
        resolved.push_str(rest);
        resolved
    }

    /// Writes a starter config with anthropic, openai and gemini presets to
    /// [`ConfigManager::default_config_path`]. API keys are stored as `${VAR}`
    /// placeholders, not secrets.
    ///
    /// # Errors
    /// Fails if the file already exists or cannot be written.
    pub fn init_default_config(&self) -> Result<PathBuf> {
        let path = Self::default_config_path();
        if path.exists() {
            return Err(anyhow!("Config file already exists at {:?}", path));
        }

        let preset = |api_key: &str, base_url: &str, model: &str, api_format: &str| {
            ProviderConfig {
                api_key: api_key.to_string(),
                base_url: base_url.to_string(),
                model: model.to_string(),
                api_format: api_format.to_string(),
                default_max_tokens: 4096,
                default_temperature: 0.7,
            }
        };

        let mut providers = HashMap::new();
        providers.insert(
            "anthropic".to_string(),
            preset(
                "${ANTHROPIC_API_KEY}",
                "https://api.anthropic.com/v1",
                "claude-sonnet-4-6",
                "anthropic",
            ),
        );
        providers.insert(
            "openai".to_string(),
            preset("${OPENAI_API_KEY}", "https://api.openai.com/v1", "gpt-4", "openai"),
        );
        providers.insert(
            "gemini".to_string(),
            preset(
                "${GEMINI_API_KEY}",
                "https://generativelanguage.googleapis.com/v1",
                "gemini-pro",
                "google",
            ),
        );

        let default_config = Self {
            active_provider: "anthropic".to_string(),
            providers,
            settings: GlobalSettings::default(),
            extra: HashMap::new(),
        };
        default_config.save_sync(&path)?;
        Ok(path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test fixture: an OpenAI-format provider with the given fields.
    fn sample_provider(
        key: &str,
        url: &str,
        model_name: &str,
        max_tokens: u32,
        temperature: f32,
    ) -> ProviderConfig {
        ProviderConfig {
            api_key: key.to_string(),
            base_url: url.to_string(),
            model: model_name.to_string(),
            api_format: "openai".to_string(),
            default_max_tokens: max_tokens,
            default_temperature: temperature,
        }
    }

    #[test]
    fn test_config_manager_creation() {
        assert_eq!(ConfigManager::new().active_provider, "anthropic");
    }

    #[test]
    fn test_provider_config_default() {
        let defaults = ProviderConfig::default();
        assert_eq!(
            (defaults.default_max_tokens, defaults.default_temperature),
            (4096, 0.7)
        );
    }

    #[test]
    fn test_global_settings_default() {
        let defaults = GlobalSettings::default();
        assert!(defaults.session_auto_save && defaults.checkpoint_enabled);
    }

    #[test]
    fn test_add_provider() {
        let mut manager = ConfigManager::new();
        manager.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 8192, 0.5),
        );
        assert!(manager.providers.contains_key("test"));
    }

    #[test]
    fn test_use_provider() {
        let mut manager = ConfigManager::new();
        manager.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 4096, 0.7),
        );
        manager
            .use_provider("test")
            .expect("provider 'test' was just added");
        assert_eq!(manager.active_provider, "test");
    }

    #[test]
    fn test_use_provider_not_found() {
        assert!(ConfigManager::new().use_provider("nonexistent").is_err());
    }

    #[test]
    fn test_resolve_env_string() {
        std::env::set_var("TEST_VAR", "test_value");
        assert_eq!(ConfigManager::resolve_env_string("${TEST_VAR}"), "test_value");
        std::env::remove_var("TEST_VAR");
    }

    #[test]
    fn test_set_get_config() {
        let mut manager = ConfigManager::new();
        manager.set("test_key".to_string(), "test_value".to_string());
        assert_eq!(manager.get("test_key").map(String::as_str), Some("test_value"));
    }

    #[test]
    fn test_list_providers() {
        let mut manager = ConfigManager::new();
        manager.add_provider("provider1", sample_provider("key1", "url1", "model1", 4096, 0.7));
        assert!(manager
            .list_providers()
            .iter()
            .any(|name| name.as_str() == "provider1"));
    }

    #[test]
    fn test_config_serialization() {
        let serialized = toml::to_string(&ConfigManager::new()).unwrap();
        assert!(serialized.contains("active_provider"));
    }
}