1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::Context;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "lowercase")]
12pub enum IssueSource {
13 #[default]
14 Github,
15 Local,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
19#[serde(default)]
20pub struct Config {
21 pub project: ProjectConfig,
22 pub pipeline: PipelineConfig,
23 pub labels: LabelConfig,
24 pub multi_repo: MultiRepoConfig,
25 pub models: ModelConfig,
26 #[serde(default)]
27 pub repos: HashMap<String, PathBuf>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
31#[serde(default)]
32pub struct MultiRepoConfig {
33 pub enabled: bool,
34 pub target_field: String,
35}
36
37impl Default for MultiRepoConfig {
38 fn default() -> Self {
39 Self { enabled: false, target_field: "target_repo".to_string() }
40 }
41}
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
48#[serde(default)]
49pub struct ModelConfig {
50 pub default: Option<String>,
51 pub planner: Option<String>,
52 pub implementer: Option<String>,
53 pub reviewer: Option<String>,
54 pub fixer: Option<String>,
55}
56
57impl ModelConfig {
58 pub fn model_for(&self, role: &str) -> Option<&str> {
61 let agent_override = match role {
62 "planner" => self.planner.as_deref(),
63 "implementer" => self.implementer.as_deref(),
64 "reviewer" => self.reviewer.as_deref(),
65 "fixer" => self.fixer.as_deref(),
66 _ => None,
67 };
68 agent_override.or(self.default.as_deref())
69 }
70}
71
72#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
73#[serde(default)]
74pub struct ProjectConfig {
75 pub name: Option<String>,
76 pub test: Option<String>,
77 pub lint: Option<String>,
78 pub issue_source: IssueSource,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82#[serde(default)]
83pub struct PipelineConfig {
84 pub max_parallel: u32,
85 pub cost_budget: f64,
86 pub poll_interval: u64,
87 pub turn_limit: u32,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91#[serde(default)]
92pub struct LabelConfig {
93 pub ready: String,
94 pub cooking: String,
95 pub complete: String,
96 pub failed: String,
97}
98
99impl Default for PipelineConfig {
100 fn default() -> Self {
101 Self { max_parallel: 2, cost_budget: 15.0, poll_interval: 60, turn_limit: 50 }
102 }
103}
104
105impl Default for LabelConfig {
106 fn default() -> Self {
107 Self {
108 ready: "o-ready".to_string(),
109 cooking: "o-cooking".to_string(),
110 complete: "o-complete".to_string(),
111 failed: "o-failed".to_string(),
112 }
113 }
114}
115
116#[derive(Debug, Deserialize, Default)]
119#[serde(default)]
120struct RawConfig {
121 project: Option<RawProjectConfig>,
122 pipeline: Option<RawPipelineConfig>,
123 labels: Option<RawLabelConfig>,
124 multi_repo: Option<RawMultiRepoConfig>,
125 models: Option<RawModelConfig>,
126 repos: Option<HashMap<String, PathBuf>>,
127}
128
129#[derive(Debug, Default, Deserialize)]
130#[serde(default)]
131struct RawProjectConfig {
132 name: Option<String>,
133 test: Option<String>,
134 lint: Option<String>,
135 issue_source: Option<IssueSource>,
136}
137
138#[derive(Debug, Default, Deserialize)]
139#[serde(default)]
140struct RawPipelineConfig {
141 max_parallel: Option<u32>,
142 cost_budget: Option<f64>,
143 poll_interval: Option<u64>,
144 turn_limit: Option<u32>,
145}
146
147#[derive(Debug, Default, Deserialize)]
148#[serde(default)]
149struct RawLabelConfig {
150 ready: Option<String>,
151 cooking: Option<String>,
152 complete: Option<String>,
153 failed: Option<String>,
154}
155
156#[derive(Debug, Default, Deserialize)]
157#[serde(default)]
158struct RawMultiRepoConfig {
159 enabled: Option<bool>,
160 target_field: Option<String>,
161}
162
163#[derive(Debug, Default, Deserialize)]
164#[serde(default)]
165struct RawModelConfig {
166 default: Option<String>,
167 planner: Option<String>,
168 implementer: Option<String>,
169 reviewer: Option<String>,
170 fixer: Option<String>,
171}
172
173impl Config {
174 pub fn load(project_dir: &Path) -> anyhow::Result<Self> {
181 let mut config = Self::default();
182
183 if let Some(config_dir) = dirs::config_dir() {
185 let user_path = config_dir.join("oven").join("recipe.toml");
186 if user_path.exists() {
187 let content = std::fs::read_to_string(&user_path)
188 .with_context(|| format!("reading user config: {}", user_path.display()))?;
189 let raw: RawConfig = toml::from_str(&content)
190 .with_context(|| format!("parsing user config: {}", user_path.display()))?;
191 apply_raw(&mut config, &raw, true);
192 }
193 }
194
195 let project_path = project_dir.join("recipe.toml");
197 if project_path.exists() {
198 let content = std::fs::read_to_string(&project_path)
199 .with_context(|| format!("reading project config: {}", project_path.display()))?;
200 let raw: RawConfig = toml::from_str(&content)
201 .with_context(|| format!("parsing project config: {}", project_path.display()))?;
202 apply_raw(&mut config, &raw, false);
203 }
204
205 config.validate()?;
206 Ok(config)
207 }
208
209 pub fn resolve_repo(&self, name: &str) -> anyhow::Result<PathBuf> {
213 let path = self
214 .repos
215 .get(name)
216 .with_context(|| format!("repo '{name}' not found in user config [repos] section"))?;
217
218 let expanded = if path.starts_with("~") {
219 dirs::home_dir().map_or_else(
220 || path.clone(),
221 |home| home.join(path.strip_prefix("~").unwrap_or(path)),
222 )
223 } else {
224 path.clone()
225 };
226
227 if !expanded.exists() {
228 anyhow::bail!("repo '{name}' path does not exist: {}", expanded.display());
229 }
230
231 Ok(expanded)
232 }
233
234 fn validate(&self) -> anyhow::Result<()> {
236 if self.pipeline.max_parallel == 0 {
237 anyhow::bail!("pipeline.max_parallel must be >= 1 (got 0, which would deadlock)");
238 }
239 if self.pipeline.poll_interval < 10 {
240 anyhow::bail!(
241 "pipeline.poll_interval must be >= 10 (got {}, which would hammer the API)",
242 self.pipeline.poll_interval
243 );
244 }
245 if !self.pipeline.cost_budget.is_finite() || self.pipeline.cost_budget <= 0.0 {
246 anyhow::bail!(
247 "pipeline.cost_budget must be a finite number > 0 (got {})",
248 self.pipeline.cost_budget
249 );
250 }
251 if self.pipeline.turn_limit == 0 {
252 anyhow::bail!("pipeline.turn_limit must be >= 1 (got 0)");
253 }
254 Ok(())
255 }
256
257 pub fn default_user_toml() -> String {
259 r#"# Global oven defaults (all projects inherit these)
260
261[pipeline]
262# max_parallel = 2
263# cost_budget = 15.0
264# poll_interval = 60
265# turn_limit = 50
266
267# [labels]
268# ready = "o-ready"
269# cooking = "o-cooking"
270# complete = "o-complete"
271# failed = "o-failed"
272
273# Multi-repo path mappings (only honored from user config)
274# [repos]
275# api = "~/dev/api"
276# web = "~/dev/web"
277"#
278 .to_string()
279 }
280
281 pub fn default_project_toml() -> String {
283 r#"[project]
284# name = "my-project" # auto-detected from git remote
285# test = "cargo test" # test command
286# lint = "cargo clippy" # lint command
287# issue_source = "github" # "github" (default) or "local"
288
289[pipeline]
290max_parallel = 2
291cost_budget = 15.0
292poll_interval = 60
293
294# [labels]
295# ready = "o-ready"
296# cooking = "o-cooking"
297# complete = "o-complete"
298# failed = "o-failed"
299
300# [models]
301# default = "sonnet"
302# implementer = "opus"
303# fixer = "opus"
304"#
305 .to_string()
306 }
307}
308
309fn apply_raw(config: &mut Config, raw: &RawConfig, allow_repos: bool) {
312 if let Some(ref project) = raw.project {
313 if project.name.is_some() {
314 config.project.name.clone_from(&project.name);
315 }
316 if project.test.is_some() {
317 config.project.test.clone_from(&project.test);
318 }
319 if project.lint.is_some() {
320 config.project.lint.clone_from(&project.lint);
321 }
322 if let Some(ref source) = project.issue_source {
323 config.project.issue_source = source.clone();
324 }
325 }
326
327 if let Some(ref pipeline) = raw.pipeline {
328 if let Some(v) = pipeline.max_parallel {
329 config.pipeline.max_parallel = v;
330 }
331 if let Some(v) = pipeline.cost_budget {
332 config.pipeline.cost_budget = v;
333 }
334 if let Some(v) = pipeline.poll_interval {
335 config.pipeline.poll_interval = v;
336 }
337 if let Some(v) = pipeline.turn_limit {
338 config.pipeline.turn_limit = v;
339 }
340 }
341
342 if let Some(ref labels) = raw.labels {
343 if let Some(ref v) = labels.ready {
344 config.labels.ready.clone_from(v);
345 }
346 if let Some(ref v) = labels.cooking {
347 config.labels.cooking.clone_from(v);
348 }
349 if let Some(ref v) = labels.complete {
350 config.labels.complete.clone_from(v);
351 }
352 if let Some(ref v) = labels.failed {
353 config.labels.failed.clone_from(v);
354 }
355 }
356
357 if let Some(ref multi_repo) = raw.multi_repo {
359 if let Some(v) = multi_repo.enabled {
360 config.multi_repo.enabled = v;
361 }
362 if let Some(ref v) = multi_repo.target_field {
363 config.multi_repo.target_field.clone_from(v);
364 }
365 }
366
367 if let Some(ref models) = raw.models {
368 if models.default.is_some() {
369 config.models.default.clone_from(&models.default);
370 }
371 if models.planner.is_some() {
372 config.models.planner.clone_from(&models.planner);
373 }
374 if models.implementer.is_some() {
375 config.models.implementer.clone_from(&models.implementer);
376 }
377 if models.reviewer.is_some() {
378 config.models.reviewer.clone_from(&models.reviewer);
379 }
380 if models.fixer.is_some() {
381 config.models.fixer.clone_from(&models.fixer);
382 }
383 }
384
385 if allow_repos {
388 if let Some(ref repos) = raw.repos {
389 config.repos.clone_from(repos);
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        #[test]
        fn config_toml_roundtrip(
            max_parallel in 1..100u32,
            cost_budget in 0.0..1000.0f64,
            poll_interval in 1..3600u64,
            turn_limit in 1..200u32,
            ready in "[a-z][a-z0-9-]{1,20}",
            cooking in "[a-z][a-z0-9-]{1,20}",
            complete in "[a-z][a-z0-9-]{1,20}",
            failed in "[a-z][a-z0-9-]{1,20}",
        ) {
            let config = Config {
                project: ProjectConfig::default(),
                pipeline: PipelineConfig { max_parallel, cost_budget, poll_interval, turn_limit },
                labels: LabelConfig { ready, cooking, complete, failed },
                multi_repo: MultiRepoConfig::default(),
                models: ModelConfig::default(),
                repos: HashMap::new(),
            };
            let serialized = toml::to_string(&config).unwrap();
            let deserialized: Config = toml::from_str(&serialized).unwrap();
            assert_eq!(config.pipeline.max_parallel, deserialized.pipeline.max_parallel);
            assert!((config.pipeline.cost_budget - deserialized.pipeline.cost_budget).abs() < 1e-6);
            assert_eq!(config.pipeline.poll_interval, deserialized.pipeline.poll_interval);
            assert_eq!(config.pipeline.turn_limit, deserialized.pipeline.turn_limit);
            assert_eq!(config.labels, deserialized.labels);
        }

        #[test]
        fn partial_toml_always_parses(
            max_parallel in proptest::option::of(1..100u32),
            cost_budget in proptest::option::of(0.0..1000.0f64),
        ) {
            let mut parts = vec!["[pipeline]".to_string()];
            if let Some(mp) = max_parallel {
                parts.push(format!("max_parallel = {mp}"));
            }
            if let Some(cb) = cost_budget {
                parts.push(format!("cost_budget = {cb}"));
            }
            let toml_str = parts.join("\n");
            let raw: RawConfig = toml::from_str(&toml_str).unwrap();
            let mut config = Config::default();
            apply_raw(&mut config, &raw, false);
            // Keys that were written must be applied; keys that were left out
            // must keep their defaults.
            let defaults = PipelineConfig::default();
            match max_parallel {
                Some(mp) => assert_eq!(config.pipeline.max_parallel, mp),
                None => assert_eq!(config.pipeline.max_parallel, defaults.max_parallel),
            }
            let expected_budget = cost_budget.unwrap_or(defaults.cost_budget);
            assert!((config.pipeline.cost_budget - expected_budget).abs() < 1e-9);
            assert_eq!(config.pipeline.poll_interval, defaults.poll_interval);
            assert_eq!(config.pipeline.turn_limit, defaults.turn_limit);
        }
    }

    #[test]
    fn defaults_are_correct() {
        let config = Config::default();
        assert_eq!(config.pipeline.max_parallel, 2);
        assert!(
            (config.pipeline.cost_budget - 15.0).abs() < f64::EPSILON,
            "cost_budget should be 15.0"
        );
        assert_eq!(config.pipeline.poll_interval, 60);
        assert_eq!(config.pipeline.turn_limit, 50);
        assert_eq!(config.labels.ready, "o-ready");
        assert_eq!(config.labels.cooking, "o-cooking");
        assert_eq!(config.labels.complete, "o-complete");
        assert_eq!(config.labels.failed, "o-failed");
        assert!(config.project.name.is_none());
        assert!(config.repos.is_empty());
        assert!(!config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "target_repo");
    }

    #[test]
    fn load_from_valid_toml() {
        let toml_str = r#"
[project]
name = "test-project"
test = "cargo test"

[pipeline]
max_parallel = 4
cost_budget = 20.0
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);

        assert_eq!(config.project.name.as_deref(), Some("test-project"));
        assert_eq!(config.project.test.as_deref(), Some("cargo test"));
        assert_eq!(config.pipeline.max_parallel, 4);
        assert!((config.pipeline.cost_budget - 20.0).abs() < f64::EPSILON);
        assert_eq!(config.pipeline.poll_interval, 60);
    }

    #[test]
    fn project_overrides_user() {
        let user_toml = r"
[pipeline]
max_parallel = 3
cost_budget = 10.0
poll_interval = 120
";
        let project_toml = r"
[pipeline]
max_parallel = 1
cost_budget = 5.0
";
        let mut config = Config::default();

        let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
        apply_raw(&mut config, &user_raw, true);
        assert_eq!(config.pipeline.max_parallel, 3);
        assert_eq!(config.pipeline.poll_interval, 120);

        let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
        apply_raw(&mut config, &project_raw, false);
        assert_eq!(config.pipeline.max_parallel, 1);
        assert!((config.pipeline.cost_budget - 5.0).abs() < f64::EPSILON);
        assert_eq!(config.pipeline.poll_interval, 120);
    }

    #[test]
    fn repos_ignored_in_project_config() {
        let project_toml = r#"
[repos]
evil = "/tmp/evil"
"#;
        let mut config = Config::default();
        let raw: RawConfig = toml::from_str(project_toml).unwrap();
        apply_raw(&mut config, &raw, false);
        assert!(config.repos.is_empty());
    }

    #[test]
    fn repos_honored_in_user_config() {
        let user_toml = r#"
[repos]
api = "/home/user/dev/api"
"#;
        let mut config = Config::default();
        let raw: RawConfig = toml::from_str(user_toml).unwrap();
        apply_raw(&mut config, &raw, true);
        assert_eq!(config.repos.get("api").unwrap(), Path::new("/home/user/dev/api"));
    }

    #[test]
    fn missing_file_returns_defaults() {
        let dir = tempfile::tempdir().unwrap();
        let config = Config::load(dir.path()).unwrap();
        assert_eq!(config, Config::default());
    }

    #[test]
    fn load_applies_project_config() {
        // The project layer is applied last, so these values win regardless
        // of any user-level config on the machine running the tests.
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("recipe.toml"),
            "[pipeline]\nmax_parallel = 3\nturn_limit = 7\n",
        )
        .unwrap();
        let config = Config::load(dir.path()).unwrap();
        assert_eq!(config.pipeline.max_parallel, 3);
        assert_eq!(config.pipeline.turn_limit, 7);
    }

    #[test]
    fn load_rejects_invalid_project_values() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("recipe.toml"), "[pipeline]\nmax_parallel = 0\n").unwrap();
        let err = Config::load(dir.path()).unwrap_err().to_string();
        assert!(err.contains("max_parallel"), "error was: {err}");
    }

    #[test]
    fn invalid_toml_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("recipe.toml"), "this is not [valid toml").unwrap();
        let result = Config::load(dir.path());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("parsing project config"), "error was: {err}");
    }

    #[test]
    fn default_user_toml_parses() {
        let toml_str = Config::default_user_toml();
        let raw: RawConfig = toml::from_str(&toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, true);
        assert_eq!(config.pipeline.max_parallel, 2);
        assert!(config.repos.is_empty());
    }

    #[test]
    fn default_project_toml_parses() {
        let toml_str = Config::default_project_toml();
        let raw: RawConfig = toml::from_str(&toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.pipeline.max_parallel, 2);
    }

    #[test]
    fn config_roundtrip_serialize_deserialize() {
        let config = Config {
            project: ProjectConfig {
                name: Some("roundtrip".to_string()),
                test: Some("make test".to_string()),
                lint: None,
                issue_source: IssueSource::Github,
            },
            pipeline: PipelineConfig { max_parallel: 5, cost_budget: 25.0, ..Default::default() },
            labels: LabelConfig::default(),
            multi_repo: MultiRepoConfig::default(),
            models: ModelConfig::default(),
            repos: HashMap::from([("svc".to_string(), PathBuf::from("/tmp/svc"))]),
        };
        let serialized = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&serialized).unwrap();
        assert_eq!(config, deserialized);
    }

    #[test]
    fn multi_repo_config_from_project_toml() {
        let toml_str = r#"
[multi_repo]
enabled = true
target_field = "repo"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert!(config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "repo");
    }

    #[test]
    fn multi_repo_defaults_when_not_specified() {
        let toml_str = r"
[pipeline]
max_parallel = 1
";
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert!(!config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "target_repo");
    }

    #[test]
    fn resolve_repo_finds_existing_path() {
        let dir = tempfile::tempdir().unwrap();
        let mut config = Config::default();
        config.repos.insert("test-repo".to_string(), dir.path().to_path_buf());

        let resolved = config.resolve_repo("test-repo").unwrap();
        assert_eq!(resolved, dir.path());
    }

    #[test]
    fn resolve_repo_expands_tilde() {
        // Skipped when the environment has no usable home directory.
        if let Some(home) = dirs::home_dir().filter(|h| h.exists()) {
            let mut config = Config::default();
            config.repos.insert("home".to_string(), PathBuf::from("~"));
            assert_eq!(config.resolve_repo("home").unwrap(), home);
        }
    }

    #[test]
    fn resolve_repo_missing_name_errors() {
        let config = Config::default();
        let result = config.resolve_repo("nonexistent");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found in user config"));
    }

    #[test]
    fn resolve_repo_missing_path_errors() {
        let mut config = Config::default();
        config.repos.insert("bad".to_string(), PathBuf::from("/nonexistent/path/xyz"));
        let result = config.resolve_repo("bad");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("does not exist"));
    }

    #[test]
    fn issue_source_defaults_to_github() {
        let config = Config::default();
        assert_eq!(config.project.issue_source, IssueSource::Github);
    }

    #[test]
    fn issue_source_local_parses() {
        let toml_str = r#"
[project]
issue_source = "local"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.project.issue_source, IssueSource::Local);
    }

    #[test]
    fn issue_source_github_parses() {
        let toml_str = r#"
[project]
issue_source = "github"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.project.issue_source, IssueSource::Github);
    }

    #[test]
    fn validate_rejects_zero_max_parallel() {
        let mut config = Config::default();
        config.pipeline.max_parallel = 0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("max_parallel"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_low_poll_interval() {
        let mut config = Config::default();
        config.pipeline.poll_interval = 5;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("poll_interval"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_zero_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = 0.0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_nan_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = f64::NAN;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_infinity_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = f64::INFINITY;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_zero_turn_limit() {
        let mut config = Config::default();
        config.pipeline.turn_limit = 0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("turn_limit"), "error was: {err}");
    }

    #[test]
    fn validate_accepts_defaults() {
        Config::default().validate().unwrap();
    }

    #[test]
    fn issue_source_invalid_errors() {
        let toml_str = r#"
[project]
issue_source = "jira"
"#;
        let result = toml::from_str::<RawConfig>(toml_str);
        assert!(result.is_err());
    }

    #[test]
    fn issue_source_roundtrip() {
        let config = Config {
            project: ProjectConfig { issue_source: IssueSource::Local, ..Default::default() },
            ..Default::default()
        };
        let serialized = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.project.issue_source, IssueSource::Local);
    }

    #[test]
    fn model_for_returns_agent_override() {
        let models = ModelConfig {
            default: Some("sonnet".to_string()),
            implementer: Some("opus".to_string()),
            ..Default::default()
        };
        assert_eq!(models.model_for("implementer"), Some("opus"));
        assert_eq!(models.model_for("reviewer"), Some("sonnet"));
    }

    #[test]
    fn model_for_returns_none_when_unset() {
        let models = ModelConfig::default();
        assert_eq!(models.model_for("planner"), None);
    }

    #[test]
    fn model_config_from_toml() {
        let toml_str = r#"
[models]
default = "sonnet"
implementer = "opus"
fixer = "opus"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.models.default.as_deref(), Some("sonnet"));
        assert_eq!(config.models.implementer.as_deref(), Some("opus"));
        assert_eq!(config.models.fixer.as_deref(), Some("opus"));
        assert!(config.models.planner.is_none());
        assert!(config.models.reviewer.is_none());
    }

    #[test]
    fn model_config_project_overrides_user() {
        let user_toml = r#"
[models]
default = "sonnet"
implementer = "sonnet"
"#;
        let project_toml = r#"
[models]
implementer = "opus"
"#;
        let mut config = Config::default();
        let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
        apply_raw(&mut config, &user_raw, true);
        assert_eq!(config.models.implementer.as_deref(), Some("sonnet"));

        let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
        apply_raw(&mut config, &project_raw, false);
        assert_eq!(config.models.implementer.as_deref(), Some("opus"));
        assert_eq!(config.models.default.as_deref(), Some("sonnet"));
    }

    #[test]
    fn model_config_defaults_when_not_specified() {
        let toml_str = r"
[pipeline]
max_parallel = 1
";
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.models, ModelConfig::default());
    }
}
831}