1use std::env;
8use std::path::{Path, PathBuf};
9
10use anyhow::{Context, Result};
11use serde::{Deserialize, Serialize};
12
13use oxigaf::render::RasterConfig;
14use oxigaf::trainer::{
15 DensityConfig, InitConfig, LossConfig, OptimizerConfig, TensorBoardConfig, TrainingConfig,
16};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(default)]
25#[derive(Default)]
26pub struct ProjectConfig {
27 pub model: ModelSection,
29 pub device: DeviceSection,
31 pub training: TrainingSection,
33 pub output: OutputSection,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(default)]
44pub struct ModelSection {
45 pub flame_model_path: PathBuf,
47 pub diffusion_weights_dir: PathBuf,
49}
50
51impl Default for ModelSection {
52 fn default() -> Self {
53 Self {
54 flame_model_path: PathBuf::from("~/.cache/oxigaf/flame2023"),
55 diffusion_weights_dir: PathBuf::from("~/.cache/oxigaf/weights"),
56 }
57 }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(default)]
67pub struct DeviceSection {
68 pub backend: String,
70 pub gpu_index: usize,
72}
73
74impl Default for DeviceSection {
75 fn default() -> Self {
76 Self {
77 backend: "vulkan".to_string(),
78 gpu_index: 0,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
89#[serde(default)]
90pub struct TrainingSection {
91 pub total_iterations: u32,
92 pub views_per_step: usize,
93 pub image_size: u32,
94 pub guidance_scale_start: f32,
95 pub guidance_scale_end: f32,
96 pub guidance_anneal_steps: u32,
97 pub num_inference_steps: usize,
98 pub opacity_reset_interval: u32,
99
100 pub init: InitSection,
102 pub optimizer: OptimizerSection,
104 pub density_control: DensityControlSection,
106 pub loss: LossSection,
108}
109
110impl Default for TrainingSection {
111 fn default() -> Self {
112 Self {
113 total_iterations: 15_000,
114 views_per_step: 4,
115 image_size: 512,
116 guidance_scale_start: 7.5,
117 guidance_scale_end: 3.0,
118 guidance_anneal_steps: 10_000,
119 num_inference_steps: 50,
120 opacity_reset_interval: 3_000,
121 init: InitSection::default(),
122 optimizer: OptimizerSection::default(),
123 density_control: DensityControlSection::default(),
124 loss: LossSection::default(),
125 }
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(default)]
136pub struct InitSection {
137 pub num_rigid_gaussians: usize,
138 pub num_flexible_gaussians: usize,
139 pub initial_scale: f32,
140 pub initial_opacity: f32,
141 pub sh_degree: u32,
142}
143
144impl Default for InitSection {
145 fn default() -> Self {
146 let d = InitConfig::default();
147 Self {
148 num_rigid_gaussians: d.num_rigid,
149 num_flexible_gaussians: d.num_flexible,
150 initial_scale: d.initial_scale,
151 initial_opacity: d.initial_opacity,
152 sh_degree: d.sh_degree,
153 }
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(default)]
164pub struct OptimizerSection {
165 pub position_lr: f32,
166 pub position_lr_final: f32,
167 pub rotation_lr: f32,
168 pub scale_lr: f32,
169 pub opacity_lr: f32,
170 pub sh_lr: f32,
171 pub offset_lr: f32,
172 pub beta1: f32,
173 pub beta2: f32,
174 pub epsilon: f32,
175 pub position_lr_decay_steps: u32,
176}
177
178impl Default for OptimizerSection {
179 fn default() -> Self {
180 let d = OptimizerConfig::default();
181 Self {
182 position_lr: d.lr_position,
183 position_lr_final: d.lr_position_final,
184 rotation_lr: d.lr_rotation,
185 scale_lr: d.lr_scale,
186 opacity_lr: d.lr_opacity,
187 sh_lr: d.lr_sh,
188 offset_lr: d.lr_offset,
189 beta1: d.beta1,
190 beta2: d.beta2,
191 epsilon: d.epsilon,
192 position_lr_decay_steps: d.position_lr_decay_steps,
193 }
194 }
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
203#[serde(default)]
204pub struct DensityControlSection {
205 pub interval: u32,
206 pub start_iteration: u32,
207 pub end_iteration: u32,
208 pub grad_threshold: f32,
209 pub min_opacity: f32,
210 pub max_screen_size: f32,
211 pub split_scale_threshold: f32,
212 pub max_gaussians: usize,
213}
214
215impl Default for DensityControlSection {
216 fn default() -> Self {
217 let d = DensityConfig::default();
218 Self {
219 interval: 500,
220 start_iteration: 1_000,
221 end_iteration: 12_000,
222 grad_threshold: d.grad_threshold,
223 min_opacity: d.min_opacity,
224 max_screen_size: d.max_screen_size,
225 split_scale_threshold: d.split_scale_threshold,
226 max_gaussians: d.max_gaussians,
227 }
228 }
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize)]
237#[serde(default)]
238pub struct LossSection {
239 pub lambda_l1: f32,
240 pub lambda_ssim: f32,
241 pub lambda_ms_ssim: f32,
242 pub lambda_lpips: f32,
243 pub lambda_position_reg: f32,
244 pub lambda_scale_reg: f32,
245 pub lambda_opacity_reg: f32,
246 pub lambda_normal: f32,
247 pub lambda_gradient_penalty: f32,
248 pub gradient_penalty_threshold: f32,
249}
250
251impl Default for LossSection {
252 fn default() -> Self {
253 let d = LossConfig::default();
254 Self {
255 lambda_l1: d.w_l1,
256 lambda_ssim: d.w_ssim,
257 lambda_ms_ssim: d.w_ms_ssim,
258 lambda_lpips: d.w_lpips,
259 lambda_position_reg: d.w_position_reg,
260 lambda_scale_reg: d.w_scale_reg,
261 lambda_opacity_reg: d.w_opacity_reg,
262 lambda_normal: d.w_normal,
263 lambda_gradient_penalty: d.w_gradient_penalty,
264 gradient_penalty_threshold: d.gradient_penalty_threshold,
265 }
266 }
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
275#[serde(default)]
276pub struct OutputSection {
277 pub checkpoint_interval: u32,
278 pub log_interval: u32,
279 pub export_format: String,
280}
281
282impl Default for OutputSection {
283 fn default() -> Self {
284 Self {
285 checkpoint_interval: 1_000,
286 log_interval: 50,
287 export_format: "ply".to_string(),
288 }
289 }
290}
291
292impl ProjectConfig {
297 pub fn to_training_config(&self) -> TrainingConfig {
300 let t = &self.training;
301 TrainingConfig {
302 total_iterations: t.total_iterations,
303 views_per_step: t.views_per_step,
304 density_control_interval: t.density_control.interval,
305 density_control_start: t.density_control.start_iteration,
306 density_control_end: t.density_control.end_iteration,
307 opacity_reset_interval: t.opacity_reset_interval,
308 checkpoint_interval: self.output.checkpoint_interval,
309 log_interval: self.output.log_interval,
310 guidance_scale_start: t.guidance_scale_start,
311 guidance_scale_end: t.guidance_scale_end,
312 guidance_anneal_steps: t.guidance_anneal_steps,
313 optimizer: OptimizerConfig {
314 lr_position: t.optimizer.position_lr,
315 lr_position_final: t.optimizer.position_lr_final,
316 lr_rotation: t.optimizer.rotation_lr,
317 lr_scale: t.optimizer.scale_lr,
318 lr_opacity: t.optimizer.opacity_lr,
319 lr_sh: t.optimizer.sh_lr,
320 lr_offset: t.optimizer.offset_lr,
321 beta1: t.optimizer.beta1,
322 beta2: t.optimizer.beta2,
323 epsilon: t.optimizer.epsilon,
324 position_lr_decay_steps: t.optimizer.position_lr_decay_steps,
325 },
326 loss: LossConfig {
327 w_l1: t.loss.lambda_l1,
328 w_ssim: t.loss.lambda_ssim,
329 w_ms_ssim: t.loss.lambda_ms_ssim,
330 w_lpips: t.loss.lambda_lpips,
331 w_position_reg: t.loss.lambda_position_reg,
332 w_scale_reg: t.loss.lambda_scale_reg,
333 w_opacity_reg: t.loss.lambda_opacity_reg,
334 w_normal: t.loss.lambda_normal,
335 w_gradient_penalty: t.loss.lambda_gradient_penalty,
336 gradient_penalty_threshold: t.loss.gradient_penalty_threshold,
337 },
338 density: DensityConfig {
339 grad_threshold: t.density_control.grad_threshold,
340 min_opacity: t.density_control.min_opacity,
341 max_screen_size: t.density_control.max_screen_size,
342 split_scale_threshold: t.density_control.split_scale_threshold,
343 max_gaussians: t.density_control.max_gaussians,
344 },
345 init: InitConfig {
346 num_rigid: t.init.num_rigid_gaussians,
347 num_flexible: t.init.num_flexible_gaussians,
348 initial_scale: t.init.initial_scale,
349 initial_opacity: t.init.initial_opacity,
350 sh_degree: t.init.sh_degree,
351 },
352 tensorboard: TensorBoardConfig::default(),
353 }
354 }
355
356 pub fn to_raster_config(&self) -> RasterConfig {
358 RasterConfig {
359 image_width: self.training.image_size,
360 image_height: self.training.image_size,
361 sh_degree: self.training.init.sh_degree,
362 ..RasterConfig::default()
363 }
364 }
365
366 pub fn validate(&self) -> Result<()> {
368 anyhow::ensure!(
369 self.training.total_iterations > 0,
370 "total_iterations must be > 0"
371 );
372 anyhow::ensure!(
373 self.training.views_per_step > 0,
374 "views_per_step must be > 0"
375 );
376 anyhow::ensure!(self.training.image_size > 0, "image_size must be > 0");
377 anyhow::ensure!(
378 self.training.guidance_scale_start > 0.0,
379 "guidance_scale_start must be > 0"
380 );
381 anyhow::ensure!(
382 self.training.init.num_rigid_gaussians + self.training.init.num_flexible_gaussians > 0,
383 "Total Gaussian count must be > 0"
384 );
385 anyhow::ensure!(self.training.init.sh_degree <= 3, "SH degree must be <= 3");
386 Ok(())
387 }
388}
389
390pub fn load_hierarchical_config(
401 cli_config_path: Option<&Path>,
402 override_values: Option<&ProjectConfig>,
403) -> Result<ProjectConfig> {
404 let mut config = ProjectConfig::default();
406
407 if let Some(user_config_path) = get_user_config_path() {
409 if user_config_path.exists() {
410 tracing::debug!("Loading user config from: {}", user_config_path.display());
411 let user_config = load_config_from_file(&user_config_path)?;
412 config = merge_configs(config, user_config);
413 }
414 }
415
416 let project_config_path = PathBuf::from("./oxigaf.toml");
418 if project_config_path.exists() {
419 tracing::debug!(
420 "Loading project config from: {}",
421 project_config_path.display()
422 );
423 let project_config = load_config_from_file(&project_config_path)?;
424 config = merge_configs(config, project_config);
425 }
426
427 if let Some(path) = cli_config_path {
429 tracing::debug!("Loading CLI-specified config from: {}", path.display());
430 let cli_file_config = load_config_from_file(path)?;
431 config = merge_configs(config, cli_file_config);
432 }
433
434 config = apply_env_overrides(config)?;
436
437 if let Some(overrides) = override_values {
441 config = merge_configs(config, overrides.clone());
447 }
448
449 config.validate()?;
450 Ok(config)
451}
452
453fn get_user_config_path() -> Option<PathBuf> {
455 let mut path = dirs::config_dir()?;
456 path.push("oxigaf");
457 path.push("config.toml");
458 Some(path)
459}
460
461fn load_config_from_file(path: &Path) -> Result<ProjectConfig> {
463 let contents = std::fs::read_to_string(path)
464 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
465 let config: ProjectConfig = toml::from_str(&contents)
466 .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
467 Ok(config)
468}
469
470fn merge_configs(base: ProjectConfig, override_cfg: ProjectConfig) -> ProjectConfig {
472 ProjectConfig {
473 model: merge_model_section(base.model, override_cfg.model),
474 device: merge_device_section(base.device, override_cfg.device),
475 training: merge_training_section(base.training, override_cfg.training),
476 output: merge_output_section(base.output, override_cfg.output),
477 }
478}
479
480fn merge_model_section(base: ModelSection, override_cfg: ModelSection) -> ModelSection {
482 let default = ModelSection::default();
483 ModelSection {
484 flame_model_path: if override_cfg.flame_model_path != default.flame_model_path {
485 override_cfg.flame_model_path
486 } else {
487 base.flame_model_path
488 },
489 diffusion_weights_dir: if override_cfg.diffusion_weights_dir
490 != default.diffusion_weights_dir
491 {
492 override_cfg.diffusion_weights_dir
493 } else {
494 base.diffusion_weights_dir
495 },
496 }
497}
498
499fn merge_device_section(base: DeviceSection, override_cfg: DeviceSection) -> DeviceSection {
501 let default = DeviceSection::default();
502 DeviceSection {
503 backend: if override_cfg.backend != default.backend {
504 override_cfg.backend
505 } else {
506 base.backend
507 },
508 gpu_index: if override_cfg.gpu_index != default.gpu_index {
509 override_cfg.gpu_index
510 } else {
511 base.gpu_index
512 },
513 }
514}
515
516fn merge_training_section(base: TrainingSection, override_cfg: TrainingSection) -> TrainingSection {
518 let default = TrainingSection::default();
519 TrainingSection {
520 total_iterations: if override_cfg.total_iterations != default.total_iterations {
521 override_cfg.total_iterations
522 } else {
523 base.total_iterations
524 },
525 views_per_step: if override_cfg.views_per_step != default.views_per_step {
526 override_cfg.views_per_step
527 } else {
528 base.views_per_step
529 },
530 image_size: if override_cfg.image_size != default.image_size {
531 override_cfg.image_size
532 } else {
533 base.image_size
534 },
535 guidance_scale_start: if (override_cfg.guidance_scale_start - default.guidance_scale_start)
536 .abs()
537 > f32::EPSILON
538 {
539 override_cfg.guidance_scale_start
540 } else {
541 base.guidance_scale_start
542 },
543 guidance_scale_end: if (override_cfg.guidance_scale_end - default.guidance_scale_end).abs()
544 > f32::EPSILON
545 {
546 override_cfg.guidance_scale_end
547 } else {
548 base.guidance_scale_end
549 },
550 guidance_anneal_steps: if override_cfg.guidance_anneal_steps
551 != default.guidance_anneal_steps
552 {
553 override_cfg.guidance_anneal_steps
554 } else {
555 base.guidance_anneal_steps
556 },
557 num_inference_steps: if override_cfg.num_inference_steps != default.num_inference_steps {
558 override_cfg.num_inference_steps
559 } else {
560 base.num_inference_steps
561 },
562 opacity_reset_interval: if override_cfg.opacity_reset_interval
563 != default.opacity_reset_interval
564 {
565 override_cfg.opacity_reset_interval
566 } else {
567 base.opacity_reset_interval
568 },
569 init: merge_init_section(base.init, override_cfg.init),
570 optimizer: merge_optimizer_section(base.optimizer, override_cfg.optimizer),
571 density_control: merge_density_control_section(
572 base.density_control,
573 override_cfg.density_control,
574 ),
575 loss: merge_loss_section(base.loss, override_cfg.loss),
576 }
577}
578
579fn merge_init_section(base: InitSection, override_cfg: InitSection) -> InitSection {
581 let default = InitSection::default();
582 InitSection {
583 num_rigid_gaussians: if override_cfg.num_rigid_gaussians != default.num_rigid_gaussians {
584 override_cfg.num_rigid_gaussians
585 } else {
586 base.num_rigid_gaussians
587 },
588 num_flexible_gaussians: if override_cfg.num_flexible_gaussians
589 != default.num_flexible_gaussians
590 {
591 override_cfg.num_flexible_gaussians
592 } else {
593 base.num_flexible_gaussians
594 },
595 initial_scale: if (override_cfg.initial_scale - default.initial_scale).abs() > f32::EPSILON
596 {
597 override_cfg.initial_scale
598 } else {
599 base.initial_scale
600 },
601 initial_opacity: if (override_cfg.initial_opacity - default.initial_opacity).abs()
602 > f32::EPSILON
603 {
604 override_cfg.initial_opacity
605 } else {
606 base.initial_opacity
607 },
608 sh_degree: if override_cfg.sh_degree != default.sh_degree {
609 override_cfg.sh_degree
610 } else {
611 base.sh_degree
612 },
613 }
614}
615
616fn merge_optimizer_section(
618 base: OptimizerSection,
619 override_cfg: OptimizerSection,
620) -> OptimizerSection {
621 let default = OptimizerSection::default();
622 OptimizerSection {
623 position_lr: if (override_cfg.position_lr - default.position_lr).abs() > f32::EPSILON {
624 override_cfg.position_lr
625 } else {
626 base.position_lr
627 },
628 position_lr_final: if (override_cfg.position_lr_final - default.position_lr_final).abs()
629 > f32::EPSILON
630 {
631 override_cfg.position_lr_final
632 } else {
633 base.position_lr_final
634 },
635 rotation_lr: if (override_cfg.rotation_lr - default.rotation_lr).abs() > f32::EPSILON {
636 override_cfg.rotation_lr
637 } else {
638 base.rotation_lr
639 },
640 scale_lr: if (override_cfg.scale_lr - default.scale_lr).abs() > f32::EPSILON {
641 override_cfg.scale_lr
642 } else {
643 base.scale_lr
644 },
645 opacity_lr: if (override_cfg.opacity_lr - default.opacity_lr).abs() > f32::EPSILON {
646 override_cfg.opacity_lr
647 } else {
648 base.opacity_lr
649 },
650 sh_lr: if (override_cfg.sh_lr - default.sh_lr).abs() > f32::EPSILON {
651 override_cfg.sh_lr
652 } else {
653 base.sh_lr
654 },
655 offset_lr: if (override_cfg.offset_lr - default.offset_lr).abs() > f32::EPSILON {
656 override_cfg.offset_lr
657 } else {
658 base.offset_lr
659 },
660 beta1: if (override_cfg.beta1 - default.beta1).abs() > f32::EPSILON {
661 override_cfg.beta1
662 } else {
663 base.beta1
664 },
665 beta2: if (override_cfg.beta2 - default.beta2).abs() > f32::EPSILON {
666 override_cfg.beta2
667 } else {
668 base.beta2
669 },
670 epsilon: if (override_cfg.epsilon - default.epsilon).abs() > f32::EPSILON {
671 override_cfg.epsilon
672 } else {
673 base.epsilon
674 },
675 position_lr_decay_steps: if override_cfg.position_lr_decay_steps
676 != default.position_lr_decay_steps
677 {
678 override_cfg.position_lr_decay_steps
679 } else {
680 base.position_lr_decay_steps
681 },
682 }
683}
684
685fn merge_density_control_section(
687 base: DensityControlSection,
688 override_cfg: DensityControlSection,
689) -> DensityControlSection {
690 let default = DensityControlSection::default();
691 DensityControlSection {
692 interval: if override_cfg.interval != default.interval {
693 override_cfg.interval
694 } else {
695 base.interval
696 },
697 start_iteration: if override_cfg.start_iteration != default.start_iteration {
698 override_cfg.start_iteration
699 } else {
700 base.start_iteration
701 },
702 end_iteration: if override_cfg.end_iteration != default.end_iteration {
703 override_cfg.end_iteration
704 } else {
705 base.end_iteration
706 },
707 grad_threshold: if (override_cfg.grad_threshold - default.grad_threshold).abs()
708 > f32::EPSILON
709 {
710 override_cfg.grad_threshold
711 } else {
712 base.grad_threshold
713 },
714 min_opacity: if (override_cfg.min_opacity - default.min_opacity).abs() > f32::EPSILON {
715 override_cfg.min_opacity
716 } else {
717 base.min_opacity
718 },
719 max_screen_size: if (override_cfg.max_screen_size - default.max_screen_size).abs()
720 > f32::EPSILON
721 {
722 override_cfg.max_screen_size
723 } else {
724 base.max_screen_size
725 },
726 split_scale_threshold: if (override_cfg.split_scale_threshold
727 - default.split_scale_threshold)
728 .abs()
729 > f32::EPSILON
730 {
731 override_cfg.split_scale_threshold
732 } else {
733 base.split_scale_threshold
734 },
735 max_gaussians: if override_cfg.max_gaussians != default.max_gaussians {
736 override_cfg.max_gaussians
737 } else {
738 base.max_gaussians
739 },
740 }
741}
742
743fn merge_loss_section(base: LossSection, override_cfg: LossSection) -> LossSection {
745 let default = LossSection::default();
746 LossSection {
747 lambda_l1: if (override_cfg.lambda_l1 - default.lambda_l1).abs() > f32::EPSILON {
748 override_cfg.lambda_l1
749 } else {
750 base.lambda_l1
751 },
752 lambda_ssim: if (override_cfg.lambda_ssim - default.lambda_ssim).abs() > f32::EPSILON {
753 override_cfg.lambda_ssim
754 } else {
755 base.lambda_ssim
756 },
757 lambda_ms_ssim: if (override_cfg.lambda_ms_ssim - default.lambda_ms_ssim).abs()
758 > f32::EPSILON
759 {
760 override_cfg.lambda_ms_ssim
761 } else {
762 base.lambda_ms_ssim
763 },
764 lambda_lpips: if (override_cfg.lambda_lpips - default.lambda_lpips).abs() > f32::EPSILON {
765 override_cfg.lambda_lpips
766 } else {
767 base.lambda_lpips
768 },
769 lambda_position_reg: if (override_cfg.lambda_position_reg - default.lambda_position_reg)
770 .abs()
771 > f32::EPSILON
772 {
773 override_cfg.lambda_position_reg
774 } else {
775 base.lambda_position_reg
776 },
777 lambda_scale_reg: if (override_cfg.lambda_scale_reg - default.lambda_scale_reg).abs()
778 > f32::EPSILON
779 {
780 override_cfg.lambda_scale_reg
781 } else {
782 base.lambda_scale_reg
783 },
784 lambda_opacity_reg: if (override_cfg.lambda_opacity_reg - default.lambda_opacity_reg).abs()
785 > f32::EPSILON
786 {
787 override_cfg.lambda_opacity_reg
788 } else {
789 base.lambda_opacity_reg
790 },
791 lambda_normal: if (override_cfg.lambda_normal - default.lambda_normal).abs() > f32::EPSILON
792 {
793 override_cfg.lambda_normal
794 } else {
795 base.lambda_normal
796 },
797 lambda_gradient_penalty: if (override_cfg.lambda_gradient_penalty
798 - default.lambda_gradient_penalty)
799 .abs()
800 > f32::EPSILON
801 {
802 override_cfg.lambda_gradient_penalty
803 } else {
804 base.lambda_gradient_penalty
805 },
806 gradient_penalty_threshold: if (override_cfg.gradient_penalty_threshold
807 - default.gradient_penalty_threshold)
808 .abs()
809 > f32::EPSILON
810 {
811 override_cfg.gradient_penalty_threshold
812 } else {
813 base.gradient_penalty_threshold
814 },
815 }
816}
817
818fn merge_output_section(base: OutputSection, override_cfg: OutputSection) -> OutputSection {
820 let default = OutputSection::default();
821 OutputSection {
822 checkpoint_interval: if override_cfg.checkpoint_interval != default.checkpoint_interval {
823 override_cfg.checkpoint_interval
824 } else {
825 base.checkpoint_interval
826 },
827 log_interval: if override_cfg.log_interval != default.log_interval {
828 override_cfg.log_interval
829 } else {
830 base.log_interval
831 },
832 export_format: if override_cfg.export_format != default.export_format {
833 override_cfg.export_format
834 } else {
835 base.export_format
836 },
837 }
838}
839
840fn apply_env_overrides(mut config: ProjectConfig) -> Result<ProjectConfig> {
842 if let Ok(val) = env::var("OXIGAF_TOTAL_ITERATIONS") {
844 config.training.total_iterations = val
845 .parse()
846 .with_context(|| format!("Invalid OXIGAF_TOTAL_ITERATIONS: {}", val))?;
847 }
848
849 if let Ok(val) = env::var("OXIGAF_IMAGE_SIZE") {
850 config.training.image_size = val
851 .parse()
852 .with_context(|| format!("Invalid OXIGAF_IMAGE_SIZE: {}", val))?;
853 }
854
855 if let Ok(val) = env::var("OXIGAF_VIEWS_PER_STEP") {
856 config.training.views_per_step = val
857 .parse()
858 .with_context(|| format!("Invalid OXIGAF_VIEWS_PER_STEP: {}", val))?;
859 }
860
861 if let Ok(val) = env::var("OXIGAF_GUIDANCE_SCALE_START") {
862 config.training.guidance_scale_start = val
863 .parse()
864 .with_context(|| format!("Invalid OXIGAF_GUIDANCE_SCALE_START: {}", val))?;
865 }
866
867 if let Ok(val) = env::var("OXIGAF_GUIDANCE_SCALE_END") {
868 config.training.guidance_scale_end = val
869 .parse()
870 .with_context(|| format!("Invalid OXIGAF_GUIDANCE_SCALE_END: {}", val))?;
871 }
872
873 if let Ok(val) = env::var("OXIGAF_POSITION_LR") {
875 config.training.optimizer.position_lr = val
876 .parse()
877 .with_context(|| format!("Invalid OXIGAF_POSITION_LR: {}", val))?;
878 }
879
880 if let Ok(val) = env::var("OXIGAF_SCALING_LR") {
881 config.training.optimizer.scale_lr = val
882 .parse()
883 .with_context(|| format!("Invalid OXIGAF_SCALING_LR: {}", val))?;
884 }
885
886 if let Ok(val) = env::var("OXIGAF_ROTATION_LR") {
887 config.training.optimizer.rotation_lr = val
888 .parse()
889 .with_context(|| format!("Invalid OXIGAF_ROTATION_LR: {}", val))?;
890 }
891
892 if let Ok(val) = env::var("OXIGAF_OPACITY_LR") {
893 config.training.optimizer.opacity_lr = val
894 .parse()
895 .with_context(|| format!("Invalid OXIGAF_OPACITY_LR: {}", val))?;
896 }
897
898 if let Ok(val) = env::var("OXIGAF_SH_LR") {
899 config.training.optimizer.sh_lr = val
900 .parse()
901 .with_context(|| format!("Invalid OXIGAF_SH_LR: {}", val))?;
902 }
903
904 if let Ok(val) = env::var("OXIGAF_DEVICE_BACKEND") {
906 config.device.backend = val;
907 }
908
909 if let Ok(val) = env::var("OXIGAF_DEVICE_GPU_INDEX") {
910 config.device.gpu_index = val
911 .parse()
912 .with_context(|| format!("Invalid OXIGAF_DEVICE_GPU_INDEX: {}", val))?;
913 }
914
915 if let Ok(val) = env::var("OXIGAF_OUTPUT_CHECKPOINT_INTERVAL") {
917 config.output.checkpoint_interval = val
918 .parse()
919 .with_context(|| format!("Invalid OXIGAF_OUTPUT_CHECKPOINT_INTERVAL: {}", val))?;
920 }
921
922 if let Ok(val) = env::var("OXIGAF_OUTPUT_LOG_INTERVAL") {
923 config.output.log_interval = val
924 .parse()
925 .with_context(|| format!("Invalid OXIGAF_OUTPUT_LOG_INTERVAL: {}", val))?;
926 }
927
928 if let Ok(val) = env::var("OXIGAF_OUTPUT_EXPORT_FORMAT") {
929 config.output.export_format = val;
930 }
931
932 if let Ok(val) = env::var("OXIGAF_SH_DEGREE") {
934 config.training.init.sh_degree = val
935 .parse()
936 .with_context(|| format!("Invalid OXIGAF_SH_DEGREE: {}", val))?;
937 }
938
939 if let Ok(val) = env::var("OXIGAF_NUM_RIGID_GAUSSIANS") {
940 config.training.init.num_rigid_gaussians = val
941 .parse()
942 .with_context(|| format!("Invalid OXIGAF_NUM_RIGID_GAUSSIANS: {}", val))?;
943 }
944
945 if let Ok(val) = env::var("OXIGAF_NUM_FLEXIBLE_GAUSSIANS") {
946 config.training.init.num_flexible_gaussians = val
947 .parse()
948 .with_context(|| format!("Invalid OXIGAF_NUM_FLEXIBLE_GAUSSIANS: {}", val))?;
949 }
950
951 Ok(config)
952}
953
954#[allow(dead_code)]
963pub fn load_config(path: &Path) -> Result<ProjectConfig> {
964 if path.exists() {
965 let contents = std::fs::read_to_string(path)
966 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
967 let config: ProjectConfig = toml::from_str(&contents)
968 .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
969 config.validate()?;
970 Ok(config)
971 } else if path.ends_with("oxigaf.toml") {
972 tracing::info!(
973 "Config file not found at {}, using defaults",
974 path.display()
975 );
976 Ok(ProjectConfig::default())
977 } else {
978 anyhow::bail!("Config file not found: {}", path.display())
979 }
980}
981
982#[allow(dead_code)]
984pub fn generate_default_config() -> Result<String> {
985 let config = ProjectConfig::default();
986 toml::to_string_pretty(&config).context("Failed to serialize default config")
987}
988
/// Expands a leading `~` path component to the current user's home directory.
///
/// Only a bare `~` first component is expanded (`~` and `~/...`). Paths such as
/// `~user/...` and paths without a leading tilde are returned unchanged. The home
/// directory is read from the `HOME` environment variable; when it is unset or empty the
/// path is returned unchanged.
///
/// The expansion works on path components instead of a lossily converted string, so
/// non-UTF-8 bytes in `path` or in `HOME` are preserved.
pub fn expand_tilde(path: &Path) -> PathBuf {
    // `strip_prefix` matches whole components, so `~user` is not mistaken for `~`.
    let rest = match path.strip_prefix("~") {
        Ok(rest) => rest,
        Err(_) => return path.to_path_buf(),
    };

    match std::env::var_os("HOME") {
        // An empty HOME would turn `~/x` into a relative or root-anchored path; treat it
        // as "no home directory known".
        Some(home) if !home.is_empty() => {
            let home = PathBuf::from(home);
            if rest.as_os_str().is_empty() {
                // Plain `~`: return the home directory itself, without a trailing separator.
                home
            } else {
                home.join(rest)
            }
        }
        _ => path.to_path_buf(),
    }
}
1003
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing the defaults and parsing them back keeps representative values.
    #[test]
    fn default_config_round_trips() -> Result<()> {
        let serialized = toml::to_string_pretty(&ProjectConfig::default())
            .context("Failed to serialize default config")?;
        let reparsed: ProjectConfig =
            toml::from_str(&serialized).context("Failed to parse serialized config")?;

        let training = &reparsed.training;
        assert_eq!(training.total_iterations, 15_000);
        assert_eq!(training.init.num_rigid_gaussians, 50_000);
        Ok(())
    }

    /// Keys missing from a partial file are filled in from the defaults.
    #[test]
    fn partial_config_uses_defaults() -> Result<()> {
        let partial = r#"
[training]
total_iterations = 5000
"#;
        let parsed = toml::from_str::<ProjectConfig>(partial)
            .context("Failed to parse partial config")?;

        // The explicitly set key is honoured...
        assert_eq!(parsed.training.total_iterations, 5000);
        // ...while sibling and nested keys keep their defaults.
        assert_eq!(parsed.training.views_per_step, 4);
        assert_eq!(parsed.training.init.num_rigid_gaussians, 50_000);
        Ok(())
    }

    /// A zero iteration count must be rejected by `validate`.
    #[test]
    fn validation_catches_zero_iterations() {
        let config = ProjectConfig {
            training: TrainingSection {
                total_iterations: 0,
                ..TrainingSection::default()
            },
            ..ProjectConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// `~/...` is rewritten relative to `$HOME`.
    #[test]
    fn expand_tilde_works() {
        // Mutates process-wide state; no other test in this module reads HOME.
        std::env::set_var("HOME", "/home/test");
        let expanded = expand_tilde(Path::new("~/.cache/oxigaf"));
        assert_eq!(expanded, PathBuf::from("/home/test/.cache/oxigaf"));
    }
}