Skip to main content

dsfb_add/
config.rs

1use serde::{Deserialize, Serialize};
2use serde_with::{serde_as, DefaultOnNull};
3
4use crate::AddError;
5
6#[serde_as]
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(default)]
9pub struct SimulationConfig {
10    #[serde_as(as = "DefaultOnNull")]
11    pub num_lambda: usize,
12    #[serde_as(as = "DefaultOnNull")]
13    pub lambda_min: f64,
14    #[serde_as(as = "DefaultOnNull")]
15    pub lambda_max: f64,
16    #[serde_as(as = "DefaultOnNull")]
17    pub steps_per_run: usize,
18    #[serde(default)]
19    pub multi_steps_per_run: Vec<usize>,
20    #[serde_as(as = "DefaultOnNull")]
21    pub random_seed: u64,
22    #[serde_as(as = "DefaultOnNull")]
23    pub enable_aet: bool,
24    #[serde_as(as = "DefaultOnNull")]
25    pub enable_tcp: bool,
26    #[serde_as(as = "DefaultOnNull")]
27    pub enable_rlt: bool,
28    #[serde_as(as = "DefaultOnNull")]
29    pub enable_iwlt: bool,
30}
31
32impl Default for SimulationConfig {
33    fn default() -> Self {
34        Self {
35            num_lambda: 360,
36            lambda_min: 0.0,
37            lambda_max: 1.0,
38            steps_per_run: 512,
39            multi_steps_per_run: vec![512, 5_000, 10_000, 20_000],
40            random_seed: 0xADD2_0260_0001_u64,
41            enable_aet: true,
42            enable_tcp: true,
43            enable_rlt: true,
44            enable_iwlt: true,
45        }
46    }
47}
48
49impl SimulationConfig {
50    pub fn validate(&self) -> Result<(), AddError> {
51        if self.num_lambda == 0 {
52            return Err(AddError::InvalidConfig(
53                "num_lambda must be greater than zero".to_string(),
54            ));
55        }
56
57        if self.steps_per_run == 0 {
58            return Err(AddError::InvalidConfig(
59                "steps_per_run must be greater than zero".to_string(),
60            ));
61        }
62
63        if self.multi_steps_per_run.iter().any(|&steps| steps == 0) {
64            return Err(AddError::InvalidConfig(
65                "multi_steps_per_run must contain only values greater than zero".to_string(),
66            ));
67        }
68
69        if !self.lambda_min.is_finite() || !self.lambda_max.is_finite() {
70            return Err(AddError::InvalidConfig(
71                "lambda_min and lambda_max must be finite".to_string(),
72            ));
73        }
74
75        if self.lambda_max < self.lambda_min {
76            return Err(AddError::InvalidConfig(
77                "lambda_max must be greater than or equal to lambda_min".to_string(),
78            ));
79        }
80
81        if !(self.enable_aet || self.enable_tcp || self.enable_rlt || self.enable_iwlt) {
82            return Err(AddError::InvalidConfig(
83                "at least one sub-theory must be enabled".to_string(),
84            ));
85        }
86
87        Ok(())
88    }
89
90    pub fn lambda_grid(&self) -> Vec<f64> {
91        if self.num_lambda == 1 {
92            return vec![self.lambda_min];
93        }
94
95        let span = self.lambda_max - self.lambda_min;
96        let denom = (self.num_lambda - 1) as f64;
97
98        (0..self.num_lambda)
99            .map(|idx| self.lambda_min + span * idx as f64 / denom)
100            .collect()
101    }
102
103    pub fn normalized_lambda(&self, lambda: f64) -> f64 {
104        let span = self.lambda_max - self.lambda_min;
105        if span.abs() < f64::EPSILON {
106            return 0.5;
107        }
108
109        ((lambda - self.lambda_min) / span).clamp(0.0, 1.0)
110    }
111
112    pub fn sweep_steps(&self) -> Vec<usize> {
113        if self.multi_steps_per_run.is_empty() {
114            vec![self.steps_per_run]
115        } else {
116            self.multi_steps_per_run.clone()
117        }
118    }
119}