Skip to main content

dag_ml_core/
fold.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::ids::{FoldId, GroupId, SampleId};
8use crate::rng::SeedContext;
9
/// One train/validation partition inside a [`FoldSet`].
///
/// [`FoldSet::validate`] requires both lists to be duplicate-free, disjoint,
/// drawn from the owning fold set's `sample_ids`, and the validation list to be
/// non-empty.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FoldAssignment {
    /// Identifier of this fold; must be unique within its `FoldSet`.
    pub fold_id: FoldId,
    /// Training-side samples for this fold.
    pub train_sample_ids: Vec<SampleId>,
    /// Held-out samples for this fold; across a valid `FoldSet` every sample
    /// appears in exactly one fold's validation list.
    pub validation_sample_ids: Vec<SampleId>,
    /// Free-form per-fold metadata. Defaults to empty when absent from the
    /// serialized form; an empty map is stripped before fingerprinting.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
18
/// A complete cross-validation partition over a universe of samples.
///
/// See [`FoldSet::validate`] for the enforced invariants: a non-blank id,
/// unique `sample_ids`, unique fold ids, every sample validated in exactly one
/// fold, and (when `sample_groups` is non-empty) no group present on both the
/// train and validation side of any fold.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FoldSet {
    /// Identifier of this fold set; must not be blank.
    pub id: String,
    /// The sample universe every fold draws its members from.
    pub sample_ids: Vec<SampleId>,
    /// The folds of this partition; must be non-empty.
    pub folds: Vec<FoldAssignment>,
    /// Optional sample -> group map. Empty means "no groups"; when non-empty it
    /// must have exactly one entry per id in `sample_ids`. Stripped before
    /// fingerprinting when empty.
    #[serde(default)]
    pub sample_groups: BTreeMap<SampleId, GroupId>,
}
27
28impl FoldSet {
29    pub fn validate(&self) -> Result<()> {
30        if self.id.trim().is_empty() {
31            return Err(DagMlError::OofValidation(
32                "fold set id is empty".to_string(),
33            ));
34        }
35        if self.sample_ids.is_empty() {
36            return Err(DagMlError::OofValidation(
37                "fold set contains no samples".to_string(),
38            ));
39        }
40        if self.folds.is_empty() {
41            return Err(DagMlError::OofValidation(
42                "fold set contains no folds".to_string(),
43            ));
44        }
45        let universe = unique_samples("fold set sample_ids", &self.sample_ids)?;
46        if !self.sample_groups.is_empty() {
47            for sample_id in self.sample_groups.keys() {
48                if !universe.contains(sample_id) {
49                    return Err(DagMlError::OofValidation(format!(
50                        "sample group map references unknown sample `{sample_id}`"
51                    )));
52                }
53            }
54            for sample_id in &self.sample_ids {
55                if !self.sample_groups.contains_key(sample_id) {
56                    return Err(DagMlError::OofValidation(format!(
57                        "sample `{sample_id}` is missing from non-empty group map"
58                    )));
59                }
60            }
61        }
62        let mut fold_ids = BTreeSet::new();
63        let mut validation_counts = self
64            .sample_ids
65            .iter()
66            .cloned()
67            .map(|sample_id| (sample_id, 0usize))
68            .collect::<BTreeMap<_, _>>();
69
70        for fold in &self.folds {
71            if !fold_ids.insert(&fold.fold_id) {
72                return Err(DagMlError::OofValidation(format!(
73                    "duplicate fold id `{}`",
74                    fold.fold_id
75                )));
76            }
77            let train = unique_samples(
78                &format!("fold `{}` train_sample_ids", fold.fold_id),
79                &fold.train_sample_ids,
80            )?;
81            let validation = unique_samples(
82                &format!("fold `{}` validation_sample_ids", fold.fold_id),
83                &fold.validation_sample_ids,
84            )?;
85            if validation.is_empty() {
86                return Err(DagMlError::OofValidation(format!(
87                    "fold `{}` has no validation samples",
88                    fold.fold_id
89                )));
90            }
91            for sample_id in train.union(&validation) {
92                if !universe.contains(sample_id) {
93                    return Err(DagMlError::OofValidation(format!(
94                        "fold `{}` references unknown sample `{}`",
95                        fold.fold_id, sample_id
96                    )));
97                }
98            }
99            let overlap = train.intersection(&validation).collect::<Vec<_>>();
100            if !overlap.is_empty() {
101                return Err(DagMlError::OofValidation(format!(
102                    "fold `{}` has train/validation overlap at sample `{}`",
103                    fold.fold_id, overlap[0]
104                )));
105            }
106            for sample_id in validation {
107                *validation_counts
108                    .get_mut(sample_id)
109                    .expect("validation sample is in universe") += 1;
110            }
111            self.validate_group_boundary(fold, &train)?;
112        }
113
114        for (sample_id, count) in validation_counts {
115            if count != 1 {
116                return Err(DagMlError::OofValidation(format!(
117                    "sample `{}` appears in validation {} time(s), expected exactly once",
118                    sample_id, count
119                )));
120            }
121        }
122
123        Ok(())
124    }
125
126    fn validate_group_boundary(
127        &self,
128        fold: &FoldAssignment,
129        train: &BTreeSet<&SampleId>,
130    ) -> Result<()> {
131        if self.sample_groups.is_empty() {
132            return Ok(());
133        }
134        let train_groups = train
135            .iter()
136            .filter_map(|sample_id| self.sample_groups.get(*sample_id))
137            .collect::<BTreeSet<_>>();
138        for sample_id in &fold.validation_sample_ids {
139            let Some(group_id) = self.sample_groups.get(sample_id) else {
140                continue;
141            };
142            if train_groups.contains(group_id) {
143                return Err(DagMlError::OofValidation(format!(
144                    "fold `{}` leaks group `{}` across train/validation",
145                    fold.fold_id, group_id
146                )));
147            }
148        }
149        Ok(())
150    }
151}
152
153pub fn fold_set_fingerprint(fold_set: &FoldSet) -> Result<String> {
154    let mut canonical = fold_set.clone();
155    canonical.validate()?;
156    canonical.sample_ids.sort();
157    canonical
158        .folds
159        .sort_by(|left, right| left.fold_id.cmp(&right.fold_id));
160    for fold in &mut canonical.folds {
161        fold.train_sample_ids.sort();
162        fold.validation_sample_ids.sort();
163    }
164
165    let mut value = serde_json::to_value(&canonical)?;
166    remove_empty_fold_set_maps(&mut value);
167    stable_json_fingerprint(&value)
168}
169
170fn remove_empty_fold_set_maps(value: &mut serde_json::Value) {
171    let Some(object) = value.as_object_mut() else {
172        return;
173    };
174    if object
175        .get("sample_groups")
176        .and_then(serde_json::Value::as_object)
177        .is_some_and(serde_json::Map::is_empty)
178    {
179        object.remove("sample_groups");
180    }
181    let Some(folds) = object
182        .get_mut("folds")
183        .and_then(serde_json::Value::as_array_mut)
184    else {
185        return;
186    };
187    for fold in folds {
188        let Some(fold_object) = fold.as_object_mut() else {
189            continue;
190        };
191        if fold_object
192            .get("metadata")
193            .and_then(serde_json::Value::as_object)
194            .is_some_and(serde_json::Map::is_empty)
195        {
196            fold_object.remove("metadata");
197        }
198    }
199}
200
/// Plain K-fold splitter specification.
///
/// [`KFoldSpec::split`] sorts the samples, optionally shuffles them
/// deterministically from `seed`, and deals them round-robin into `n_splits`
/// validation folds.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct KFoldSpec {
    /// Number of folds; must be at least 2 and at most the sample count.
    pub n_splits: usize,
    /// Whether to shuffle before dealing samples into folds; defaults to
    /// `false` when omitted from the serialized form.
    #[serde(default)]
    pub shuffle: bool,
    /// Shuffle seed; `None` is treated as `0`. Unused when `shuffle` is false.
    pub seed: Option<u64>,
}
208
209impl KFoldSpec {
210    pub fn split(&self, id: impl Into<String>, samples: &[SampleId]) -> Result<FoldSet> {
211        if self.n_splits < 2 {
212            return Err(DagMlError::OofValidation(
213                "KFold requires at least two splits".to_string(),
214            ));
215        }
216        let unique = unique_samples("KFold samples", samples)?;
217        if self.n_splits > unique.len() {
218            return Err(DagMlError::OofValidation(format!(
219                "KFold n_splits={} exceeds sample count {}",
220                self.n_splits,
221                unique.len()
222            )));
223        }
224        let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));
225        let folds = (0..self.n_splits)
226            .map(|fold_idx| {
227                let validation = ordered
228                    .iter()
229                    .enumerate()
230                    .filter_map(|(idx, sample_id)| {
231                        (idx % self.n_splits == fold_idx).then_some(sample_id.clone())
232                    })
233                    .collect::<Vec<_>>();
234                let validation_set = validation.iter().collect::<BTreeSet<_>>();
235                let train = ordered
236                    .iter()
237                    .filter(|sample_id| !validation_set.contains(sample_id))
238                    .cloned()
239                    .collect::<Vec<_>>();
240                Ok(FoldAssignment {
241                    fold_id: FoldId::new(format!("fold{fold_idx}"))?,
242                    train_sample_ids: train,
243                    validation_sample_ids: validation,
244                    metadata: BTreeMap::new(),
245                })
246            })
247            .collect::<Result<Vec<_>>>()?;
248        let fold_set = FoldSet {
249            id: id.into(),
250            sample_ids: ordered_samples(samples, false, 0),
251            folds,
252            sample_groups: BTreeMap::new(),
253        };
254        fold_set.validate()?;
255        Ok(fold_set)
256    }
257}
258
/// Stratified K-fold: each sample is validated exactly once (OOF-safe like
/// plain K-fold), but folds are balanced by a per-sample class label so that
/// every fold mirrors the overall class distribution as closely as integer
/// counts allow. The `strata` argument of [`StratifiedKFoldSpec::split`] maps
/// each sample id to its class label (identity-keyed metadata — never feature
/// values).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct StratifiedKFoldSpec {
    /// Number of folds; must be at least 2 and at most the sample count.
    pub n_splits: usize,
    /// Whether to shuffle samples (deterministically, from `seed`) before
    /// grouping them by label; defaults to `false` when omitted.
    #[serde(default)]
    pub shuffle: bool,
    /// Shuffle seed; `None` is treated as `0`. Unused when `shuffle` is false.
    pub seed: Option<u64>,
}
270
271impl StratifiedKFoldSpec {
272    pub fn split(
273        &self,
274        id: impl Into<String>,
275        samples: &[SampleId],
276        strata: &BTreeMap<SampleId, String>,
277    ) -> Result<FoldSet> {
278        if self.n_splits < 2 {
279            return Err(DagMlError::OofValidation(
280                "StratifiedKFold requires at least two splits".to_string(),
281            ));
282        }
283        let unique = unique_samples("StratifiedKFold samples", samples)?;
284        if self.n_splits > unique.len() {
285            return Err(DagMlError::OofValidation(format!(
286                "StratifiedKFold n_splits={} exceeds sample count {}",
287                self.n_splits,
288                unique.len()
289            )));
290        }
291        // Group samples by class (deterministic label order), preserving the
292        // within-class order, then assign folds by GLOBAL round-robin over that
293        // class-grouped order. Each sample lands in exactly one fold (OOF) and
294        // every class is spread across folds; crucially no fold is left empty
295        // whenever KFold's `n_splits <= n_samples` invariant holds (the previous
296        // per-class counter could pile singleton classes all into fold 0).
297        let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));
298        let mut by_label: BTreeMap<String, Vec<SampleId>> = BTreeMap::new();
299        for sample_id in &ordered {
300            let label = strata.get(sample_id).ok_or_else(|| {
301                DagMlError::OofValidation(format!(
302                    "StratifiedKFold: sample `{sample_id}` has no stratum label"
303                ))
304            })?;
305            by_label
306                .entry(label.clone())
307                .or_default()
308                .push(sample_id.clone());
309        }
310        let mut fold_of: BTreeMap<SampleId, usize> = BTreeMap::new();
311        let mut position = 0usize;
312        for members in by_label.values() {
313            for sample_id in members {
314                fold_of.insert(sample_id.clone(), position % self.n_splits);
315                position += 1;
316            }
317        }
318        let folds = (0..self.n_splits)
319            .map(|fold_idx| {
320                let validation = ordered
321                    .iter()
322                    .filter(|s| fold_of.get(*s) == Some(&fold_idx))
323                    .cloned()
324                    .collect::<Vec<_>>();
325                let train = ordered
326                    .iter()
327                    .filter(|s| fold_of.get(*s) != Some(&fold_idx))
328                    .cloned()
329                    .collect::<Vec<_>>();
330                Ok(FoldAssignment {
331                    fold_id: FoldId::new(format!("fold{fold_idx}"))?,
332                    train_sample_ids: train,
333                    validation_sample_ids: validation,
334                    metadata: BTreeMap::new(),
335                })
336            })
337            .collect::<Result<Vec<_>>>()?;
338        let fold_set = FoldSet {
339            id: id.into(),
340            sample_ids: ordered_samples(samples, false, 0),
341            folds,
342            sample_groups: BTreeMap::new(),
343        };
344        fold_set.validate()?;
345        Ok(fold_set)
346    }
347}
348
/// Group-aware K-fold splitter specification.
///
/// [`GroupKFoldSpec::split`] places every sample of a group in the same
/// validation fold, so no group straddles a fold's train and validation
/// sides. Groups are assigned largest-first to the fold that currently holds
/// the fewest validation samples; there is no shuffling.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct GroupKFoldSpec {
    /// Number of folds; must be at least 2 and at most the number of groups.
    pub n_splits: usize,
}
353
354impl GroupKFoldSpec {
355    pub fn split(
356        &self,
357        id: impl Into<String>,
358        sample_groups: &BTreeMap<SampleId, GroupId>,
359    ) -> Result<FoldSet> {
360        if self.n_splits < 2 {
361            return Err(DagMlError::OofValidation(
362                "GroupKFold requires at least two splits".to_string(),
363            ));
364        }
365        if sample_groups.is_empty() {
366            return Err(DagMlError::OofValidation(
367                "GroupKFold requires sample groups".to_string(),
368            ));
369        }
370        let mut groups = BTreeMap::<GroupId, Vec<SampleId>>::new();
371        for (sample_id, group_id) in sample_groups {
372            groups
373                .entry(group_id.clone())
374                .or_default()
375                .push(sample_id.clone());
376        }
377        if self.n_splits > groups.len() {
378            return Err(DagMlError::OofValidation(format!(
379                "GroupKFold n_splits={} exceeds group count {}",
380                self.n_splits,
381                groups.len()
382            )));
383        }
384
385        let mut grouped = groups.into_iter().collect::<Vec<_>>();
386        grouped.sort_by(|(left_group, left_samples), (right_group, right_samples)| {
387            right_samples
388                .len()
389                .cmp(&left_samples.len())
390                .then_with(|| left_group.cmp(right_group))
391        });
392
393        let mut fold_validation = vec![Vec::<SampleId>::new(); self.n_splits];
394        for (_group_id, mut samples) in grouped {
395            samples.sort();
396            let fold_idx = fold_validation
397                .iter()
398                .enumerate()
399                .min_by(|(left_idx, left), (right_idx, right)| {
400                    left.len()
401                        .cmp(&right.len())
402                        .then_with(|| left_idx.cmp(right_idx))
403                })
404                .map(|(idx, _)| idx)
405                .expect("at least one fold");
406            fold_validation[fold_idx].extend(samples);
407        }
408
409        let mut sample_ids = sample_groups.keys().cloned().collect::<Vec<_>>();
410        sample_ids.sort();
411        let folds = fold_validation
412            .into_iter()
413            .enumerate()
414            .map(|(fold_idx, mut validation)| {
415                validation.sort();
416                let validation_set = validation.iter().collect::<BTreeSet<_>>();
417                let train = sample_ids
418                    .iter()
419                    .filter(|sample_id| !validation_set.contains(sample_id))
420                    .cloned()
421                    .collect::<Vec<_>>();
422                Ok(FoldAssignment {
423                    fold_id: FoldId::new(format!("fold{fold_idx}"))?,
424                    train_sample_ids: train,
425                    validation_sample_ids: validation,
426                    metadata: BTreeMap::new(),
427                })
428            })
429            .collect::<Result<Vec<_>>>()?;
430
431        let fold_set = FoldSet {
432            id: id.into(),
433            sample_ids,
434            folds,
435            sample_groups: sample_groups.clone(),
436        };
437        fold_set.validate()?;
438        Ok(fold_set)
439    }
440}
441
/// Inner (nested) cross-validation policy.
///
/// Declared globally on the `CampaignSpec` and/or locally on a `NodePlan`
/// (e.g. a finetune/tuner or branch node); the local policy overrides the global
/// default (see [`resolve_inner_cv`]). dag-ml builds the inner `FoldSet` from each
/// outer fold's **training** samples via [`NestedCvSpec::build_inner_fold_set`],
/// so the inner folds are a subset of outer-train *by construction* — nested CV
/// cannot leak outer-validation rows into inner tuning.
///
/// Serialized as an internally tagged object: a `"kind"` field of `"kfold"` or
/// `"group_kfold"` sits alongside the fields of the wrapped spec.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum NestedCvSpec {
    /// Index-based inner K-fold, built in-core from outer-train samples.
    #[serde(rename = "kfold")]
    KFold(KFoldSpec),
    /// Group-aware inner K-fold, built in-core from outer-train sample groups.
    #[serde(rename = "group_kfold")]
    GroupKFold(GroupKFoldSpec),
}
460
461impl NestedCvSpec {
462    /// Validate the nested-CV policy's parameters independently of any outer fold.
463    /// Mirrors the checks the splitters enforce (`n_splits >= 2`) so a malformed
464    /// declaration is rejected at plan time rather than deferred to FIT_CV.
465    pub fn validate(&self) -> Result<()> {
466        match self {
467            Self::KFold(spec) => {
468                if spec.n_splits < 2 {
469                    return Err(DagMlError::OofValidation(
470                        "inner KFold requires at least two splits".to_string(),
471                    ));
472                }
473            }
474            Self::GroupKFold(spec) => {
475                if spec.n_splits < 2 {
476                    return Err(DagMlError::OofValidation(
477                        "inner GroupKFold requires at least two splits".to_string(),
478                    ));
479                }
480            }
481        }
482        Ok(())
483    }
484
485    /// Build the inner `FoldSet` for one outer fold from its **training** samples
486    /// only. `outer_groups` is the outer `FoldSet.sample_groups` (used by
487    /// `GroupKFold`; ignored otherwise). The result is validated to lie entirely
488    /// within the outer fold's training set.
489    pub fn build_inner_fold_set(
490        &self,
491        outer: &FoldAssignment,
492        outer_groups: &BTreeMap<SampleId, GroupId>,
493    ) -> Result<FoldSet> {
494        let inner_id = format!("{}.inner", outer.fold_id);
495        let inner = match self {
496            Self::KFold(spec) => spec.split(inner_id, &outer.train_sample_ids)?,
497            Self::GroupKFold(spec) => {
498                let train = outer.train_sample_ids.iter().collect::<BTreeSet<_>>();
499                let inner_groups = outer_groups
500                    .iter()
501                    .filter(|(sample_id, _)| train.contains(sample_id))
502                    .map(|(sample_id, group_id)| (sample_id.clone(), group_id.clone()))
503                    .collect::<BTreeMap<_, _>>();
504                spec.split(inner_id, &inner_groups)?
505            }
506        };
507        validate_inner_fold_set_within_outer(&inner, outer)?;
508        Ok(inner)
509    }
510}
511
512/// Resolve the effective inner-CV policy for a node: a node-local policy
513/// overrides the campaign-global default; `None` means no nested CV.
514pub fn resolve_inner_cv<'a>(
515    node_inner_cv: Option<&'a NestedCvSpec>,
516    campaign_inner_cv: Option<&'a NestedCvSpec>,
517) -> Option<&'a NestedCvSpec> {
518    node_inner_cv.or(campaign_inner_cv)
519}
520
521/// Enforce the nested-CV invariant: every sample in `inner` — both the top-level
522/// universe and every fold's train/validation members — must be an outer-fold
523/// **training** sample (never an outer-validation sample). Holds by construction
524/// for dag-ml-built inner folds, and also validates inner folds supplied from
525/// elsewhere. Refuses with an OOF-validation error on any leaking sample.
526pub fn validate_inner_fold_set_within_outer(inner: &FoldSet, outer: &FoldAssignment) -> Result<()> {
527    // Ensure the inner fold set is structurally sound first; otherwise a malformed
528    // supplied fold set could hide a leaking sample in a fold while omitting it
529    // from `sample_ids`. After this, fold members are guaranteed ⊆ `sample_ids`.
530    inner.validate()?;
531    let train = outer.train_sample_ids.iter().collect::<BTreeSet<_>>();
532    let ensure_train = |sample_id: &SampleId| -> Result<()> {
533        if !train.contains(sample_id) {
534            return Err(DagMlError::OofValidation(format!(
535                "nested CV leakage: inner-CV sample `{sample_id}` for outer fold `{}` is not an outer training sample",
536                outer.fold_id
537            )));
538        }
539        Ok(())
540    };
541    for sample_id in &inner.sample_ids {
542        ensure_train(sample_id)?;
543    }
544    // Defence-in-depth: check every fold member directly, independent of the
545    // sample_ids / structural invariants above.
546    for fold in &inner.folds {
547        for sample_id in fold
548            .train_sample_ids
549            .iter()
550            .chain(&fold.validation_sample_ids)
551        {
552            ensure_train(sample_id)?;
553        }
554    }
555    Ok(())
556}
557
558fn unique_samples<'a>(label: &str, samples: &'a [SampleId]) -> Result<BTreeSet<&'a SampleId>> {
559    let mut seen = BTreeSet::new();
560    for sample_id in samples {
561        if !seen.insert(sample_id) {
562            return Err(DagMlError::OofValidation(format!(
563                "{label} contains duplicate sample `{sample_id}`"
564            )));
565        }
566    }
567    Ok(seen)
568}
569
570fn ordered_samples(samples: &[SampleId], shuffle: bool, seed: u64) -> Vec<SampleId> {
571    let mut ordered = samples.to_vec();
572    ordered.sort();
573    if shuffle {
574        let context = SeedContext::root(seed).child("kfold");
575        ordered.sort_by(|left, right| {
576            context
577                .derive_u64(left.as_str())
578                .cmp(&context.derive_u64(right.as_str()))
579                .then_with(|| left.cmp(right))
580        });
581    }
582    ordered
583}
584
585#[cfg(test)]
586mod tests {
587    use super::*;
588
589    const SHARED_FOLD_SET_FINGERPRINT: &str =
590        "54d3185d6c628ef0df848828a8d8ae650222a283a78bbd3ab3bc2256f222c05c";
591
592    fn sid(value: &str) -> SampleId {
593        SampleId::new(value).unwrap()
594    }
595
596    fn gid(value: &str) -> GroupId {
597        GroupId::new(value).unwrap()
598    }
599
600    #[test]
601    fn kfold_is_deterministic_and_covers_samples_once() {
602        let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
603            .into_iter()
604            .map(sid)
605            .collect::<Vec<_>>();
606        let spec = KFoldSpec {
607            n_splits: 3,
608            shuffle: true,
609            seed: Some(42),
610        };
611
612        let left = spec.split("kfold", &samples).unwrap();
613        let right = spec.split("kfold", &samples).unwrap();
614
615        assert_eq!(left, right);
616        left.validate().unwrap();
617        for fold in &left.folds {
618            assert_eq!(fold.validation_sample_ids.len(), 2);
619            assert_eq!(fold.train_sample_ids.len(), 4);
620        }
621    }
622
623    #[test]
624    fn fold_validation_rejects_overlap() {
625        let fold_set = FoldSet {
626            id: "bad".to_string(),
627            sample_ids: vec![sid("s1"), sid("s2")],
628            folds: vec![FoldAssignment {
629                fold_id: FoldId::new("fold0").unwrap(),
630                train_sample_ids: vec![sid("s1")],
631                validation_sample_ids: vec![sid("s1")],
632                metadata: BTreeMap::new(),
633            }],
634            sample_groups: BTreeMap::new(),
635        };
636
637        assert!(fold_set.validate().is_err());
638    }
639
640    #[test]
641    fn fold_validation_rejects_partial_group_maps() {
642        let fold_set = FoldSet {
643            id: "bad-groups".to_string(),
644            sample_ids: vec![sid("s1"), sid("s2")],
645            folds: vec![FoldAssignment {
646                fold_id: FoldId::new("fold0").unwrap(),
647                train_sample_ids: vec![sid("s2")],
648                validation_sample_ids: vec![sid("s1")],
649                metadata: BTreeMap::new(),
650            }],
651            sample_groups: BTreeMap::from([(sid("s1"), gid("g1"))]),
652        };
653
654        assert!(fold_set.validate().is_err());
655    }
656
657    #[test]
658    fn fold_set_fingerprint_is_independent_of_ordering() {
659        let mut left = FoldSet {
660            id: "cv.partition".to_string(),
661            sample_ids: vec![sid("s3"), sid("s2"), sid("s1")],
662            folds: vec![
663                FoldAssignment {
664                    fold_id: FoldId::new("fold1").unwrap(),
665                    train_sample_ids: vec![sid("s2"), sid("s1")],
666                    validation_sample_ids: vec![sid("s3")],
667                    metadata: BTreeMap::new(),
668                },
669                FoldAssignment {
670                    fold_id: FoldId::new("fold0").unwrap(),
671                    train_sample_ids: vec![sid("s3")],
672                    validation_sample_ids: vec![sid("s2"), sid("s1")],
673                    metadata: BTreeMap::new(),
674                },
675            ],
676            sample_groups: BTreeMap::new(),
677        };
678        let mut right = left.clone();
679        right.sample_ids.reverse();
680        right.folds.reverse();
681        for fold in &mut right.folds {
682            fold.train_sample_ids.reverse();
683            fold.validation_sample_ids.reverse();
684        }
685
686        assert_eq!(
687            fold_set_fingerprint(&left).unwrap(),
688            fold_set_fingerprint(&right).unwrap()
689        );
690
691        left.id = "cv.partition.changed".to_string();
692        assert_ne!(
693            fold_set_fingerprint(&left).unwrap(),
694            fold_set_fingerprint(&right).unwrap()
695        );
696    }
697
698    #[test]
699    fn shared_fold_set_fixture_fingerprint_is_locked() {
700        let fixture = include_str!("../../../examples/fixtures/shared/fold_set_cv_partition.json");
701        let fold_set = serde_json::from_str::<FoldSet>(fixture).unwrap();
702
703        assert_eq!(
704            fold_set_fingerprint(&fold_set).unwrap(),
705            SHARED_FOLD_SET_FINGERPRINT
706        );
707    }
708
709    #[test]
710    fn group_kfold_keeps_groups_out_of_train_validation_overlap() {
711        let groups = BTreeMap::from([
712            (sid("s1"), gid("g1")),
713            (sid("s2"), gid("g1")),
714            (sid("s3"), gid("g2")),
715            (sid("s4"), gid("g2")),
716            (sid("s5"), gid("g3")),
717            (sid("s6"), gid("g3")),
718        ]);
719        let fold_set = GroupKFoldSpec { n_splits: 3 }
720            .split("group-kfold", &groups)
721            .unwrap();
722
723        fold_set.validate().unwrap();
724        for fold in &fold_set.folds {
725            let train_groups = fold
726                .train_sample_ids
727                .iter()
728                .map(|sample_id| groups.get(sample_id).unwrap())
729                .collect::<BTreeSet<_>>();
730            for sample_id in &fold.validation_sample_ids {
731                assert!(!train_groups.contains(groups.get(sample_id).unwrap()));
732            }
733        }
734    }
735
736    #[test]
737    fn stratified_kfold_is_oof_safe_and_balances_classes() {
738        // 8 samples, 2 classes (4 each); 2-fold stratified → each fold gets 2 of each class.
739        let samples = (0..8).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
740        let strata = BTreeMap::from_iter(samples.iter().enumerate().map(|(i, s)| {
741            (
742                s.clone(),
743                if i % 2 == 0 {
744                    "A".to_string()
745                } else {
746                    "B".to_string()
747                },
748            )
749        }));
750        let fold_set = StratifiedKFoldSpec {
751            n_splits: 2,
752            shuffle: false,
753            seed: Some(0),
754        }
755        .split("strat", &samples, &strata)
756        .unwrap();
757        fold_set.validate().unwrap(); // OOF: each sample validated exactly once
758        assert_eq!(fold_set.folds.len(), 2);
759        for fold in &fold_set.folds {
760            let mut counts: BTreeMap<&str, usize> = BTreeMap::new();
761            for s in &fold.validation_sample_ids {
762                *counts.entry(strata.get(s).unwrap().as_str()).or_insert(0) += 1;
763            }
764            assert_eq!(counts.get("A"), Some(&2));
765            assert_eq!(counts.get("B"), Some(&2));
766        }
767    }
768
769    #[test]
770    fn stratified_kfold_singleton_classes_leave_no_empty_fold() {
771        // Codex repro: 3 singleton classes with n_splits=3 must not pile all
772        // samples into fold0 (which FoldSet.validate rejects as an empty fold1).
773        let samples = ["s0", "s1", "s2"].into_iter().map(sid).collect::<Vec<_>>();
774        let strata = BTreeMap::from_iter([
775            (sid("s0"), "A".to_string()),
776            (sid("s1"), "B".to_string()),
777            (sid("s2"), "C".to_string()),
778        ]);
779        let fold_set = StratifiedKFoldSpec {
780            n_splits: 3,
781            shuffle: false,
782            seed: Some(0),
783        }
784        .split("strat", &samples, &strata)
785        .expect("singleton-class stratified split must succeed");
786        fold_set.validate().unwrap();
787        for fold in &fold_set.folds {
788            assert_eq!(fold.validation_sample_ids.len(), 1);
789        }
790    }
791
792    #[test]
793    fn stratified_kfold_rejects_missing_label() {
794        let samples = (0..4).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
795        let strata = BTreeMap::from_iter([(sid("s0"), "A".to_string())]); // incomplete
796        let err = StratifiedKFoldSpec {
797            n_splits: 2,
798            shuffle: false,
799            seed: Some(0),
800        }
801        .split("strat", &samples, &strata);
802        assert!(err.is_err());
803    }
804
805    fn outer_kfold(samples: &[SampleId]) -> FoldSet {
806        KFoldSpec {
807            n_splits: 2,
808            shuffle: false,
809            seed: Some(0),
810        }
811        .split("outer", samples)
812        .unwrap()
813    }
814
815    #[test]
816    fn nested_kfold_inner_folds_are_subset_of_outer_train() {
817        let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
818            .into_iter()
819            .map(sid)
820            .collect::<Vec<_>>();
821        let outer = outer_kfold(&samples);
822        let spec = NestedCvSpec::KFold(KFoldSpec {
823            n_splits: 2,
824            shuffle: false,
825            seed: Some(1),
826        });
827        for outer_fold in &outer.folds {
828            let inner = spec
829                .build_inner_fold_set(outer_fold, &outer.sample_groups)
830                .expect("inner fold set");
831            let outer_train = outer_fold.train_sample_ids.iter().collect::<BTreeSet<_>>();
832            // Every inner sample is an outer training sample.
833            for sample_id in &inner.sample_ids {
834                assert!(outer_train.contains(sample_id));
835            }
836            // The inner fold set is itself valid and covers exactly outer-train.
837            inner.validate().unwrap();
838            assert_eq!(
839                inner.sample_ids.iter().collect::<BTreeSet<_>>(),
840                outer_train
841            );
842        }
843    }
844
845    #[test]
846    fn nested_cv_validation_refuses_inner_sample_from_outer_validation() {
847        let samples = ["s1", "s2", "s3", "s4"]
848            .into_iter()
849            .map(sid)
850            .collect::<Vec<_>>();
851        let outer = outer_kfold(&samples);
852        let outer_fold = &outer.folds[0];
853        // A STRUCTURALLY VALID inner fold set that nonetheless includes an outer
854        // VALIDATION sample — the nested-CV boundary check must refuse it.
855        let leaking_sample = outer_fold.validation_sample_ids[0].clone();
856        let train_sample = outer_fold.train_sample_ids[0].clone();
857        let inner = FoldSet {
858            id: "leaky.inner".to_string(),
859            sample_ids: vec![train_sample.clone(), leaking_sample.clone()],
860            folds: vec![
861                FoldAssignment {
862                    fold_id: FoldId::new("if0").unwrap(),
863                    train_sample_ids: vec![leaking_sample.clone()],
864                    validation_sample_ids: vec![train_sample.clone()],
865                    metadata: BTreeMap::new(),
866                },
867                FoldAssignment {
868                    fold_id: FoldId::new("if1").unwrap(),
869                    train_sample_ids: vec![train_sample],
870                    validation_sample_ids: vec![leaking_sample],
871                    metadata: BTreeMap::new(),
872                },
873            ],
874            sample_groups: BTreeMap::new(),
875        };
876        inner
877            .validate()
878            .expect("inner fold set is structurally valid");
879        let err = validate_inner_fold_set_within_outer(&inner, outer_fold)
880            .expect_err("inner fold leaking an outer-validation sample must be refused");
881        assert!(err.to_string().contains("nested CV leakage"));
882    }
883
884    #[test]
885    fn nested_cv_validation_refuses_leak_hidden_in_fold_members() {
886        // A malformed supplied inner fold set hides an outer-validation sample in a
887        // fold's members while omitting it from the top-level `sample_ids`. It must
888        // still be refused (structural validation catches the inconsistency).
889        let samples = ["s1", "s2", "s3", "s4"]
890            .into_iter()
891            .map(sid)
892            .collect::<Vec<_>>();
893        let outer = outer_kfold(&samples);
894        let outer_fold = &outer.folds[0];
895        let leaking_sample = outer_fold.validation_sample_ids[0].clone();
896        let train_sample = outer_fold.train_sample_ids[0].clone();
897        let inner = FoldSet {
898            id: "hidden.inner".to_string(),
899            // `sample_ids` omits the leaking sample, but a fold member smuggles it in.
900            sample_ids: vec![train_sample.clone()],
901            folds: vec![FoldAssignment {
902                fold_id: FoldId::new("if0").unwrap(),
903                train_sample_ids: vec![train_sample],
904                validation_sample_ids: vec![leaking_sample],
905                metadata: BTreeMap::new(),
906            }],
907            sample_groups: BTreeMap::new(),
908        };
909        assert!(validate_inner_fold_set_within_outer(&inner, outer_fold).is_err());
910    }
911
912    #[test]
913    fn nested_cv_spec_json_shape_is_stable() {
914        let spec = NestedCvSpec::KFold(KFoldSpec {
915            n_splits: 3,
916            shuffle: false,
917            seed: Some(7),
918        });
919        let value = serde_json::to_value(&spec).unwrap();
920        assert_eq!(value["kind"], "kfold");
921        assert_eq!(value["n_splits"], 3);
922        assert_eq!(value["seed"], 7);
923        let round: NestedCvSpec = serde_json::from_value(value).unwrap();
924        assert_eq!(round, spec);
925
926        let group = NestedCvSpec::GroupKFold(GroupKFoldSpec { n_splits: 2 });
927        let gv = serde_json::to_value(&group).unwrap();
928        assert_eq!(gv["kind"], "group_kfold");
929        assert_eq!(gv["n_splits"], 2);
930        assert_eq!(serde_json::from_value::<NestedCvSpec>(gv).unwrap(), group);
931    }
932
933    #[test]
934    fn resolve_inner_cv_prefers_node_over_campaign() {
935        let node = NestedCvSpec::KFold(KFoldSpec {
936            n_splits: 3,
937            shuffle: false,
938            seed: Some(2),
939        });
940        let campaign = NestedCvSpec::KFold(KFoldSpec {
941            n_splits: 5,
942            shuffle: false,
943            seed: Some(3),
944        });
945        assert_eq!(resolve_inner_cv(Some(&node), Some(&campaign)), Some(&node));
946        assert_eq!(resolve_inner_cv(None, Some(&campaign)), Some(&campaign));
947        assert_eq!(resolve_inner_cv(Some(&node), None), Some(&node));
948        assert_eq!(resolve_inner_cv(None, None), None);
949    }
950}