use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use tracing::warn;
/// Returns the path of the `pitboss.toml` config file inside `workspace`.
///
/// This only builds the path; the file does not need to exist.
pub fn config_path(workspace: impl AsRef<Path>) -> PathBuf {
    const FILE_NAME: &str = "pitboss.toml";
    let root: &Path = workspace.as_ref();
    root.join(FILE_NAME)
}
/// Complete contents of `pitboss.toml`.
///
/// Every section is optional. The container-level `#[serde(default)]` fills
/// any section missing from the file from `Config::default()`, which in turn
/// uses each section type's own `Default` impl.
#[derive(Clone, Debug, Default, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct Config {
    /// `[models]`: model identifier per role.
    pub models: ModelRoles,
    /// `[retries]`: attempt limits.
    pub retries: RetryBudgets,
    /// `[audit]`: audit toggle and small-fix threshold.
    pub audit: AuditConfig,
    /// `[git]`: branch naming and PR creation.
    pub git: GitConfig,
    /// `[tests]`: optional test command.
    pub tests: TestsConfig,
    /// `[budgets]`: token/USD caps and per-model pricing.
    pub budgets: Budgets,
    /// `[agent]`: backend selection and per-backend overrides.
    pub agent: AgentConfig,
    /// `[caveman]`: caveman toggle and intensity.
    pub caveman: CavemanConfig,
}
/// `[models]` section: the model identifier used for each named role.
///
/// Every role defaults to `"claude-opus-4-7"`. Because of the container-level
/// `#[serde(default)]`, a role omitted from the file keeps that default while
/// the others are overridden.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct ModelRoles {
    /// Model for the planner role.
    pub planner: String,
    /// Model for the implementer role.
    pub implementer: String,
    /// Model for the auditor role.
    pub auditor: String,
    /// Model for the fixer role.
    pub fixer: String,
}

impl Default for ModelRoles {
    fn default() -> Self {
        // All roles share one default model.
        let model = String::from("claude-opus-4-7");
        ModelRoles {
            planner: model.clone(),
            implementer: model.clone(),
            auditor: model.clone(),
            fixer: model,
        }
    }
}
/// `[retries]` section: attempt limits.
///
/// Defaults are `fixer_max_attempts = 2` and `max_phase_attempts = 3`. How
/// these limits are enforced is decided by the callers, which are not in this
/// module.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct RetryBudgets {
    /// Presumably the max number of fixer attempts (TODO confirm at call site).
    pub fixer_max_attempts: u32,
    /// Presumably the max number of attempts per phase (TODO confirm at call site).
    pub max_phase_attempts: u32,
}

impl Default for RetryBudgets {
    fn default() -> Self {
        RetryBudgets {
            max_phase_attempts: 3,
            fixer_max_attempts: 2,
        }
    }
}
/// `[audit]` section.
///
/// Defaults: auditing `enabled = true`, `small_fix_line_limit = 30`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct AuditConfig {
    /// Whether the audit step is enabled.
    pub enabled: bool,
    /// Line-count threshold for what counts as a "small fix".
    /// NOTE(review): exact use lives outside this module — confirm at call site.
    pub small_fix_line_limit: u32,
}

impl Default for AuditConfig {
    fn default() -> Self {
        AuditConfig {
            small_fix_line_limit: 30,
            enabled: true,
        }
    }
}
/// `[git]` section.
///
/// Defaults: `branch_prefix = "pitboss/run-"`, `create_pr = false`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct GitConfig {
    /// Prefix for branch names created by a run.
    pub branch_prefix: String,
    /// Whether a pull request should be opened.
    pub create_pr: bool,
}

impl Default for GitConfig {
    fn default() -> Self {
        let branch_prefix = String::from("pitboss/run-");
        GitConfig {
            branch_prefix,
            create_pr: false,
        }
    }
}
/// `[budgets]` section: optional spending caps plus a model price table.
///
/// Both caps default to `None`. Presumably `None` means "no cap"; enforcement
/// lives outside this module (TODO confirm).
///
/// `pricing` maps a model identifier to its [`ModelPricing`]. The default
/// table has entries for `claude-opus-4-7`, `claude-sonnet-4-6` and
/// `claude-haiku-4-5`.
///
/// Note: the container-level `#[serde(default)]` only applies to *missing*
/// fields. If the file contains a `[budgets.pricing.*]` table, the whole map
/// is deserialized from the file and the built-in entries are not merged in.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct Budgets {
    /// Optional cap on total tokens.
    pub max_total_tokens: Option<u64>,
    /// Optional cap on total spend in USD.
    pub max_total_usd: Option<f64>,
    /// Model identifier -> per-million-token prices.
    pub pricing: HashMap<String, ModelPricing>,
}

impl Default for Budgets {
    fn default() -> Self {
        // (model id, USD per 1M input tokens, USD per 1M output tokens)
        const BUILTIN_PRICES: [(&str, f64, f64); 3] = [
            ("claude-opus-4-7", 15.0, 75.0),
            ("claude-sonnet-4-6", 3.0, 15.0),
            ("claude-haiku-4-5", 1.0, 5.0),
        ];

        let pricing = BUILTIN_PRICES
            .iter()
            .map(|&(model, input, output)| {
                let price = ModelPricing {
                    input_per_million_usd: input,
                    output_per_million_usd: output,
                };
                (model.to_string(), price)
            })
            .collect();

        Budgets {
            max_total_tokens: None,
            max_total_usd: None,
            pricing,
        }
    }
}
/// Token prices for one model, in USD per one million tokens.
///
/// Both fields are required when read from TOML: this struct has no
/// `#[serde(default)]`.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
pub struct ModelPricing {
    /// USD charged per 1,000,000 input tokens.
    pub input_per_million_usd: f64,
    /// USD charged per 1,000,000 output tokens.
    pub output_per_million_usd: f64,
}

impl ModelPricing {
    /// Returns the USD cost of `input` input tokens plus `output` output tokens.
    pub fn cost_usd(&self, input: u64, output: u64) -> f64 {
        const TOKENS_PER_MILLION: f64 = 1_000_000.0;
        let input_usd = input as f64 * self.input_per_million_usd / TOKENS_PER_MILLION;
        let output_usd = output as f64 * self.output_per_million_usd / TOKENS_PER_MILLION;
        input_usd + output_usd
    }
}
/// `[tests]` section.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct TestsConfig {
    /// Test command string, if configured (default `None`).
    /// NOTE(review): presumably executed by the caller as a shell command — confirm.
    pub command: Option<String>,
}
/// `[agent]` section: backend selection plus one override table per backend.
///
/// Each `[agent.<backend>]` table is optional and defaults to an empty
/// [`BackendOverrides`].
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct AgentConfig {
    /// Name of the selected backend, if any (default `None`).
    pub backend: Option<String>,
    /// `[agent.claude_code]` overrides.
    pub claude_code: BackendOverrides,
    /// `[agent.codex]` overrides.
    pub codex: BackendOverrides,
    /// `[agent.aider]` overrides.
    pub aider: BackendOverrides,
    /// `[agent.gemini]` overrides.
    pub gemini: BackendOverrides,
}
/// `[caveman]` section.
///
/// Defaults: `enabled = false`, `intensity = "full"`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct CavemanConfig {
    /// Whether caveman mode is turned on.
    pub enabled: bool,
    /// Selected intensity level.
    pub intensity: CavemanIntensity,
}
/// Level chosen by `intensity` in the `[caveman]` section.
///
/// Written in TOML as lowercase strings `"lite"`, `"full"` or `"ultra"`. Any
/// other string fails deserialization. Defaults to [`CavemanIntensity::Full`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CavemanIntensity {
    /// `"lite"`.
    Lite,
    /// `"full"` (the default).
    #[default]
    Full,
    /// `"ultra"`.
    Ultra,
}
/// Optional per-backend overrides, as found in an `[agent.<backend>]` table.
///
/// Every field defaults to empty (`None` / no args). How they are applied is
/// decided by the code that launches the backend, which is not in this module.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(default)]
pub struct BackendOverrides {
    /// Path to the backend executable, if overridden.
    pub binary: Option<PathBuf>,
    /// Extra command-line arguments.
    pub extra_args: Vec<String>,
    /// Model name override.
    pub model: Option<String>,
}
pub fn load(workspace: impl AsRef<Path>) -> Result<Config> {
let path = config_path(workspace.as_ref());
let text = match fs::read_to_string(&path) {
Ok(s) => s,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Config::default()),
Err(e) => {
return Err(anyhow::Error::new(e).context(format!("config::load: reading {:?}", path)));
}
};
parse(&text).with_context(|| format!("config::load: parsing {:?}", path))
}
/// Parses the text of a `pitboss.toml` into a [`Config`].
///
/// Whitespace-only input yields `Config::default()`. Keys the schema does not
/// recognise are reported through `warn!` and otherwise ignored.
///
/// # Errors
///
/// Fails in two cases:
/// - `text` is not valid TOML;
/// - `text` is valid TOML but its values do not fit the [`Config`] schema,
///   e.g. a string where a number is expected.
pub fn parse(text: &str) -> Result<Config> {
    let is_blank = text.trim().is_empty();
    if is_blank {
        return Ok(Config::default());
    }

    let document = toml::from_str::<toml::Value>(text).context("pitboss.toml is not valid TOML")?;

    // Unknown keys are advisory only: warn about each and keep going.
    find_unknown_keys(&document).iter().for_each(|key| {
        warn!(key = %key, "pitboss.toml: unknown key {:?} (ignored)", key);
    });

    let config: Config = document
        .try_into()
        .context("pitboss.toml does not match the expected schema")?;
    Ok(config)
}
/// Collects dotted paths of keys in a parsed `pitboss.toml` that [`Config`]
/// does not recognise.
///
/// [`parse`] only uses the result for warnings. Unknown keys otherwise parse
/// successfully and are dropped, so without this pass a typo silently falls
/// back to the default value.
///
/// Reported paths:
/// - an unrecognised top-level key or section, by name (`"telemetry"`);
/// - an unrecognised key inside a known section (`"models.new_role"`);
/// - an unrecognised key inside an `[agent.<backend>]` table
///   (`"agent.codex.bianry"`);
/// - an unrecognised key inside a `[budgets.pricing.<model>]` table
///   (`"budgets.pricing.my-model.cached_input"`).
///
/// Model names under `budgets.pricing` are free-form and never reported. A
/// known section or sub-table whose value is not a table is skipped here;
/// deserialization in [`parse`] is what rejects it. Returns an empty list when
/// `value` itself is not a table.
fn find_unknown_keys(value: &toml::Value) -> Vec<String> {
    // Field names of `BackendOverrides`, `ModelPricing`, and the backend
    // tables of `AgentConfig`, respectively.
    const BACKEND_KEYS: &[&str] = &["binary", "extra_args", "model"];
    const PRICING_KEYS: &[&str] = &["input_per_million_usd", "output_per_million_usd"];
    const BACKENDS: &[&str] = &["claude_code", "codex", "aider", "gemini"];

    let mut out = Vec::new();
    let toml::Value::Table(top) = value else {
        return out;
    };
    for (section, sub) in top {
        let known_subkeys: &[&str] = match section.as_str() {
            "models" => &["planner", "implementer", "auditor", "fixer"],
            "retries" => &["fixer_max_attempts", "max_phase_attempts"],
            "audit" => &["enabled", "small_fix_line_limit"],
            "git" => &["branch_prefix", "create_pr"],
            "tests" => &["command"],
            "budgets" => &["max_total_tokens", "max_total_usd", "pricing"],
            "agent" => &["backend", "claude_code", "codex", "aider", "gemini"],
            "caveman" => &["enabled", "intensity"],
            _ => {
                out.push(section.clone());
                continue;
            }
        };
        let toml::Value::Table(sub_table) = sub else {
            // Wrong value type for a known section: a schema error, not an
            // unknown key.
            continue;
        };
        collect_unknown_subkeys(&mut out, section, sub_table, known_subkeys);

        // Sections whose sub-tables have a fixed schema of their own.
        match section.as_str() {
            "agent" => {
                for &backend in BACKENDS {
                    if let Some(toml::Value::Table(overrides)) = sub_table.get(backend) {
                        collect_unknown_subkeys(
                            &mut out,
                            &format!("agent.{backend}"),
                            overrides,
                            BACKEND_KEYS,
                        );
                    }
                }
            }
            "budgets" => {
                if let Some(toml::Value::Table(models)) = sub_table.get("pricing") {
                    for (model, entry) in models {
                        if let toml::Value::Table(prices) = entry {
                            collect_unknown_subkeys(
                                &mut out,
                                &format!("budgets.pricing.{model}"),
                                prices,
                                PRICING_KEYS,
                            );
                        }
                    }
                }
            }
            _ => {}
        }
    }
    out
}

/// Appends `"{prefix}.{key}"` to `out` for every key of `table` not listed in `known`.
fn collect_unknown_subkeys(
    out: &mut Vec<String>,
    prefix: &str,
    table: &toml::map::Map<String, toml::Value>,
    known: &[&str],
) {
    out.extend(
        table
            .keys()
            .filter(|key| !known.contains(&key.as_str()))
            .map(|key| format!("{prefix}.{key}")),
    );
}
#[cfg(test)]
mod tests {
    //! Tests for defaults, parsing, unknown-key detection, and file loading.

    use super::*;
    use tempfile::tempdir;

    #[test]
    fn defaults_are_self_consistent() {
        let cfg = Config::default();
        assert_eq!(cfg.models.planner, "claude-opus-4-7");
        assert_eq!(cfg.models.implementer, "claude-opus-4-7");
        assert_eq!(cfg.models.auditor, "claude-opus-4-7");
        assert_eq!(cfg.models.fixer, "claude-opus-4-7");
        assert_eq!(cfg.retries.fixer_max_attempts, 2);
        assert_eq!(cfg.retries.max_phase_attempts, 3);
        assert!(cfg.audit.enabled);
        assert_eq!(cfg.audit.small_fix_line_limit, 30);
        assert_eq!(cfg.git.branch_prefix, "pitboss/run-");
        assert!(!cfg.git.create_pr);
        assert!(cfg.tests.command.is_none());
        assert_eq!(cfg.budgets.max_total_tokens, None);
        assert_eq!(cfg.budgets.max_total_usd, None);
        assert!(cfg.budgets.pricing.contains_key("claude-opus-4-7"));
        assert_eq!(cfg.agent, AgentConfig::default());
        assert_eq!(cfg.agent.backend, None);
        assert!(!cfg.caveman.enabled);
        assert_eq!(cfg.caveman.intensity, CavemanIntensity::Full);
    }

    #[test]
    fn caveman_section_round_trips_full_form() {
        let text = "
[caveman]
enabled = true
intensity = \"ultra\"
";
        let cfg = parse(text).unwrap();
        assert!(cfg.caveman.enabled);
        assert_eq!(cfg.caveman.intensity, CavemanIntensity::Ultra);
        let value: toml::Value = toml::from_str(text).unwrap();
        assert!(find_unknown_keys(&value).is_empty());
    }

    #[test]
    fn caveman_section_accepts_each_intensity_level() {
        for (s, expected) in [
            ("lite", CavemanIntensity::Lite),
            ("full", CavemanIntensity::Full),
            ("ultra", CavemanIntensity::Ultra),
        ] {
            let text = format!("[caveman]\nenabled = true\nintensity = \"{s}\"\n");
            let cfg = parse(&text).unwrap();
            assert_eq!(cfg.caveman.intensity, expected, "intensity {s}");
        }
    }

    #[test]
    fn caveman_section_rejects_unknown_intensity() {
        let text = "
[caveman]
enabled = true
intensity = \"galaxybrain\"
";
        let err = parse(text).unwrap_err();
        let msg = format!("{:#}", err);
        assert!(
            msg.contains("expected schema"),
            "expected schema error for unknown intensity, got: {msg}"
        );
    }

    #[test]
    fn caveman_unknown_subkeys_are_flagged() {
        let text = "
[caveman]
enabled = true
mode = \"wenyan\"
";
        let value: toml::Value = toml::from_str(text).unwrap();
        let unknown = find_unknown_keys(&value);
        assert!(unknown.contains(&"caveman.mode".to_string()));
    }

    #[test]
    fn model_pricing_cost_usd_is_per_million_tokens() {
        let p = ModelPricing {
            input_per_million_usd: 10.0,
            output_per_million_usd: 100.0,
        };
        let cost = p.cost_usd(1_000_000, 100_000);
        assert!((cost - 20.0).abs() < 1e-9, "cost: {cost}");
    }

    #[test]
    fn budgets_section_parses_full_form() {
        let text = "
[budgets]
max_total_tokens = 1_000_000
max_total_usd = 5.0
[budgets.pricing.claude-opus-4-7]
input_per_million_usd = 12.5
output_per_million_usd = 60.0
[budgets.pricing.custom-model]
input_per_million_usd = 0.5
output_per_million_usd = 2.0
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.budgets.max_total_tokens, Some(1_000_000));
        assert_eq!(cfg.budgets.max_total_usd, Some(5.0));
        let opus = cfg.budgets.pricing.get("claude-opus-4-7").unwrap();
        assert_eq!(opus.input_per_million_usd, 12.5);
        assert_eq!(opus.output_per_million_usd, 60.0);
        let custom = cfg.budgets.pricing.get("custom-model").unwrap();
        assert_eq!(custom.input_per_million_usd, 0.5);
        assert_eq!(custom.output_per_million_usd, 2.0);
    }

    #[test]
    fn budgets_pricing_subkeys_are_not_flagged_as_unknown() {
        let text = "
[budgets]
max_total_tokens = 100
[budgets.pricing.brand-new-model]
input_per_million_usd = 1.0
output_per_million_usd = 2.0
";
        let value: toml::Value = toml::from_str(text).unwrap();
        let unknown = find_unknown_keys(&value);
        assert!(unknown.is_empty(), "unexpected unknown keys: {:?}", unknown);
    }

    #[test]
    fn agent_section_round_trips_full_form() {
        let text = "
[agent]
backend = \"codex\"
[agent.claude_code]
binary = \"/opt/anthropic/claude\"
extra_args = [\"--max-turns\", \"50\"]
model = \"claude-opus-4-7\"
[agent.codex]
binary = \"/usr/local/bin/codex\"
extra_args = [\"--quiet\"]
model = \"gpt-5\"
[agent.aider]
binary = \"/usr/local/bin/aider\"
extra_args = []
model = \"sonnet\"
[agent.gemini]
binary = \"/usr/local/bin/gemini\"
extra_args = [\"--no-stream\"]
model = \"gemini-2.5-pro\"
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.agent.backend.as_deref(), Some("codex"));
        assert_eq!(
            cfg.agent.claude_code.binary,
            Some(PathBuf::from("/opt/anthropic/claude"))
        );
        assert_eq!(
            cfg.agent.claude_code.extra_args,
            vec!["--max-turns".to_string(), "50".to_string()]
        );
        assert_eq!(
            cfg.agent.claude_code.model.as_deref(),
            Some("claude-opus-4-7")
        );
        assert_eq!(
            cfg.agent.codex.binary,
            Some(PathBuf::from("/usr/local/bin/codex"))
        );
        assert_eq!(cfg.agent.codex.extra_args, vec!["--quiet".to_string()]);
        assert_eq!(cfg.agent.codex.model.as_deref(), Some("gpt-5"));
        assert_eq!(
            cfg.agent.aider.binary,
            Some(PathBuf::from("/usr/local/bin/aider"))
        );
        assert!(cfg.agent.aider.extra_args.is_empty());
        assert_eq!(cfg.agent.aider.model.as_deref(), Some("sonnet"));
        assert_eq!(
            cfg.agent.gemini.binary,
            Some(PathBuf::from("/usr/local/bin/gemini"))
        );
        assert_eq!(cfg.agent.gemini.extra_args, vec!["--no-stream".to_string()]);
        assert_eq!(cfg.agent.gemini.model.as_deref(), Some("gemini-2.5-pro"));
        let value: toml::Value = toml::from_str(text).unwrap();
        assert!(find_unknown_keys(&value).is_empty());
    }

    #[test]
    fn agent_backend_alone_round_trips_with_defaults() {
        let text = "
[agent]
backend = \"codex\"
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.agent.backend.as_deref(), Some("codex"));
        assert_eq!(cfg.agent.codex, BackendOverrides::default());
        assert_eq!(cfg.agent.claude_code, BackendOverrides::default());
        assert_eq!(cfg.agent.aider, BackendOverrides::default());
        assert_eq!(cfg.agent.gemini, BackendOverrides::default());
    }

    #[test]
    fn empty_input_yields_defaults() {
        assert_eq!(parse("").unwrap(), Config::default());
        assert_eq!(parse(" \n\t\n").unwrap(), Config::default());
    }

    #[test]
    fn comment_only_input_yields_defaults() {
        // Not whitespace-only, so this goes through the TOML parser and the
        // serde defaults rather than the early return in `parse`.
        assert_eq!(
            parse("# nothing configured yet\n").unwrap(),
            Config::default()
        );
    }

    #[test]
    fn full_input_overrides_every_field() {
        let text = "
[models]
planner = \"a\"
implementer = \"b\"
auditor = \"c\"
fixer = \"d\"
[retries]
fixer_max_attempts = 7
max_phase_attempts = 11
[audit]
enabled = false
small_fix_line_limit = 5
[git]
branch_prefix = \"work/\"
create_pr = true
[tests]
command = \"make check\"
[budgets]
max_total_tokens = 42
max_total_usd = 1.5
[agent]
backend = \"aider\"
[caveman]
enabled = true
intensity = \"lite\"
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.models.planner, "a");
        assert_eq!(cfg.models.implementer, "b");
        assert_eq!(cfg.models.auditor, "c");
        assert_eq!(cfg.models.fixer, "d");
        assert_eq!(cfg.retries.fixer_max_attempts, 7);
        assert_eq!(cfg.retries.max_phase_attempts, 11);
        assert!(!cfg.audit.enabled);
        assert_eq!(cfg.audit.small_fix_line_limit, 5);
        assert_eq!(cfg.git.branch_prefix, "work/");
        assert!(cfg.git.create_pr);
        assert_eq!(cfg.tests.command.as_deref(), Some("make check"));
        assert_eq!(cfg.budgets.max_total_tokens, Some(42));
        assert_eq!(cfg.budgets.max_total_usd, Some(1.5));
        assert_eq!(cfg.agent.backend.as_deref(), Some("aider"));
        assert!(cfg.caveman.enabled);
        assert_eq!(cfg.caveman.intensity, CavemanIntensity::Lite);
    }

    #[test]
    fn partial_input_fills_remaining_with_defaults() {
        let text = "
[git]
create_pr = true
";
        let cfg = parse(text).unwrap();
        assert!(cfg.git.create_pr);
        assert_eq!(cfg.git.branch_prefix, "pitboss/run-");
        assert_eq!(cfg.models, ModelRoles::default());
        assert_eq!(cfg.retries, RetryBudgets::default());
        assert_eq!(cfg.audit, AuditConfig::default());
    }

    #[test]
    fn partial_section_fills_missing_subkeys() {
        let text = "
[models]
implementer = \"custom-impl\"
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.models.implementer, "custom-impl");
        assert_eq!(cfg.models.planner, ModelRoles::default().planner);
        assert_eq!(cfg.models.auditor, ModelRoles::default().auditor);
        assert_eq!(cfg.models.fixer, ModelRoles::default().fixer);
    }

    #[test]
    fn malformed_toml_is_an_error() {
        let err = parse("[models\nplanner = \"x\"").unwrap_err();
        let msg = format!("{:#}", err);
        assert!(msg.contains("not valid TOML"), "msg: {msg}");
    }

    #[test]
    fn wrong_value_type_is_an_error() {
        let text = "
[retries]
fixer_max_attempts = \"two\"
";
        let err = parse(text).unwrap_err();
        let msg = format!("{:#}", err);
        assert!(
            msg.contains("expected schema"),
            "expected schema error, got: {msg}"
        );
    }

    #[test]
    fn unknown_keys_are_collected_not_errored() {
        let text = "
something_extra = 1
[models]
planner = \"p\"
new_role = \"x\"
[telemetry]
sink = \"stdout\"
";
        let cfg = parse(text).unwrap();
        assert_eq!(cfg.models.planner, "p");
        let toml_value: toml::Value = toml::from_str(text).unwrap();
        let unknown = find_unknown_keys(&toml_value);
        assert!(unknown.contains(&"something_extra".to_string()));
        assert!(unknown.contains(&"models.new_role".to_string()));
        assert!(unknown.contains(&"telemetry".to_string()));
    }

    #[test]
    fn no_unknown_keys_for_canonical_input() {
        // Every section of the schema, using only recognised keys.
        let text = "
[models]
planner = \"p\"
implementer = \"i\"
auditor = \"a\"
fixer = \"f\"
[retries]
fixer_max_attempts = 1
max_phase_attempts = 2
[audit]
enabled = true
small_fix_line_limit = 10
[git]
branch_prefix = \"x/\"
create_pr = false
[tests]
command = \"cargo test\"
[budgets]
max_total_tokens = 100
max_total_usd = 2.5
[budgets.pricing.some-model]
input_per_million_usd = 1.0
output_per_million_usd = 2.0
[agent]
backend = \"claude_code\"
[agent.claude_code]
binary = \"/usr/bin/claude\"
extra_args = []
model = \"m\"
[caveman]
enabled = false
intensity = \"full\"
";
        let value: toml::Value = toml::from_str(text).unwrap();
        let unknown = find_unknown_keys(&value);
        assert!(unknown.is_empty(), "unexpected unknown keys: {:?}", unknown);
    }

    #[test]
    fn load_returns_defaults_when_file_missing() {
        let dir = tempdir().unwrap();
        let cfg = load(dir.path()).unwrap();
        assert_eq!(cfg, Config::default());
    }

    #[test]
    fn load_reads_file_from_workspace() {
        let dir = tempdir().unwrap();
        std::fs::write(
            dir.path().join("pitboss.toml"),
            "[git]\nbranch_prefix = \"loaded/\"\n",
        )
        .unwrap();
        let cfg = load(dir.path()).unwrap();
        assert_eq!(cfg.git.branch_prefix, "loaded/");
    }

    #[test]
    fn load_surfaces_parse_errors_with_path_context() {
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("pitboss.toml"), "[broken").unwrap();
        let err = load(dir.path()).unwrap_err();
        let msg = format!("{:#}", err);
        assert!(msg.contains("pitboss.toml"), "msg: {msg}");
    }

    #[test]
    fn init_template_round_trips_through_loader() {
        let dir = tempdir().unwrap();
        crate::cli::init::run(dir.path()).unwrap();
        let cfg = load(dir.path()).unwrap();
        assert_eq!(cfg, Config::default());
    }
}