1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::Context;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "lowercase")]
12pub enum IssueSource {
13 #[default]
14 Github,
15 Local,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
20#[serde(rename_all = "lowercase")]
21pub enum MergeStrategy {
22 Squash,
23 #[default]
24 Merge,
25 Rebase,
26}
27
28impl MergeStrategy {
29 pub const fn gh_flag(&self) -> &'static str {
31 match self {
32 Self::Squash => "--squash",
33 Self::Merge => "--merge",
34 Self::Rebase => "--rebase",
35 }
36 }
37}
38
39#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
40#[serde(default)]
41pub struct Config {
42 pub project: ProjectConfig,
43 pub pipeline: PipelineConfig,
44 pub labels: LabelConfig,
45 pub multi_repo: MultiRepoConfig,
46 pub models: ModelConfig,
47 #[serde(default)]
48 pub repos: HashMap<String, PathBuf>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52#[serde(default)]
53pub struct MultiRepoConfig {
54 pub enabled: bool,
55 pub target_field: String,
56}
57
58impl Default for MultiRepoConfig {
59 fn default() -> Self {
60 Self { enabled: false, target_field: "target_repo".to_string() }
61 }
62}
63
64#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
69#[serde(default)]
70pub struct ModelConfig {
71 pub default: Option<String>,
72 pub planner: Option<String>,
73 pub implementer: Option<String>,
74 pub reviewer: Option<String>,
75 pub fixer: Option<String>,
76}
77
78impl ModelConfig {
79 pub fn model_for(&self, role: &str) -> Option<&str> {
82 let agent_override = match role {
83 "planner" => self.planner.as_deref(),
84 "implementer" => self.implementer.as_deref(),
85 "reviewer" => self.reviewer.as_deref(),
86 "fixer" => self.fixer.as_deref(),
87 _ => None,
88 };
89 agent_override.or(self.default.as_deref())
90 }
91}
92
93#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
94#[serde(default)]
95pub struct ProjectConfig {
96 pub name: Option<String>,
97 pub test: Option<String>,
98 pub lint: Option<String>,
99 pub issue_source: IssueSource,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
103#[serde(default)]
104pub struct PipelineConfig {
105 pub max_parallel: u32,
106 pub cost_budget: f64,
107 pub poll_interval: u64,
108 pub turn_limit: u32,
109 pub merge_strategy: MergeStrategy,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
113#[serde(default)]
114pub struct LabelConfig {
115 pub ready: String,
116 pub cooking: String,
117 pub complete: String,
118 pub failed: String,
119}
120
121impl Default for PipelineConfig {
122 fn default() -> Self {
123 Self {
124 max_parallel: 2,
125 cost_budget: 15.0,
126 poll_interval: 60,
127 turn_limit: 50,
128 merge_strategy: MergeStrategy::default(),
129 }
130 }
131}
132
133impl Default for LabelConfig {
134 fn default() -> Self {
135 Self {
136 ready: "o-ready".to_string(),
137 cooking: "o-cooking".to_string(),
138 complete: "o-complete".to_string(),
139 failed: "o-failed".to_string(),
140 }
141 }
142}
143
144#[derive(Debug, Deserialize, Default)]
147#[serde(default)]
148struct RawConfig {
149 project: Option<RawProjectConfig>,
150 pipeline: Option<RawPipelineConfig>,
151 labels: Option<RawLabelConfig>,
152 multi_repo: Option<RawMultiRepoConfig>,
153 models: Option<RawModelConfig>,
154 repos: Option<HashMap<String, PathBuf>>,
155}
156
157#[derive(Debug, Default, Deserialize)]
158#[serde(default)]
159struct RawProjectConfig {
160 name: Option<String>,
161 test: Option<String>,
162 lint: Option<String>,
163 issue_source: Option<IssueSource>,
164}
165
166#[derive(Debug, Default, Deserialize)]
167#[serde(default)]
168struct RawPipelineConfig {
169 max_parallel: Option<u32>,
170 cost_budget: Option<f64>,
171 poll_interval: Option<u64>,
172 turn_limit: Option<u32>,
173 merge_strategy: Option<MergeStrategy>,
174}
175
176#[derive(Debug, Default, Deserialize)]
177#[serde(default)]
178struct RawLabelConfig {
179 ready: Option<String>,
180 cooking: Option<String>,
181 complete: Option<String>,
182 failed: Option<String>,
183}
184
185#[derive(Debug, Default, Deserialize)]
186#[serde(default)]
187struct RawMultiRepoConfig {
188 enabled: Option<bool>,
189 target_field: Option<String>,
190}
191
192#[derive(Debug, Default, Deserialize)]
193#[serde(default)]
194struct RawModelConfig {
195 default: Option<String>,
196 planner: Option<String>,
197 implementer: Option<String>,
198 reviewer: Option<String>,
199 fixer: Option<String>,
200}
201
202impl Config {
203 pub fn load(project_dir: &Path) -> anyhow::Result<Self> {
210 let mut config = Self::default();
211
212 if let Some(config_dir) = dirs::config_dir() {
214 let user_path = config_dir.join("oven").join("recipe.toml");
215 if user_path.exists() {
216 let content = std::fs::read_to_string(&user_path)
217 .with_context(|| format!("reading user config: {}", user_path.display()))?;
218 let raw: RawConfig = toml::from_str(&content)
219 .with_context(|| format!("parsing user config: {}", user_path.display()))?;
220 apply_raw(&mut config, &raw, true);
221 }
222 }
223
224 let project_path = project_dir.join("recipe.toml");
226 if project_path.exists() {
227 let content = std::fs::read_to_string(&project_path)
228 .with_context(|| format!("reading project config: {}", project_path.display()))?;
229 let raw: RawConfig = toml::from_str(&content)
230 .with_context(|| format!("parsing project config: {}", project_path.display()))?;
231 apply_raw(&mut config, &raw, false);
232 }
233
234 config.validate()?;
235 Ok(config)
236 }
237
238 pub fn resolve_repo(&self, name: &str) -> anyhow::Result<PathBuf> {
242 let path = self
243 .repos
244 .get(name)
245 .with_context(|| format!("repo '{name}' not found in user config [repos] section"))?;
246
247 let expanded = if path.starts_with("~") {
248 dirs::home_dir().map_or_else(
249 || path.clone(),
250 |home| home.join(path.strip_prefix("~").unwrap_or(path)),
251 )
252 } else {
253 path.clone()
254 };
255
256 if !expanded.exists() {
257 anyhow::bail!("repo '{name}' path does not exist: {}", expanded.display());
258 }
259
260 Ok(expanded)
261 }
262
263 fn validate(&self) -> anyhow::Result<()> {
265 if self.pipeline.max_parallel == 0 {
266 anyhow::bail!("pipeline.max_parallel must be >= 1 (got 0, which would deadlock)");
267 }
268 if self.pipeline.poll_interval < 10 {
269 anyhow::bail!(
270 "pipeline.poll_interval must be >= 10 (got {}, which would hammer the API)",
271 self.pipeline.poll_interval
272 );
273 }
274 if !self.pipeline.cost_budget.is_finite() || self.pipeline.cost_budget <= 0.0 {
275 anyhow::bail!(
276 "pipeline.cost_budget must be a finite number > 0 (got {})",
277 self.pipeline.cost_budget
278 );
279 }
280 if self.pipeline.turn_limit == 0 {
281 anyhow::bail!("pipeline.turn_limit must be >= 1 (got 0)");
282 }
283 Ok(())
284 }
285
286 pub fn default_user_toml() -> String {
288 r#"# Global oven defaults (all projects inherit these)
289
290[pipeline]
291# max_parallel = 2
292# cost_budget = 15.0
293# poll_interval = 60
294# turn_limit = 50
295# merge_strategy = "merge" # "merge" (default), "squash", or "rebase"
296
297# [labels]
298# ready = "o-ready"
299# cooking = "o-cooking"
300# complete = "o-complete"
301# failed = "o-failed"
302
303# Multi-repo path mappings (only honored from user config)
304# [repos]
305# api = "~/dev/api"
306# web = "~/dev/web"
307"#
308 .to_string()
309 }
310
311 pub fn default_project_toml() -> String {
313 r#"[project]
314# name = "my-project" # auto-detected from git remote
315# test = "cargo test" # test command
316# lint = "cargo clippy" # lint command
317# issue_source = "github" # "github" (default) or "local"
318
319[pipeline]
320max_parallel = 2
321cost_budget = 15.0
322poll_interval = 60
323# merge_strategy = "merge" # "merge" (default), "squash", or "rebase"
324
325# [labels]
326# ready = "o-ready"
327# cooking = "o-cooking"
328# complete = "o-complete"
329# failed = "o-failed"
330
331# [models]
332# default = "sonnet"
333# implementer = "opus"
334# fixer = "opus"
335"#
336 .to_string()
337 }
338}
339
340fn apply_raw(config: &mut Config, raw: &RawConfig, allow_repos: bool) {
343 if let Some(ref project) = raw.project {
344 if project.name.is_some() {
345 config.project.name.clone_from(&project.name);
346 }
347 if project.test.is_some() {
348 config.project.test.clone_from(&project.test);
349 }
350 if project.lint.is_some() {
351 config.project.lint.clone_from(&project.lint);
352 }
353 if let Some(ref source) = project.issue_source {
354 config.project.issue_source = source.clone();
355 }
356 }
357
358 if let Some(ref pipeline) = raw.pipeline {
359 if let Some(v) = pipeline.max_parallel {
360 config.pipeline.max_parallel = v;
361 }
362 if let Some(v) = pipeline.cost_budget {
363 config.pipeline.cost_budget = v;
364 }
365 if let Some(v) = pipeline.poll_interval {
366 config.pipeline.poll_interval = v;
367 }
368 if let Some(v) = pipeline.turn_limit {
369 config.pipeline.turn_limit = v;
370 }
371 if let Some(ref v) = pipeline.merge_strategy {
372 config.pipeline.merge_strategy = v.clone();
373 }
374 }
375
376 if let Some(ref labels) = raw.labels {
377 if let Some(ref v) = labels.ready {
378 config.labels.ready.clone_from(v);
379 }
380 if let Some(ref v) = labels.cooking {
381 config.labels.cooking.clone_from(v);
382 }
383 if let Some(ref v) = labels.complete {
384 config.labels.complete.clone_from(v);
385 }
386 if let Some(ref v) = labels.failed {
387 config.labels.failed.clone_from(v);
388 }
389 }
390
391 if let Some(ref multi_repo) = raw.multi_repo {
393 if let Some(v) = multi_repo.enabled {
394 config.multi_repo.enabled = v;
395 }
396 if let Some(ref v) = multi_repo.target_field {
397 config.multi_repo.target_field.clone_from(v);
398 }
399 }
400
401 if let Some(ref models) = raw.models {
402 if models.default.is_some() {
403 config.models.default.clone_from(&models.default);
404 }
405 if models.planner.is_some() {
406 config.models.planner.clone_from(&models.planner);
407 }
408 if models.implementer.is_some() {
409 config.models.implementer.clone_from(&models.implementer);
410 }
411 if models.reviewer.is_some() {
412 config.models.reviewer.clone_from(&models.reviewer);
413 }
414 if models.fixer.is_some() {
415 config.models.fixer.clone_from(&models.fixer);
416 }
417 }
418
419 if allow_repos {
422 if let Some(ref repos) = raw.repos {
423 config.repos.clone_from(repos);
424 }
425 }
426}
427
428#[cfg(test)]
429mod tests {
430 use proptest::prelude::*;
431
432 use super::*;
433
434 proptest! {
435 #[test]
436 fn config_toml_roundtrip(
437 max_parallel in 1..100u32,
438 cost_budget in 0.0..1000.0f64,
439 poll_interval in 1..3600u64,
440 turn_limit in 1..200u32,
441 ready in "[a-z][a-z0-9-]{1,20}",
442 cooking in "[a-z][a-z0-9-]{1,20}",
443 complete in "[a-z][a-z0-9-]{1,20}",
444 failed in "[a-z][a-z0-9-]{1,20}",
445 ) {
446 let config = Config {
447 project: ProjectConfig::default(),
448 pipeline: PipelineConfig { max_parallel, cost_budget, poll_interval, turn_limit, ..Default::default() },
449 labels: LabelConfig { ready, cooking, complete, failed },
450 multi_repo: MultiRepoConfig::default(),
451 models: ModelConfig::default(),
452 repos: HashMap::new(),
453 };
454 let serialized = toml::to_string(&config).unwrap();
455 let deserialized: Config = toml::from_str(&serialized).unwrap();
456 assert_eq!(config.pipeline.max_parallel, deserialized.pipeline.max_parallel);
457 assert!((config.pipeline.cost_budget - deserialized.pipeline.cost_budget).abs() < 1e-6);
458 assert_eq!(config.pipeline.poll_interval, deserialized.pipeline.poll_interval);
459 assert_eq!(config.pipeline.turn_limit, deserialized.pipeline.turn_limit);
460 assert_eq!(config.labels, deserialized.labels);
461 }
462
463 #[test]
464 fn partial_toml_always_parses(
465 max_parallel in proptest::option::of(1..100u32),
466 cost_budget in proptest::option::of(0.0..1000.0f64),
467 ) {
468 let mut parts = vec!["[pipeline]".to_string()];
469 if let Some(mp) = max_parallel {
470 parts.push(format!("max_parallel = {mp}"));
471 }
472 if let Some(cb) = cost_budget {
473 parts.push(format!("cost_budget = {cb}"));
474 }
475 let toml_str = parts.join("\n");
476 let raw: RawConfig = toml::from_str(&toml_str).unwrap();
477 let mut config = Config::default();
478 apply_raw(&mut config, &raw, false);
479 if let Some(mp) = max_parallel {
480 assert_eq!(config.pipeline.max_parallel, mp);
481 }
482 }
483 }
484
485 #[test]
486 fn defaults_are_correct() {
487 let config = Config::default();
488 assert_eq!(config.pipeline.max_parallel, 2);
489 assert!(
490 (config.pipeline.cost_budget - 15.0).abs() < f64::EPSILON,
491 "cost_budget should be 15.0"
492 );
493 assert_eq!(config.pipeline.poll_interval, 60);
494 assert_eq!(config.pipeline.turn_limit, 50);
495 assert_eq!(config.labels.ready, "o-ready");
496 assert_eq!(config.labels.cooking, "o-cooking");
497 assert_eq!(config.labels.complete, "o-complete");
498 assert_eq!(config.labels.failed, "o-failed");
499 assert!(config.project.name.is_none());
500 assert!(config.repos.is_empty());
501 assert!(!config.multi_repo.enabled);
502 assert_eq!(config.multi_repo.target_field, "target_repo");
503 }
504
505 #[test]
506 fn load_from_valid_toml() {
507 let toml_str = r#"
508[project]
509name = "test-project"
510test = "cargo test"
511
512[pipeline]
513max_parallel = 4
514cost_budget = 20.0
515"#;
516 let raw: RawConfig = toml::from_str(toml_str).unwrap();
517 let mut config = Config::default();
518 apply_raw(&mut config, &raw, false);
519
520 assert_eq!(config.project.name.as_deref(), Some("test-project"));
521 assert_eq!(config.project.test.as_deref(), Some("cargo test"));
522 assert_eq!(config.pipeline.max_parallel, 4);
523 assert!((config.pipeline.cost_budget - 20.0).abs() < f64::EPSILON);
524 assert_eq!(config.pipeline.poll_interval, 60);
526 }
527
528 #[test]
529 fn project_overrides_user() {
530 let user_toml = r"
531[pipeline]
532max_parallel = 3
533cost_budget = 10.0
534poll_interval = 120
535";
536 let project_toml = r"
537[pipeline]
538max_parallel = 1
539cost_budget = 5.0
540";
541 let mut config = Config::default();
542
543 let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
544 apply_raw(&mut config, &user_raw, true);
545 assert_eq!(config.pipeline.max_parallel, 3);
546 assert_eq!(config.pipeline.poll_interval, 120);
547
548 let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
549 apply_raw(&mut config, &project_raw, false);
550 assert_eq!(config.pipeline.max_parallel, 1);
551 assert!((config.pipeline.cost_budget - 5.0).abs() < f64::EPSILON);
552 assert_eq!(config.pipeline.poll_interval, 120);
554 }
555
556 #[test]
557 fn repos_ignored_in_project_config() {
558 let project_toml = r#"
559[repos]
560evil = "/tmp/evil"
561"#;
562 let mut config = Config::default();
563 let raw: RawConfig = toml::from_str(project_toml).unwrap();
564 apply_raw(&mut config, &raw, false);
565 assert!(config.repos.is_empty());
566 }
567
568 #[test]
569 fn repos_honored_in_user_config() {
570 let user_toml = r#"
571[repos]
572api = "/home/user/dev/api"
573"#;
574 let mut config = Config::default();
575 let raw: RawConfig = toml::from_str(user_toml).unwrap();
576 apply_raw(&mut config, &raw, true);
577 assert_eq!(config.repos.get("api").unwrap(), Path::new("/home/user/dev/api"));
578 }
579
580 #[test]
581 fn missing_file_returns_defaults() {
582 let dir = tempfile::tempdir().unwrap();
583 let config = Config::load(dir.path()).unwrap();
584 assert_eq!(config, Config::default());
585 }
586
587 #[test]
588 fn invalid_toml_returns_error() {
589 let dir = tempfile::tempdir().unwrap();
590 std::fs::write(dir.path().join("recipe.toml"), "this is not [valid toml").unwrap();
591 let result = Config::load(dir.path());
592 assert!(result.is_err());
593 let err = result.unwrap_err().to_string();
594 assert!(err.contains("parsing project config"), "error was: {err}");
595 }
596
597 #[test]
598 fn default_user_toml_parses() {
599 let toml_str = Config::default_user_toml();
600 let raw: RawConfig = toml::from_str(&toml_str).unwrap();
601 let mut config = Config::default();
602 apply_raw(&mut config, &raw, true);
603 assert_eq!(config.pipeline.max_parallel, 2);
605 assert!(config.repos.is_empty());
606 }
607
608 #[test]
609 fn default_project_toml_parses() {
610 let toml_str = Config::default_project_toml();
611 let raw: RawConfig = toml::from_str(&toml_str).unwrap();
612 let mut config = Config::default();
613 apply_raw(&mut config, &raw, false);
614 assert_eq!(config.pipeline.max_parallel, 2);
616 }
617
618 #[test]
619 fn config_roundtrip_serialize_deserialize() {
620 let config = Config {
621 project: ProjectConfig {
622 name: Some("roundtrip".to_string()),
623 test: Some("make test".to_string()),
624 lint: None,
625 issue_source: IssueSource::Github,
626 },
627 pipeline: PipelineConfig { max_parallel: 5, cost_budget: 25.0, ..Default::default() },
628 labels: LabelConfig::default(),
629 multi_repo: MultiRepoConfig::default(),
630 models: ModelConfig::default(),
631 repos: HashMap::from([("svc".to_string(), PathBuf::from("/tmp/svc"))]),
632 };
633 let serialized = toml::to_string(&config).unwrap();
634 let deserialized: Config = toml::from_str(&serialized).unwrap();
635 assert_eq!(config, deserialized);
636 }
637
638 #[test]
639 fn multi_repo_config_from_project_toml() {
640 let toml_str = r#"
641[multi_repo]
642enabled = true
643target_field = "repo"
644"#;
645 let raw: RawConfig = toml::from_str(toml_str).unwrap();
646 let mut config = Config::default();
647 apply_raw(&mut config, &raw, false);
648 assert!(config.multi_repo.enabled);
649 assert_eq!(config.multi_repo.target_field, "repo");
650 }
651
652 #[test]
653 fn multi_repo_defaults_when_not_specified() {
654 let toml_str = r"
655[pipeline]
656max_parallel = 1
657";
658 let raw: RawConfig = toml::from_str(toml_str).unwrap();
659 let mut config = Config::default();
660 apply_raw(&mut config, &raw, false);
661 assert!(!config.multi_repo.enabled);
662 assert_eq!(config.multi_repo.target_field, "target_repo");
663 }
664
665 #[test]
666 fn resolve_repo_finds_existing_path() {
667 let dir = tempfile::tempdir().unwrap();
668 let mut config = Config::default();
669 config.repos.insert("test-repo".to_string(), dir.path().to_path_buf());
670
671 let resolved = config.resolve_repo("test-repo").unwrap();
672 assert_eq!(resolved, dir.path());
673 }
674
675 #[test]
676 fn resolve_repo_missing_name_errors() {
677 let config = Config::default();
678 let result = config.resolve_repo("nonexistent");
679 assert!(result.is_err());
680 assert!(result.unwrap_err().to_string().contains("not found in user config"));
681 }
682
683 #[test]
684 fn resolve_repo_missing_path_errors() {
685 let mut config = Config::default();
686 config.repos.insert("bad".to_string(), PathBuf::from("/nonexistent/path/xyz"));
687 let result = config.resolve_repo("bad");
688 assert!(result.is_err());
689 assert!(result.unwrap_err().to_string().contains("does not exist"));
690 }
691
692 #[test]
693 fn issue_source_defaults_to_github() {
694 let config = Config::default();
695 assert_eq!(config.project.issue_source, IssueSource::Github);
696 }
697
698 #[test]
699 fn issue_source_local_parses() {
700 let toml_str = r#"
701[project]
702issue_source = "local"
703"#;
704 let raw: RawConfig = toml::from_str(toml_str).unwrap();
705 let mut config = Config::default();
706 apply_raw(&mut config, &raw, false);
707 assert_eq!(config.project.issue_source, IssueSource::Local);
708 }
709
710 #[test]
711 fn issue_source_github_parses() {
712 let toml_str = r#"
713[project]
714issue_source = "github"
715"#;
716 let raw: RawConfig = toml::from_str(toml_str).unwrap();
717 let mut config = Config::default();
718 apply_raw(&mut config, &raw, false);
719 assert_eq!(config.project.issue_source, IssueSource::Github);
720 }
721
722 #[test]
723 fn validate_rejects_zero_max_parallel() {
724 let mut config = Config::default();
725 config.pipeline.max_parallel = 0;
726 let err = config.validate().unwrap_err().to_string();
727 assert!(err.contains("max_parallel"), "error was: {err}");
728 }
729
730 #[test]
731 fn validate_rejects_low_poll_interval() {
732 let mut config = Config::default();
733 config.pipeline.poll_interval = 5;
734 let err = config.validate().unwrap_err().to_string();
735 assert!(err.contains("poll_interval"), "error was: {err}");
736 }
737
738 #[test]
739 fn validate_rejects_zero_cost_budget() {
740 let mut config = Config::default();
741 config.pipeline.cost_budget = 0.0;
742 let err = config.validate().unwrap_err().to_string();
743 assert!(err.contains("cost_budget"), "error was: {err}");
744 }
745
746 #[test]
747 fn validate_rejects_nan_cost_budget() {
748 let mut config = Config::default();
749 config.pipeline.cost_budget = f64::NAN;
750 let err = config.validate().unwrap_err().to_string();
751 assert!(err.contains("cost_budget"), "error was: {err}");
752 }
753
754 #[test]
755 fn validate_rejects_infinity_cost_budget() {
756 let mut config = Config::default();
757 config.pipeline.cost_budget = f64::INFINITY;
758 let err = config.validate().unwrap_err().to_string();
759 assert!(err.contains("cost_budget"), "error was: {err}");
760 }
761
762 #[test]
763 fn validate_rejects_zero_turn_limit() {
764 let mut config = Config::default();
765 config.pipeline.turn_limit = 0;
766 let err = config.validate().unwrap_err().to_string();
767 assert!(err.contains("turn_limit"), "error was: {err}");
768 }
769
770 #[test]
771 fn validate_accepts_defaults() {
772 Config::default().validate().unwrap();
773 }
774
775 #[test]
776 fn issue_source_invalid_errors() {
777 let toml_str = r#"
778[project]
779issue_source = "jira"
780"#;
781 let result = toml::from_str::<RawConfig>(toml_str);
782 assert!(result.is_err());
783 }
784
785 #[test]
786 fn issue_source_roundtrip() {
787 let config = Config {
788 project: ProjectConfig { issue_source: IssueSource::Local, ..Default::default() },
789 ..Default::default()
790 };
791 let serialized = toml::to_string(&config).unwrap();
792 let deserialized: Config = toml::from_str(&serialized).unwrap();
793 assert_eq!(deserialized.project.issue_source, IssueSource::Local);
794 }
795
796 #[test]
797 fn model_for_returns_agent_override() {
798 let models = ModelConfig {
799 default: Some("sonnet".to_string()),
800 implementer: Some("opus".to_string()),
801 ..Default::default()
802 };
803 assert_eq!(models.model_for("implementer"), Some("opus"));
804 assert_eq!(models.model_for("reviewer"), Some("sonnet"));
805 }
806
807 #[test]
808 fn model_for_returns_none_when_unset() {
809 let models = ModelConfig::default();
810 assert_eq!(models.model_for("planner"), None);
811 }
812
813 #[test]
814 fn model_config_from_toml() {
815 let toml_str = r#"
816[models]
817default = "sonnet"
818implementer = "opus"
819fixer = "opus"
820"#;
821 let raw: RawConfig = toml::from_str(toml_str).unwrap();
822 let mut config = Config::default();
823 apply_raw(&mut config, &raw, false);
824 assert_eq!(config.models.default.as_deref(), Some("sonnet"));
825 assert_eq!(config.models.implementer.as_deref(), Some("opus"));
826 assert_eq!(config.models.fixer.as_deref(), Some("opus"));
827 assert!(config.models.planner.is_none());
828 assert!(config.models.reviewer.is_none());
829 }
830
831 #[test]
832 fn model_config_project_overrides_user() {
833 let user_toml = r#"
834[models]
835default = "sonnet"
836implementer = "sonnet"
837"#;
838 let project_toml = r#"
839[models]
840implementer = "opus"
841"#;
842 let mut config = Config::default();
843 let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
844 apply_raw(&mut config, &user_raw, true);
845 assert_eq!(config.models.implementer.as_deref(), Some("sonnet"));
846
847 let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
848 apply_raw(&mut config, &project_raw, false);
849 assert_eq!(config.models.implementer.as_deref(), Some("opus"));
850 assert_eq!(config.models.default.as_deref(), Some("sonnet"));
852 }
853
854 #[test]
855 fn model_config_defaults_when_not_specified() {
856 let toml_str = r"
857[pipeline]
858max_parallel = 1
859";
860 let raw: RawConfig = toml::from_str(toml_str).unwrap();
861 let mut config = Config::default();
862 apply_raw(&mut config, &raw, false);
863 assert_eq!(config.models, ModelConfig::default());
864 }
865
866 #[test]
867 fn merge_strategy_defaults_to_merge() {
868 let config = Config::default();
869 assert_eq!(config.pipeline.merge_strategy, MergeStrategy::Merge);
870 }
871
872 #[test]
873 fn merge_strategy_squash_parses() {
874 let toml_str = r#"
875[pipeline]
876merge_strategy = "squash"
877"#;
878 let raw: RawConfig = toml::from_str(toml_str).unwrap();
879 let mut config = Config::default();
880 apply_raw(&mut config, &raw, false);
881 assert_eq!(config.pipeline.merge_strategy, MergeStrategy::Squash);
882 }
883
884 #[test]
885 fn merge_strategy_rebase_parses() {
886 let toml_str = r#"
887[pipeline]
888merge_strategy = "rebase"
889"#;
890 let raw: RawConfig = toml::from_str(toml_str).unwrap();
891 let mut config = Config::default();
892 apply_raw(&mut config, &raw, false);
893 assert_eq!(config.pipeline.merge_strategy, MergeStrategy::Rebase);
894 }
895
896 #[test]
897 fn merge_strategy_invalid_errors() {
898 let toml_str = r#"
899[pipeline]
900merge_strategy = "fast-forward"
901"#;
902 let result = toml::from_str::<RawConfig>(toml_str);
903 assert!(result.is_err());
904 }
905
906 #[test]
907 fn merge_strategy_project_overrides_user() {
908 let user_toml = r#"
909[pipeline]
910merge_strategy = "squash"
911"#;
912 let project_toml = r#"
913[pipeline]
914merge_strategy = "rebase"
915"#;
916 let mut config = Config::default();
917 let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
918 apply_raw(&mut config, &user_raw, true);
919 assert_eq!(config.pipeline.merge_strategy, MergeStrategy::Squash);
920
921 let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
922 apply_raw(&mut config, &project_raw, false);
923 assert_eq!(config.pipeline.merge_strategy, MergeStrategy::Rebase);
924 }
925
926 #[test]
927 fn merge_strategy_gh_flags() {
928 assert_eq!(MergeStrategy::Squash.gh_flag(), "--squash");
929 assert_eq!(MergeStrategy::Merge.gh_flag(), "--merge");
930 assert_eq!(MergeStrategy::Rebase.gh_flag(), "--rebase");
931 }
932
933 #[test]
934 fn merge_strategy_roundtrip() {
935 let config = Config {
936 pipeline: PipelineConfig {
937 merge_strategy: MergeStrategy::Rebase,
938 ..Default::default()
939 },
940 ..Default::default()
941 };
942 let serialized = toml::to_string(&config).unwrap();
943 let deserialized: Config = toml::from_str(&serialized).unwrap();
944 assert_eq!(deserialized.pipeline.merge_strategy, MergeStrategy::Rebase);
945 }
946}