Skip to main content

dag_ml_core/
rng.rs

1use serde::{Deserialize, Serialize};
2use sha2::{Digest, Sha256};
3
4#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
5pub struct SeedContext {
6    pub root_seed: u64,
7    pub path: Vec<String>,
8}
9
10impl SeedContext {
11    pub fn root(root_seed: u64) -> Self {
12        Self {
13            root_seed,
14            path: Vec::new(),
15        }
16    }
17
18    pub fn child(&self, label: impl Into<String>) -> Self {
19        let mut next = self.clone();
20        next.path.push(label.into());
21        next
22    }
23
24    pub fn derive_u64(&self, label: impl AsRef<str>) -> u64 {
25        let mut hasher = Sha256::new();
26        hasher.update(self.root_seed.to_le_bytes());
27        for part in &self.path {
28            hasher.update([0]);
29            hasher.update(part.as_bytes());
30        }
31        hasher.update([0xff]);
32        hasher.update(label.as_ref().as_bytes());
33
34        let digest = hasher.finalize();
35        let mut bytes = [0u8; 8];
36        bytes.copy_from_slice(&digest[..8]);
37        u64::from_le_bytes(bytes)
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-level context shared by the assertions below.
    fn fold_context() -> SeedContext {
        SeedContext::root(7).child("node:model").child("fold:1")
    }

    #[test]
    fn derives_stable_streams() {
        let first = fold_context();
        let second = fold_context();

        // Independently built but equal contexts must agree for the same label.
        assert_eq!(first.derive_u64("split"), second.derive_u64("split"));
        // Different labels on the same context must yield different values.
        assert_ne!(first.derive_u64("split"), first.derive_u64("bootstrap"));
    }
}
53}