use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
use crate::campaign::stable_json_fingerprint;
use crate::error::{DagMlError, Result};
use crate::ids::{FoldId, GroupId, SampleId};
use crate::rng::SeedContext;
/// One fold of a cross-validation partition: the samples to train on, the
/// samples held out for validation, and free-form metadata.
///
/// The struct itself enforces nothing; `FoldSet::validate` checks the
/// invariants described on the fields below.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FoldAssignment {
/// Identifier of this fold; `FoldSet::validate` rejects duplicates within one fold set.
pub fold_id: FoldId,
/// Samples the fold trains on. `FoldSet::validate` rejects duplicates and
/// any overlap with `validation_sample_ids`.
pub train_sample_ids: Vec<SampleId>,
/// Samples the fold holds out. `FoldSet::validate` requires this list to be
/// non-empty and free of duplicates.
pub validation_sample_ids: Vec<SampleId>,
/// Arbitrary JSON metadata; an absent field deserializes to an empty map.
#[serde(default)]
pub metadata: BTreeMap<String, serde_json::Value>,
}
/// A cross-validation partition over a fixed universe of samples.
///
/// Constructing or deserializing a `FoldSet` performs no checks; call
/// `validate` to verify the invariants described on the fields below.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FoldSet {
/// Name of the fold set; `validate` rejects an empty or whitespace-only id.
pub id: String,
/// The sample universe; `validate` requires it to be non-empty and free of duplicates.
pub sample_ids: Vec<SampleId>,
/// The folds of the partition; `validate` requires at least one and requires
/// every sample to be a validation sample of exactly one fold.
pub folds: Vec<FoldAssignment>,
/// Optional sample-to-group mapping; an absent field deserializes to an empty map.
/// When non-empty, `validate` requires its keys to be exactly the samples in
/// `sample_ids` and rejects folds where a group has samples in both train and
/// validation.
#[serde(default)]
pub sample_groups: BTreeMap<SampleId, GroupId>,
}
impl FoldSet {
pub fn validate(&self) -> Result<()> {
if self.id.trim().is_empty() {
return Err(DagMlError::OofValidation(
"fold set id is empty".to_string(),
));
}
if self.sample_ids.is_empty() {
return Err(DagMlError::OofValidation(
"fold set contains no samples".to_string(),
));
}
if self.folds.is_empty() {
return Err(DagMlError::OofValidation(
"fold set contains no folds".to_string(),
));
}
let universe = unique_samples("fold set sample_ids", &self.sample_ids)?;
if !self.sample_groups.is_empty() {
for sample_id in self.sample_groups.keys() {
if !universe.contains(sample_id) {
return Err(DagMlError::OofValidation(format!(
"sample group map references unknown sample `{sample_id}`"
)));
}
}
for sample_id in &self.sample_ids {
if !self.sample_groups.contains_key(sample_id) {
return Err(DagMlError::OofValidation(format!(
"sample `{sample_id}` is missing from non-empty group map"
)));
}
}
}
let mut fold_ids = BTreeSet::new();
let mut validation_counts = self
.sample_ids
.iter()
.cloned()
.map(|sample_id| (sample_id, 0usize))
.collect::<BTreeMap<_, _>>();
for fold in &self.folds {
if !fold_ids.insert(&fold.fold_id) {
return Err(DagMlError::OofValidation(format!(
"duplicate fold id `{}`",
fold.fold_id
)));
}
let train = unique_samples(
&format!("fold `{}` train_sample_ids", fold.fold_id),
&fold.train_sample_ids,
)?;
let validation = unique_samples(
&format!("fold `{}` validation_sample_ids", fold.fold_id),
&fold.validation_sample_ids,
)?;
if validation.is_empty() {
return Err(DagMlError::OofValidation(format!(
"fold `{}` has no validation samples",
fold.fold_id
)));
}
for sample_id in train.union(&validation) {
if !universe.contains(sample_id) {
return Err(DagMlError::OofValidation(format!(
"fold `{}` references unknown sample `{}`",
fold.fold_id, sample_id
)));
}
}
let overlap = train.intersection(&validation).collect::<Vec<_>>();
if !overlap.is_empty() {
return Err(DagMlError::OofValidation(format!(
"fold `{}` has train/validation overlap at sample `{}`",
fold.fold_id, overlap[0]
)));
}
for sample_id in validation {
*validation_counts
.get_mut(sample_id)
.expect("validation sample is in universe") += 1;
}
self.validate_group_boundary(fold, &train)?;
}
for (sample_id, count) in validation_counts {
if count != 1 {
return Err(DagMlError::OofValidation(format!(
"sample `{}` appears in validation {} time(s), expected exactly once",
sample_id, count
)));
}
}
Ok(())
}
fn validate_group_boundary(
&self,
fold: &FoldAssignment,
train: &BTreeSet<&SampleId>,
) -> Result<()> {
if self.sample_groups.is_empty() {
return Ok(());
}
let train_groups = train
.iter()
.filter_map(|sample_id| self.sample_groups.get(*sample_id))
.collect::<BTreeSet<_>>();
for sample_id in &fold.validation_sample_ids {
let Some(group_id) = self.sample_groups.get(sample_id) else {
continue;
};
if train_groups.contains(group_id) {
return Err(DagMlError::OofValidation(format!(
"fold `{}` leaks group `{}` across train/validation",
fold.fold_id, group_id
)));
}
}
Ok(())
}
}
pub fn fold_set_fingerprint(fold_set: &FoldSet) -> Result<String> {
let mut canonical = fold_set.clone();
canonical.validate()?;
canonical.sample_ids.sort();
canonical
.folds
.sort_by(|left, right| left.fold_id.cmp(&right.fold_id));
for fold in &mut canonical.folds {
fold.train_sample_ids.sort();
fold.validation_sample_ids.sort();
}
let mut value = serde_json::to_value(&canonical)?;
remove_empty_fold_set_maps(&mut value);
stable_json_fingerprint(&value)
}
/// Strips empty optional maps from a serialized fold set.
///
/// Removes a top-level `sample_groups` entry and each fold's `metadata` entry
/// when the entry is an empty JSON object. Values with a different shape
/// (non-object root, missing or non-array `folds`, non-object folds) are left
/// untouched.
fn remove_empty_fold_set_maps(value: &mut serde_json::Value) {
    /// Removes `key` from `object` when it maps to an empty JSON object.
    fn drop_if_empty_object(object: &mut serde_json::Map<String, serde_json::Value>, key: &str) {
        let is_empty_object = matches!(
            object.get(key),
            Some(serde_json::Value::Object(entries)) if entries.is_empty()
        );
        if is_empty_object {
            object.remove(key);
        }
    }

    let Some(root) = value.as_object_mut() else {
        return;
    };
    drop_if_empty_object(root, "sample_groups");

    if let Some(serde_json::Value::Array(folds)) = root.get_mut("folds") {
        for fold in folds.iter_mut().filter_map(|fold| fold.as_object_mut()) {
            drop_if_empty_object(fold, "metadata");
        }
    }
}
/// Configuration for a plain K-fold split (see `KFoldSpec::split`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct KFoldSpec {
/// Number of folds; `split` requires at least 2 and at most the number of samples.
pub n_splits: usize,
/// Whether to reorder samples by a seed-derived key before assigning folds;
/// an absent field deserializes to `false`.
#[serde(default)]
pub shuffle: bool,
/// Seed for the shuffled ordering; `split` uses 0 when this is `None`.
pub seed: Option<u64>,
}
impl KFoldSpec {
pub fn split(&self, id: impl Into<String>, samples: &[SampleId]) -> Result<FoldSet> {
if self.n_splits < 2 {
return Err(DagMlError::OofValidation(
"KFold requires at least two splits".to_string(),
));
}
let unique = unique_samples("KFold samples", samples)?;
if self.n_splits > unique.len() {
return Err(DagMlError::OofValidation(format!(
"KFold n_splits={} exceeds sample count {}",
self.n_splits,
unique.len()
)));
}
let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));
let folds = (0..self.n_splits)
.map(|fold_idx| {
let validation = ordered
.iter()
.enumerate()
.filter_map(|(idx, sample_id)| {
(idx % self.n_splits == fold_idx).then_some(sample_id.clone())
})
.collect::<Vec<_>>();
let validation_set = validation.iter().collect::<BTreeSet<_>>();
let train = ordered
.iter()
.filter(|sample_id| !validation_set.contains(sample_id))
.cloned()
.collect::<Vec<_>>();
Ok(FoldAssignment {
fold_id: FoldId::new(format!("fold{fold_idx}"))?,
train_sample_ids: train,
validation_sample_ids: validation,
metadata: BTreeMap::new(),
})
})
.collect::<Result<Vec<_>>>()?;
let fold_set = FoldSet {
id: id.into(),
sample_ids: ordered_samples(samples, false, 0),
folds,
sample_groups: BTreeMap::new(),
};
fold_set.validate()?;
Ok(fold_set)
}
}
/// Configuration for a stratified K-fold split (see `StratifiedKFoldSpec::split`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct StratifiedKFoldSpec {
/// Number of folds; `split` requires at least 2 and at most the number of samples.
pub n_splits: usize,
/// Whether to reorder samples by a seed-derived key before assigning folds;
/// an absent field deserializes to `false`.
#[serde(default)]
pub shuffle: bool,
/// Seed for the shuffled ordering; `split` uses 0 when this is `None`.
pub seed: Option<u64>,
}
impl StratifiedKFoldSpec {
    /// Splits `samples` into `n_splits` folds while spreading each stratum
    /// across folds.
    ///
    /// Samples are put into a deterministic order (sorted, or shuffled by the
    /// seed when `shuffle` is set), grouped by their label in `strata`, and
    /// the groups are concatenated in label order. The sample at position `p`
    /// of that concatenation is a validation sample of fold `p % n_splits`.
    /// The counter is not reset between labels, so folds stay non-empty even
    /// when every label has a single member. Each fold lists its samples in
    /// the deterministic order; the fold set lists `sample_ids` sorted and
    /// carries no group map.
    ///
    /// # Errors
    ///
    /// Returns `DagMlError::OofValidation` when `n_splits < 2`, when `samples`
    /// contains duplicates, when `n_splits` exceeds the sample count, or when
    /// a sample has no entry in `strata`; also propagates fold-id construction
    /// and fold-set validation failures.
    pub fn split(
        &self,
        id: impl Into<String>,
        samples: &[SampleId],
        strata: &BTreeMap<SampleId, String>,
    ) -> Result<FoldSet> {
        if self.n_splits < 2 {
            return Err(DagMlError::OofValidation(
                "StratifiedKFold requires at least two splits".to_string(),
            ));
        }
        let sample_count = unique_samples("StratifiedKFold samples", samples)?.len();
        if self.n_splits > sample_count {
            return Err(DagMlError::OofValidation(format!(
                "StratifiedKFold n_splits={} exceeds sample count {}",
                self.n_splits, sample_count
            )));
        }

        let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));

        // Bucket samples per label, keeping the deterministic order inside
        // each bucket.
        let mut by_label: BTreeMap<&str, Vec<&SampleId>> = BTreeMap::new();
        for sample_id in &ordered {
            let Some(label) = strata.get(sample_id) else {
                return Err(DagMlError::OofValidation(format!(
                    "StratifiedKFold: sample `{sample_id}` has no stratum label"
                )));
            };
            by_label.entry(label.as_str()).or_default().push(sample_id);
        }

        // Deal the label-ordered concatenation round-robin over the folds.
        let fold_of: BTreeMap<&SampleId, usize> = by_label
            .values()
            .flatten()
            .enumerate()
            .map(|(position, sample_id)| (*sample_id, position % self.n_splits))
            .collect();

        let mut folds = Vec::with_capacity(self.n_splits);
        for fold_idx in 0..self.n_splits {
            let (validation, train): (Vec<SampleId>, Vec<SampleId>) = ordered
                .iter()
                .cloned()
                .partition(|sample_id| fold_of.get(sample_id) == Some(&fold_idx));
            folds.push(FoldAssignment {
                fold_id: FoldId::new(format!("fold{fold_idx}"))?,
                train_sample_ids: train,
                validation_sample_ids: validation,
                metadata: BTreeMap::new(),
            });
        }

        let fold_set = FoldSet {
            id: id.into(),
            sample_ids: ordered_samples(samples, false, 0),
            folds,
            sample_groups: BTreeMap::new(),
        };
        fold_set.validate()?;
        Ok(fold_set)
    }
}
/// Configuration for a group-aware K-fold split (see `GroupKFoldSpec::split`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct GroupKFoldSpec {
/// Number of folds; `split` requires at least 2 and at most the number of distinct groups.
pub n_splits: usize,
}
impl GroupKFoldSpec {
pub fn split(
&self,
id: impl Into<String>,
sample_groups: &BTreeMap<SampleId, GroupId>,
) -> Result<FoldSet> {
if self.n_splits < 2 {
return Err(DagMlError::OofValidation(
"GroupKFold requires at least two splits".to_string(),
));
}
if sample_groups.is_empty() {
return Err(DagMlError::OofValidation(
"GroupKFold requires sample groups".to_string(),
));
}
let mut groups = BTreeMap::<GroupId, Vec<SampleId>>::new();
for (sample_id, group_id) in sample_groups {
groups
.entry(group_id.clone())
.or_default()
.push(sample_id.clone());
}
if self.n_splits > groups.len() {
return Err(DagMlError::OofValidation(format!(
"GroupKFold n_splits={} exceeds group count {}",
self.n_splits,
groups.len()
)));
}
let mut grouped = groups.into_iter().collect::<Vec<_>>();
grouped.sort_by(|(left_group, left_samples), (right_group, right_samples)| {
right_samples
.len()
.cmp(&left_samples.len())
.then_with(|| left_group.cmp(right_group))
});
let mut fold_validation = vec![Vec::<SampleId>::new(); self.n_splits];
for (_group_id, mut samples) in grouped {
samples.sort();
let fold_idx = fold_validation
.iter()
.enumerate()
.min_by(|(left_idx, left), (right_idx, right)| {
left.len()
.cmp(&right.len())
.then_with(|| left_idx.cmp(right_idx))
})
.map(|(idx, _)| idx)
.expect("at least one fold");
fold_validation[fold_idx].extend(samples);
}
let mut sample_ids = sample_groups.keys().cloned().collect::<Vec<_>>();
sample_ids.sort();
let folds = fold_validation
.into_iter()
.enumerate()
.map(|(fold_idx, mut validation)| {
validation.sort();
let validation_set = validation.iter().collect::<BTreeSet<_>>();
let train = sample_ids
.iter()
.filter(|sample_id| !validation_set.contains(sample_id))
.cloned()
.collect::<Vec<_>>();
Ok(FoldAssignment {
fold_id: FoldId::new(format!("fold{fold_idx}"))?,
train_sample_ids: train,
validation_sample_ids: validation,
metadata: BTreeMap::new(),
})
})
.collect::<Result<Vec<_>>>()?;
let fold_set = FoldSet {
id: id.into(),
sample_ids,
folds,
sample_groups: sample_groups.clone(),
};
fold_set.validate()?;
Ok(fold_set)
}
}
/// Splitter used for the inner loop of nested cross-validation.
///
/// Serialized as an internally tagged enum: the `kind` field selects the
/// variant and the remaining fields are those of the wrapped spec.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum NestedCvSpec {
/// Plain K-fold inner split; serialized with `kind = "kfold"`.
#[serde(rename = "kfold")]
KFold(KFoldSpec),
/// Group-aware K-fold inner split; serialized with `kind = "group_kfold"`.
#[serde(rename = "group_kfold")]
GroupKFold(GroupKFoldSpec),
}
impl NestedCvSpec {
    /// Checks that the wrapped spec asks for at least two splits.
    ///
    /// # Errors
    ///
    /// Returns `DagMlError::OofValidation` when `n_splits < 2`.
    pub fn validate(&self) -> Result<()> {
        let (family, n_splits) = match self {
            Self::KFold(spec) => ("KFold", spec.n_splits),
            Self::GroupKFold(spec) => ("GroupKFold", spec.n_splits),
        };
        if n_splits < 2 {
            return Err(DagMlError::OofValidation(format!(
                "inner {family} requires at least two splits"
            )));
        }
        Ok(())
    }

    /// Builds the inner fold set for one outer fold, using only that fold's
    /// training samples.
    ///
    /// The inner fold set is named `<outer fold id>.inner`. For the K-fold
    /// variant the outer training samples are split directly. For the group
    /// variant, `outer_groups` is first restricted to the outer training
    /// samples and then split by group. The result is finally checked with
    /// `validate_inner_fold_set_within_outer`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying splitter or from the
    /// inner-within-outer check. For the group variant, additionally returns
    /// `DagMlError::OofValidation` when an outer training sample has no entry
    /// in `outer_groups`: such samples would otherwise be silently left out of
    /// the inner cross-validation.
    pub fn build_inner_fold_set(
        &self,
        outer: &FoldAssignment,
        outer_groups: &BTreeMap<SampleId, GroupId>,
    ) -> Result<FoldSet> {
        let inner_id = format!("{}.inner", outer.fold_id);
        let inner = match self {
            Self::KFold(spec) => spec.split(inner_id, &outer.train_sample_ids)?,
            Self::GroupKFold(spec) => {
                let train = outer.train_sample_ids.iter().collect::<BTreeSet<_>>();
                let inner_groups = outer_groups
                    .iter()
                    .filter(|(sample_id, _)| train.contains(sample_id))
                    .map(|(sample_id, group_id)| (sample_id.clone(), group_id.clone()))
                    .collect::<BTreeMap<_, _>>();
                let inner = spec.split(inner_id, &inner_groups)?;
                // Checked after the split so that every failure the splitter
                // already reports keeps its original message; only a split
                // that would have succeeded with missing samples is rejected.
                if let Some(sample_id) = outer
                    .train_sample_ids
                    .iter()
                    .find(|sample_id| !inner_groups.contains_key(*sample_id))
                {
                    return Err(DagMlError::OofValidation(format!(
                        "nested GroupKFold: outer training sample `{sample_id}` of outer fold `{}` has no group and would be dropped from inner CV",
                        outer.fold_id
                    )));
                }
                inner
            }
        };
        validate_inner_fold_set_within_outer(&inner, outer)?;
        Ok(inner)
    }
}
/// Picks the inner-CV spec that applies to a node.
///
/// The node-level spec wins when present; otherwise the campaign-level spec
/// is used. Returns `None` when neither is set.
pub fn resolve_inner_cv<'a>(
    node_inner_cv: Option<&'a NestedCvSpec>,
    campaign_inner_cv: Option<&'a NestedCvSpec>,
) -> Option<&'a NestedCvSpec> {
    match node_inner_cv {
        Some(node_spec) => Some(node_spec),
        None => campaign_inner_cv,
    }
}
/// Guards against nested-CV leakage: every sample mentioned by `inner` must
/// be a training sample of the `outer` fold.
///
/// `inner` is validated first. Then the inner `sample_ids`, followed by each
/// inner fold's train and validation lists, are scanned in order, and the
/// first sample outside the outer training set is reported.
///
/// # Errors
///
/// Propagates `inner.validate()` failures and returns
/// `DagMlError::OofValidation` for the first leaking sample.
pub fn validate_inner_fold_set_within_outer(inner: &FoldSet, outer: &FoldAssignment) -> Result<()> {
    inner.validate()?;

    let outer_train: BTreeSet<&SampleId> = outer.train_sample_ids.iter().collect();

    let fold_members = inner.folds.iter().flat_map(|fold| {
        fold.train_sample_ids
            .iter()
            .chain(&fold.validation_sample_ids)
    });
    let leaked = inner
        .sample_ids
        .iter()
        .chain(fold_members)
        .find(|sample_id| !outer_train.contains(*sample_id));

    match leaked {
        Some(sample_id) => Err(DagMlError::OofValidation(format!(
            "nested CV leakage: inner-CV sample `{sample_id}` for outer fold `{}` is not an outer training sample",
            outer.fold_id
        ))),
        None => Ok(()),
    }
}
/// Collects `samples` into a set of references, failing on the first repeat.
///
/// `label` names the collection in the error message.
///
/// # Errors
///
/// Returns `DagMlError::OofValidation` naming the first sample that occurs a
/// second time.
fn unique_samples<'a>(label: &str, samples: &'a [SampleId]) -> Result<BTreeSet<&'a SampleId>> {
    let mut distinct = BTreeSet::new();
    // `insert` returns false for an element that is already present.
    let duplicate = samples
        .iter()
        .find(|sample_id| !distinct.insert(*sample_id));
    match duplicate {
        Some(sample_id) => Err(DagMlError::OofValidation(format!(
            "{label} contains duplicate sample `{sample_id}`"
        ))),
        None => Ok(distinct),
    }
}
/// Returns `samples` in a deterministic order.
///
/// Without `shuffle` the samples are sorted ascending. With `shuffle` they
/// are ordered by a key derived from `seed` (under the `"kfold"` child seed
/// context) and the sample id, with ties broken by ascending sample id.
fn ordered_samples(samples: &[SampleId], shuffle: bool, seed: u64) -> Vec<SampleId> {
    let mut ordered = samples.to_vec();
    ordered.sort();
    if shuffle {
        let context = SeedContext::root(seed).child("kfold");
        // Derive each key once instead of on every comparison. The sort is
        // stable and the input is already sorted by id, so equal keys keep
        // ascending id order, matching an explicit id tie-break.
        ordered.sort_by_cached_key(|sample_id| context.derive_u64(sample_id.as_str()));
    }
    ordered
}
// Unit tests for fold-set validation, fingerprinting, the K-fold family of
// splitters, and the nested-CV helpers.
#[cfg(test)]
mod tests {
use super::*;
// Expected fingerprint of the shared `fold_set_cv_partition.json` fixture;
// `shared_fold_set_fixture_fingerprint_is_locked` fails if the value changes.
const SHARED_FOLD_SET_FINGERPRINT: &str =
"54d3185d6c628ef0df848828a8d8ae650222a283a78bbd3ab3bc2256f222c05c";
// Test helper: builds a `SampleId`, panicking if `value` is rejected.
fn sid(value: &str) -> SampleId {
SampleId::new(value).unwrap()
}
// Test helper: builds a `GroupId`, panicking if `value` is rejected.
fn gid(value: &str) -> GroupId {
GroupId::new(value).unwrap()
}
// Splitting the same samples twice with the same shuffled spec gives equal
// fold sets; 6 samples over 3 splits give 2 validation and 4 training
// samples per fold.
#[test]
fn kfold_is_deterministic_and_covers_samples_once() {
let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
.into_iter()
.map(sid)
.collect::<Vec<_>>();
let spec = KFoldSpec {
n_splits: 3,
shuffle: true,
seed: Some(42),
};
let left = spec.split("kfold", &samples).unwrap();
let right = spec.split("kfold", &samples).unwrap();
assert_eq!(left, right);
left.validate().unwrap();
for fold in &left.folds {
assert_eq!(fold.validation_sample_ids.len(), 2);
assert_eq!(fold.train_sample_ids.len(), 4);
}
}
// A fold that lists `s1` as both a training and a validation sample must
// fail validation.
#[test]
fn fold_validation_rejects_overlap() {
let fold_set = FoldSet {
id: "bad".to_string(),
sample_ids: vec![sid("s1"), sid("s2")],
folds: vec![FoldAssignment {
fold_id: FoldId::new("fold0").unwrap(),
train_sample_ids: vec![sid("s1")],
validation_sample_ids: vec![sid("s1")],
metadata: BTreeMap::new(),
}],
sample_groups: BTreeMap::new(),
};
assert!(fold_set.validate().is_err());
}
// A non-empty group map that covers only `s1` (and not `s2`) must fail
// validation.
#[test]
fn fold_validation_rejects_partial_group_maps() {
let fold_set = FoldSet {
id: "bad-groups".to_string(),
sample_ids: vec![sid("s1"), sid("s2")],
folds: vec![FoldAssignment {
fold_id: FoldId::new("fold0").unwrap(),
train_sample_ids: vec![sid("s2")],
validation_sample_ids: vec![sid("s1")],
metadata: BTreeMap::new(),
}],
sample_groups: BTreeMap::from([(sid("s1"), gid("g1"))]),
};
assert!(fold_set.validate().is_err());
}
// Reversing the sample list, the fold list and each fold's member lists
// must not change the fingerprint; changing the id must change it.
#[test]
fn fold_set_fingerprint_is_independent_of_ordering() {
let mut left = FoldSet {
id: "cv.partition".to_string(),
sample_ids: vec![sid("s3"), sid("s2"), sid("s1")],
folds: vec![
FoldAssignment {
fold_id: FoldId::new("fold1").unwrap(),
train_sample_ids: vec![sid("s2"), sid("s1")],
validation_sample_ids: vec![sid("s3")],
metadata: BTreeMap::new(),
},
FoldAssignment {
fold_id: FoldId::new("fold0").unwrap(),
train_sample_ids: vec![sid("s3")],
validation_sample_ids: vec![sid("s2"), sid("s1")],
metadata: BTreeMap::new(),
},
],
sample_groups: BTreeMap::new(),
};
let mut right = left.clone();
right.sample_ids.reverse();
right.folds.reverse();
for fold in &mut right.folds {
fold.train_sample_ids.reverse();
fold.validation_sample_ids.reverse();
}
assert_eq!(
fold_set_fingerprint(&left).unwrap(),
fold_set_fingerprint(&right).unwrap()
);
left.id = "cv.partition.changed".to_string();
assert_ne!(
fold_set_fingerprint(&left).unwrap(),
fold_set_fingerprint(&right).unwrap()
);
}
// Pins the fingerprint of the shared JSON fixture to the constant above.
// NOTE(review): the fixture is presumably shared with other consumers of the
// fingerprint format -- confirm before changing either side.
#[test]
fn shared_fold_set_fixture_fingerprint_is_locked() {
let fixture = include_str!("../../../examples/fixtures/shared/fold_set_cv_partition.json");
let fold_set = serde_json::from_str::<FoldSet>(fixture).unwrap();
assert_eq!(
fold_set_fingerprint(&fold_set).unwrap(),
SHARED_FOLD_SET_FINGERPRINT
);
}
// With three groups of two samples and three splits, no group of a fold's
// validation samples may also appear among that fold's training samples.
#[test]
fn group_kfold_keeps_groups_out_of_train_validation_overlap() {
let groups = BTreeMap::from([
(sid("s1"), gid("g1")),
(sid("s2"), gid("g1")),
(sid("s3"), gid("g2")),
(sid("s4"), gid("g2")),
(sid("s5"), gid("g3")),
(sid("s6"), gid("g3")),
]);
let fold_set = GroupKFoldSpec { n_splits: 3 }
.split("group-kfold", &groups)
.unwrap();
fold_set.validate().unwrap();
for fold in &fold_set.folds {
let train_groups = fold
.train_sample_ids
.iter()
.map(|sample_id| groups.get(sample_id).unwrap())
.collect::<BTreeSet<_>>();
for sample_id in &fold.validation_sample_ids {
assert!(!train_groups.contains(groups.get(sample_id).unwrap()));
}
}
}
// Eight samples alternating between labels "A" and "B", split two ways:
// each fold's validation set must hold exactly two samples of each label.
#[test]
fn stratified_kfold_is_oof_safe_and_balances_classes() {
let samples = (0..8).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
let strata = BTreeMap::from_iter(samples.iter().enumerate().map(|(i, s)| {
(
s.clone(),
if i % 2 == 0 {
"A".to_string()
} else {
"B".to_string()
},
)
}));
let fold_set = StratifiedKFoldSpec {
n_splits: 2,
shuffle: false,
seed: Some(0),
}
.split("strat", &samples, &strata)
.unwrap();
fold_set.validate().unwrap(); assert_eq!(fold_set.folds.len(), 2);
for fold in &fold_set.folds {
let mut counts: BTreeMap<&str, usize> = BTreeMap::new();
for s in &fold.validation_sample_ids {
*counts.entry(strata.get(s).unwrap().as_str()).or_insert(0) += 1;
}
assert_eq!(counts.get("A"), Some(&2));
assert_eq!(counts.get("B"), Some(&2));
}
}
// Three samples with three distinct labels and three splits: the split must
// succeed and every fold must validate exactly one sample.
#[test]
fn stratified_kfold_singleton_classes_leave_no_empty_fold() {
let samples = ["s0", "s1", "s2"].into_iter().map(sid).collect::<Vec<_>>();
let strata = BTreeMap::from_iter([
(sid("s0"), "A".to_string()),
(sid("s1"), "B".to_string()),
(sid("s2"), "C".to_string()),
]);
let fold_set = StratifiedKFoldSpec {
n_splits: 3,
shuffle: false,
seed: Some(0),
}
.split("strat", &samples, &strata)
.expect("singleton-class stratified split must succeed");
fold_set.validate().unwrap();
for fold in &fold_set.folds {
assert_eq!(fold.validation_sample_ids.len(), 1);
}
}
// Only `s0` has a label, so splitting four samples must fail.
// Note: `err` holds the whole `Result`, not an unwrapped error.
#[test]
fn stratified_kfold_rejects_missing_label() {
let samples = (0..4).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
let strata = BTreeMap::from_iter([(sid("s0"), "A".to_string())]); let err = StratifiedKFoldSpec {
n_splits: 2,
shuffle: false,
seed: Some(0),
}
.split("strat", &samples, &strata);
assert!(err.is_err());
}
// Test helper: an unshuffled two-way K-fold over `samples`, used as the
// outer partition in the nested-CV tests.
fn outer_kfold(samples: &[SampleId]) -> FoldSet {
KFoldSpec {
n_splits: 2,
shuffle: false,
seed: Some(0),
}
.split("outer", samples)
.unwrap()
}
// For every outer fold, the inner K-fold set must be valid and its sample
// universe must equal the outer fold's training samples.
#[test]
fn nested_kfold_inner_folds_are_subset_of_outer_train() {
let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
.into_iter()
.map(sid)
.collect::<Vec<_>>();
let outer = outer_kfold(&samples);
let spec = NestedCvSpec::KFold(KFoldSpec {
n_splits: 2,
shuffle: false,
seed: Some(1),
});
for outer_fold in &outer.folds {
let inner = spec
.build_inner_fold_set(outer_fold, &outer.sample_groups)
.expect("inner fold set");
let outer_train = outer_fold.train_sample_ids.iter().collect::<BTreeSet<_>>();
for sample_id in &inner.sample_ids {
assert!(outer_train.contains(sample_id));
}
inner.validate().unwrap();
assert_eq!(
inner.sample_ids.iter().collect::<BTreeSet<_>>(),
outer_train
);
}
}
// An inner fold set that is valid on its own but includes an outer
// validation sample must be refused with a "nested CV leakage" error.
#[test]
fn nested_cv_validation_refuses_inner_sample_from_outer_validation() {
let samples = ["s1", "s2", "s3", "s4"]
.into_iter()
.map(sid)
.collect::<Vec<_>>();
let outer = outer_kfold(&samples);
let outer_fold = &outer.folds[0];
let leaking_sample = outer_fold.validation_sample_ids[0].clone();
let train_sample = outer_fold.train_sample_ids[0].clone();
let inner = FoldSet {
id: "leaky.inner".to_string(),
sample_ids: vec![train_sample.clone(), leaking_sample.clone()],
folds: vec![
FoldAssignment {
fold_id: FoldId::new("if0").unwrap(),
train_sample_ids: vec![leaking_sample.clone()],
validation_sample_ids: vec![train_sample.clone()],
metadata: BTreeMap::new(),
},
FoldAssignment {
fold_id: FoldId::new("if1").unwrap(),
train_sample_ids: vec![train_sample],
validation_sample_ids: vec![leaking_sample],
metadata: BTreeMap::new(),
},
],
sample_groups: BTreeMap::new(),
};
inner
.validate()
.expect("inner fold set is structurally valid");
let err = validate_inner_fold_set_within_outer(&inner, outer_fold)
.expect_err("inner fold leaking an outer-validation sample must be refused");
assert!(err.to_string().contains("nested CV leakage"));
}
// The leaking sample appears only inside a fold's validation list and not in
// the inner `sample_ids`; the check must still return an error.
#[test]
fn nested_cv_validation_refuses_leak_hidden_in_fold_members() {
let samples = ["s1", "s2", "s3", "s4"]
.into_iter()
.map(sid)
.collect::<Vec<_>>();
let outer = outer_kfold(&samples);
let outer_fold = &outer.folds[0];
let leaking_sample = outer_fold.validation_sample_ids[0].clone();
let train_sample = outer_fold.train_sample_ids[0].clone();
let inner = FoldSet {
id: "hidden.inner".to_string(),
sample_ids: vec![train_sample.clone()],
folds: vec![FoldAssignment {
fold_id: FoldId::new("if0").unwrap(),
train_sample_ids: vec![train_sample],
validation_sample_ids: vec![leaking_sample],
metadata: BTreeMap::new(),
}],
sample_groups: BTreeMap::new(),
};
assert!(validate_inner_fold_set_within_outer(&inner, outer_fold).is_err());
}
// Pins the JSON shape of `NestedCvSpec`: a `kind` tag next to the flattened
// spec fields, and a lossless round trip for both variants.
#[test]
fn nested_cv_spec_json_shape_is_stable() {
let spec = NestedCvSpec::KFold(KFoldSpec {
n_splits: 3,
shuffle: false,
seed: Some(7),
});
let value = serde_json::to_value(&spec).unwrap();
assert_eq!(value["kind"], "kfold");
assert_eq!(value["n_splits"], 3);
assert_eq!(value["seed"], 7);
let round: NestedCvSpec = serde_json::from_value(value).unwrap();
assert_eq!(round, spec);
let group = NestedCvSpec::GroupKFold(GroupKFoldSpec { n_splits: 2 });
let gv = serde_json::to_value(&group).unwrap();
assert_eq!(gv["kind"], "group_kfold");
assert_eq!(gv["n_splits"], 2);
assert_eq!(serde_json::from_value::<NestedCvSpec>(gv).unwrap(), group);
}
// The node-level spec takes precedence, the campaign-level spec is the
// fallback, and the result is `None` when neither is given.
#[test]
fn resolve_inner_cv_prefers_node_over_campaign() {
let node = NestedCvSpec::KFold(KFoldSpec {
n_splits: 3,
shuffle: false,
seed: Some(2),
});
let campaign = NestedCvSpec::KFold(KFoldSpec {
n_splits: 5,
shuffle: false,
seed: Some(3),
});
assert_eq!(resolve_inner_cv(Some(&node), Some(&campaign)), Some(&node));
assert_eq!(resolve_inner_cv(None, Some(&campaign)), Some(&campaign));
assert_eq!(resolve_inner_cv(Some(&node), None), Some(&node));
assert_eq!(resolve_inner_cv(None, None), None);
}
}