use std::collections::BTreeMap;
use std::fmt;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use costroid_core::{AlertThresholds, BudgetTargets};
use rust_decimal::Decimal;
use serde::de::{self, Deserializer, Visitor};
use serde::Deserialize;
/// Root of the user's `config.toml`.
///
/// Every table is optional: `#[serde(default)]` fills anything missing from
/// `Default`, and keys this struct does not know about are ignored rather than
/// rejected (there is no `deny_unknown_fields`).
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct Config {
    /// The `[budget]` table.
    pub(crate) budget: BudgetConfig,
    /// The `[alerts]` table.
    pub(crate) alerts: AlertsConfig,
}

/// The `[budget]` table: optional spend targets expressed as exact dollar amounts.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct BudgetConfig {
    /// `total_monthly_usd`: overall monthly target, if configured.
    total_monthly_usd: Option<Money>,
    /// `[budget.per_tool]`: one target per tool, keyed by tool name (e.g. `claude-code`).
    per_tool: BTreeMap<String, Money>,
}

/// The `[alerts]` table. Alerts stay off unless `enabled = true` is set.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct AlertsConfig {
    /// Master switch; `false` when omitted.
    enabled: bool,
    /// Optional override for the quota warning fraction (validated in `Config::alert_thresholds`).
    quota_warn: Option<f64>,
    /// Optional override for the quota critical fraction (validated in `Config::alert_thresholds`).
    quota_critical: Option<f64>,
}
impl Config {
pub(crate) fn budget_targets(&self) -> BudgetTargets {
BudgetTargets {
total_monthly_usd: self.budget.total_monthly_usd.map(|money| money.0),
per_tool: self
.budget
.per_tool
.iter()
.map(|(tool, money)| (tool.clone(), money.0))
.collect(),
}
}
pub(crate) fn alerts_enabled(&self) -> bool {
self.alerts.enabled
}
pub(crate) fn alert_thresholds(&self) -> AlertThresholds {
let mut thresholds = AlertThresholds::default();
if let Some(warn) = sane_fraction(self.alerts.quota_warn) {
thresholds.quota_warn_fraction = warn;
}
if let Some(critical) = sane_fraction(self.alerts.quota_critical) {
thresholds.quota_critical_fraction = critical;
}
if thresholds.quota_warn_fraction > thresholds.quota_critical_fraction {
thresholds = AlertThresholds::default();
}
thresholds
}
}
/// Keeps a configured threshold fraction only if it is finite and strictly positive.
///
/// Returns `None` for a missing value, zero, negatives, `NaN`, and infinities, so the
/// caller falls back to its default for that threshold.
fn sane_fraction(value: Option<f64>) -> Option<f64> {
    match value {
        Some(fraction) if fraction.is_finite() && fraction > 0.0 => Some(fraction),
        _ => None,
    }
}
/// A US-dollar amount from the config file, stored as an exact `Decimal`.
///
/// Deserializes from a TOML integer, a float, or a quoted decimal string (whitespace
/// around the string is trimmed). Quoting is the most explicit form, but unquoted
/// floats are also converted to the decimal the user wrote, not to the float's
/// binary expansion.
#[derive(Debug, Clone, Copy)]
struct Money(Decimal);

impl<'de> Deserialize<'de> for Money {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        struct MoneyVisitor;

        impl Visitor<'_> for MoneyVisitor {
            type Value = Money;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("a dollar amount (a number or a quoted decimal string)")
            }

            fn visit_i64<E: de::Error>(self, value: i64) -> Result<Money, E> {
                Ok(Money(Decimal::from(value)))
            }

            fn visit_u64<E: de::Error>(self, value: u64) -> Result<Money, E> {
                Ok(Money(Decimal::from(value)))
            }

            fn visit_f64<E: de::Error>(self, value: f64) -> Result<Money, E> {
                // `Decimal::from_f64_retain` keeps the float's full binary expansion
                // (`0.1` becomes `0.1000000000000000055511151231`). An unquoted
                // `99.99` target would then not equal a spend of exactly 99.99.
                // `f64`'s `Display` prints the shortest decimal that round-trips to the
                // same float, which is the literal as written, so parse that text first.
                // Fall back to the retained expansion when the text form is not
                // accepted (e.g. too many digits). NaN/inf fail both paths and keep the
                // original error.
                Decimal::from_str(&value.to_string())
                    .ok()
                    .or_else(|| Decimal::from_f64_retain(value))
                    .map(Money)
                    .ok_or_else(|| de::Error::custom("dollar amount must be a finite number"))
            }

            fn visit_str<E: de::Error>(self, value: &str) -> Result<Money, E> {
                Decimal::from_str(value.trim()).map(Money).map_err(|err| {
                    de::Error::custom(format!("invalid dollar amount '{value}': {err}"))
                })
            }
        }

        deserializer.deserialize_any(MoneyVisitor)
    }
}
/// Resolves where the costroid config file should live.
///
/// Uses `$XDG_CONFIG_HOME/costroid/config.toml` when `XDG_CONFIG_HOME` is an
/// absolute path, otherwise `$HOME/.config/costroid/config.toml`. Returns `None`
/// when neither variable provides an absolute base directory, so callers fall back
/// to the built-in defaults instead of reading a path relative to the current
/// working directory.
pub(crate) fn config_path() -> Option<PathBuf> {
    // The XDG Base Directory spec requires these paths to be absolute and says
    // relative values must be ignored. Filtering on `is_absolute` also rejects empty
    // values, which would otherwise resolve `.config/...` against the CWD.
    let absolute_dir = |name: &str| {
        std::env::var_os(name)
            .map(PathBuf::from)
            .filter(|path| path.is_absolute())
    };
    let base = absolute_dir("XDG_CONFIG_HOME")
        .or_else(|| absolute_dir("HOME").map(|home| home.join(".config")))?;
    Some(base.join("costroid").join("config.toml"))
}
/// Why loading the config file failed. Both variants carry the offending path.
#[derive(Debug)]
pub(crate) enum ConfigError {
    /// Reading the file failed with an I/O error other than `NotFound`
    /// (`load_from` treats a missing file as the all-defaults config).
    Read {
        path: PathBuf,
        source: std::io::Error,
    },
    /// The file was read but could not be deserialized into `Config`: TOML syntax
    /// errors as well as invalid values such as a malformed dollar amount.
    Parse {
        path: PathBuf,
        source: toml::de::Error,
    },
}

impl fmt::Display for ConfigError {
    /// Renders a single-line message naming the config path.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Read { path, source } => write!(
                formatter,
                "could not read config {}: {source}",
                path.display()
            ),
            Self::Parse { path, source } => {
                // The TOML error text may span several lines; keep only the first so
                // the rendered message stays on one line.
                let rendered = source.to_string();
                let headline = rendered.lines().next().unwrap_or_default();
                write!(formatter, "invalid config {}: {headline}", path.display())
            }
        }
    }
}

impl std::error::Error for ConfigError {}
/// Loads the config from its default location (see `config_path`).
///
/// When no location can be determined, this returns the all-defaults config rather
/// than failing.
///
/// # Errors
/// Propagates any `ConfigError` from `load_from`.
pub(crate) fn load() -> Result<Config, ConfigError> {
    if let Some(path) = config_path() {
        load_from(&path)
    } else {
        Ok(Config::default())
    }
}
/// Reads and parses the config file at `path`.
///
/// A missing file (`NotFound`) is not an error: it yields `Config::default()`.
///
/// # Errors
/// - `ConfigError::Read` for any other I/O failure while reading the file.
/// - `ConfigError::Parse` when the contents cannot be deserialized as `Config`.
pub(crate) fn load_from(path: &Path) -> Result<Config, ConfigError> {
    let owned_path = || path.to_path_buf();
    let contents = match std::fs::read_to_string(path) {
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Config::default()),
        outcome => outcome.map_err(|source| ConfigError::Read {
            path: owned_path(),
            source,
        })?,
    };
    toml::from_str::<Config>(&contents).map_err(|source| ConfigError::Parse {
        path: owned_path(),
        source,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniquely named scratch directory under the system temp dir, removed on drop.
    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        fn new() -> Self {
            use std::sync::atomic::{AtomicU32, Ordering};
            // pid + per-process counter keeps parallel tests (and concurrently running
            // test binaries) from sharing a directory.
            static COUNTER: AtomicU32 = AtomicU32::new(0);
            let n = COUNTER.fetch_add(1, Ordering::Relaxed);
            let pid = std::process::id();
            let path = std::env::temp_dir().join(format!("costroid-config-test-{pid}-{n}"));
            if let Err(err) = std::fs::create_dir_all(&path) {
                panic!("temp dir should create: {err}");
            }
            Self { path }
        }

        /// Writes `contents` to `config.toml` in this directory (overwriting any
        /// earlier fixture) and returns the file's path.
        fn write(&self, contents: &str) -> PathBuf {
            let file = self.path.join("config.toml");
            if let Err(err) = std::fs::write(&file, contents) {
                panic!("fixture config should write: {err}");
            }
            file
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// Shorthand for an exact decimal `value * 10^-scale`.
    fn cents(value: i64, scale: u32) -> Decimal {
        Decimal::new(value, scale)
    }

    #[test]
    fn absent_file_loads_the_zero_config_default() {
        let dir = TempDir::new();
        let missing = dir.path.join("does-not-exist.toml");
        let config = match load_from(&missing) {
            Ok(config) => config,
            Err(err) => panic!("absent file should default, not error: {err}"),
        };
        assert!(config.budget_targets().is_empty());
    }

    #[test]
    fn present_file_parses_total_and_per_tool_targets() {
        let dir = TempDir::new();
        let path = dir.write(
            "[budget]\n\
             total_monthly_usd = 100.00\n\n\
             [budget.per_tool]\n\
             claude-code = 60.00\n\
             codex = 40\n",
        );
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("valid config should parse: {err}"),
        };
        let targets = config.budget_targets();
        assert!(!targets.is_empty());
        assert_eq!(targets.total_monthly_usd, Some(cents(10_000, 2)));
        // `get` yields `Option<&Decimal>`, so compare against a reference.
        assert_eq!(targets.per_tool.get("claude-code"), Some(&cents(6_000, 2)));
        assert_eq!(targets.per_tool.get("codex"), Some(&cents(40, 0)));
    }

    #[test]
    fn quoted_string_money_is_exact() {
        let dir = TempDir::new();
        let path = dir.write("[budget]\ntotal_monthly_usd = \"99.99\"\n");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("quoted money should parse: {err}"),
        };
        assert_eq!(
            config.budget_targets().total_monthly_usd,
            Some(cents(9_999, 2))
        );
    }

    #[test]
    fn malformed_file_is_a_typed_error_not_a_panic() {
        let dir = TempDir::new();
        let path = dir.write("[budget\ntotal_monthly_usd = 100\n");
        match load_from(&path) {
            Ok(config) => panic!("malformed config should error, got {config:?}"),
            Err(err @ ConfigError::Parse { .. }) => {
                let message = err.to_string();
                assert!(message.contains("invalid config"), "message: {message}");
                assert!(
                    !message.contains('\n'),
                    "status must be one line: {message}"
                );
            }
            Err(other) => panic!("expected a Parse error, got {other}"),
        }
    }

    #[test]
    fn invalid_money_value_is_a_typed_error() {
        let dir = TempDir::new();
        let path = dir.write("[budget]\ntotal_monthly_usd = \"not-a-number\"\n");
        match load_from(&path) {
            Ok(config) => panic!("invalid money should error, got {config:?}"),
            Err(ConfigError::Parse { .. }) => {}
            Err(other) => panic!("expected a Parse error, got {other}"),
        }
    }

    #[test]
    fn unknown_keys_are_ignored_for_forward_compatibility() {
        let dir = TempDir::new();
        let path = dir.write(
            "schema_version = 99\n\
             [budget]\n\
             total_monthly_usd = 50\n\
             future_field = true\n\n\
             [alerts]\n\
             enabled = true\n",
        );
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("unknown keys should be ignored, not error: {err}"),
        };
        assert_eq!(
            config.budget_targets().total_monthly_usd,
            Some(cents(50, 0))
        );
    }

    #[test]
    fn empty_file_is_the_zero_config_default() {
        let dir = TempDir::new();
        let path = dir.write("");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("empty file should default: {err}"),
        };
        assert!(config.budget_targets().is_empty());
    }

    #[test]
    fn alerts_default_off_with_canonical_thresholds() {
        let dir = TempDir::new();
        for contents in ["", "[budget]\ntotal_monthly_usd = 100\n"] {
            let path = dir.write(contents);
            let config = match load_from(&path) {
                Ok(config) => config,
                Err(err) => panic!("should default: {err}"),
            };
            assert!(
                !config.alerts_enabled(),
                "alerts must default OFF: {contents:?}"
            );
            let thresholds = config.alert_thresholds();
            assert_eq!(
                thresholds.quota_warn_fraction,
                costroid_core::ALERT_WARN_FRACTION
            );
            assert_eq!(
                thresholds.quota_critical_fraction,
                costroid_core::ALERT_CRITICAL_FRACTION
            );
        }
    }

    #[test]
    fn alerts_enabled_parses_and_keeps_default_thresholds() {
        let dir = TempDir::new();
        let path = dir.write("[alerts]\nenabled = true\n");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("alerts config should parse: {err}"),
        };
        assert!(config.alerts_enabled());
        let thresholds = config.alert_thresholds();
        assert_eq!(
            thresholds.quota_warn_fraction,
            costroid_core::ALERT_WARN_FRACTION
        );
        assert_eq!(
            thresholds.quota_critical_fraction,
            costroid_core::ALERT_CRITICAL_FRACTION
        );
    }

    #[test]
    fn alerts_threshold_overrides_apply() {
        let dir = TempDir::new();
        let path = dir.write("[alerts]\nenabled = true\nquota_warn = 0.5\nquota_critical = 0.9\n");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("alert overrides should parse: {err}"),
        };
        let thresholds = config.alert_thresholds();
        assert_eq!(thresholds.quota_warn_fraction, 0.5);
        assert_eq!(thresholds.quota_critical_fraction, 0.9);
    }

    #[test]
    fn alerts_hostile_threshold_overrides_fall_back_to_defaults() {
        let dir = TempDir::new();
        let path = dir.write("[alerts]\nquota_warn = 0.0\nquota_critical = -0.5\n");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("odd-but-valid values should parse: {err}"),
        };
        let thresholds = config.alert_thresholds();
        assert_eq!(
            thresholds.quota_warn_fraction,
            costroid_core::ALERT_WARN_FRACTION
        );
        assert_eq!(
            thresholds.quota_critical_fraction,
            costroid_core::ALERT_CRITICAL_FRACTION
        );
    }

    #[test]
    fn alerts_inverted_threshold_pair_falls_back_to_defaults() {
        let dir = TempDir::new();
        let path = dir.write("[alerts]\nenabled = true\nquota_warn = 0.9\nquota_critical = 0.5\n");
        let config = match load_from(&path) {
            Ok(config) => config,
            Err(err) => panic!("inverted-but-valid values should parse: {err}"),
        };
        let thresholds = config.alert_thresholds();
        assert_eq!(
            thresholds.quota_warn_fraction,
            costroid_core::ALERT_WARN_FRACTION
        );
        assert_eq!(
            thresholds.quota_critical_fraction,
            costroid_core::ALERT_CRITICAL_FRACTION
        );
        // A lone critical override below the default warn fraction also inverts the pair.
        let path2 = dir.write("[alerts]\nenabled = true\nquota_critical = 0.5\n");
        let config2 = match load_from(&path2) {
            Ok(config) => config,
            Err(err) => panic!("should parse: {err}"),
        };
        let t2 = config2.alert_thresholds();
        assert_eq!(t2.quota_warn_fraction, costroid_core::ALERT_WARN_FRACTION);
        assert_eq!(
            t2.quota_critical_fraction,
            costroid_core::ALERT_CRITICAL_FRACTION
        );
    }

    #[test]
    fn malformed_alerts_value_is_a_typed_error_not_a_panic() {
        let dir = TempDir::new();
        let path = dir.write("[alerts]\nenabled = \"yes\"\n");
        match load_from(&path) {
            Ok(config) => panic!("malformed alerts should error, got {config:?}"),
            Err(ConfigError::Parse { .. }) => {}
            Err(other) => panic!("expected a Parse error, got {other}"),
        }
    }
}