1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::ids::{FoldId, GroupId, SampleId};
8use crate::rng::SeedContext;
9
/// One fold of a cross-validation partition: the samples used for training
/// and the samples held out for validation.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FoldAssignment {
    /// Identifier of this fold; `FoldSet::validate` requires it to be unique
    /// within the owning fold set.
    pub fold_id: FoldId,
    /// Samples the fold trains on.
    pub train_sample_ids: Vec<SampleId>,
    /// Samples the fold holds out for validation.
    pub validation_sample_ids: Vec<SampleId>,
    /// Free-form JSON metadata; deserializes to an empty map when the field
    /// is absent.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct FoldSet {
21 pub id: String,
22 pub sample_ids: Vec<SampleId>,
23 pub folds: Vec<FoldAssignment>,
24 #[serde(default)]
25 pub sample_groups: BTreeMap<SampleId, GroupId>,
26}
27
28impl FoldSet {
29 pub fn validate(&self) -> Result<()> {
30 if self.id.trim().is_empty() {
31 return Err(DagMlError::OofValidation(
32 "fold set id is empty".to_string(),
33 ));
34 }
35 if self.sample_ids.is_empty() {
36 return Err(DagMlError::OofValidation(
37 "fold set contains no samples".to_string(),
38 ));
39 }
40 if self.folds.is_empty() {
41 return Err(DagMlError::OofValidation(
42 "fold set contains no folds".to_string(),
43 ));
44 }
45 let universe = unique_samples("fold set sample_ids", &self.sample_ids)?;
46 if !self.sample_groups.is_empty() {
47 for sample_id in self.sample_groups.keys() {
48 if !universe.contains(sample_id) {
49 return Err(DagMlError::OofValidation(format!(
50 "sample group map references unknown sample `{sample_id}`"
51 )));
52 }
53 }
54 for sample_id in &self.sample_ids {
55 if !self.sample_groups.contains_key(sample_id) {
56 return Err(DagMlError::OofValidation(format!(
57 "sample `{sample_id}` is missing from non-empty group map"
58 )));
59 }
60 }
61 }
62 let mut fold_ids = BTreeSet::new();
63 let mut validation_counts = self
64 .sample_ids
65 .iter()
66 .cloned()
67 .map(|sample_id| (sample_id, 0usize))
68 .collect::<BTreeMap<_, _>>();
69
70 for fold in &self.folds {
71 if !fold_ids.insert(&fold.fold_id) {
72 return Err(DagMlError::OofValidation(format!(
73 "duplicate fold id `{}`",
74 fold.fold_id
75 )));
76 }
77 let train = unique_samples(
78 &format!("fold `{}` train_sample_ids", fold.fold_id),
79 &fold.train_sample_ids,
80 )?;
81 let validation = unique_samples(
82 &format!("fold `{}` validation_sample_ids", fold.fold_id),
83 &fold.validation_sample_ids,
84 )?;
85 if validation.is_empty() {
86 return Err(DagMlError::OofValidation(format!(
87 "fold `{}` has no validation samples",
88 fold.fold_id
89 )));
90 }
91 for sample_id in train.union(&validation) {
92 if !universe.contains(sample_id) {
93 return Err(DagMlError::OofValidation(format!(
94 "fold `{}` references unknown sample `{}`",
95 fold.fold_id, sample_id
96 )));
97 }
98 }
99 let overlap = train.intersection(&validation).collect::<Vec<_>>();
100 if !overlap.is_empty() {
101 return Err(DagMlError::OofValidation(format!(
102 "fold `{}` has train/validation overlap at sample `{}`",
103 fold.fold_id, overlap[0]
104 )));
105 }
106 for sample_id in validation {
107 *validation_counts
108 .get_mut(sample_id)
109 .expect("validation sample is in universe") += 1;
110 }
111 self.validate_group_boundary(fold, &train)?;
112 }
113
114 for (sample_id, count) in validation_counts {
115 if count != 1 {
116 return Err(DagMlError::OofValidation(format!(
117 "sample `{}` appears in validation {} time(s), expected exactly once",
118 sample_id, count
119 )));
120 }
121 }
122
123 Ok(())
124 }
125
126 fn validate_group_boundary(
127 &self,
128 fold: &FoldAssignment,
129 train: &BTreeSet<&SampleId>,
130 ) -> Result<()> {
131 if self.sample_groups.is_empty() {
132 return Ok(());
133 }
134 let train_groups = train
135 .iter()
136 .filter_map(|sample_id| self.sample_groups.get(*sample_id))
137 .collect::<BTreeSet<_>>();
138 for sample_id in &fold.validation_sample_ids {
139 let Some(group_id) = self.sample_groups.get(sample_id) else {
140 continue;
141 };
142 if train_groups.contains(group_id) {
143 return Err(DagMlError::OofValidation(format!(
144 "fold `{}` leaks group `{}` across train/validation",
145 fold.fold_id, group_id
146 )));
147 }
148 }
149 Ok(())
150 }
151}
152
153pub fn fold_set_fingerprint(fold_set: &FoldSet) -> Result<String> {
154 let mut canonical = fold_set.clone();
155 canonical.validate()?;
156 canonical.sample_ids.sort();
157 canonical
158 .folds
159 .sort_by(|left, right| left.fold_id.cmp(&right.fold_id));
160 for fold in &mut canonical.folds {
161 fold.train_sample_ids.sort();
162 fold.validation_sample_ids.sort();
163 }
164
165 let mut value = serde_json::to_value(&canonical)?;
166 remove_empty_fold_set_maps(&mut value);
167 stable_json_fingerprint(&value)
168}
169
170fn remove_empty_fold_set_maps(value: &mut serde_json::Value) {
171 let Some(object) = value.as_object_mut() else {
172 return;
173 };
174 if object
175 .get("sample_groups")
176 .and_then(serde_json::Value::as_object)
177 .is_some_and(serde_json::Map::is_empty)
178 {
179 object.remove("sample_groups");
180 }
181 let Some(folds) = object
182 .get_mut("folds")
183 .and_then(serde_json::Value::as_array_mut)
184 else {
185 return;
186 };
187 for fold in folds {
188 let Some(fold_object) = fold.as_object_mut() else {
189 continue;
190 };
191 if fold_object
192 .get("metadata")
193 .and_then(serde_json::Value::as_object)
194 .is_some_and(serde_json::Map::is_empty)
195 {
196 fold_object.remove("metadata");
197 }
198 }
199}
200
201#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
202pub struct KFoldSpec {
203 pub n_splits: usize,
204 #[serde(default)]
205 pub shuffle: bool,
206 pub seed: Option<u64>,
207}
208
209impl KFoldSpec {
210 pub fn split(&self, id: impl Into<String>, samples: &[SampleId]) -> Result<FoldSet> {
211 if self.n_splits < 2 {
212 return Err(DagMlError::OofValidation(
213 "KFold requires at least two splits".to_string(),
214 ));
215 }
216 let unique = unique_samples("KFold samples", samples)?;
217 if self.n_splits > unique.len() {
218 return Err(DagMlError::OofValidation(format!(
219 "KFold n_splits={} exceeds sample count {}",
220 self.n_splits,
221 unique.len()
222 )));
223 }
224 let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));
225 let folds = (0..self.n_splits)
226 .map(|fold_idx| {
227 let validation = ordered
228 .iter()
229 .enumerate()
230 .filter_map(|(idx, sample_id)| {
231 (idx % self.n_splits == fold_idx).then_some(sample_id.clone())
232 })
233 .collect::<Vec<_>>();
234 let validation_set = validation.iter().collect::<BTreeSet<_>>();
235 let train = ordered
236 .iter()
237 .filter(|sample_id| !validation_set.contains(sample_id))
238 .cloned()
239 .collect::<Vec<_>>();
240 Ok(FoldAssignment {
241 fold_id: FoldId::new(format!("fold{fold_idx}"))?,
242 train_sample_ids: train,
243 validation_sample_ids: validation,
244 metadata: BTreeMap::new(),
245 })
246 })
247 .collect::<Result<Vec<_>>>()?;
248 let fold_set = FoldSet {
249 id: id.into(),
250 sample_ids: ordered_samples(samples, false, 0),
251 folds,
252 sample_groups: BTreeMap::new(),
253 };
254 fold_set.validate()?;
255 Ok(fold_set)
256 }
257}
258
259#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
264pub struct StratifiedKFoldSpec {
265 pub n_splits: usize,
266 #[serde(default)]
267 pub shuffle: bool,
268 pub seed: Option<u64>,
269}
270
271impl StratifiedKFoldSpec {
272 pub fn split(
273 &self,
274 id: impl Into<String>,
275 samples: &[SampleId],
276 strata: &BTreeMap<SampleId, String>,
277 ) -> Result<FoldSet> {
278 if self.n_splits < 2 {
279 return Err(DagMlError::OofValidation(
280 "StratifiedKFold requires at least two splits".to_string(),
281 ));
282 }
283 let unique = unique_samples("StratifiedKFold samples", samples)?;
284 if self.n_splits > unique.len() {
285 return Err(DagMlError::OofValidation(format!(
286 "StratifiedKFold n_splits={} exceeds sample count {}",
287 self.n_splits,
288 unique.len()
289 )));
290 }
291 let ordered = ordered_samples(samples, self.shuffle, self.seed.unwrap_or(0));
298 let mut by_label: BTreeMap<String, Vec<SampleId>> = BTreeMap::new();
299 for sample_id in &ordered {
300 let label = strata.get(sample_id).ok_or_else(|| {
301 DagMlError::OofValidation(format!(
302 "StratifiedKFold: sample `{sample_id}` has no stratum label"
303 ))
304 })?;
305 by_label
306 .entry(label.clone())
307 .or_default()
308 .push(sample_id.clone());
309 }
310 let mut fold_of: BTreeMap<SampleId, usize> = BTreeMap::new();
311 let mut position = 0usize;
312 for members in by_label.values() {
313 for sample_id in members {
314 fold_of.insert(sample_id.clone(), position % self.n_splits);
315 position += 1;
316 }
317 }
318 let folds = (0..self.n_splits)
319 .map(|fold_idx| {
320 let validation = ordered
321 .iter()
322 .filter(|s| fold_of.get(*s) == Some(&fold_idx))
323 .cloned()
324 .collect::<Vec<_>>();
325 let train = ordered
326 .iter()
327 .filter(|s| fold_of.get(*s) != Some(&fold_idx))
328 .cloned()
329 .collect::<Vec<_>>();
330 Ok(FoldAssignment {
331 fold_id: FoldId::new(format!("fold{fold_idx}"))?,
332 train_sample_ids: train,
333 validation_sample_ids: validation,
334 metadata: BTreeMap::new(),
335 })
336 })
337 .collect::<Result<Vec<_>>>()?;
338 let fold_set = FoldSet {
339 id: id.into(),
340 sample_ids: ordered_samples(samples, false, 0),
341 folds,
342 sample_groups: BTreeMap::new(),
343 };
344 fold_set.validate()?;
345 Ok(fold_set)
346 }
347}
348
349#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
350pub struct GroupKFoldSpec {
351 pub n_splits: usize,
352}
353
354impl GroupKFoldSpec {
355 pub fn split(
356 &self,
357 id: impl Into<String>,
358 sample_groups: &BTreeMap<SampleId, GroupId>,
359 ) -> Result<FoldSet> {
360 if self.n_splits < 2 {
361 return Err(DagMlError::OofValidation(
362 "GroupKFold requires at least two splits".to_string(),
363 ));
364 }
365 if sample_groups.is_empty() {
366 return Err(DagMlError::OofValidation(
367 "GroupKFold requires sample groups".to_string(),
368 ));
369 }
370 let mut groups = BTreeMap::<GroupId, Vec<SampleId>>::new();
371 for (sample_id, group_id) in sample_groups {
372 groups
373 .entry(group_id.clone())
374 .or_default()
375 .push(sample_id.clone());
376 }
377 if self.n_splits > groups.len() {
378 return Err(DagMlError::OofValidation(format!(
379 "GroupKFold n_splits={} exceeds group count {}",
380 self.n_splits,
381 groups.len()
382 )));
383 }
384
385 let mut grouped = groups.into_iter().collect::<Vec<_>>();
386 grouped.sort_by(|(left_group, left_samples), (right_group, right_samples)| {
387 right_samples
388 .len()
389 .cmp(&left_samples.len())
390 .then_with(|| left_group.cmp(right_group))
391 });
392
393 let mut fold_validation = vec![Vec::<SampleId>::new(); self.n_splits];
394 for (_group_id, mut samples) in grouped {
395 samples.sort();
396 let fold_idx = fold_validation
397 .iter()
398 .enumerate()
399 .min_by(|(left_idx, left), (right_idx, right)| {
400 left.len()
401 .cmp(&right.len())
402 .then_with(|| left_idx.cmp(right_idx))
403 })
404 .map(|(idx, _)| idx)
405 .expect("at least one fold");
406 fold_validation[fold_idx].extend(samples);
407 }
408
409 let mut sample_ids = sample_groups.keys().cloned().collect::<Vec<_>>();
410 sample_ids.sort();
411 let folds = fold_validation
412 .into_iter()
413 .enumerate()
414 .map(|(fold_idx, mut validation)| {
415 validation.sort();
416 let validation_set = validation.iter().collect::<BTreeSet<_>>();
417 let train = sample_ids
418 .iter()
419 .filter(|sample_id| !validation_set.contains(sample_id))
420 .cloned()
421 .collect::<Vec<_>>();
422 Ok(FoldAssignment {
423 fold_id: FoldId::new(format!("fold{fold_idx}"))?,
424 train_sample_ids: train,
425 validation_sample_ids: validation,
426 metadata: BTreeMap::new(),
427 })
428 })
429 .collect::<Result<Vec<_>>>()?;
430
431 let fold_set = FoldSet {
432 id: id.into(),
433 sample_ids,
434 folds,
435 sample_groups: sample_groups.clone(),
436 };
437 fold_set.validate()?;
438 Ok(fold_set)
439 }
440}
441
442#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
451#[serde(tag = "kind")]
452pub enum NestedCvSpec {
453 #[serde(rename = "kfold")]
455 KFold(KFoldSpec),
456 #[serde(rename = "group_kfold")]
458 GroupKFold(GroupKFoldSpec),
459}
460
461impl NestedCvSpec {
462 pub fn validate(&self) -> Result<()> {
466 match self {
467 Self::KFold(spec) => {
468 if spec.n_splits < 2 {
469 return Err(DagMlError::OofValidation(
470 "inner KFold requires at least two splits".to_string(),
471 ));
472 }
473 }
474 Self::GroupKFold(spec) => {
475 if spec.n_splits < 2 {
476 return Err(DagMlError::OofValidation(
477 "inner GroupKFold requires at least two splits".to_string(),
478 ));
479 }
480 }
481 }
482 Ok(())
483 }
484
485 pub fn build_inner_fold_set(
490 &self,
491 outer: &FoldAssignment,
492 outer_groups: &BTreeMap<SampleId, GroupId>,
493 ) -> Result<FoldSet> {
494 let inner_id = format!("{}.inner", outer.fold_id);
495 let inner = match self {
496 Self::KFold(spec) => spec.split(inner_id, &outer.train_sample_ids)?,
497 Self::GroupKFold(spec) => {
498 let train = outer.train_sample_ids.iter().collect::<BTreeSet<_>>();
499 let inner_groups = outer_groups
500 .iter()
501 .filter(|(sample_id, _)| train.contains(sample_id))
502 .map(|(sample_id, group_id)| (sample_id.clone(), group_id.clone()))
503 .collect::<BTreeMap<_, _>>();
504 spec.split(inner_id, &inner_groups)?
505 }
506 };
507 validate_inner_fold_set_within_outer(&inner, outer)?;
508 Ok(inner)
509 }
510}
511
512pub fn resolve_inner_cv<'a>(
515 node_inner_cv: Option<&'a NestedCvSpec>,
516 campaign_inner_cv: Option<&'a NestedCvSpec>,
517) -> Option<&'a NestedCvSpec> {
518 node_inner_cv.or(campaign_inner_cv)
519}
520
521pub fn validate_inner_fold_set_within_outer(inner: &FoldSet, outer: &FoldAssignment) -> Result<()> {
527 inner.validate()?;
531 let train = outer.train_sample_ids.iter().collect::<BTreeSet<_>>();
532 let ensure_train = |sample_id: &SampleId| -> Result<()> {
533 if !train.contains(sample_id) {
534 return Err(DagMlError::OofValidation(format!(
535 "nested CV leakage: inner-CV sample `{sample_id}` for outer fold `{}` is not an outer training sample",
536 outer.fold_id
537 )));
538 }
539 Ok(())
540 };
541 for sample_id in &inner.sample_ids {
542 ensure_train(sample_id)?;
543 }
544 for fold in &inner.folds {
547 for sample_id in fold
548 .train_sample_ids
549 .iter()
550 .chain(&fold.validation_sample_ids)
551 {
552 ensure_train(sample_id)?;
553 }
554 }
555 Ok(())
556}
557
558fn unique_samples<'a>(label: &str, samples: &'a [SampleId]) -> Result<BTreeSet<&'a SampleId>> {
559 let mut seen = BTreeSet::new();
560 for sample_id in samples {
561 if !seen.insert(sample_id) {
562 return Err(DagMlError::OofValidation(format!(
563 "{label} contains duplicate sample `{sample_id}`"
564 )));
565 }
566 }
567 Ok(seen)
568}
569
570fn ordered_samples(samples: &[SampleId], shuffle: bool, seed: u64) -> Vec<SampleId> {
571 let mut ordered = samples.to_vec();
572 ordered.sort();
573 if shuffle {
574 let context = SeedContext::root(seed).child("kfold");
575 ordered.sort_by(|left, right| {
576 context
577 .derive_u64(left.as_str())
578 .cmp(&context.derive_u64(right.as_str()))
579 .then_with(|| left.cmp(right))
580 });
581 }
582 ordered
583}
584
/// Unit tests for fold-set validation, fingerprinting, the K-fold /
/// stratified / group splitters and the nested-CV helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Expected fingerprint of the shared fixture loaded in
    // `shared_fold_set_fixture_fingerprint_is_locked`.
    const SHARED_FOLD_SET_FINGERPRINT: &str =
        "54d3185d6c628ef0df848828a8d8ae650222a283a78bbd3ab3bc2256f222c05c";

    // Builds a `SampleId`, panicking on an invalid value.
    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    // Builds a `GroupId`, panicking on an invalid value.
    fn gid(value: &str) -> GroupId {
        GroupId::new(value).unwrap()
    }

    // Same spec and samples must give identical fold sets; six samples over
    // three folds means two validation and four training samples per fold.
    #[test]
    fn kfold_is_deterministic_and_covers_samples_once() {
        let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
            .into_iter()
            .map(sid)
            .collect::<Vec<_>>();
        let spec = KFoldSpec {
            n_splits: 3,
            shuffle: true,
            seed: Some(42),
        };

        let left = spec.split("kfold", &samples).unwrap();
        let right = spec.split("kfold", &samples).unwrap();

        assert_eq!(left, right);
        left.validate().unwrap();
        for fold in &left.folds {
            assert_eq!(fold.validation_sample_ids.len(), 2);
            assert_eq!(fold.train_sample_ids.len(), 4);
        }
    }

    // A sample listed in both train and validation of one fold is rejected.
    #[test]
    fn fold_validation_rejects_overlap() {
        let fold_set = FoldSet {
            id: "bad".to_string(),
            sample_ids: vec![sid("s1"), sid("s2")],
            folds: vec![FoldAssignment {
                fold_id: FoldId::new("fold0").unwrap(),
                train_sample_ids: vec![sid("s1")],
                validation_sample_ids: vec![sid("s1")],
                metadata: BTreeMap::new(),
            }],
            sample_groups: BTreeMap::new(),
        };

        assert!(fold_set.validate().is_err());
    }

    // A non-empty group map that omits a sample (`s2` here) is rejected.
    #[test]
    fn fold_validation_rejects_partial_group_maps() {
        let fold_set = FoldSet {
            id: "bad-groups".to_string(),
            sample_ids: vec![sid("s1"), sid("s2")],
            folds: vec![FoldAssignment {
                fold_id: FoldId::new("fold0").unwrap(),
                train_sample_ids: vec![sid("s2")],
                validation_sample_ids: vec![sid("s1")],
                metadata: BTreeMap::new(),
            }],
            sample_groups: BTreeMap::from([(sid("s1"), gid("g1"))]),
        };

        assert!(fold_set.validate().is_err());
    }

    // Reordering samples and folds must not change the fingerprint, while
    // changing the fold set id must.
    #[test]
    fn fold_set_fingerprint_is_independent_of_ordering() {
        let mut left = FoldSet {
            id: "cv.partition".to_string(),
            sample_ids: vec![sid("s3"), sid("s2"), sid("s1")],
            folds: vec![
                FoldAssignment {
                    fold_id: FoldId::new("fold1").unwrap(),
                    train_sample_ids: vec![sid("s2"), sid("s1")],
                    validation_sample_ids: vec![sid("s3")],
                    metadata: BTreeMap::new(),
                },
                FoldAssignment {
                    fold_id: FoldId::new("fold0").unwrap(),
                    train_sample_ids: vec![sid("s3")],
                    validation_sample_ids: vec![sid("s2"), sid("s1")],
                    metadata: BTreeMap::new(),
                },
            ],
            sample_groups: BTreeMap::new(),
        };
        let mut right = left.clone();
        right.sample_ids.reverse();
        right.folds.reverse();
        for fold in &mut right.folds {
            fold.train_sample_ids.reverse();
            fold.validation_sample_ids.reverse();
        }

        assert_eq!(
            fold_set_fingerprint(&left).unwrap(),
            fold_set_fingerprint(&right).unwrap()
        );

        left.id = "cv.partition.changed".to_string();
        assert_ne!(
            fold_set_fingerprint(&left).unwrap(),
            fold_set_fingerprint(&right).unwrap()
        );
    }

    // Pins the fingerprint of the fixture file to a fixed value.
    // NOTE(review): the fixture lives under `examples/fixtures/shared`, so it
    // is presumably also consumed by other implementations - confirm.
    #[test]
    fn shared_fold_set_fixture_fingerprint_is_locked() {
        let fixture = include_str!("../../../examples/fixtures/shared/fold_set_cv_partition.json");
        let fold_set = serde_json::from_str::<FoldSet>(fixture).unwrap();

        assert_eq!(
            fold_set_fingerprint(&fold_set).unwrap(),
            SHARED_FOLD_SET_FINGERPRINT
        );
    }

    // No group may appear on both the train and validation side of a fold.
    #[test]
    fn group_kfold_keeps_groups_out_of_train_validation_overlap() {
        let groups = BTreeMap::from([
            (sid("s1"), gid("g1")),
            (sid("s2"), gid("g1")),
            (sid("s3"), gid("g2")),
            (sid("s4"), gid("g2")),
            (sid("s5"), gid("g3")),
            (sid("s6"), gid("g3")),
        ]);
        let fold_set = GroupKFoldSpec { n_splits: 3 }
            .split("group-kfold", &groups)
            .unwrap();

        fold_set.validate().unwrap();
        for fold in &fold_set.folds {
            let train_groups = fold
                .train_sample_ids
                .iter()
                .map(|sample_id| groups.get(sample_id).unwrap())
                .collect::<BTreeSet<_>>();
            for sample_id in &fold.validation_sample_ids {
                assert!(!train_groups.contains(groups.get(sample_id).unwrap()));
            }
        }
    }

    // Eight samples alternating between labels "A" and "B", split in two:
    // each fold must validate two samples of each label.
    #[test]
    fn stratified_kfold_is_oof_safe_and_balances_classes() {
        let samples = (0..8).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
        let strata = BTreeMap::from_iter(samples.iter().enumerate().map(|(i, s)| {
            (
                s.clone(),
                if i % 2 == 0 {
                    "A".to_string()
                } else {
                    "B".to_string()
                },
            )
        }));
        let fold_set = StratifiedKFoldSpec {
            n_splits: 2,
            shuffle: false,
            seed: Some(0),
        }
        .split("strat", &samples, &strata)
        .unwrap();
        fold_set.validate().unwrap();
        assert_eq!(fold_set.folds.len(), 2);
        for fold in &fold_set.folds {
            let mut counts: BTreeMap<&str, usize> = BTreeMap::new();
            for s in &fold.validation_sample_ids {
                *counts.entry(strata.get(s).unwrap().as_str()).or_insert(0) += 1;
            }
            assert_eq!(counts.get("A"), Some(&2));
            assert_eq!(counts.get("B"), Some(&2));
        }
    }

    // Three single-member classes over three folds: every fold must still
    // receive exactly one validation sample.
    #[test]
    fn stratified_kfold_singleton_classes_leave_no_empty_fold() {
        let samples = ["s0", "s1", "s2"].into_iter().map(sid).collect::<Vec<_>>();
        let strata = BTreeMap::from_iter([
            (sid("s0"), "A".to_string()),
            (sid("s1"), "B".to_string()),
            (sid("s2"), "C".to_string()),
        ]);
        let fold_set = StratifiedKFoldSpec {
            n_splits: 3,
            shuffle: false,
            seed: Some(0),
        }
        .split("strat", &samples, &strata)
        .expect("singleton-class stratified split must succeed");
        fold_set.validate().unwrap();
        for fold in &fold_set.folds {
            assert_eq!(fold.validation_sample_ids.len(), 1);
        }
    }

    // Only `s0` has a label, so splitting four samples must fail.
    #[test]
    fn stratified_kfold_rejects_missing_label() {
        let samples = (0..4).map(|i| sid(&format!("s{i}"))).collect::<Vec<_>>();
        let strata = BTreeMap::from_iter([(sid("s0"), "A".to_string())]);
        let err = StratifiedKFoldSpec {
            n_splits: 2,
            shuffle: false,
            seed: Some(0),
        }
        .split("strat", &samples, &strata);
        assert!(err.is_err());
    }

    // Two-fold, unshuffled outer split shared by the nested-CV tests.
    fn outer_kfold(samples: &[SampleId]) -> FoldSet {
        KFoldSpec {
            n_splits: 2,
            shuffle: false,
            seed: Some(0),
        }
        .split("outer", samples)
        .unwrap()
    }

    // The inner fold set must cover exactly the outer fold's training samples.
    #[test]
    fn nested_kfold_inner_folds_are_subset_of_outer_train() {
        let samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
            .into_iter()
            .map(sid)
            .collect::<Vec<_>>();
        let outer = outer_kfold(&samples);
        let spec = NestedCvSpec::KFold(KFoldSpec {
            n_splits: 2,
            shuffle: false,
            seed: Some(1),
        });
        for outer_fold in &outer.folds {
            let inner = spec
                .build_inner_fold_set(outer_fold, &outer.sample_groups)
                .expect("inner fold set");
            let outer_train = outer_fold.train_sample_ids.iter().collect::<BTreeSet<_>>();
            for sample_id in &inner.sample_ids {
                assert!(outer_train.contains(sample_id));
            }
            inner.validate().unwrap();
            assert_eq!(
                inner.sample_ids.iter().collect::<BTreeSet<_>>(),
                outer_train
            );
        }
    }

    // A hand-built inner fold set that is valid on its own but includes an
    // outer-validation sample must be refused by the nested-CV check.
    #[test]
    fn nested_cv_validation_refuses_inner_sample_from_outer_validation() {
        let samples = ["s1", "s2", "s3", "s4"]
            .into_iter()
            .map(sid)
            .collect::<Vec<_>>();
        let outer = outer_kfold(&samples);
        let outer_fold = &outer.folds[0];
        let leaking_sample = outer_fold.validation_sample_ids[0].clone();
        let train_sample = outer_fold.train_sample_ids[0].clone();
        let inner = FoldSet {
            id: "leaky.inner".to_string(),
            sample_ids: vec![train_sample.clone(), leaking_sample.clone()],
            folds: vec![
                FoldAssignment {
                    fold_id: FoldId::new("if0").unwrap(),
                    train_sample_ids: vec![leaking_sample.clone()],
                    validation_sample_ids: vec![train_sample.clone()],
                    metadata: BTreeMap::new(),
                },
                FoldAssignment {
                    fold_id: FoldId::new("if1").unwrap(),
                    train_sample_ids: vec![train_sample],
                    validation_sample_ids: vec![leaking_sample],
                    metadata: BTreeMap::new(),
                },
            ],
            sample_groups: BTreeMap::new(),
        };
        inner
            .validate()
            .expect("inner fold set is structurally valid");
        let err = validate_inner_fold_set_within_outer(&inner, outer_fold)
            .expect_err("inner fold leaking an outer-validation sample must be refused");
        assert!(err.to_string().contains("nested CV leakage"));
    }

    // The leaking sample appears only inside a fold, not in the inner
    // `sample_ids`; the nested-CV check must still fail.
    #[test]
    fn nested_cv_validation_refuses_leak_hidden_in_fold_members() {
        let samples = ["s1", "s2", "s3", "s4"]
            .into_iter()
            .map(sid)
            .collect::<Vec<_>>();
        let outer = outer_kfold(&samples);
        let outer_fold = &outer.folds[0];
        let leaking_sample = outer_fold.validation_sample_ids[0].clone();
        let train_sample = outer_fold.train_sample_ids[0].clone();
        let inner = FoldSet {
            id: "hidden.inner".to_string(),
            sample_ids: vec![train_sample.clone()],
            folds: vec![FoldAssignment {
                fold_id: FoldId::new("if0").unwrap(),
                train_sample_ids: vec![train_sample],
                validation_sample_ids: vec![leaking_sample],
                metadata: BTreeMap::new(),
            }],
            sample_groups: BTreeMap::new(),
        };
        assert!(validate_inner_fold_set_within_outer(&inner, outer_fold).is_err());
    }

    // Pins the serialized shape: a `kind` tag alongside the spec's own
    // fields, and a lossless round trip for both variants.
    #[test]
    fn nested_cv_spec_json_shape_is_stable() {
        let spec = NestedCvSpec::KFold(KFoldSpec {
            n_splits: 3,
            shuffle: false,
            seed: Some(7),
        });
        let value = serde_json::to_value(&spec).unwrap();
        assert_eq!(value["kind"], "kfold");
        assert_eq!(value["n_splits"], 3);
        assert_eq!(value["seed"], 7);
        let round: NestedCvSpec = serde_json::from_value(value).unwrap();
        assert_eq!(round, spec);

        let group = NestedCvSpec::GroupKFold(GroupKFoldSpec { n_splits: 2 });
        let gv = serde_json::to_value(&group).unwrap();
        assert_eq!(gv["kind"], "group_kfold");
        assert_eq!(gv["n_splits"], 2);
        assert_eq!(serde_json::from_value::<NestedCvSpec>(gv).unwrap(), group);
    }

    // Node-level spec wins; the campaign-level spec is the fallback.
    #[test]
    fn resolve_inner_cv_prefers_node_over_campaign() {
        let node = NestedCvSpec::KFold(KFoldSpec {
            n_splits: 3,
            shuffle: false,
            seed: Some(2),
        });
        let campaign = NestedCvSpec::KFold(KFoldSpec {
            n_splits: 5,
            shuffle: false,
            seed: Some(3),
        });
        assert_eq!(resolve_inner_cv(Some(&node), Some(&campaign)), Some(&node));
        assert_eq!(resolve_inner_cv(None, Some(&campaign)), Some(&campaign));
        assert_eq!(resolve_inner_cv(Some(&node), None), Some(&node));
        assert_eq!(resolve_inner_cv(None, None), None);
    }
}