Skip to main content

dsfb_srd/
config.rs

1use std::fmt::Write;
2use std::path::{Path, PathBuf};
3
/// Name of this crate; written into the `"crate"` field of the canonical JSON
/// produced by `SimulationConfig::canonical_json`.
pub const CRATE_NAME: &str = "dsfb-srd";
/// Crate version, captured from the `CARGO_PKG_VERSION` environment variable
/// at compile time; written into the `"version"` field of the canonical JSON.
pub const CRATE_VERSION: &str = env!("CARGO_PKG_VERSION");
6
/// Parameters controlling a simulation run.
///
/// All fields are public, so a value can be built or mutated freely; call
/// `SimulationConfig::validate` to check the constraints listed on each field
/// before relying on them.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example with `assert_eq!`). `Eq` is intentionally not derived because the
/// struct contains `f64` fields.
#[derive(Clone, Debug, PartialEq)]
pub struct SimulationConfig {
    /// Number of events. `validate` requires at least 2.
    pub n_events: usize,
    /// Number of channels. `validate` requires at least 1.
    pub n_channels: usize,
    /// Causal window size. `validate` requires at least 1.
    // NOTE(review): presumably measured in events — confirm against the simulator.
    pub causal_window: usize,
    /// Number of evenly spaced thresholds returned by `tau_thresholds`.
    /// `validate` requires at least 2.
    pub tau_steps: usize,
    /// Start index of the shock window. `validate` requires
    /// `shock_start < shock_end`.
    pub shock_start: usize,
    /// End index of the shock window. `validate` requires
    /// `shock_end <= n_events`.
    // NOTE(review): the `<= n_events` bound suggests an exclusive end — confirm.
    pub shock_end: usize,
    /// `validate` requires a finite value strictly greater than 0.
    pub beta: f64,
    /// `validate` requires a finite value in the closed interval `[0, 1]`.
    pub envelope_decay: f64,
}
18
19impl Default for SimulationConfig {
20    fn default() -> Self {
21        Self {
22            n_events: 2_000,
23            n_channels: 4,
24            causal_window: 24,
25            tau_steps: 401,
26            shock_start: 800,
27            shock_end: 1_200,
28            beta: 4.0,
29            envelope_decay: 0.97,
30        }
31    }
32}
33
34impl SimulationConfig {
35    pub fn validate(&self) -> Result<(), String> {
36        if self.n_events < 2 {
37            return Err("n_events must be at least 2".to_string());
38        }
39        if self.n_channels == 0 {
40            return Err("n_channels must be at least 1".to_string());
41        }
42        if self.causal_window == 0 {
43            return Err("causal_window must be at least 1".to_string());
44        }
45        if self.tau_steps < 2 {
46            return Err("tau_steps must be at least 2".to_string());
47        }
48        if self.shock_start >= self.shock_end {
49            return Err("shock_start must be less than shock_end".to_string());
50        }
51        if self.shock_end > self.n_events {
52            return Err("shock_end must be less than or equal to n_events".to_string());
53        }
54        if !(self.beta.is_finite() && self.beta > 0.0) {
55            return Err("beta must be a finite value greater than 0".to_string());
56        }
57        if !(self.envelope_decay.is_finite() && (0.0..=1.0).contains(&self.envelope_decay)) {
58            return Err("envelope_decay must be finite and lie in [0, 1]".to_string());
59        }
60        Ok(())
61    }
62
63    pub fn from_args<I, S>(args: I) -> Result<Self, String>
64    where
65        I: IntoIterator<Item = S>,
66        S: Into<String>,
67    {
68        let mut config = Self::default();
69        let mut args = args.into_iter().map(Into::into).peekable();
70
71        while let Some(arg) = args.next() {
72            let (flag, inline_value) = match arg.split_once('=') {
73                Some((flag, value)) => (flag.to_string(), Some(value.to_string())),
74                None => (arg, None),
75            };
76
77            let value =
78                |name: &str, inline_value: Option<String>, args: &mut std::iter::Peekable<_>| {
79                    inline_value
80                        .or_else(|| args.next())
81                        .ok_or_else(|| format!("missing value for {name}"))
82                };
83
84            match flag.as_str() {
85                "--n-events" => {
86                    config.n_events =
87                        parse_usize("--n-events", value("--n-events", inline_value, &mut args)?)?;
88                }
89                "--n-channels" => {
90                    config.n_channels = parse_usize(
91                        "--n-channels",
92                        value("--n-channels", inline_value, &mut args)?,
93                    )?;
94                }
95                "--causal-window" => {
96                    config.causal_window = parse_usize(
97                        "--causal-window",
98                        value("--causal-window", inline_value, &mut args)?,
99                    )?;
100                }
101                "--tau-steps" => {
102                    config.tau_steps = parse_usize(
103                        "--tau-steps",
104                        value("--tau-steps", inline_value, &mut args)?,
105                    )?;
106                }
107                "--shock-start" => {
108                    config.shock_start = parse_usize(
109                        "--shock-start",
110                        value("--shock-start", inline_value, &mut args)?,
111                    )?;
112                }
113                "--shock-end" => {
114                    config.shock_end = parse_usize(
115                        "--shock-end",
116                        value("--shock-end", inline_value, &mut args)?,
117                    )?;
118                }
119                "--beta" => {
120                    config.beta = parse_f64("--beta", value("--beta", inline_value, &mut args)?)?;
121                }
122                "--envelope-decay" => {
123                    config.envelope_decay = parse_f64(
124                        "--envelope-decay",
125                        value("--envelope-decay", inline_value, &mut args)?,
126                    )?;
127                }
128                _ => {
129                    return Err(format!(
130                        "unrecognized argument `{flag}`\n\n{}",
131                        Self::usage(
132                            "cargo run --manifest-path crates/dsfb-srd/Cargo.toml --release --bin dsfb-srd-generate --"
133                        )
134                    ));
135                }
136            }
137        }
138
139        config.validate()?;
140        Ok(config)
141    }
142
143    pub fn usage(program: &str) -> String {
144        format!(
145            "Usage:\n  {program} [options]\n\nOptions:\n  \
146--n-events <usize>\n  \
147--n-channels <usize>\n  \
148--causal-window <usize>\n  \
149--tau-steps <usize>\n  \
150--shock-start <usize>\n  \
151--shock-end <usize>\n  \
152--beta <f64>\n  \
153--envelope-decay <f64>\n  \
154--help"
155        )
156    }
157
158    pub fn tau_thresholds(&self) -> Vec<f64> {
159        let denominator = (self.tau_steps - 1) as f64;
160        (0..self.tau_steps)
161            .map(|index| index as f64 / denominator)
162            .collect()
163    }
164
165    pub fn scaled_for_n_events(&self, n_events: usize) -> Self {
166        let mut scaled = self.clone();
167        let scale = n_events as f64 / self.n_events as f64;
168        scaled.n_events = n_events;
169        scaled.shock_start = scale_index(self.shock_start, scale, n_events.saturating_sub(1));
170        scaled.shock_end = scale_index(self.shock_end, scale, n_events);
171        if scaled.shock_end <= scaled.shock_start {
172            scaled.shock_end = (scaled.shock_start + 1).min(n_events);
173        }
174        scaled
175    }
176
177    pub fn canonical_json(&self) -> String {
178        format!(
179            concat!(
180                "{{",
181                "\"crate\":\"{}\",",
182                "\"version\":\"{}\",",
183                "\"n_events\":{},",
184                "\"n_channels\":{},",
185                "\"causal_window\":{},",
186                "\"tau_steps\":{},",
187                "\"shock_start\":{},",
188                "\"shock_end\":{},",
189                "\"beta\":{},",
190                "\"envelope_decay\":{}",
191                "}}"
192            ),
193            CRATE_NAME,
194            CRATE_VERSION,
195            self.n_events,
196            self.n_channels,
197            self.causal_window,
198            self.tau_steps,
199            self.shock_start,
200            self.shock_end,
201            canonical_float(self.beta),
202            canonical_float(self.envelope_decay),
203        )
204    }
205
206    pub fn config_hash(&self) -> String {
207        sha256_hex(self.canonical_json().as_bytes())
208    }
209
210    pub fn repo_root() -> PathBuf {
211        Path::new(env!("CARGO_MANIFEST_DIR"))
212            .ancestors()
213            .nth(2)
214            .map(Path::to_path_buf)
215            .expect("dsfb-srd must live under <repo>/crates/dsfb-srd")
216    }
217
218    pub fn output_root() -> PathBuf {
219        Self::repo_root().join("output-dsfb-srd")
220    }
221}
222
223pub fn compute_run_id(config: &SimulationConfig) -> String {
224    let config_hash = config.config_hash();
225    config_hash.chars().take(32).collect()
226}
227
/// Parses `raw` as a `usize` for the option called `name`.
///
/// On failure the error message names the option and echoes the rejected
/// text.
fn parse_usize(name: &str, raw: String) -> Result<usize, String> {
    match raw.parse::<usize>() {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(format!("invalid integer for {name}: `{raw}`")),
    }
}
232
/// Parses `raw` as an `f64` for the option called `name`.
///
/// On failure the error message names the option and echoes the rejected
/// text.
fn parse_f64(name: &str, raw: String) -> Result<f64, String> {
    match raw.parse::<f64>() {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(format!(
            "invalid floating-point value for {name}: `{raw}`"
        )),
    }
}
237
/// Multiplies `index` by `scale`, rounds to the nearest integer, and clamps
/// the result to at most `upper_bound`.
fn scale_index(index: usize, scale: f64, upper_bound: usize) -> usize {
    let scaled = (index as f64 * scale).round();
    // Float-to-integer `as` casts saturate: negative values and NaN become 0,
    // values beyond `usize::MAX` become `usize::MAX`.
    let rounded = scaled as usize;
    if rounded > upper_bound {
        upper_bound
    } else {
        rounded
    }
}
241
/// Formats `value` in fixed-point notation with exactly twelve digits after
/// the decimal point.
fn canonical_float(value: f64) -> String {
    // `.*` takes the precision from the argument list, ahead of the value.
    format!("{:.*}", 12, value)
}
245
246fn sha256_hex(bytes: &[u8]) -> String {
247    let digest = sha256(bytes);
248    let mut output = String::with_capacity(64);
249    for byte in digest {
250        let _ = write!(&mut output, "{byte:02x}");
251    }
252    output
253}
254
/// Computes the SHA-256 digest of `bytes`.
///
/// The message is padded with a `0x80` byte, zero bytes, and the 64-bit
/// big-endian bit length so that its size is a multiple of 64 bytes. Each
/// 64-byte block is then expanded into a 64-word schedule and mixed into the
/// eight-word state over 64 rounds. The final state is returned as 32
/// big-endian bytes.
fn sha256(bytes: &[u8]) -> [u8; 32] {
    // Initial hash state.
    const H0: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Per-round additive constants.
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Build the padded message: data, 0x80 marker, zero fill, 64-bit length.
    let bit_len = (bytes.len() as u64) * 8;
    let mut message = bytes.to_vec();
    message.push(0x80);
    // Zero bytes needed so that, once the 8 length bytes are appended, the
    // total is a multiple of 64.
    let zero_fill = (64 - (message.len() + 8) % 64) % 64;
    message.resize(message.len() + zero_fill, 0);
    message.extend_from_slice(&bit_len.to_be_bytes());

    let mut state = H0;

    for block in message.chunks_exact(64) {
        // Message schedule: the first 16 words come straight from the block
        // (big-endian); the remaining 48 are derived from earlier words.
        let mut schedule = [0u32; 64];
        for (slot, quad) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([quad[0], quad[1], quad[2], quad[3]]);
        }
        for t in 16..64 {
            let w15 = schedule[t - 15];
            let w2 = schedule[t - 2];
            let small_sigma0 = w15.rotate_right(7) ^ w15.rotate_right(18) ^ (w15 >> 3);
            let small_sigma1 = w2.rotate_right(17) ^ w2.rotate_right(19) ^ (w2 >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(small_sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(small_sigma1);
        }

        // Working variables start from the current state.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;

        for (&round_constant, &word) in K.iter().zip(schedule.iter()) {
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ ((!e) & g);
            let t1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(round_constant)
                .wrapping_add(word);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_sigma0.wrapping_add(majority);

            // Rotate the working variables down by one position.
            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }

        // Fold this block's result back into the running state.
        for (accumulated, working) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *accumulated = accumulated.wrapping_add(working);
        }
    }

    // Serialize the state words as big-endian bytes.
    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
357
#[cfg(test)]
mod tests {
    use super::{compute_run_id, sha256_hex, SimulationConfig};

    /// The hand-written SHA-256 must reproduce the published digest of "abc".
    #[test]
    fn sha256_matches_known_digest() {
        let expected = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";
        let actual = sha256_hex(b"abc");
        assert_eq!(actual, expected);
    }

    /// Computing the run id twice for the same configuration must give the
    /// same string.
    #[test]
    fn run_id_is_deterministic() {
        let config = SimulationConfig::default();
        let first = compute_run_id(&config);
        let second = compute_run_id(&config);
        assert_eq!(first, second);
    }
}