Skip to main content

dag_ml_core/
campaign.rs

1use serde::{Deserialize, Serialize};
2use sha2::{Digest, Sha256};
3
4use crate::error::Result;
5use crate::fold::FoldSet;
6
7#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
8pub struct CampaignFingerprintSpec {
9    pub graph_id: String,
10    pub root_seed: u64,
11    pub splitter: serde_json::Value,
12    pub fold_set: FoldSet,
13}
14
15pub fn campaign_fingerprint(spec: &CampaignFingerprintSpec) -> Result<String> {
16    spec.fold_set.validate()?;
17    stable_json_fingerprint(spec)
18}
19
20pub(crate) fn stable_json_fingerprint<T: Serialize + ?Sized>(value: &T) -> Result<String> {
21    let json = serde_json::to_vec(value)?;
22    let digest = Sha256::digest(json);
23    Ok(to_hex(&digest))
24}
25
/// Renders `bytes` as lowercase hexadecimal, always two digits per byte
/// (e.g. `[0x0f, 0xab]` becomes `"0fab"`).
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(2 * bytes.len());
    for &b in bytes {
        // High nibble first, so single-digit values keep their leading zero.
        hex.push(char::from(DIGITS[usize::from(b >> 4)]));
        hex.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    hex
}
34
#[cfg(test)]
mod tests {
    use serde_json::json;

    use crate::fold::KFoldSpec;
    use crate::ids::SampleId;

    use super::*;

    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    /// Builds the small spec shared by the fingerprint tests: four samples
    /// split into two shuffled folds with a fixed seed.
    fn sample_spec() -> CampaignFingerprintSpec {
        let samples = ["s1", "s2", "s3", "s4"]
            .into_iter()
            .map(sid)
            .collect::<Vec<_>>();
        let fold_set = KFoldSpec {
            n_splits: 2,
            shuffle: true,
            seed: Some(9),
        }
        .split("outer", &samples)
        .unwrap();
        CampaignFingerprintSpec {
            graph_id: "g".to_string(),
            root_seed: 9,
            splitter: json!({"kind": "kfold", "n_splits": 2}),
            fold_set,
        }
    }

    #[test]
    fn campaign_fingerprint_is_stable_and_sensitive() {
        let spec = sample_spec();

        let left = campaign_fingerprint(&spec).unwrap();
        let right = campaign_fingerprint(&spec).unwrap();
        assert_eq!(left, right);

        // Every hashed field must influence the result, not just the seed.
        let mut seed_changed = spec.clone();
        seed_changed.root_seed = 10;
        assert_ne!(left, campaign_fingerprint(&seed_changed).unwrap());

        let mut graph_changed = spec.clone();
        graph_changed.graph_id = "h".to_string();
        assert_ne!(left, campaign_fingerprint(&graph_changed).unwrap());

        let mut splitter_changed = spec;
        splitter_changed.splitter = json!({"kind": "kfold", "n_splits": 3});
        assert_ne!(left, campaign_fingerprint(&splitter_changed).unwrap());
    }

    #[test]
    fn campaign_fingerprint_is_lowercase_sha256_hex() {
        let fingerprint = campaign_fingerprint(&sample_spec()).unwrap();
        // SHA-256 yields 32 bytes, i.e. 64 hex digits.
        assert_eq!(fingerprint.len(), 64);
        assert!(fingerprint
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')));
    }

    #[test]
    fn to_hex_zero_pads_each_byte() {
        assert_eq!(to_hex(b""), "");
        assert_eq!(to_hex(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }
}