1use serde::{Deserialize, Serialize};
2use serde_with::{serde_as, DefaultOnNull};
3
4use crate::AddError;
5
6#[serde_as]
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(default)]
9pub struct SimulationConfig {
10 #[serde_as(as = "DefaultOnNull")]
11 pub num_lambda: usize,
12 #[serde_as(as = "DefaultOnNull")]
13 pub lambda_min: f64,
14 #[serde_as(as = "DefaultOnNull")]
15 pub lambda_max: f64,
16 #[serde_as(as = "DefaultOnNull")]
17 pub steps_per_run: usize,
18 #[serde(default)]
19 pub multi_steps_per_run: Vec<usize>,
20 #[serde_as(as = "DefaultOnNull")]
21 pub random_seed: u64,
22 #[serde_as(as = "DefaultOnNull")]
23 pub enable_aet: bool,
24 #[serde_as(as = "DefaultOnNull")]
25 pub enable_tcp: bool,
26 #[serde_as(as = "DefaultOnNull")]
27 pub enable_rlt: bool,
28 #[serde_as(as = "DefaultOnNull")]
29 pub enable_iwlt: bool,
30}
31
32impl Default for SimulationConfig {
33 fn default() -> Self {
34 Self {
35 num_lambda: 360,
36 lambda_min: 0.0,
37 lambda_max: 1.0,
38 steps_per_run: 512,
39 multi_steps_per_run: vec![512, 5_000, 10_000, 20_000],
40 random_seed: 0xADD2_0260_0001_u64,
41 enable_aet: true,
42 enable_tcp: true,
43 enable_rlt: true,
44 enable_iwlt: true,
45 }
46 }
47}
48
49impl SimulationConfig {
50 pub fn validate(&self) -> Result<(), AddError> {
51 if self.num_lambda == 0 {
52 return Err(AddError::InvalidConfig(
53 "num_lambda must be greater than zero".to_string(),
54 ));
55 }
56
57 if self.steps_per_run == 0 {
58 return Err(AddError::InvalidConfig(
59 "steps_per_run must be greater than zero".to_string(),
60 ));
61 }
62
63 if self.multi_steps_per_run.iter().any(|&steps| steps == 0) {
64 return Err(AddError::InvalidConfig(
65 "multi_steps_per_run must contain only values greater than zero".to_string(),
66 ));
67 }
68
69 if !self.lambda_min.is_finite() || !self.lambda_max.is_finite() {
70 return Err(AddError::InvalidConfig(
71 "lambda_min and lambda_max must be finite".to_string(),
72 ));
73 }
74
75 if self.lambda_max < self.lambda_min {
76 return Err(AddError::InvalidConfig(
77 "lambda_max must be greater than or equal to lambda_min".to_string(),
78 ));
79 }
80
81 if !(self.enable_aet || self.enable_tcp || self.enable_rlt || self.enable_iwlt) {
82 return Err(AddError::InvalidConfig(
83 "at least one sub-theory must be enabled".to_string(),
84 ));
85 }
86
87 Ok(())
88 }
89
90 pub fn lambda_grid(&self) -> Vec<f64> {
91 if self.num_lambda == 1 {
92 return vec![self.lambda_min];
93 }
94
95 let span = self.lambda_max - self.lambda_min;
96 let denom = (self.num_lambda - 1) as f64;
97
98 (0..self.num_lambda)
99 .map(|idx| self.lambda_min + span * idx as f64 / denom)
100 .collect()
101 }
102
103 pub fn normalized_lambda(&self, lambda: f64) -> f64 {
104 let span = self.lambda_max - self.lambda_min;
105 if span.abs() < f64::EPSILON {
106 return 0.5;
107 }
108
109 ((lambda - self.lambda_min) / span).clamp(0.0, 1.0)
110 }
111
112 pub fn sweep_steps(&self) -> Vec<usize> {
113 if self.multi_steps_per_run.is_empty() {
114 vec![self.steps_per_run]
115 } else {
116 self.multi_steps_per_run.clone()
117 }
118 }
119}