Skip to main content

ralph_workflow/config/unified/
loading.rs

//! Configuration loading and initialization.
//!
//! This module provides functions for loading and initializing Ralph's unified configuration.
//!
//! # Loading Strategy
//!
//! Configuration loading supports both production and testing scenarios:
//!
//! - **Production**: Uses `load_default()` which reads from `~/.config/ralph-workflow.toml`
//! - **Testing**: Uses `load_with_env()` with a `ConfigEnvironment` trait for test isolation
//!
//! # Initialization
//!
//! Ralph can automatically create a default configuration file if none exists.
//! The example is marked `no_run` because it writes to the real home directory.
//!
//! ```rust,no_run
//! use ralph_workflow::config::unified::UnifiedConfig;
//!
//! // Ensure config exists, creating it if needed
//! let result = UnifiedConfig::ensure_config_exists()?;
//!
//! // Load the config (load_default returns None if the file is missing or fails to parse)
//! let config = UnifiedConfig::load_default()
//!     .expect("Config should exist after ensure_config_exists");
//! # Ok::<(), std::io::Error>(())
//! ```
22//! // Load the config
23//! let config = UnifiedConfig::load_default()
24//!     .expect("Config should exist after ensure_config_exists");
25//! # Ok::<(), std::io::Error>(())
26//! ```
27
28use super::types::UnifiedConfig;
29use std::io;
30
/// Outcome of making sure a config file is present on disk.
///
/// Returned by the `ensure_config_exists*` family of functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConfigInitResult {
    /// The file was missing and has just been written from the default template.
    Created,
    /// A file was already present at the target path, so nothing was written.
    AlreadyExists,
}
39
40/// Error type for unified config loading.
41#[derive(Debug, thiserror::Error)]
42pub enum ConfigLoadError {
43    #[error("Failed to read config file: {0}")]
44    Io(#[from] std::io::Error),
45    #[error("Failed to parse TOML: {0}")]
46    Toml(#[from] toml::de::Error),
47}
48
49/// Default unified config template embedded at compile time.
50pub const DEFAULT_UNIFIED_CONFIG: &str = include_str!("../../../examples/ralph-workflow.toml");
51
52impl UnifiedConfig {
53    /// Load unified configuration from the default path.
54    ///
55    /// Returns None if the file doesn't exist.
56    ///
57    /// # Examples
58    ///
59    /// ```rust
60    /// use ralph_workflow::config::unified::UnifiedConfig;
61    ///
62    /// if let Some(config) = UnifiedConfig::load_default() {
63    ///     println!("Verbosity level: {}", config.general.verbosity);
64    /// }
65    /// ```
66    pub fn load_default() -> Option<Self> {
67        Self::load_with_env(&super::super::path_resolver::RealConfigEnvironment)
68    }
69
70    /// Load unified configuration using a `ConfigEnvironment`.
71    ///
72    /// This is the testable version of `load_default`. It reads from the
73    /// unified config path as determined by the environment.
74    ///
75    /// Returns None if no config path is available or the file doesn't exist.
76    pub fn load_with_env(env: &dyn super::super::path_resolver::ConfigEnvironment) -> Option<Self> {
77        env.unified_config_path().and_then(|path| {
78            if env.file_exists(&path) {
79                Self::load_from_path_with_env(&path, env).ok()
80            } else {
81                None
82            }
83        })
84    }
85
86    /// Load unified configuration from a specific path.
87    ///
88    /// **Note:** This method uses `std::fs` directly. For testable code,
89    /// use `load_from_path_with_env` with a `ConfigEnvironment` instead.
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if:
94    /// - The file cannot be read
95    /// - The TOML syntax is invalid
96    /// - Required fields are missing
97    pub fn load_from_path(path: &std::path::Path) -> Result<Self, ConfigLoadError> {
98        let contents = std::fs::read_to_string(path)?;
99        let config: Self = toml::from_str(&contents)?;
100        Ok(config)
101    }
102
103    /// Load unified configuration from a specific path using a `ConfigEnvironment`.
104    ///
105    /// This is the testable version of `load_from_path`.
106    pub fn load_from_path_with_env(
107        path: &std::path::Path,
108        env: &dyn super::super::path_resolver::ConfigEnvironment,
109    ) -> Result<Self, ConfigLoadError> {
110        let contents = env.read_file(path)?;
111        let config: Self = toml::from_str(&contents)?;
112        Ok(config)
113    }
114
115    /// Load unified configuration from pre-read content.
116    ///
117    /// This avoids re-reading the file when content is already available.
118    /// The path is used only for error messages.
119    ///
120    /// # Arguments
121    ///
122    /// * `content` - The raw TOML content string
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if the TOML syntax is invalid or required fields are missing.
127    ///
128    /// # Examples
129    ///
130    /// ```rust
131    /// use ralph_workflow::config::unified::UnifiedConfig;
132    ///
133    /// let toml_content = r#"
134    ///     [general]
135    ///     verbosity = 3
136    /// "#;
137    ///
138    /// let config = UnifiedConfig::load_from_content(toml_content)?;
139    /// assert_eq!(config.general.verbosity, 3);
140    /// # Ok::<(), Box<dyn std::error::Error>>(())
141    /// ```
142    pub fn load_from_content(content: &str) -> Result<Self, ConfigLoadError> {
143        let config: Self = toml::from_str(content)?;
144        Ok(config)
145    }
146
147    /// Ensure unified config file exists, creating it from template if needed.
148    ///
149    /// This creates `~/.config/ralph-workflow.toml` with the default template
150    /// if it doesn't already exist.
151    ///
152    /// # Returns
153    ///
154    /// - `Created` if the config file was created
155    /// - `AlreadyExists` if the config file already existed
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if:
160    /// - The home directory cannot be determined
161    /// - The config file cannot be written
162    ///
163    /// # Examples
164    ///
165    /// ```rust
166    /// use ralph_workflow::config::unified::{UnifiedConfig, ConfigInitResult};
167    ///
168    /// match UnifiedConfig::ensure_config_exists() {
169    ///     Ok(ConfigInitResult::Created) => println!("Created new config"),
170    ///     Ok(ConfigInitResult::AlreadyExists) => println!("Config already exists"),
171    ///     Err(e) => eprintln!("Failed to create config: {}", e),
172    /// }
173    /// # Ok::<(), std::io::Error>(())
174    /// ```
175    pub fn ensure_config_exists() -> io::Result<ConfigInitResult> {
176        Self::ensure_config_exists_with_env(&super::super::path_resolver::RealConfigEnvironment)
177    }
178
179    /// Ensure unified config file exists using a `ConfigEnvironment`.
180    ///
181    /// This is the testable version of `ensure_config_exists`.
182    pub fn ensure_config_exists_with_env(
183        env: &dyn super::super::path_resolver::ConfigEnvironment,
184    ) -> io::Result<ConfigInitResult> {
185        let Some(path) = env.unified_config_path() else {
186            return Err(io::Error::new(
187                io::ErrorKind::NotFound,
188                "Cannot determine config directory (no home directory)",
189            ));
190        };
191
192        Self::ensure_config_exists_at_with_env(&path, env)
193    }
194
195    /// Ensure a config file exists at the specified path.
196    ///
197    /// This is useful for custom config file locations or testing.
198    pub fn ensure_config_exists_at(path: &std::path::Path) -> io::Result<ConfigInitResult> {
199        Self::ensure_config_exists_at_with_env(
200            path,
201            &super::super::path_resolver::RealConfigEnvironment,
202        )
203    }
204
205    /// Ensure a config file exists at the specified path using a `ConfigEnvironment`.
206    ///
207    /// This is the testable version of `ensure_config_exists_at`.
208    pub fn ensure_config_exists_at_with_env(
209        path: &std::path::Path,
210        env: &dyn super::super::path_resolver::ConfigEnvironment,
211    ) -> io::Result<ConfigInitResult> {
212        if env.file_exists(path) {
213            return Ok(ConfigInitResult::AlreadyExists);
214        }
215
216        // Write the default template (write_file creates parent directories)
217        env.write_file(path, DEFAULT_UNIFIED_CONFIG)?;
218
219        Ok(ConfigInitResult::Created)
220    }
221
222    /// Merge local config into self (global), returning merged config.
223    ///
224    /// Local values override global values with these semantics:
225    /// - Scalar values: local replaces global when explicitly present in TOML
226    /// - Maps (agents, ccs_aliases): local entries merge with global (local wins on collision)
227    /// - Arrays (agent_chain): local replaces global entirely (not appended)
228    /// - Optional values: local Some(_) replaces global, local None preserves global
229    /// - CCS string values: empty string ("") means disabled, missing means use global
230    ///
231    /// This is a pure function - no I/O, cannot fail.
232    ///
233    /// IMPORTANT: This uses default-comparison heuristic and is primarily for tests.
234    /// For real TOML-based configs, use `merge_with_content` for proper presence tracking.
235    ///
236    /// # Arguments
237    ///
238    /// * `local` - The local configuration to merge into this global configuration
239    ///
240    /// # Returns
241    ///
242    /// A new `UnifiedConfig` with local values merged into global values.
243    ///
244    /// # Examples
245    ///
246    /// ```rust
247    /// use ralph_workflow::config::unified::UnifiedConfig;
248    ///
249    /// let global = UnifiedConfig::default();
250    /// let mut local = UnifiedConfig::default();
251    /// local.general.verbosity = 4;
252    ///
253    /// let merged = global.merge_with(&local);
254    /// assert_eq!(merged.general.verbosity, 4);
255    /// ```
256    pub fn merge_with(&self, local: &UnifiedConfig) -> UnifiedConfig {
257        use super::types::{
258            CcsConfig, GeneralBehaviorFlags, GeneralConfig, GeneralExecutionFlags,
259            GeneralWorkflowFlags,
260        };
261
262        // For programmatically-constructed configs, we use default comparison
263        // NOTE: This has known issues with booleans and default-valued fields (Issue #2)
264        // but is kept for backward compatibility with tests
265        let defaults = GeneralConfig::default();
266
267        // Merge general config - override if local differs from default
268        let general = GeneralConfig {
269            verbosity: if local.general.verbosity != defaults.verbosity {
270                local.general.verbosity
271            } else {
272                self.general.verbosity
273            },
274            behavior: GeneralBehaviorFlags {
275                interactive: if local.general.behavior.interactive != defaults.behavior.interactive
276                {
277                    local.general.behavior.interactive
278                } else {
279                    self.general.behavior.interactive
280                },
281                auto_detect_stack: if local.general.behavior.auto_detect_stack
282                    != defaults.behavior.auto_detect_stack
283                {
284                    local.general.behavior.auto_detect_stack
285                } else {
286                    self.general.behavior.auto_detect_stack
287                },
288                strict_validation: if local.general.behavior.strict_validation
289                    != defaults.behavior.strict_validation
290                {
291                    local.general.behavior.strict_validation
292                } else {
293                    self.general.behavior.strict_validation
294                },
295            },
296            workflow: GeneralWorkflowFlags {
297                checkpoint_enabled: if local.general.workflow.checkpoint_enabled
298                    != defaults.workflow.checkpoint_enabled
299                {
300                    local.general.workflow.checkpoint_enabled
301                } else {
302                    self.general.workflow.checkpoint_enabled
303                },
304            },
305            execution: GeneralExecutionFlags {
306                force_universal_prompt: if local.general.execution.force_universal_prompt
307                    != defaults.execution.force_universal_prompt
308                {
309                    local.general.execution.force_universal_prompt
310                } else {
311                    self.general.execution.force_universal_prompt
312                },
313                isolation_mode: if local.general.execution.isolation_mode
314                    != defaults.execution.isolation_mode
315                {
316                    local.general.execution.isolation_mode
317                } else {
318                    self.general.execution.isolation_mode
319                },
320            },
321            developer_iters: if local.general.developer_iters != defaults.developer_iters {
322                local.general.developer_iters
323            } else {
324                self.general.developer_iters
325            },
326            reviewer_reviews: if local.general.reviewer_reviews != defaults.reviewer_reviews {
327                local.general.reviewer_reviews
328            } else {
329                self.general.reviewer_reviews
330            },
331            developer_context: if local.general.developer_context != defaults.developer_context {
332                local.general.developer_context
333            } else {
334                self.general.developer_context
335            },
336            reviewer_context: if local.general.reviewer_context != defaults.reviewer_context {
337                local.general.reviewer_context
338            } else {
339                self.general.reviewer_context
340            },
341            review_depth: if local.general.review_depth != defaults.review_depth {
342                local.general.review_depth.clone()
343            } else {
344                self.general.review_depth.clone()
345            },
346            prompt_path: local
347                .general
348                .prompt_path
349                .clone()
350                .or_else(|| self.general.prompt_path.clone()),
351            templates_dir: local
352                .general
353                .templates_dir
354                .clone()
355                .or_else(|| self.general.templates_dir.clone()),
356            git_user_name: local
357                .general
358                .git_user_name
359                .clone()
360                .or_else(|| self.general.git_user_name.clone()),
361            git_user_email: local
362                .general
363                .git_user_email
364                .clone()
365                .or_else(|| self.general.git_user_email.clone()),
366            max_dev_continuations: if local.general.max_dev_continuations
367                != defaults.max_dev_continuations
368            {
369                local.general.max_dev_continuations
370            } else {
371                self.general.max_dev_continuations
372            },
373            max_xsd_retries: if local.general.max_xsd_retries != defaults.max_xsd_retries {
374                local.general.max_xsd_retries
375            } else {
376                self.general.max_xsd_retries
377            },
378            max_same_agent_retries: if local.general.max_same_agent_retries
379                != defaults.max_same_agent_retries
380            {
381                local.general.max_same_agent_retries
382            } else {
383                self.general.max_same_agent_retries
384            },
385        };
386
387        // Merge CCS config - empty string means use global
388        fn merge_ccs_string(local: &str, global: &str) -> String {
389            if local.is_empty() {
390                global.to_string()
391            } else {
392                local.to_string()
393            }
394        }
395
396        let ccs = CcsConfig {
397            output_flag: merge_ccs_string(&local.ccs.output_flag, &self.ccs.output_flag),
398            yolo_flag: merge_ccs_string(&local.ccs.yolo_flag, &self.ccs.yolo_flag),
399            verbose_flag: merge_ccs_string(&local.ccs.verbose_flag, &self.ccs.verbose_flag),
400            print_flag: merge_ccs_string(&local.ccs.print_flag, &self.ccs.print_flag),
401            streaming_flag: merge_ccs_string(&local.ccs.streaming_flag, &self.ccs.streaming_flag),
402            json_parser: merge_ccs_string(&local.ccs.json_parser, &self.ccs.json_parser),
403            session_flag: merge_ccs_string(&local.ccs.session_flag, &self.ccs.session_flag),
404            can_commit: if local.ccs.can_commit != CcsConfig::default().can_commit {
405                local.ccs.can_commit
406            } else {
407                self.ccs.can_commit
408            },
409        };
410
411        // Merge agents map (local entries override global entries)
412        let mut agents = self.agents.clone();
413        for (key, value) in &local.agents {
414            agents.insert(key.clone(), value.clone());
415        }
416
417        // Merge CCS aliases map (local entries override global entries)
418        let mut ccs_aliases = self.ccs_aliases.clone();
419        for (key, value) in &local.ccs_aliases {
420            ccs_aliases.insert(key.clone(), value.clone());
421        }
422
423        // Agent chain: local replaces global entirely (not merged)
424        let agent_chain = if local.agent_chain.is_some() {
425            local.agent_chain.clone()
426        } else {
427            self.agent_chain.clone()
428        };
429
430        UnifiedConfig {
431            general,
432            ccs,
433            agents,
434            ccs_aliases,
435            agent_chain,
436        }
437    }
438
439    /// Merge local config content (TOML string) into self (global).
440    ///
441    /// This version tracks which fields are actually present in the TOML source
442    /// to distinguish "not set" from "set to default value".
443    ///
444    /// # Arguments
445    ///
446    /// * `local_content` - The raw TOML content of the local config
447    /// * `local_parsed` - The parsed local config (already deserialized)
448    ///
449    /// # Returns
450    ///
451    /// A new `UnifiedConfig` with local values merged into global values, using
452    /// presence-based tracking to avoid false overrides of default values.
453    ///
454    /// # Examples
455    ///
456    /// ```rust
457    /// use ralph_workflow::config::unified::UnifiedConfig;
458    ///
459    /// let global = UnifiedConfig::default();
460    /// let local_toml = r#"
461    ///     [general]
462    ///     verbosity = 4
463    /// "#;
464    /// let local = UnifiedConfig::load_from_content(local_toml).unwrap();
465    ///
466    /// let merged = global.merge_with_content(local_toml, &local);
467    /// assert_eq!(merged.general.verbosity, 4);
468    /// ```
469    pub fn merge_with_content(
470        &self,
471        local_content: &str,
472        local_parsed: &UnifiedConfig,
473    ) -> UnifiedConfig {
474        use super::types::{
475            CcsConfig, GeneralBehaviorFlags, GeneralConfig, GeneralExecutionFlags,
476            GeneralWorkflowFlags,
477        };
478
479        // Parse raw TOML to check field presence
480        let local_toml: toml::Value =
481            toml::from_str(local_content).unwrap_or(toml::Value::Table(Default::default()));
482
483        // Helper to check if a field is present in the TOML
484        let general_table = local_toml.get("general");
485        let behavior_table = general_table.and_then(|g| g.get("behavior"));
486
487        // NOTE: workflow and execution fields are flattened into [general], not separate tables.
488        // So we check for them at the [general] level, not [general.workflow] or [general.execution].
489        let has_field = |key: &str| -> bool { general_table.and_then(|g| g.get(key)).is_some() };
490        let has_behavior_field =
491            |key: &str| -> bool { behavior_table.and_then(|b| b.get(key)).is_some() };
492
493        // Merge general config with presence-based override detection
494        // Only override if field was explicitly present in local TOML
495        let general = GeneralConfig {
496            verbosity: if has_field("verbosity") {
497                local_parsed.general.verbosity
498            } else {
499                self.general.verbosity
500            },
501            behavior: GeneralBehaviorFlags {
502                interactive: if has_behavior_field("interactive") {
503                    local_parsed.general.behavior.interactive
504                } else {
505                    self.general.behavior.interactive
506                },
507                auto_detect_stack: if has_behavior_field("auto_detect_stack") {
508                    local_parsed.general.behavior.auto_detect_stack
509                } else {
510                    self.general.behavior.auto_detect_stack
511                },
512                strict_validation: if has_behavior_field("strict_validation") {
513                    local_parsed.general.behavior.strict_validation
514                } else {
515                    self.general.behavior.strict_validation
516                },
517            },
518            workflow: GeneralWorkflowFlags {
519                checkpoint_enabled: if has_field("checkpoint_enabled") {
520                    local_parsed.general.workflow.checkpoint_enabled
521                } else {
522                    self.general.workflow.checkpoint_enabled
523                },
524            },
525            execution: GeneralExecutionFlags {
526                force_universal_prompt: if has_field("force_universal_prompt") {
527                    local_parsed.general.execution.force_universal_prompt
528                } else {
529                    self.general.execution.force_universal_prompt
530                },
531                isolation_mode: if has_field("isolation_mode") {
532                    local_parsed.general.execution.isolation_mode
533                } else {
534                    self.general.execution.isolation_mode
535                },
536            },
537            developer_iters: if has_field("developer_iters") {
538                local_parsed.general.developer_iters
539            } else {
540                self.general.developer_iters
541            },
542            reviewer_reviews: if has_field("reviewer_reviews") {
543                local_parsed.general.reviewer_reviews
544            } else {
545                self.general.reviewer_reviews
546            },
547            developer_context: if has_field("developer_context") {
548                local_parsed.general.developer_context
549            } else {
550                self.general.developer_context
551            },
552            reviewer_context: if has_field("reviewer_context") {
553                local_parsed.general.reviewer_context
554            } else {
555                self.general.reviewer_context
556            },
557            review_depth: if has_field("review_depth") {
558                local_parsed.general.review_depth.clone()
559            } else {
560                self.general.review_depth.clone()
561            },
562            prompt_path: local_parsed
563                .general
564                .prompt_path
565                .clone()
566                .or_else(|| self.general.prompt_path.clone()),
567            templates_dir: local_parsed
568                .general
569                .templates_dir
570                .clone()
571                .or_else(|| self.general.templates_dir.clone()),
572            git_user_name: local_parsed
573                .general
574                .git_user_name
575                .clone()
576                .or_else(|| self.general.git_user_name.clone()),
577            git_user_email: local_parsed
578                .general
579                .git_user_email
580                .clone()
581                .or_else(|| self.general.git_user_email.clone()),
582            max_dev_continuations: if has_field("max_dev_continuations") {
583                local_parsed.general.max_dev_continuations
584            } else {
585                self.general.max_dev_continuations
586            },
587            max_xsd_retries: if has_field("max_xsd_retries") {
588                local_parsed.general.max_xsd_retries
589            } else {
590                self.general.max_xsd_retries
591            },
592            max_same_agent_retries: if has_field("max_same_agent_retries") {
593                local_parsed.general.max_same_agent_retries
594            } else {
595                self.general.max_same_agent_retries
596            },
597        };
598
599        // Merge CCS config with presence-based semantics
600        // Check if CCS fields are present in local TOML
601        let ccs_table = local_toml.get("ccs");
602        let has_ccs_field = |key: &str| -> bool { ccs_table.and_then(|c| c.get(key)).is_some() };
603
604        let ccs = CcsConfig {
605            output_flag: if has_ccs_field("output_flag") {
606                local_parsed.ccs.output_flag.clone()
607            } else {
608                self.ccs.output_flag.clone()
609            },
610            yolo_flag: if has_ccs_field("yolo_flag") {
611                local_parsed.ccs.yolo_flag.clone()
612            } else {
613                self.ccs.yolo_flag.clone()
614            },
615            verbose_flag: if has_ccs_field("verbose_flag") {
616                local_parsed.ccs.verbose_flag.clone()
617            } else {
618                self.ccs.verbose_flag.clone()
619            },
620            print_flag: if has_ccs_field("print_flag") {
621                local_parsed.ccs.print_flag.clone()
622            } else {
623                self.ccs.print_flag.clone()
624            },
625            streaming_flag: if has_ccs_field("streaming_flag") {
626                local_parsed.ccs.streaming_flag.clone()
627            } else {
628                self.ccs.streaming_flag.clone()
629            },
630            json_parser: if has_ccs_field("json_parser") {
631                local_parsed.ccs.json_parser.clone()
632            } else {
633                self.ccs.json_parser.clone()
634            },
635            session_flag: if has_ccs_field("session_flag") {
636                local_parsed.ccs.session_flag.clone()
637            } else {
638                self.ccs.session_flag.clone()
639            },
640            can_commit: if has_ccs_field("can_commit") {
641                local_parsed.ccs.can_commit
642            } else {
643                self.ccs.can_commit
644            },
645        };
646
647        // Merge agents map (local entries override global entries)
648        let mut agents = self.agents.clone();
649        for (key, value) in &local_parsed.agents {
650            agents.insert(key.clone(), value.clone());
651        }
652
653        // Merge CCS aliases map (local entries override global entries)
654        let mut ccs_aliases = self.ccs_aliases.clone();
655        for (key, value) in &local_parsed.ccs_aliases {
656            ccs_aliases.insert(key.clone(), value.clone());
657        }
658
659        // Agent chain: local replaces global entirely (not merged)
660        let agent_chain = if local_parsed.agent_chain.is_some() {
661            local_parsed.agent_chain.clone()
662        } else {
663            self.agent_chain.clone()
664        };
665
666        UnifiedConfig {
667            general,
668            ccs,
669            agents,
670            ccs_aliases,
671            agent_chain,
672        }
673    }
674}