Skip to main content

oxigaf_cli/
config.rs

1//! TOML configuration file loading and validation.
2//!
3//! Loads the `oxigaf.toml` project configuration and converts it to the
4//! internal [`TrainingConfig`] and [`RasterConfig`] types used by the trainer
5//! and rasterizer subsystems.
6
7use std::env;
8use std::path::{Path, PathBuf};
9
10use anyhow::{Context, Result};
11use serde::{Deserialize, Serialize};
12
13use oxigaf::render::RasterConfig;
14use oxigaf::trainer::{
15    DensityConfig, InitConfig, LossConfig, OptimizerConfig, TensorBoardConfig, TrainingConfig,
16};
17
18// ---------------------------------------------------------------------------
19// Top-level configuration
20// ---------------------------------------------------------------------------
21
22/// Top-level project configuration loaded from `oxigaf.toml`.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(default)]
25#[derive(Default)]
26pub struct ProjectConfig {
27    /// Model file paths.
28    pub model: ModelSection,
29    /// GPU / backend settings.
30    pub device: DeviceSection,
31    /// Training hyper-parameters.
32    pub training: TrainingSection,
33    /// Output settings (checkpointing, logging, export).
34    pub output: OutputSection,
35}
36
37// ---------------------------------------------------------------------------
38// [model]
39// ---------------------------------------------------------------------------
40
41/// `[model]` section — paths to pretrained model files.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(default)]
44pub struct ModelSection {
45    /// Path to the converted FLAME model directory.
46    pub flame_model_path: PathBuf,
47    /// Path to the directory containing diffusion model weights.
48    pub diffusion_weights_dir: PathBuf,
49}
50
51impl Default for ModelSection {
52    fn default() -> Self {
53        Self {
54            flame_model_path: PathBuf::from("~/.cache/oxigaf/flame2023"),
55            diffusion_weights_dir: PathBuf::from("~/.cache/oxigaf/weights"),
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// [device]
62// ---------------------------------------------------------------------------
63
64/// `[device]` section — GPU backend configuration.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(default)]
67pub struct DeviceSection {
68    /// GPU backend: `vulkan`, `metal`, `dx12`, or `gl`.
69    pub backend: String,
70    /// GPU device index.
71    pub gpu_index: usize,
72}
73
74impl Default for DeviceSection {
75    fn default() -> Self {
76        Self {
77            backend: "vulkan".to_string(),
78            gpu_index: 0,
79        }
80    }
81}
82
83// ---------------------------------------------------------------------------
84// [training]
85// ---------------------------------------------------------------------------
86
87/// `[training]` section — training hyper-parameters with sub-sections.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89#[serde(default)]
90pub struct TrainingSection {
91    pub total_iterations: u32,
92    pub views_per_step: usize,
93    pub image_size: u32,
94    pub guidance_scale_start: f32,
95    pub guidance_scale_end: f32,
96    pub guidance_anneal_steps: u32,
97    pub num_inference_steps: usize,
98    pub opacity_reset_interval: u32,
99
100    /// `[training.init]`
101    pub init: InitSection,
102    /// `[training.optimizer]`
103    pub optimizer: OptimizerSection,
104    /// `[training.density_control]`
105    pub density_control: DensityControlSection,
106    /// `[training.loss]`
107    pub loss: LossSection,
108}
109
110impl Default for TrainingSection {
111    fn default() -> Self {
112        Self {
113            total_iterations: 15_000,
114            views_per_step: 4,
115            image_size: 512,
116            guidance_scale_start: 7.5,
117            guidance_scale_end: 3.0,
118            guidance_anneal_steps: 10_000,
119            num_inference_steps: 50,
120            opacity_reset_interval: 3_000,
121            init: InitSection::default(),
122            optimizer: OptimizerSection::default(),
123            density_control: DensityControlSection::default(),
124            loss: LossSection::default(),
125        }
126    }
127}
128
129// ---------------------------------------------------------------------------
130// [training.init]
131// ---------------------------------------------------------------------------
132
133/// `[training.init]` — Gaussian initialisation parameters.
134#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(default)]
136pub struct InitSection {
137    pub num_rigid_gaussians: usize,
138    pub num_flexible_gaussians: usize,
139    pub initial_scale: f32,
140    pub initial_opacity: f32,
141    pub sh_degree: u32,
142}
143
144impl Default for InitSection {
145    fn default() -> Self {
146        let d = InitConfig::default();
147        Self {
148            num_rigid_gaussians: d.num_rigid,
149            num_flexible_gaussians: d.num_flexible,
150            initial_scale: d.initial_scale,
151            initial_opacity: d.initial_opacity,
152            sh_degree: d.sh_degree,
153        }
154    }
155}
156
157// ---------------------------------------------------------------------------
158// [training.optimizer]
159// ---------------------------------------------------------------------------
160
161/// `[training.optimizer]` — per-parameter-group learning rates.
162#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(default)]
164pub struct OptimizerSection {
165    pub position_lr: f32,
166    pub position_lr_final: f32,
167    pub rotation_lr: f32,
168    pub scale_lr: f32,
169    pub opacity_lr: f32,
170    pub sh_lr: f32,
171    pub offset_lr: f32,
172    pub beta1: f32,
173    pub beta2: f32,
174    pub epsilon: f32,
175    pub position_lr_decay_steps: u32,
176}
177
178impl Default for OptimizerSection {
179    fn default() -> Self {
180        let d = OptimizerConfig::default();
181        Self {
182            position_lr: d.lr_position,
183            position_lr_final: d.lr_position_final,
184            rotation_lr: d.lr_rotation,
185            scale_lr: d.lr_scale,
186            opacity_lr: d.lr_opacity,
187            sh_lr: d.lr_sh,
188            offset_lr: d.lr_offset,
189            beta1: d.beta1,
190            beta2: d.beta2,
191            epsilon: d.epsilon,
192            position_lr_decay_steps: d.position_lr_decay_steps,
193        }
194    }
195}
196
197// ---------------------------------------------------------------------------
198// [training.density_control]
199// ---------------------------------------------------------------------------
200
201/// `[training.density_control]` — adaptive density control parameters.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203#[serde(default)]
204pub struct DensityControlSection {
205    pub interval: u32,
206    pub start_iteration: u32,
207    pub end_iteration: u32,
208    pub grad_threshold: f32,
209    pub min_opacity: f32,
210    pub max_screen_size: f32,
211    pub split_scale_threshold: f32,
212    pub max_gaussians: usize,
213}
214
215impl Default for DensityControlSection {
216    fn default() -> Self {
217        let d = DensityConfig::default();
218        Self {
219            interval: 500,
220            start_iteration: 1_000,
221            end_iteration: 12_000,
222            grad_threshold: d.grad_threshold,
223            min_opacity: d.min_opacity,
224            max_screen_size: d.max_screen_size,
225            split_scale_threshold: d.split_scale_threshold,
226            max_gaussians: d.max_gaussians,
227        }
228    }
229}
230
231// ---------------------------------------------------------------------------
232// [training.loss]
233// ---------------------------------------------------------------------------
234
235/// `[training.loss]` — loss function weights.
236#[derive(Debug, Clone, Serialize, Deserialize)]
237#[serde(default)]
238pub struct LossSection {
239    pub lambda_l1: f32,
240    pub lambda_ssim: f32,
241    pub lambda_ms_ssim: f32,
242    pub lambda_lpips: f32,
243    pub lambda_position_reg: f32,
244    pub lambda_scale_reg: f32,
245    pub lambda_opacity_reg: f32,
246    pub lambda_normal: f32,
247    pub lambda_gradient_penalty: f32,
248    pub gradient_penalty_threshold: f32,
249}
250
251impl Default for LossSection {
252    fn default() -> Self {
253        let d = LossConfig::default();
254        Self {
255            lambda_l1: d.w_l1,
256            lambda_ssim: d.w_ssim,
257            lambda_ms_ssim: d.w_ms_ssim,
258            lambda_lpips: d.w_lpips,
259            lambda_position_reg: d.w_position_reg,
260            lambda_scale_reg: d.w_scale_reg,
261            lambda_opacity_reg: d.w_opacity_reg,
262            lambda_normal: d.w_normal,
263            lambda_gradient_penalty: d.w_gradient_penalty,
264            gradient_penalty_threshold: d.gradient_penalty_threshold,
265        }
266    }
267}
268
269// ---------------------------------------------------------------------------
270// [output]
271// ---------------------------------------------------------------------------
272
273/// `[output]` section — checkpoint, logging, and export settings.
274#[derive(Debug, Clone, Serialize, Deserialize)]
275#[serde(default)]
276pub struct OutputSection {
277    pub checkpoint_interval: u32,
278    pub log_interval: u32,
279    pub export_format: String,
280}
281
282impl Default for OutputSection {
283    fn default() -> Self {
284        Self {
285            checkpoint_interval: 1_000,
286            log_interval: 50,
287            export_format: "ply".to_string(),
288        }
289    }
290}
291
292// ---------------------------------------------------------------------------
293// Conversion helpers
294// ---------------------------------------------------------------------------
295
296impl ProjectConfig {
297    /// Convert the user-facing project configuration to the internal
298    /// [`TrainingConfig`] consumed by the trainer.
299    pub fn to_training_config(&self) -> TrainingConfig {
300        let t = &self.training;
301        TrainingConfig {
302            total_iterations: t.total_iterations,
303            views_per_step: t.views_per_step,
304            density_control_interval: t.density_control.interval,
305            density_control_start: t.density_control.start_iteration,
306            density_control_end: t.density_control.end_iteration,
307            opacity_reset_interval: t.opacity_reset_interval,
308            checkpoint_interval: self.output.checkpoint_interval,
309            log_interval: self.output.log_interval,
310            guidance_scale_start: t.guidance_scale_start,
311            guidance_scale_end: t.guidance_scale_end,
312            guidance_anneal_steps: t.guidance_anneal_steps,
313            optimizer: OptimizerConfig {
314                lr_position: t.optimizer.position_lr,
315                lr_position_final: t.optimizer.position_lr_final,
316                lr_rotation: t.optimizer.rotation_lr,
317                lr_scale: t.optimizer.scale_lr,
318                lr_opacity: t.optimizer.opacity_lr,
319                lr_sh: t.optimizer.sh_lr,
320                lr_offset: t.optimizer.offset_lr,
321                beta1: t.optimizer.beta1,
322                beta2: t.optimizer.beta2,
323                epsilon: t.optimizer.epsilon,
324                position_lr_decay_steps: t.optimizer.position_lr_decay_steps,
325            },
326            loss: LossConfig {
327                w_l1: t.loss.lambda_l1,
328                w_ssim: t.loss.lambda_ssim,
329                w_ms_ssim: t.loss.lambda_ms_ssim,
330                w_lpips: t.loss.lambda_lpips,
331                w_position_reg: t.loss.lambda_position_reg,
332                w_scale_reg: t.loss.lambda_scale_reg,
333                w_opacity_reg: t.loss.lambda_opacity_reg,
334                w_normal: t.loss.lambda_normal,
335                w_gradient_penalty: t.loss.lambda_gradient_penalty,
336                gradient_penalty_threshold: t.loss.gradient_penalty_threshold,
337            },
338            density: DensityConfig {
339                grad_threshold: t.density_control.grad_threshold,
340                min_opacity: t.density_control.min_opacity,
341                max_screen_size: t.density_control.max_screen_size,
342                split_scale_threshold: t.density_control.split_scale_threshold,
343                max_gaussians: t.density_control.max_gaussians,
344            },
345            init: InitConfig {
346                num_rigid: t.init.num_rigid_gaussians,
347                num_flexible: t.init.num_flexible_gaussians,
348                initial_scale: t.init.initial_scale,
349                initial_opacity: t.init.initial_opacity,
350                sh_degree: t.init.sh_degree,
351            },
352            tensorboard: TensorBoardConfig::default(),
353        }
354    }
355
356    /// Build a [`RasterConfig`] from the project configuration.
357    pub fn to_raster_config(&self) -> RasterConfig {
358        RasterConfig {
359            image_width: self.training.image_size,
360            image_height: self.training.image_size,
361            sh_degree: self.training.init.sh_degree,
362            ..RasterConfig::default()
363        }
364    }
365
366    /// Basic validation of configuration values.
367    pub fn validate(&self) -> Result<()> {
368        anyhow::ensure!(
369            self.training.total_iterations > 0,
370            "total_iterations must be > 0"
371        );
372        anyhow::ensure!(
373            self.training.views_per_step > 0,
374            "views_per_step must be > 0"
375        );
376        anyhow::ensure!(self.training.image_size > 0, "image_size must be > 0");
377        anyhow::ensure!(
378            self.training.guidance_scale_start > 0.0,
379            "guidance_scale_start must be > 0"
380        );
381        anyhow::ensure!(
382            self.training.init.num_rigid_gaussians + self.training.init.num_flexible_gaussians > 0,
383            "Total Gaussian count must be > 0"
384        );
385        anyhow::ensure!(self.training.init.sh_degree <= 3, "SH degree must be <= 3");
386        Ok(())
387    }
388}
389
390// ---------------------------------------------------------------------------
391// Hierarchical Configuration Loading
392// ---------------------------------------------------------------------------
393
394/// Load config with hierarchical priority:
395/// 1. CLI arguments (highest)
396/// 2. Environment variables
397/// 3. Project config file (./oxigaf.toml)
398/// 4. User config file (~/.config/oxigaf/config.toml)
399/// 5. Default values (lowest)
400pub fn load_hierarchical_config(
401    cli_config_path: Option<&Path>,
402    override_values: Option<&ProjectConfig>,
403) -> Result<ProjectConfig> {
404    // Start with defaults
405    let mut config = ProjectConfig::default();
406
407    // Layer 1: User config (~/.config/oxigaf/config.toml)
408    if let Some(user_config_path) = get_user_config_path() {
409        if user_config_path.exists() {
410            tracing::debug!("Loading user config from: {}", user_config_path.display());
411            let user_config = load_config_from_file(&user_config_path)?;
412            config = merge_configs(config, user_config);
413        }
414    }
415
416    // Layer 2: Project config (./oxigaf.toml)
417    let project_config_path = PathBuf::from("./oxigaf.toml");
418    if project_config_path.exists() {
419        tracing::debug!(
420            "Loading project config from: {}",
421            project_config_path.display()
422        );
423        let project_config = load_config_from_file(&project_config_path)?;
424        config = merge_configs(config, project_config);
425    }
426
427    // Layer 3: CLI-specified config file
428    if let Some(path) = cli_config_path {
429        tracing::debug!("Loading CLI-specified config from: {}", path.display());
430        let cli_file_config = load_config_from_file(path)?;
431        config = merge_configs(config, cli_file_config);
432    }
433
434    // Layer 4: Environment variables
435    config = apply_env_overrides(config)?;
436
437    // Layer 5: CLI arguments (always take priority, regardless of value)
438    // Note: For actual CLI usage, it's recommended to apply CLI overrides
439    // directly after calling this function with override_values=None
440    if let Some(overrides) = override_values {
441        // For override_values, we do a simple overlay: any field that's been
442        // explicitly set in the override takes priority. Since we can't detect
443        // which fields were explicitly set vs. defaulted, this parameter is
444        // primarily for testing. In production, apply CLI overrides after calling
445        // this function.
446        config = merge_configs(config, overrides.clone());
447    }
448
449    config.validate()?;
450    Ok(config)
451}
452
453/// Get user config path (~/.config/oxigaf/config.toml)
454fn get_user_config_path() -> Option<PathBuf> {
455    let mut path = dirs::config_dir()?;
456    path.push("oxigaf");
457    path.push("config.toml");
458    Some(path)
459}
460
461/// Load config from a file without checking if it's the default oxigaf.toml
462fn load_config_from_file(path: &Path) -> Result<ProjectConfig> {
463    let contents = std::fs::read_to_string(path)
464        .with_context(|| format!("Failed to read config file: {}", path.display()))?;
465    let config: ProjectConfig = toml::from_str(&contents)
466        .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
467    Ok(config)
468}
469
470/// Merge two configs (second takes priority)
471fn merge_configs(base: ProjectConfig, override_cfg: ProjectConfig) -> ProjectConfig {
472    ProjectConfig {
473        model: merge_model_section(base.model, override_cfg.model),
474        device: merge_device_section(base.device, override_cfg.device),
475        training: merge_training_section(base.training, override_cfg.training),
476        output: merge_output_section(base.output, override_cfg.output),
477    }
478}
479
480/// Merge model sections (override takes priority for non-default values)
481fn merge_model_section(base: ModelSection, override_cfg: ModelSection) -> ModelSection {
482    let default = ModelSection::default();
483    ModelSection {
484        flame_model_path: if override_cfg.flame_model_path != default.flame_model_path {
485            override_cfg.flame_model_path
486        } else {
487            base.flame_model_path
488        },
489        diffusion_weights_dir: if override_cfg.diffusion_weights_dir
490            != default.diffusion_weights_dir
491        {
492            override_cfg.diffusion_weights_dir
493        } else {
494            base.diffusion_weights_dir
495        },
496    }
497}
498
499/// Merge device sections (override takes priority for non-default values)
500fn merge_device_section(base: DeviceSection, override_cfg: DeviceSection) -> DeviceSection {
501    let default = DeviceSection::default();
502    DeviceSection {
503        backend: if override_cfg.backend != default.backend {
504            override_cfg.backend
505        } else {
506            base.backend
507        },
508        gpu_index: if override_cfg.gpu_index != default.gpu_index {
509            override_cfg.gpu_index
510        } else {
511            base.gpu_index
512        },
513    }
514}
515
516/// Merge training sections (override takes priority for non-default values)
517fn merge_training_section(base: TrainingSection, override_cfg: TrainingSection) -> TrainingSection {
518    let default = TrainingSection::default();
519    TrainingSection {
520        total_iterations: if override_cfg.total_iterations != default.total_iterations {
521            override_cfg.total_iterations
522        } else {
523            base.total_iterations
524        },
525        views_per_step: if override_cfg.views_per_step != default.views_per_step {
526            override_cfg.views_per_step
527        } else {
528            base.views_per_step
529        },
530        image_size: if override_cfg.image_size != default.image_size {
531            override_cfg.image_size
532        } else {
533            base.image_size
534        },
535        guidance_scale_start: if (override_cfg.guidance_scale_start - default.guidance_scale_start)
536            .abs()
537            > f32::EPSILON
538        {
539            override_cfg.guidance_scale_start
540        } else {
541            base.guidance_scale_start
542        },
543        guidance_scale_end: if (override_cfg.guidance_scale_end - default.guidance_scale_end).abs()
544            > f32::EPSILON
545        {
546            override_cfg.guidance_scale_end
547        } else {
548            base.guidance_scale_end
549        },
550        guidance_anneal_steps: if override_cfg.guidance_anneal_steps
551            != default.guidance_anneal_steps
552        {
553            override_cfg.guidance_anneal_steps
554        } else {
555            base.guidance_anneal_steps
556        },
557        num_inference_steps: if override_cfg.num_inference_steps != default.num_inference_steps {
558            override_cfg.num_inference_steps
559        } else {
560            base.num_inference_steps
561        },
562        opacity_reset_interval: if override_cfg.opacity_reset_interval
563            != default.opacity_reset_interval
564        {
565            override_cfg.opacity_reset_interval
566        } else {
567            base.opacity_reset_interval
568        },
569        init: merge_init_section(base.init, override_cfg.init),
570        optimizer: merge_optimizer_section(base.optimizer, override_cfg.optimizer),
571        density_control: merge_density_control_section(
572            base.density_control,
573            override_cfg.density_control,
574        ),
575        loss: merge_loss_section(base.loss, override_cfg.loss),
576    }
577}
578
579/// Merge init sections (override takes priority for non-default values)
580fn merge_init_section(base: InitSection, override_cfg: InitSection) -> InitSection {
581    let default = InitSection::default();
582    InitSection {
583        num_rigid_gaussians: if override_cfg.num_rigid_gaussians != default.num_rigid_gaussians {
584            override_cfg.num_rigid_gaussians
585        } else {
586            base.num_rigid_gaussians
587        },
588        num_flexible_gaussians: if override_cfg.num_flexible_gaussians
589            != default.num_flexible_gaussians
590        {
591            override_cfg.num_flexible_gaussians
592        } else {
593            base.num_flexible_gaussians
594        },
595        initial_scale: if (override_cfg.initial_scale - default.initial_scale).abs() > f32::EPSILON
596        {
597            override_cfg.initial_scale
598        } else {
599            base.initial_scale
600        },
601        initial_opacity: if (override_cfg.initial_opacity - default.initial_opacity).abs()
602            > f32::EPSILON
603        {
604            override_cfg.initial_opacity
605        } else {
606            base.initial_opacity
607        },
608        sh_degree: if override_cfg.sh_degree != default.sh_degree {
609            override_cfg.sh_degree
610        } else {
611            base.sh_degree
612        },
613    }
614}
615
616/// Merge optimizer sections (override takes priority for non-default values)
617fn merge_optimizer_section(
618    base: OptimizerSection,
619    override_cfg: OptimizerSection,
620) -> OptimizerSection {
621    let default = OptimizerSection::default();
622    OptimizerSection {
623        position_lr: if (override_cfg.position_lr - default.position_lr).abs() > f32::EPSILON {
624            override_cfg.position_lr
625        } else {
626            base.position_lr
627        },
628        position_lr_final: if (override_cfg.position_lr_final - default.position_lr_final).abs()
629            > f32::EPSILON
630        {
631            override_cfg.position_lr_final
632        } else {
633            base.position_lr_final
634        },
635        rotation_lr: if (override_cfg.rotation_lr - default.rotation_lr).abs() > f32::EPSILON {
636            override_cfg.rotation_lr
637        } else {
638            base.rotation_lr
639        },
640        scale_lr: if (override_cfg.scale_lr - default.scale_lr).abs() > f32::EPSILON {
641            override_cfg.scale_lr
642        } else {
643            base.scale_lr
644        },
645        opacity_lr: if (override_cfg.opacity_lr - default.opacity_lr).abs() > f32::EPSILON {
646            override_cfg.opacity_lr
647        } else {
648            base.opacity_lr
649        },
650        sh_lr: if (override_cfg.sh_lr - default.sh_lr).abs() > f32::EPSILON {
651            override_cfg.sh_lr
652        } else {
653            base.sh_lr
654        },
655        offset_lr: if (override_cfg.offset_lr - default.offset_lr).abs() > f32::EPSILON {
656            override_cfg.offset_lr
657        } else {
658            base.offset_lr
659        },
660        beta1: if (override_cfg.beta1 - default.beta1).abs() > f32::EPSILON {
661            override_cfg.beta1
662        } else {
663            base.beta1
664        },
665        beta2: if (override_cfg.beta2 - default.beta2).abs() > f32::EPSILON {
666            override_cfg.beta2
667        } else {
668            base.beta2
669        },
670        epsilon: if (override_cfg.epsilon - default.epsilon).abs() > f32::EPSILON {
671            override_cfg.epsilon
672        } else {
673            base.epsilon
674        },
675        position_lr_decay_steps: if override_cfg.position_lr_decay_steps
676            != default.position_lr_decay_steps
677        {
678            override_cfg.position_lr_decay_steps
679        } else {
680            base.position_lr_decay_steps
681        },
682    }
683}
684
685/// Merge density control sections (override takes priority for non-default values)
686fn merge_density_control_section(
687    base: DensityControlSection,
688    override_cfg: DensityControlSection,
689) -> DensityControlSection {
690    let default = DensityControlSection::default();
691    DensityControlSection {
692        interval: if override_cfg.interval != default.interval {
693            override_cfg.interval
694        } else {
695            base.interval
696        },
697        start_iteration: if override_cfg.start_iteration != default.start_iteration {
698            override_cfg.start_iteration
699        } else {
700            base.start_iteration
701        },
702        end_iteration: if override_cfg.end_iteration != default.end_iteration {
703            override_cfg.end_iteration
704        } else {
705            base.end_iteration
706        },
707        grad_threshold: if (override_cfg.grad_threshold - default.grad_threshold).abs()
708            > f32::EPSILON
709        {
710            override_cfg.grad_threshold
711        } else {
712            base.grad_threshold
713        },
714        min_opacity: if (override_cfg.min_opacity - default.min_opacity).abs() > f32::EPSILON {
715            override_cfg.min_opacity
716        } else {
717            base.min_opacity
718        },
719        max_screen_size: if (override_cfg.max_screen_size - default.max_screen_size).abs()
720            > f32::EPSILON
721        {
722            override_cfg.max_screen_size
723        } else {
724            base.max_screen_size
725        },
726        split_scale_threshold: if (override_cfg.split_scale_threshold
727            - default.split_scale_threshold)
728            .abs()
729            > f32::EPSILON
730        {
731            override_cfg.split_scale_threshold
732        } else {
733            base.split_scale_threshold
734        },
735        max_gaussians: if override_cfg.max_gaussians != default.max_gaussians {
736            override_cfg.max_gaussians
737        } else {
738            base.max_gaussians
739        },
740    }
741}
742
743/// Merge loss sections (override takes priority for non-default values)
744fn merge_loss_section(base: LossSection, override_cfg: LossSection) -> LossSection {
745    let default = LossSection::default();
746    LossSection {
747        lambda_l1: if (override_cfg.lambda_l1 - default.lambda_l1).abs() > f32::EPSILON {
748            override_cfg.lambda_l1
749        } else {
750            base.lambda_l1
751        },
752        lambda_ssim: if (override_cfg.lambda_ssim - default.lambda_ssim).abs() > f32::EPSILON {
753            override_cfg.lambda_ssim
754        } else {
755            base.lambda_ssim
756        },
757        lambda_ms_ssim: if (override_cfg.lambda_ms_ssim - default.lambda_ms_ssim).abs()
758            > f32::EPSILON
759        {
760            override_cfg.lambda_ms_ssim
761        } else {
762            base.lambda_ms_ssim
763        },
764        lambda_lpips: if (override_cfg.lambda_lpips - default.lambda_lpips).abs() > f32::EPSILON {
765            override_cfg.lambda_lpips
766        } else {
767            base.lambda_lpips
768        },
769        lambda_position_reg: if (override_cfg.lambda_position_reg - default.lambda_position_reg)
770            .abs()
771            > f32::EPSILON
772        {
773            override_cfg.lambda_position_reg
774        } else {
775            base.lambda_position_reg
776        },
777        lambda_scale_reg: if (override_cfg.lambda_scale_reg - default.lambda_scale_reg).abs()
778            > f32::EPSILON
779        {
780            override_cfg.lambda_scale_reg
781        } else {
782            base.lambda_scale_reg
783        },
784        lambda_opacity_reg: if (override_cfg.lambda_opacity_reg - default.lambda_opacity_reg).abs()
785            > f32::EPSILON
786        {
787            override_cfg.lambda_opacity_reg
788        } else {
789            base.lambda_opacity_reg
790        },
791        lambda_normal: if (override_cfg.lambda_normal - default.lambda_normal).abs() > f32::EPSILON
792        {
793            override_cfg.lambda_normal
794        } else {
795            base.lambda_normal
796        },
797        lambda_gradient_penalty: if (override_cfg.lambda_gradient_penalty
798            - default.lambda_gradient_penalty)
799            .abs()
800            > f32::EPSILON
801        {
802            override_cfg.lambda_gradient_penalty
803        } else {
804            base.lambda_gradient_penalty
805        },
806        gradient_penalty_threshold: if (override_cfg.gradient_penalty_threshold
807            - default.gradient_penalty_threshold)
808            .abs()
809            > f32::EPSILON
810        {
811            override_cfg.gradient_penalty_threshold
812        } else {
813            base.gradient_penalty_threshold
814        },
815    }
816}
817
818/// Merge output sections (override takes priority for non-default values)
819fn merge_output_section(base: OutputSection, override_cfg: OutputSection) -> OutputSection {
820    let default = OutputSection::default();
821    OutputSection {
822        checkpoint_interval: if override_cfg.checkpoint_interval != default.checkpoint_interval {
823            override_cfg.checkpoint_interval
824        } else {
825            base.checkpoint_interval
826        },
827        log_interval: if override_cfg.log_interval != default.log_interval {
828            override_cfg.log_interval
829        } else {
830            base.log_interval
831        },
832        export_format: if override_cfg.export_format != default.export_format {
833            override_cfg.export_format
834        } else {
835            base.export_format
836        },
837    }
838}
839
840/// Apply environment variable overrides
841fn apply_env_overrides(mut config: ProjectConfig) -> Result<ProjectConfig> {
842    // Training parameters
843    if let Ok(val) = env::var("OXIGAF_TOTAL_ITERATIONS") {
844        config.training.total_iterations = val
845            .parse()
846            .with_context(|| format!("Invalid OXIGAF_TOTAL_ITERATIONS: {}", val))?;
847    }
848
849    if let Ok(val) = env::var("OXIGAF_IMAGE_SIZE") {
850        config.training.image_size = val
851            .parse()
852            .with_context(|| format!("Invalid OXIGAF_IMAGE_SIZE: {}", val))?;
853    }
854
855    if let Ok(val) = env::var("OXIGAF_VIEWS_PER_STEP") {
856        config.training.views_per_step = val
857            .parse()
858            .with_context(|| format!("Invalid OXIGAF_VIEWS_PER_STEP: {}", val))?;
859    }
860
861    if let Ok(val) = env::var("OXIGAF_GUIDANCE_SCALE_START") {
862        config.training.guidance_scale_start = val
863            .parse()
864            .with_context(|| format!("Invalid OXIGAF_GUIDANCE_SCALE_START: {}", val))?;
865    }
866
867    if let Ok(val) = env::var("OXIGAF_GUIDANCE_SCALE_END") {
868        config.training.guidance_scale_end = val
869            .parse()
870            .with_context(|| format!("Invalid OXIGAF_GUIDANCE_SCALE_END: {}", val))?;
871    }
872
873    // Optimizer parameters
874    if let Ok(val) = env::var("OXIGAF_POSITION_LR") {
875        config.training.optimizer.position_lr = val
876            .parse()
877            .with_context(|| format!("Invalid OXIGAF_POSITION_LR: {}", val))?;
878    }
879
880    if let Ok(val) = env::var("OXIGAF_SCALING_LR") {
881        config.training.optimizer.scale_lr = val
882            .parse()
883            .with_context(|| format!("Invalid OXIGAF_SCALING_LR: {}", val))?;
884    }
885
886    if let Ok(val) = env::var("OXIGAF_ROTATION_LR") {
887        config.training.optimizer.rotation_lr = val
888            .parse()
889            .with_context(|| format!("Invalid OXIGAF_ROTATION_LR: {}", val))?;
890    }
891
892    if let Ok(val) = env::var("OXIGAF_OPACITY_LR") {
893        config.training.optimizer.opacity_lr = val
894            .parse()
895            .with_context(|| format!("Invalid OXIGAF_OPACITY_LR: {}", val))?;
896    }
897
898    if let Ok(val) = env::var("OXIGAF_SH_LR") {
899        config.training.optimizer.sh_lr = val
900            .parse()
901            .with_context(|| format!("Invalid OXIGAF_SH_LR: {}", val))?;
902    }
903
904    // Device parameters
905    if let Ok(val) = env::var("OXIGAF_DEVICE_BACKEND") {
906        config.device.backend = val;
907    }
908
909    if let Ok(val) = env::var("OXIGAF_DEVICE_GPU_INDEX") {
910        config.device.gpu_index = val
911            .parse()
912            .with_context(|| format!("Invalid OXIGAF_DEVICE_GPU_INDEX: {}", val))?;
913    }
914
915    // Output parameters
916    if let Ok(val) = env::var("OXIGAF_OUTPUT_CHECKPOINT_INTERVAL") {
917        config.output.checkpoint_interval = val
918            .parse()
919            .with_context(|| format!("Invalid OXIGAF_OUTPUT_CHECKPOINT_INTERVAL: {}", val))?;
920    }
921
922    if let Ok(val) = env::var("OXIGAF_OUTPUT_LOG_INTERVAL") {
923        config.output.log_interval = val
924            .parse()
925            .with_context(|| format!("Invalid OXIGAF_OUTPUT_LOG_INTERVAL: {}", val))?;
926    }
927
928    if let Ok(val) = env::var("OXIGAF_OUTPUT_EXPORT_FORMAT") {
929        config.output.export_format = val;
930    }
931
932    // Init parameters
933    if let Ok(val) = env::var("OXIGAF_SH_DEGREE") {
934        config.training.init.sh_degree = val
935            .parse()
936            .with_context(|| format!("Invalid OXIGAF_SH_DEGREE: {}", val))?;
937    }
938
939    if let Ok(val) = env::var("OXIGAF_NUM_RIGID_GAUSSIANS") {
940        config.training.init.num_rigid_gaussians = val
941            .parse()
942            .with_context(|| format!("Invalid OXIGAF_NUM_RIGID_GAUSSIANS: {}", val))?;
943    }
944
945    if let Ok(val) = env::var("OXIGAF_NUM_FLEXIBLE_GAUSSIANS") {
946        config.training.init.num_flexible_gaussians = val
947            .parse()
948            .with_context(|| format!("Invalid OXIGAF_NUM_FLEXIBLE_GAUSSIANS: {}", val))?;
949    }
950
951    Ok(config)
952}
953
954// ---------------------------------------------------------------------------
955// Loading
956// ---------------------------------------------------------------------------
957
958/// Load and validate a [`ProjectConfig`] from a TOML file.
959///
960/// If `path` does not exist and is the default `oxigaf.toml`, returns the
961/// default configuration instead of erroring.
962#[allow(dead_code)]
963pub fn load_config(path: &Path) -> Result<ProjectConfig> {
964    if path.exists() {
965        let contents = std::fs::read_to_string(path)
966            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
967        let config: ProjectConfig = toml::from_str(&contents)
968            .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
969        config.validate()?;
970        Ok(config)
971    } else if path.ends_with("oxigaf.toml") {
972        tracing::info!(
973            "Config file not found at {}, using defaults",
974            path.display()
975        );
976        Ok(ProjectConfig::default())
977    } else {
978        anyhow::bail!("Config file not found: {}", path.display())
979    }
980}
981
982/// Generate a default TOML configuration string that can be written to a file.
983#[allow(dead_code)]
984pub fn generate_default_config() -> Result<String> {
985    let config = ProjectConfig::default();
986    toml::to_string_pretty(&config).context("Failed to serialize default config")
987}
988
989// ---------------------------------------------------------------------------
990// Utilities
991// ---------------------------------------------------------------------------
992
/// Expand a leading `~` path component to the user's home directory.
///
/// Only a standalone leading `~` component is expanded (`~` or `~/rest`).
/// `~user/...`, absolute paths, and paths containing `~` elsewhere are returned
/// unchanged. The home directory is taken from `HOME`, or from `USERPROFILE` when
/// `HOME` is unset or empty. If neither is available, the path is returned as-is.
/// Non-UTF-8 path bytes are preserved, because no string conversion is performed.
pub fn expand_tilde(path: &Path) -> PathBuf {
    // `strip_prefix` compares whole components, so `~user` does not match `~`.
    let rest = match path.strip_prefix("~") {
        Ok(rest) => rest,
        Err(_) => return path.to_path_buf(),
    };

    // An empty HOME would otherwise turn `~/x` into `/x`; treat it as unset.
    let home = env::var_os("HOME")
        .filter(|h| !h.is_empty())
        .or_else(|| env::var_os("USERPROFILE").filter(|h| !h.is_empty()));

    match home {
        Some(home) => {
            let home = PathBuf::from(home);
            if rest.as_os_str().is_empty() {
                home
            } else {
                home.join(rest)
            }
        }
        None => path.to_path_buf(),
    }
}
1003
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_round_trips() -> Result<()> {
        let config = ProjectConfig::default();
        let toml_str =
            toml::to_string_pretty(&config).context("Failed to serialize default config")?;
        let parsed: ProjectConfig =
            toml::from_str(&toml_str).context("Failed to parse serialized config")?;
        assert_eq!(parsed.training.total_iterations, 15_000);
        assert_eq!(parsed.training.init.num_rigid_gaussians, 50_000);
        Ok(())
    }

    #[test]
    fn partial_config_uses_defaults() -> Result<()> {
        let toml_str = r#"
[training]
total_iterations = 5000
"#;
        let config: ProjectConfig =
            toml::from_str(toml_str).context("Failed to parse partial config")?;
        assert_eq!(config.training.total_iterations, 5000);
        // Unspecified fields use defaults.
        assert_eq!(config.training.views_per_step, 4);
        assert_eq!(config.training.init.num_rigid_gaussians, 50_000);
        Ok(())
    }

    #[test]
    fn validation_catches_zero_iterations() {
        let mut config = ProjectConfig::default();
        config.training.total_iterations = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn expand_tilde_works() {
        // `HOME` is process-wide state shared with every other test in this binary.
        // Snapshot it, override it only around the calls below, and restore it before
        // any assertion can panic, so a failure here cannot leak a fake HOME elsewhere.
        let saved_home = std::env::var_os("HOME");
        std::env::set_var("HOME", "/home/test");
        let nested = expand_tilde(Path::new("~/.cache/oxigaf"));
        let bare = expand_tilde(Path::new("~"));
        let other_user = expand_tilde(Path::new("~user/data"));
        let absolute = expand_tilde(Path::new("/abs/path"));
        match saved_home {
            Some(home) => std::env::set_var("HOME", home),
            None => std::env::remove_var("HOME"),
        }

        assert_eq!(nested, PathBuf::from("/home/test/.cache/oxigaf"));
        assert_eq!(bare, PathBuf::from("/home/test"));
        // Only a standalone leading `~` is expanded.
        assert_eq!(other_user, PathBuf::from("~user/data"));
        assert_eq!(absolute, PathBuf::from("/abs/path"));
    }
}