use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Placeholder shown instead of the real API key whenever a [`Config`] is
/// displayed (see `Config::masked` and the `Debug` impl below).
const MASKED_API_KEY: &str = "***";

/// Application settings, loadable from and serializable to TOML.
///
/// Every field is optional:
/// - `#[serde(default)]` lets a file specify any subset of keys.
/// - `skip_serializing_if` keeps unset keys out of the serialized output.
///
/// Legacy key names listed under `alias` are accepted when parsing. They are
/// always written back under the canonical field name.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// `Config` with `{:?}` never exposes the raw `api_key`.
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct Config {
    /// Provider name (e.g. `"openai"`). Also read from `provider_type` and
    /// `default_provider`.
    #[serde(
        alias = "provider_type",
        alias = "default_provider",
        skip_serializing_if = "Option::is_none"
    )]
    pub provider: Option<String>,
    /// Default model name. Also read from `default_model`.
    #[serde(alias = "default_model", skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Secret credential. Stored in plain text here; use `Config::masked`
    /// before showing a config to a user.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Optional endpoint override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    // The four `*_model` fields below are independent optional model names.
    // How they are used is decided by callers outside this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architect_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub actuator_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verifier_model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speculator_model: Option<String>,
}

impl std::fmt::Debug for Config {
    /// Formats like a derived `Debug` impl, except that a present `api_key`
    /// is rendered as the masked placeholder instead of the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("provider", &self.provider)
            .field("model", &self.model)
            // Keep `Some`/`None` visible so it is still clear whether a key
            // is configured, without revealing its value.
            .field("api_key", &self.api_key.as_ref().map(|_| MASKED_API_KEY))
            .field("base_url", &self.base_url)
            .field("architect_model", &self.architect_model)
            .field("actuator_model", &self.actuator_model)
            .field("verifier_model", &self.verifier_model)
            .field("speculator_model", &self.speculator_model)
            .finish()
    }
}
impl Config {
pub fn from_toml_str(content: &str) -> Result<Self> {
toml::from_str(content).context("Failed to parse TOML configuration")
}
pub fn load_from_path(path: &Path) -> Result<Self> {
if !path.exists() {
return Ok(Self::default());
}
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read config file: {}", path.display()))?;
Self::from_toml_str(&content)
}
pub fn to_toml_string(&self) -> Result<String> {
toml::to_string_pretty(self).context("Failed to serialize configuration to TOML")
}
pub fn masked(&self) -> Self {
let mut clone = self.clone();
if clone.api_key.is_some() {
clone.api_key = Some(MASKED_API_KEY.to_string());
}
clone
}
pub fn set_value(&mut self, key: &str, value: &str) -> Result<()> {
let value = value.to_string();
match key {
"provider" | "provider_type" | "default_provider" => self.provider = Some(value),
"model" | "default_model" => self.model = Some(value),
"api_key" => self.api_key = Some(value),
"base_url" => self.base_url = Some(value),
"architect_model" => self.architect_model = Some(value),
"actuator_model" => self.actuator_model = Some(value),
"verifier_model" => self.verifier_model = Some(value),
"speculator_model" => self.speculator_model = Some(value),
other => anyhow::bail!(
"Unknown configuration key: {other}. Valid keys: provider, model, api_key, \
base_url, architect_model, actuator_model, verifier_model, speculator_model"
),
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_string_parses_to_defaults() {
        let cfg = Config::from_toml_str("").unwrap();
        assert!(cfg.provider.is_none());
        assert!(cfg.model.is_none());
        assert!(cfg.api_key.is_none());
    }

    #[test]
    fn aliases_are_accepted() {
        let cfg = Config::from_toml_str(
            r#"
provider_type = "openai"
default_model = "phi-4-npu-ov"
"#,
        )
        .unwrap();
        assert_eq!(cfg.provider.as_deref(), Some("openai"));
        assert_eq!(cfg.model.as_deref(), Some("phi-4-npu-ov"));
    }

    #[test]
    fn malformed_toml_is_rejected() {
        // Syntactically invalid TOML.
        assert!(Config::from_toml_str("provider = ").is_err());
        // Valid TOML, but the value has the wrong type for a string field.
        assert!(Config::from_toml_str("model = 42").is_err());
    }

    #[test]
    fn missing_file_returns_default() {
        let path = Path::new("/nonexistent/perspt/config.toml");
        let cfg = Config::load_from_path(path).unwrap();
        assert!(cfg.provider.is_none());
    }

    #[test]
    fn load_from_path_reads_existing_file() {
        // Include the process id so concurrent test runs do not collide.
        let path = std::env::temp_dir().join(format!(
            "perspt-config-test-{}.toml",
            std::process::id()
        ));
        std::fs::write(&path, "provider = \"openai\"\nmodel = \"phi-4-npu-ov\"\n").unwrap();
        let result = Config::load_from_path(&path);
        // Remove the file before asserting so a failure does not leak it.
        let _ = std::fs::remove_file(&path);
        let cfg = result.unwrap();
        assert_eq!(cfg.provider.as_deref(), Some("openai"));
        assert_eq!(cfg.model.as_deref(), Some("phi-4-npu-ov"));
    }

    #[test]
    fn masked_hides_api_key() {
        let cfg = Config {
            api_key: Some("super-secret".to_string()),
            ..Default::default()
        };
        assert_eq!(cfg.masked().api_key.as_deref(), Some("***"));
    }

    #[test]
    fn masked_leaves_absent_key_absent() {
        let cfg = Config::default();
        assert!(cfg.masked().api_key.is_none());
    }

    #[test]
    fn masked_does_not_modify_original() {
        let cfg = Config {
            api_key: Some("super-secret".to_string()),
            model: Some("m".to_string()),
            ..Default::default()
        };
        let shown = cfg.masked();
        assert_eq!(cfg.api_key.as_deref(), Some("super-secret"));
        // Non-secret fields are carried over unchanged.
        assert_eq!(shown.model.as_deref(), Some("m"));
    }

    #[test]
    fn set_value_updates_known_keys() {
        let mut cfg = Config::default();
        cfg.set_value("default_model", "phi-4-npu-ov").unwrap();
        assert_eq!(cfg.model.as_deref(), Some("phi-4-npu-ov"));
        cfg.set_value("provider", "openai").unwrap();
        assert_eq!(cfg.provider.as_deref(), Some("openai"));
    }

    #[test]
    fn set_value_rejects_unknown_key() {
        let mut cfg = Config::default();
        assert!(cfg.set_value("nope", "x").is_err());
    }

    #[test]
    fn every_settable_key_round_trips() {
        let keys = [
            "provider",
            "model",
            "api_key",
            "base_url",
            "architect_model",
            "actuator_model",
            "verifier_model",
            "speculator_model",
        ];
        let mut cfg = Config::default();
        for &key in &keys {
            cfg.set_value(key, &format!("{key}-value")).unwrap();
        }
        let reparsed = Config::from_toml_str(&cfg.to_toml_string().unwrap()).unwrap();
        assert_eq!(reparsed.provider.as_deref(), Some("provider-value"));
        assert_eq!(reparsed.model.as_deref(), Some("model-value"));
        assert_eq!(reparsed.api_key.as_deref(), Some("api_key-value"));
        assert_eq!(reparsed.base_url.as_deref(), Some("base_url-value"));
        assert_eq!(
            reparsed.architect_model.as_deref(),
            Some("architect_model-value")
        );
        assert_eq!(
            reparsed.actuator_model.as_deref(),
            Some("actuator_model-value")
        );
        assert_eq!(
            reparsed.verifier_model.as_deref(),
            Some("verifier_model-value")
        );
        assert_eq!(
            reparsed.speculator_model.as_deref(),
            Some("speculator_model-value")
        );
    }

    #[test]
    fn round_trip_set_does_not_duplicate() {
        let mut cfg = Config::default();
        cfg.set_value("default_model", "a").unwrap();
        cfg.set_value("default_model", "b").unwrap();
        let serialized = cfg.to_toml_string().unwrap();
        assert_eq!(serialized.matches("model").count(), 1);
        let reparsed = Config::from_toml_str(&serialized).unwrap();
        assert_eq!(reparsed.model.as_deref(), Some("b"));
    }
}