Skip to main content

oven_cli/config/
mod.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4};
5
6use anyhow::Context;
7use serde::{Deserialize, Serialize};
8
9/// Where oven reads issues from: GitHub API or local `.oven/issues/` files.
10#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "lowercase")]
12pub enum IssueSource {
13    #[default]
14    Github,
15    Local,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
19#[serde(default)]
20pub struct Config {
21    pub project: ProjectConfig,
22    pub pipeline: PipelineConfig,
23    pub labels: LabelConfig,
24    pub multi_repo: MultiRepoConfig,
25    #[serde(default)]
26    pub repos: HashMap<String, PathBuf>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
30#[serde(default)]
31pub struct MultiRepoConfig {
32    pub enabled: bool,
33    pub target_field: String,
34}
35
36impl Default for MultiRepoConfig {
37    fn default() -> Self {
38        Self { enabled: false, target_field: "target_repo".to_string() }
39    }
40}
41
42#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
43#[serde(default)]
44pub struct ProjectConfig {
45    pub name: Option<String>,
46    pub test: Option<String>,
47    pub lint: Option<String>,
48    pub issue_source: IssueSource,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(default)]
53pub struct PipelineConfig {
54    pub max_parallel: u32,
55    pub cost_budget: f64,
56    pub poll_interval: u64,
57    pub turn_limit: u32,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
61#[serde(default)]
62pub struct LabelConfig {
63    pub ready: String,
64    pub cooking: String,
65    pub complete: String,
66    pub failed: String,
67}
68
69impl Default for PipelineConfig {
70    fn default() -> Self {
71        Self { max_parallel: 2, cost_budget: 15.0, poll_interval: 60, turn_limit: 50 }
72    }
73}
74
75impl Default for LabelConfig {
76    fn default() -> Self {
77        Self {
78            ready: "o-ready".to_string(),
79            cooking: "o-cooking".to_string(),
80            complete: "o-complete".to_string(),
81            failed: "o-failed".to_string(),
82        }
83    }
84}
85
86/// Intermediate representation for partial config deserialization.
87/// All fields are optional so we can tell which ones were explicitly set.
88#[derive(Debug, Deserialize, Default)]
89#[serde(default)]
90struct RawConfig {
91    project: Option<RawProjectConfig>,
92    pipeline: Option<RawPipelineConfig>,
93    labels: Option<RawLabelConfig>,
94    multi_repo: Option<RawMultiRepoConfig>,
95    repos: Option<HashMap<String, PathBuf>>,
96}
97
98#[derive(Debug, Default, Deserialize)]
99#[serde(default)]
100struct RawProjectConfig {
101    name: Option<String>,
102    test: Option<String>,
103    lint: Option<String>,
104    issue_source: Option<IssueSource>,
105}
106
107#[derive(Debug, Default, Deserialize)]
108#[serde(default)]
109struct RawPipelineConfig {
110    max_parallel: Option<u32>,
111    cost_budget: Option<f64>,
112    poll_interval: Option<u64>,
113    turn_limit: Option<u32>,
114}
115
116#[derive(Debug, Default, Deserialize)]
117#[serde(default)]
118struct RawLabelConfig {
119    ready: Option<String>,
120    cooking: Option<String>,
121    complete: Option<String>,
122    failed: Option<String>,
123}
124
125#[derive(Debug, Default, Deserialize)]
126#[serde(default)]
127struct RawMultiRepoConfig {
128    enabled: Option<bool>,
129    target_field: Option<String>,
130}
131
132impl Config {
133    /// Load config by merging user defaults with project overrides.
134    ///
135    /// User config: `~/.config/oven/recipe.toml`
136    /// Project config: `recipe.toml` in `project_dir`
137    ///
138    /// Missing files are not errors - defaults are used instead.
139    pub fn load(project_dir: &Path) -> anyhow::Result<Self> {
140        let mut config = Self::default();
141
142        // Load user config
143        if let Some(config_dir) = dirs::config_dir() {
144            let user_path = config_dir.join("oven").join("recipe.toml");
145            if user_path.exists() {
146                let content = std::fs::read_to_string(&user_path)
147                    .with_context(|| format!("reading user config: {}", user_path.display()))?;
148                let raw: RawConfig = toml::from_str(&content)
149                    .with_context(|| format!("parsing user config: {}", user_path.display()))?;
150                apply_raw(&mut config, &raw, true);
151            }
152        }
153
154        // Load project config (overrides user config)
155        let project_path = project_dir.join("recipe.toml");
156        if project_path.exists() {
157            let content = std::fs::read_to_string(&project_path)
158                .with_context(|| format!("reading project config: {}", project_path.display()))?;
159            let raw: RawConfig = toml::from_str(&content)
160                .with_context(|| format!("parsing project config: {}", project_path.display()))?;
161            apply_raw(&mut config, &raw, false);
162        }
163
164        config.validate()?;
165        Ok(config)
166    }
167
168    /// Resolve a repo name to a local path.
169    ///
170    /// Returns an error if the repo name is not in the config or the path doesn't exist.
171    pub fn resolve_repo(&self, name: &str) -> anyhow::Result<PathBuf> {
172        let path = self
173            .repos
174            .get(name)
175            .with_context(|| format!("repo '{name}' not found in user config [repos] section"))?;
176
177        let expanded = if path.starts_with("~") {
178            dirs::home_dir().map_or_else(
179                || path.clone(),
180                |home| home.join(path.strip_prefix("~").unwrap_or(path)),
181            )
182        } else {
183            path.clone()
184        };
185
186        if !expanded.exists() {
187            anyhow::bail!("repo '{name}' path does not exist: {}", expanded.display());
188        }
189
190        Ok(expanded)
191    }
192
193    /// Validate config values that could cause hangs or resource exhaustion.
194    fn validate(&self) -> anyhow::Result<()> {
195        if self.pipeline.max_parallel == 0 {
196            anyhow::bail!("pipeline.max_parallel must be >= 1 (got 0, which would deadlock)");
197        }
198        if self.pipeline.poll_interval < 10 {
199            anyhow::bail!(
200                "pipeline.poll_interval must be >= 10 (got {}, which would hammer the API)",
201                self.pipeline.poll_interval
202            );
203        }
204        if !self.pipeline.cost_budget.is_finite() || self.pipeline.cost_budget <= 0.0 {
205            anyhow::bail!(
206                "pipeline.cost_budget must be a finite number > 0 (got {})",
207                self.pipeline.cost_budget
208            );
209        }
210        if self.pipeline.turn_limit == 0 {
211            anyhow::bail!("pipeline.turn_limit must be >= 1 (got 0)");
212        }
213        Ok(())
214    }
215
216    /// Generate a starter user TOML for `~/.config/oven/recipe.toml`.
217    pub fn default_user_toml() -> String {
218        r#"# Global oven defaults (all projects inherit these)
219
220[pipeline]
221# max_parallel = 2
222# cost_budget = 15.0
223# poll_interval = 60
224# turn_limit = 50
225
226# [labels]
227# ready = "o-ready"
228# cooking = "o-cooking"
229# complete = "o-complete"
230# failed = "o-failed"
231
232# Multi-repo path mappings (only honored from user config)
233# [repos]
234# api = "~/dev/api"
235# web = "~/dev/web"
236"#
237        .to_string()
238    }
239
240    /// Generate a starter project TOML for `oven prep`.
241    pub fn default_project_toml() -> String {
242        r#"[project]
243# name = "my-project"    # auto-detected from git remote
244# test = "cargo test"    # test command
245# lint = "cargo clippy"  # lint command
246# issue_source = "github"  # "github" (default) or "local"
247
248[pipeline]
249max_parallel = 2
250cost_budget = 15.0
251poll_interval = 60
252
253# [labels]
254# ready = "o-ready"
255# cooking = "o-cooking"
256# complete = "o-complete"
257# failed = "o-failed"
258"#
259        .to_string()
260    }
261}
262
263/// Apply a raw (partial) config onto the resolved config.
264/// `allow_repos` controls whether the `repos` key is honored (only from user config).
265fn apply_raw(config: &mut Config, raw: &RawConfig, allow_repos: bool) {
266    if let Some(ref project) = raw.project {
267        if project.name.is_some() {
268            config.project.name.clone_from(&project.name);
269        }
270        if project.test.is_some() {
271            config.project.test.clone_from(&project.test);
272        }
273        if project.lint.is_some() {
274            config.project.lint.clone_from(&project.lint);
275        }
276        if let Some(ref source) = project.issue_source {
277            config.project.issue_source = source.clone();
278        }
279    }
280
281    if let Some(ref pipeline) = raw.pipeline {
282        if let Some(v) = pipeline.max_parallel {
283            config.pipeline.max_parallel = v;
284        }
285        if let Some(v) = pipeline.cost_budget {
286            config.pipeline.cost_budget = v;
287        }
288        if let Some(v) = pipeline.poll_interval {
289            config.pipeline.poll_interval = v;
290        }
291        if let Some(v) = pipeline.turn_limit {
292            config.pipeline.turn_limit = v;
293        }
294    }
295
296    if let Some(ref labels) = raw.labels {
297        if let Some(ref v) = labels.ready {
298            config.labels.ready.clone_from(v);
299        }
300        if let Some(ref v) = labels.cooking {
301            config.labels.cooking.clone_from(v);
302        }
303        if let Some(ref v) = labels.complete {
304            config.labels.complete.clone_from(v);
305        }
306        if let Some(ref v) = labels.failed {
307            config.labels.failed.clone_from(v);
308        }
309    }
310
311    // multi_repo settings from project config (controls feature enablement)
312    if let Some(ref multi_repo) = raw.multi_repo {
313        if let Some(v) = multi_repo.enabled {
314            config.multi_repo.enabled = v;
315        }
316        if let Some(ref v) = multi_repo.target_field {
317            config.multi_repo.target_field.clone_from(v);
318        }
319    }
320
321    // repos only honored from user config (security: project config shouldn't
322    // be able to point the tool at arbitrary repos on the filesystem)
323    if allow_repos {
324        if let Some(ref repos) = raw.repos {
325            config.repos.clone_from(repos);
326        }
327    }
328}
329
/// Unit and property tests for config parsing, layering, validation and repo resolution.
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        /// Serializing a `Config` to TOML and parsing it back preserves pipeline and labels.
        #[test]
        fn config_toml_roundtrip(
            max_parallel in 1..100u32,
            cost_budget in 0.0..1000.0f64,
            poll_interval in 1..3600u64,
            turn_limit in 1..200u32,
            ready in "[a-z][a-z0-9-]{1,20}",
            cooking in "[a-z][a-z0-9-]{1,20}",
            complete in "[a-z][a-z0-9-]{1,20}",
            failed in "[a-z][a-z0-9-]{1,20}",
        ) {
            let config = Config {
                project: ProjectConfig::default(),
                pipeline: PipelineConfig { max_parallel, cost_budget, poll_interval, turn_limit },
                labels: LabelConfig { ready, cooking, complete, failed },
                multi_repo: MultiRepoConfig::default(),
                repos: HashMap::new(),
            };
            let serialized = toml::to_string(&config).unwrap();
            let deserialized: Config = toml::from_str(&serialized).unwrap();
            assert_eq!(config.pipeline.max_parallel, deserialized.pipeline.max_parallel);
            assert!((config.pipeline.cost_budget - deserialized.pipeline.cost_budget).abs() < 1e-6);
            assert_eq!(config.pipeline.poll_interval, deserialized.pipeline.poll_interval);
            assert_eq!(config.pipeline.turn_limit, deserialized.pipeline.turn_limit);
            assert_eq!(config.labels, deserialized.labels);
        }

        /// A `[pipeline]` table with any subset of keys parses, and set keys are applied.
        #[test]
        fn partial_toml_always_parses(
            max_parallel in proptest::option::of(1..100u32),
            cost_budget in proptest::option::of(0.0..1000.0f64),
        ) {
            let mut parts = vec!["[pipeline]".to_string()];
            if let Some(mp) = max_parallel {
                parts.push(format!("max_parallel = {mp}"));
            }
            if let Some(cb) = cost_budget {
                parts.push(format!("cost_budget = {cb}"));
            }
            let toml_str = parts.join("\n");
            let raw: RawConfig = toml::from_str(&toml_str).unwrap();
            let mut config = Config::default();
            apply_raw(&mut config, &raw, false);
            if let Some(mp) = max_parallel {
                assert_eq!(config.pipeline.max_parallel, mp);
            }
            // The generated budget must be applied as well, not just parsed.
            if let Some(cb) = cost_budget {
                assert!((config.pipeline.cost_budget - cb).abs() < 1e-6);
            }
        }
    }

    #[test]
    fn defaults_are_correct() {
        let config = Config::default();
        assert_eq!(config.pipeline.max_parallel, 2);
        assert!(
            (config.pipeline.cost_budget - 15.0).abs() < f64::EPSILON,
            "cost_budget should be 15.0"
        );
        assert_eq!(config.pipeline.poll_interval, 60);
        assert_eq!(config.pipeline.turn_limit, 50);
        assert_eq!(config.labels.ready, "o-ready");
        assert_eq!(config.labels.cooking, "o-cooking");
        assert_eq!(config.labels.complete, "o-complete");
        assert_eq!(config.labels.failed, "o-failed");
        assert!(config.project.name.is_none());
        assert!(config.repos.is_empty());
        assert!(!config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "target_repo");
    }

    #[test]
    fn load_from_valid_toml() {
        let toml_str = r#"
[project]
name = "test-project"
test = "cargo test"

[pipeline]
max_parallel = 4
cost_budget = 20.0
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);

        assert_eq!(config.project.name.as_deref(), Some("test-project"));
        assert_eq!(config.project.test.as_deref(), Some("cargo test"));
        assert_eq!(config.pipeline.max_parallel, 4);
        assert!((config.pipeline.cost_budget - 20.0).abs() < f64::EPSILON);
        // Unset fields keep defaults
        assert_eq!(config.pipeline.poll_interval, 60);
    }

    #[test]
    fn project_overrides_user() {
        let user_toml = r"
[pipeline]
max_parallel = 3
cost_budget = 10.0
poll_interval = 120
";
        let project_toml = r"
[pipeline]
max_parallel = 1
cost_budget = 5.0
";
        let mut config = Config::default();

        let user_raw: RawConfig = toml::from_str(user_toml).unwrap();
        apply_raw(&mut config, &user_raw, true);
        assert_eq!(config.pipeline.max_parallel, 3);
        assert_eq!(config.pipeline.poll_interval, 120);

        let project_raw: RawConfig = toml::from_str(project_toml).unwrap();
        apply_raw(&mut config, &project_raw, false);
        assert_eq!(config.pipeline.max_parallel, 1);
        assert!((config.pipeline.cost_budget - 5.0).abs() < f64::EPSILON);
        // poll_interval not overridden by project, stays at user value
        assert_eq!(config.pipeline.poll_interval, 120);
    }

    #[test]
    fn repos_ignored_in_project_config() {
        let project_toml = r#"
[repos]
evil = "/tmp/evil"
"#;
        let mut config = Config::default();
        let raw: RawConfig = toml::from_str(project_toml).unwrap();
        apply_raw(&mut config, &raw, false);
        assert!(config.repos.is_empty());
    }

    #[test]
    fn repos_honored_in_user_config() {
        let user_toml = r#"
[repos]
api = "/home/user/dev/api"
"#;
        let mut config = Config::default();
        let raw: RawConfig = toml::from_str(user_toml).unwrap();
        apply_raw(&mut config, &raw, true);
        assert_eq!(config.repos.get("api").unwrap(), Path::new("/home/user/dev/api"));
    }

    #[test]
    fn missing_file_returns_defaults() {
        let dir = tempfile::tempdir().unwrap();
        let config = Config::load(dir.path()).unwrap();
        assert_eq!(config, Config::default());
    }

    #[test]
    fn invalid_toml_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("recipe.toml"), "this is not [valid toml").unwrap();
        let result = Config::load(dir.path());
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("parsing project config"), "error was: {err}");
    }

    #[test]
    fn default_user_toml_parses() {
        let toml_str = Config::default_user_toml();
        let raw: RawConfig = toml::from_str(&toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, true);
        // Every key in the template is commented out, so defaults remain
        assert_eq!(config.pipeline.max_parallel, 2);
        assert!(config.repos.is_empty());
    }

    #[test]
    fn default_project_toml_parses() {
        let toml_str = Config::default_project_toml();
        let raw: RawConfig = toml::from_str(&toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        // The template sets the [pipeline] keys explicitly, to the same values as the
        // built-in defaults; everything else is commented out.
        assert_eq!(config.pipeline.max_parallel, 2);
        assert!((config.pipeline.cost_budget - 15.0).abs() < f64::EPSILON);
        assert_eq!(config.pipeline.poll_interval, 60);
        // turn_limit is not in the template, so it keeps its default
        assert_eq!(config.pipeline.turn_limit, 50);
        assert!(config.project.name.is_none());
    }

    #[test]
    fn config_roundtrip_serialize_deserialize() {
        let config = Config {
            project: ProjectConfig {
                name: Some("roundtrip".to_string()),
                test: Some("make test".to_string()),
                lint: None,
                issue_source: IssueSource::Github,
            },
            pipeline: PipelineConfig { max_parallel: 5, cost_budget: 25.0, ..Default::default() },
            labels: LabelConfig::default(),
            multi_repo: MultiRepoConfig::default(),
            repos: HashMap::from([("svc".to_string(), PathBuf::from("/tmp/svc"))]),
        };
        let serialized = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&serialized).unwrap();
        assert_eq!(config, deserialized);
    }

    #[test]
    fn multi_repo_config_from_project_toml() {
        let toml_str = r#"
[multi_repo]
enabled = true
target_field = "repo"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert!(config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "repo");
    }

    #[test]
    fn multi_repo_defaults_when_not_specified() {
        let toml_str = r"
[pipeline]
max_parallel = 1
";
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert!(!config.multi_repo.enabled);
        assert_eq!(config.multi_repo.target_field, "target_repo");
    }

    #[test]
    fn resolve_repo_finds_existing_path() {
        let dir = tempfile::tempdir().unwrap();
        let mut config = Config::default();
        config.repos.insert("test-repo".to_string(), dir.path().to_path_buf());

        let resolved = config.resolve_repo("test-repo").unwrap();
        assert_eq!(resolved, dir.path());
    }

    #[test]
    fn resolve_repo_missing_name_errors() {
        let config = Config::default();
        let result = config.resolve_repo("nonexistent");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found in user config"));
    }

    #[test]
    fn resolve_repo_missing_path_errors() {
        let mut config = Config::default();
        config.repos.insert("bad".to_string(), PathBuf::from("/nonexistent/path/xyz"));
        let result = config.resolve_repo("bad");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("does not exist"));
    }

    #[test]
    fn issue_source_defaults_to_github() {
        let config = Config::default();
        assert_eq!(config.project.issue_source, IssueSource::Github);
    }

    #[test]
    fn issue_source_local_parses() {
        let toml_str = r#"
[project]
issue_source = "local"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.project.issue_source, IssueSource::Local);
    }

    #[test]
    fn issue_source_github_parses() {
        let toml_str = r#"
[project]
issue_source = "github"
"#;
        let raw: RawConfig = toml::from_str(toml_str).unwrap();
        let mut config = Config::default();
        apply_raw(&mut config, &raw, false);
        assert_eq!(config.project.issue_source, IssueSource::Github);
    }

    #[test]
    fn validate_rejects_zero_max_parallel() {
        let mut config = Config::default();
        config.pipeline.max_parallel = 0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("max_parallel"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_low_poll_interval() {
        let mut config = Config::default();
        config.pipeline.poll_interval = 5;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("poll_interval"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_zero_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = 0.0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_nan_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = f64::NAN;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_infinity_cost_budget() {
        let mut config = Config::default();
        config.pipeline.cost_budget = f64::INFINITY;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("cost_budget"), "error was: {err}");
    }

    #[test]
    fn validate_rejects_zero_turn_limit() {
        let mut config = Config::default();
        config.pipeline.turn_limit = 0;
        let err = config.validate().unwrap_err().to_string();
        assert!(err.contains("turn_limit"), "error was: {err}");
    }

    #[test]
    fn validate_accepts_defaults() {
        Config::default().validate().unwrap();
    }

    #[test]
    fn issue_source_invalid_errors() {
        let toml_str = r#"
[project]
issue_source = "jira"
"#;
        let result = toml::from_str::<RawConfig>(toml_str);
        assert!(result.is_err());
    }

    #[test]
    fn issue_source_roundtrip() {
        let config = Config {
            project: ProjectConfig { issue_source: IssueSource::Local, ..Default::default() },
            ..Default::default()
        };
        let serialized = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.project.issue_source, IssueSource::Local);
    }
}