Skip to main content

dag_ml_core/
oof.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, OofLeakageReport, OofLeakageViolation, Result};
7use crate::fold::FoldSet;
8use crate::ids::{FoldId, NodeId, SampleId};
9
10#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum PredictionPartition {
13    Train,
14    Validation,
15    Test,
16    Final,
17}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum PredictionJoinKey {
22    SampleId,
23}
24
25fn default_prediction_join_key() -> PredictionJoinKey {
26    PredictionJoinKey::SampleId
27}
28
29#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
30pub struct PredictionBlock {
31    #[serde(default)]
32    pub prediction_id: Option<String>,
33    pub producer_node: NodeId,
34    pub partition: PredictionPartition,
35    pub fold_id: Option<FoldId>,
36    pub sample_ids: Vec<SampleId>,
37    pub values: Vec<Vec<f64>>,
38    #[serde(default)]
39    pub target_names: Vec<String>,
40}
41
42impl PredictionBlock {
43    pub fn validate_shape(&self) -> Result<usize> {
44        if self.sample_ids.len() != self.values.len() {
45            return Err(DagMlError::OofValidation(format!(
46                "producer `{}` has {} sample ids but {} prediction rows",
47                self.producer_node,
48                self.sample_ids.len(),
49                self.values.len()
50            )));
51        }
52        let width = self.values.first().map_or(0, Vec::len);
53        if width == 0 {
54            return Err(DagMlError::OofValidation(format!(
55                "producer `{}` emitted empty prediction rows",
56                self.producer_node
57            )));
58        }
59        if self.values.iter().any(|row| row.len() != width) {
60            return Err(DagMlError::OofValidation(format!(
61                "producer `{}` emitted ragged prediction rows",
62                self.producer_node
63            )));
64        }
65        if !self.target_names.is_empty() && self.target_names.len() != width {
66            return Err(DagMlError::OofValidation(format!(
67                "producer `{}` has {} target names for width {}",
68                self.producer_node,
69                self.target_names.len(),
70                width
71            )));
72        }
73        Ok(width)
74    }
75}
76
77#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
78pub struct OofMatrix {
79    pub sample_ids: Vec<SampleId>,
80    pub columns: Vec<String>,
81    pub values: Vec<Vec<f64>>,
82}
83
84#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
85pub struct OofCampaign {
86    pub fold_set: FoldSet,
87    pub join_policy: PredictionJoinPolicy,
88    pub requested_sample_order: Vec<SampleId>,
89    pub prediction_blocks: Vec<PredictionBlock>,
90}
91
92#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
93pub struct PredictionJoinPolicy {
94    pub node_id: NodeId,
95    #[serde(default = "default_prediction_join_key")]
96    pub join_on: PredictionJoinKey,
97    #[serde(default)]
98    pub allow_train_predictions_as_features: bool,
99    #[serde(default)]
100    pub include_partitions: Vec<PredictionPartition>,
101}
102
103#[derive(Clone, Debug)]
104struct ProducerPredictions {
105    width: usize,
106    target_names: Vec<String>,
107    by_sample: BTreeMap<SampleId, Vec<f64>>,
108}
109
110pub fn join_oof_features(
111    blocks: &[PredictionBlock],
112    required_samples: &[SampleId],
113) -> Result<OofMatrix> {
114    validate_prediction_blocks_are_oof(
115        &PredictionJoinPolicy {
116            node_id: NodeId::new("prediction_join")?,
117            join_on: PredictionJoinKey::SampleId,
118            allow_train_predictions_as_features: false,
119            include_partitions: vec![PredictionPartition::Validation],
120        },
121        blocks,
122    )?;
123    if required_samples.is_empty() {
124        return Err(DagMlError::OofValidation(
125            "required sample set is empty".to_string(),
126        ));
127    }
128
129    let required = required_samples.iter().collect::<BTreeSet<_>>();
130    if required.len() != required_samples.len() {
131        return Err(DagMlError::OofValidation(
132            "required sample set contains duplicates".to_string(),
133        ));
134    }
135
136    let mut rows = required_samples
137        .iter()
138        .cloned()
139        .map(|sample_id| (sample_id, Vec::<f64>::new()))
140        .collect::<BTreeMap<_, _>>();
141    let mut columns = Vec::new();
142
143    for block in blocks {
144        let width = block.validate_shape()?;
145        let mut seen = BTreeSet::new();
146        let mut by_sample = BTreeMap::new();
147        for (sample_id, values) in block.sample_ids.iter().zip(block.values.iter()) {
148            if !seen.insert(sample_id) {
149                return Err(DagMlError::OofValidation(format!(
150                    "producer `{}` emitted duplicate prediction for sample `{}`",
151                    block.producer_node, sample_id
152                )));
153            }
154            by_sample.insert(sample_id, values);
155        }
156
157        for sample_id in required_samples {
158            let values = by_sample.get(sample_id).ok_or_else(|| {
159                DagMlError::OofValidation(format!(
160                    "producer `{}` is missing required sample `{}`",
161                    block.producer_node, sample_id
162                ))
163            })?;
164            rows.get_mut(sample_id)
165                .expect("required sample row exists")
166                .extend(values.iter().copied());
167        }
168
169        for column_idx in 0..width {
170            let target = block
171                .target_names
172                .get(column_idx)
173                .cloned()
174                .unwrap_or_else(|| format!("p{column_idx}"));
175            columns.push(format!("{}__{target}", block.producer_node));
176        }
177    }
178
179    Ok(OofMatrix {
180        sample_ids: required_samples.to_vec(),
181        columns,
182        values: required_samples
183            .iter()
184            .map(|sample_id| rows.remove(sample_id).expect("row exists"))
185            .collect(),
186    })
187}
188
189pub fn join_oof_campaign_features(
190    policy: &PredictionJoinPolicy,
191    blocks: &[PredictionBlock],
192    required_samples: &[SampleId],
193) -> Result<OofMatrix> {
194    validate_prediction_blocks_are_oof(policy, blocks)?;
195    ensure_required_samples(required_samples)?;
196
197    let required = required_samples.iter().collect::<BTreeSet<_>>();
198    let included_partitions = effective_partitions(policy);
199    let mut producers = BTreeMap::<NodeId, ProducerPredictions>::new();
200
201    for block in blocks {
202        if !included_partitions.contains(&block.partition) {
203            continue;
204        }
205        let width = block.validate_shape()?;
206        let target_names = normalized_targets(block, width);
207        let producer = producers
208            .entry(block.producer_node.clone())
209            .or_insert_with(|| ProducerPredictions {
210                width,
211                target_names: target_names.clone(),
212                by_sample: BTreeMap::new(),
213            });
214        if producer.width != width {
215            return Err(DagMlError::OofValidation(format!(
216                "producer `{}` changed prediction width from {} to {}",
217                block.producer_node, producer.width, width
218            )));
219        }
220        if producer.target_names != target_names {
221            return Err(DagMlError::OofValidation(format!(
222                "producer `{}` changed target names across folds",
223                block.producer_node
224            )));
225        }
226
227        for (sample_id, values) in block.sample_ids.iter().zip(block.values.iter()) {
228            if !required.contains(sample_id) {
229                return Err(DagMlError::OofValidation(format!(
230                    "producer `{}` emitted unexpected sample `{}`",
231                    block.producer_node, sample_id
232                )));
233            }
234            if producer
235                .by_sample
236                .insert(sample_id.clone(), values.clone())
237                .is_some()
238            {
239                return Err(DagMlError::OofValidation(format!(
240                    "producer `{}` emitted duplicate OOF prediction for sample `{}`",
241                    block.producer_node, sample_id
242                )));
243            }
244        }
245    }
246
247    if producers.is_empty() {
248        return Err(DagMlError::OofValidation(
249            "no prediction blocks were selected for OOF join".to_string(),
250        ));
251    }
252
253    for (producer_node, producer) in &producers {
254        for sample_id in required_samples {
255            if !producer.by_sample.contains_key(sample_id) {
256                return Err(DagMlError::OofValidation(format!(
257                    "producer `{producer_node}` is missing required sample `{sample_id}`"
258                )));
259            }
260        }
261    }
262
263    let producer_predictions = producers.into_iter().collect::<Vec<_>>();
264    let columns = producer_predictions
265        .iter()
266        .flat_map(|(producer_node, producer)| {
267            producer
268                .target_names
269                .iter()
270                .map(move |target| format!("{producer_node}__{target}"))
271        })
272        .collect::<Vec<_>>();
273    let values = required_samples
274        .iter()
275        .map(|sample_id| {
276            let mut row = Vec::new();
277            for (_producer_node, producer) in &producer_predictions {
278                row.extend(
279                    producer
280                        .by_sample
281                        .get(sample_id)
282                        .expect("required sample was checked")
283                        .iter()
284                        .copied(),
285                );
286            }
287            row
288        })
289        .collect::<Vec<_>>();
290
291    Ok(OofMatrix {
292        sample_ids: required_samples.to_vec(),
293        columns,
294        values,
295    })
296}
297
298pub fn validate_oof_campaign(campaign: &OofCampaign) -> Result<OofMatrix> {
299    campaign.fold_set.validate()?;
300    validate_requested_samples_match_fold_set(
301        &campaign.requested_sample_order,
302        &campaign.fold_set,
303    )?;
304    validate_prediction_blocks_against_folds(&campaign.fold_set, &campaign.prediction_blocks)?;
305    join_oof_campaign_features(
306        &campaign.join_policy,
307        &campaign.prediction_blocks,
308        &campaign.requested_sample_order,
309    )
310}
311
312pub fn oof_campaign_fingerprint(campaign: &OofCampaign) -> Result<String> {
313    campaign.fold_set.validate()?;
314    validate_requested_samples_match_fold_set(
315        &campaign.requested_sample_order,
316        &campaign.fold_set,
317    )?;
318    validate_prediction_blocks_against_folds(&campaign.fold_set, &campaign.prediction_blocks)?;
319    stable_json_fingerprint(campaign)
320}
321
322pub fn validate_prediction_blocks_against_folds(
323    fold_set: &FoldSet,
324    blocks: &[PredictionBlock],
325) -> Result<()> {
326    fold_set.validate()?;
327    let folds = fold_set
328        .folds
329        .iter()
330        .map(|fold| (&fold.fold_id, fold))
331        .collect::<BTreeMap<_, _>>();
332    for block in blocks {
333        block.validate_shape()?;
334        let Some(fold_id) = &block.fold_id else {
335            if matches!(
336                block.partition,
337                PredictionPartition::Train | PredictionPartition::Validation
338            ) {
339                return Err(DagMlError::OofValidation(format!(
340                    "producer `{}` emitted {:?} predictions without fold_id",
341                    block.producer_node, block.partition
342                )));
343            }
344            continue;
345        };
346        let fold = folds.get(fold_id).ok_or_else(|| {
347            DagMlError::OofValidation(format!(
348                "producer `{}` references unknown fold `{fold_id}`",
349                block.producer_node
350            ))
351        })?;
352        match block.partition {
353            PredictionPartition::Train => {
354                assert_exact_partition_samples(block, &fold.train_sample_ids, "train")?
355            }
356            PredictionPartition::Validation => {
357                assert_exact_partition_samples(block, &fold.validation_sample_ids, "validation")?
358            }
359            PredictionPartition::Test | PredictionPartition::Final => {}
360        }
361    }
362    Ok(())
363}
364
365pub fn validate_prediction_blocks_are_oof(
366    policy: &PredictionJoinPolicy,
367    blocks: &[PredictionBlock],
368) -> Result<()> {
369    if policy.allow_train_predictions_as_features {
370        return Ok(());
371    }
372    let violators = blocks
373        .iter()
374        .filter(|block| block.partition != PredictionPartition::Validation)
375        .map(|block| OofLeakageViolation {
376            producer_node: block.producer_node.to_string(),
377            partition: format!("{:?}", block.partition).to_lowercase(),
378            fold_id: block.fold_id.as_ref().map(ToString::to_string),
379        })
380        .collect::<Vec<_>>();
381    if violators.is_empty() {
382        Ok(())
383    } else {
384        crate::observability::emit_oof_refusal(policy.node_id.as_str(), violators.len());
385        Err(DagMlError::OofLeakage(Box::new(OofLeakageReport {
386            node_id: policy.node_id.to_string(),
387            violators,
388            allow_train_predictions_as_features: policy.allow_train_predictions_as_features,
389            remediation: "Use only OOF validation predictions as training features, or explicitly set allow_train_predictions_as_features=true for an unsafe run.".to_string(),
390        })))
391    }
392}
393
394fn validate_requested_samples_match_fold_set(
395    requested_sample_order: &[SampleId],
396    fold_set: &FoldSet,
397) -> Result<()> {
398    ensure_required_samples(requested_sample_order)?;
399    let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
400    let expected = fold_set.sample_ids.iter().collect::<BTreeSet<_>>();
401    if requested != expected {
402        return Err(DagMlError::OofValidation(
403            "requested sample order does not match fold-set sample universe".to_string(),
404        ));
405    }
406    Ok(())
407}
408
409fn assert_exact_partition_samples(
410    block: &PredictionBlock,
411    expected_samples: &[SampleId],
412    partition_name: &str,
413) -> Result<()> {
414    let actual = unique_block_samples(block)?;
415    let expected = expected_samples.iter().collect::<BTreeSet<_>>();
416    if actual != expected {
417        return Err(DagMlError::OofValidation(format!(
418            "producer `{}` fold `{}` {} predictions do not match fold {} samples",
419            block.producer_node,
420            block.fold_id.as_ref().expect("fold id exists"),
421            partition_name,
422            partition_name
423        )));
424    }
425    Ok(())
426}
427
428fn unique_block_samples(block: &PredictionBlock) -> Result<BTreeSet<&SampleId>> {
429    let mut seen = BTreeSet::new();
430    for sample_id in &block.sample_ids {
431        if !seen.insert(sample_id) {
432            return Err(DagMlError::OofValidation(format!(
433                "producer `{}` emitted duplicate prediction for sample `{sample_id}`",
434                block.producer_node
435            )));
436        }
437    }
438    Ok(seen)
439}
440
441fn ensure_required_samples(required_samples: &[SampleId]) -> Result<()> {
442    if required_samples.is_empty() {
443        return Err(DagMlError::OofValidation(
444            "required sample set is empty".to_string(),
445        ));
446    }
447    let required = required_samples.iter().collect::<BTreeSet<_>>();
448    if required.len() != required_samples.len() {
449        return Err(DagMlError::OofValidation(
450            "required sample set contains duplicates".to_string(),
451        ));
452    }
453    Ok(())
454}
455
456fn effective_partitions(policy: &PredictionJoinPolicy) -> BTreeSet<PredictionPartition> {
457    if policy.include_partitions.is_empty() {
458        BTreeSet::from([PredictionPartition::Validation])
459    } else {
460        policy.include_partitions.iter().cloned().collect()
461    }
462}
463
464fn normalized_targets(block: &PredictionBlock, width: usize) -> Vec<String> {
465    if block.target_names.is_empty() {
466        (0..width)
467            .map(|column_idx| format!("p{column_idx}"))
468            .collect()
469    } else {
470        block.target_names.clone()
471    }
472}
473
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    /// Builds a sample id, panicking on invalid input (test helper).
    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    /// Producer node shared by the single-block tests.
    fn producer() -> NodeId {
        NodeId::new("model:base").unwrap()
    }

    /// Strict join policy: validation predictions only, train predictions refused.
    fn validation_policy(node_id: &str) -> PredictionJoinPolicy {
        PredictionJoinPolicy {
            node_id: NodeId::new(node_id).unwrap(),
            join_on: PredictionJoinKey::SampleId,
            allow_train_predictions_as_features: false,
            include_partitions: vec![PredictionPartition::Validation],
        }
    }

    /// Two-sample block whose rows are deliberately not in sample-id order.
    fn block(partition: PredictionPartition) -> PredictionBlock {
        PredictionBlock {
            prediction_id: None,
            producer_node: producer(),
            partition,
            fold_id: Some(FoldId::new("fold0").unwrap()),
            sample_ids: vec![sid("s2"), sid("s1")],
            values: vec![vec![20.0], vec![10.0]],
            target_names: vec!["y".to_string()],
        }
    }

    /// Validation block whose single value per sample is the numeric suffix of the sample id.
    fn campaign_block(producer_node: &str, fold_id: &str, samples: &[&str]) -> PredictionBlock {
        PredictionBlock {
            prediction_id: None,
            producer_node: NodeId::new(producer_node).unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: Some(FoldId::new(fold_id).unwrap()),
            sample_ids: samples.iter().copied().map(sid).collect(),
            values: samples
                .iter()
                .map(|sample_id| {
                    let suffix = sample_id.trim_start_matches('s').parse::<f64>().unwrap();
                    vec![suffix]
                })
                .collect(),
            target_names: vec!["y".to_string()],
        }
    }

    /// Parses a JSON campaign fixture.
    fn load_fixture(source: &str) -> OofCampaign {
        serde_json::from_str(source).unwrap()
    }

    #[test]
    fn aligns_oof_by_sample_id_not_position() {
        let joined = join_oof_features(
            &[block(PredictionPartition::Validation)],
            &[sid("s1"), sid("s2")],
        )
        .unwrap();

        assert_eq!(joined.values, vec![vec![10.0], vec![20.0]]);
        assert_eq!(joined.columns, vec!["model:base__y"]);
    }

    #[test]
    fn rejects_train_predictions_as_training_features() {
        let err = join_oof_features(
            &[block(PredictionPartition::Train)],
            &[sid("s1"), sid("s2")],
        )
        .unwrap_err();

        match err {
            DagMlError::OofLeakage(report) => {
                assert_eq!(report.violators[0].producer_node, "model:base");
                assert_eq!(report.violators[0].partition, "train");
            }
            other => panic!("expected OOF leakage error, got {other:?}"),
        }
    }

    #[test]
    fn rejects_duplicate_samples() {
        let mut duplicate = block(PredictionPartition::Validation);
        duplicate.sample_ids = vec![sid("s1"), sid("s1")];

        let err = join_oof_features(&[duplicate], &[sid("s1")]).unwrap_err();

        // Pin the failure reason so an unrelated error cannot make this test pass.
        assert!(
            err.to_string().contains("duplicate prediction"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn joins_fold_blocks_by_producer_for_campaigns() {
        let mut b1_fold0 = campaign_block("branch:b1.model:rf", "fold0", &["s4", "s1"]);
        b1_fold0.values = vec![vec![40.0], vec![10.0]];
        let mut b1_fold1 = campaign_block("branch:b1.model:rf", "fold1", &["s2", "s3"]);
        b1_fold1.values = vec![vec![20.0], vec![30.0]];
        let mut b0_fold0 = campaign_block("branch:b0.model:pls", "fold0", &["s4", "s1"]);
        b0_fold0.values = vec![vec![4.0], vec![1.0]];
        let mut b0_fold1 = campaign_block("branch:b0.model:pls", "fold1", &["s2", "s3"]);
        b0_fold1.values = vec![vec![2.0], vec![3.0]];

        let joined = join_oof_campaign_features(
            &validation_policy("merge:pred"),
            &[b1_fold0, b1_fold1, b0_fold0, b0_fold1],
            &[sid("s1"), sid("s2"), sid("s3"), sid("s4")],
        )
        .unwrap();

        // Producers are ordered by node id, not by the order their blocks were supplied in.
        assert_eq!(
            joined.columns,
            vec!["branch:b0.model:pls__y", "branch:b1.model:rf__y"]
        );
        assert_eq!(
            joined.values,
            vec![
                vec![1.0, 10.0],
                vec![2.0, 20.0],
                vec![3.0, 30.0],
                vec![4.0, 40.0]
            ]
        );
    }

    #[test]
    fn uc6_fixture_joins_successfully() {
        let fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc6_oof_success_predictions.json"
        ));

        let joined = validate_oof_campaign(&fixture).unwrap();
        // The fingerprint must be deterministic for the same campaign.
        assert_eq!(
            oof_campaign_fingerprint(&fixture).unwrap(),
            oof_campaign_fingerprint(&fixture).unwrap()
        );

        assert_eq!(joined.columns.len(), 3);
        assert_eq!(joined.values[0], vec![1.0, 10.0, 100.0]);
        assert_eq!(joined.values[5], vec![6.0, 60.0, 600.0]);
    }

    #[test]
    fn uc11_fixture_refuses_train_predictions() {
        let fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc11_train_prediction_refusal.json"
        ));

        let err = validate_oof_campaign(&fixture).unwrap_err();

        match err {
            DagMlError::OofLeakage(report) => {
                assert_eq!(report.node_id, "merge:pred");
                assert!(!report.allow_train_predictions_as_features);
                assert_eq!(report.violators.len(), 1);
                assert_eq!(report.violators[0].partition, "train");
            }
            other => panic!("expected OOF leakage error, got {other:?}"),
        }
    }

    #[test]
    fn fold_validation_rejects_wrong_validation_partition_samples() {
        let mut fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc6_oof_success_predictions.json"
        ));
        fixture.prediction_blocks[0].sample_ids = vec![sid("S001"), sid("S002")];

        let err = validate_oof_campaign(&fixture).unwrap_err();

        assert!(err
            .to_string()
            .contains("do not match fold validation samples"));
    }

    #[test]
    #[ignore = "perf sanity probe; run with --release --ignored --nocapture"]
    fn oof_join_large_campaign_under_1500ms() {
        let sample_count = 12_000usize;
        let producer_count = 4usize;
        let fold_count = 6usize;
        let required_samples = (0..sample_count)
            .map(|sample_idx| sid(&format!("s{sample_idx:05}")))
            .collect::<Vec<_>>();
        let mut blocks = Vec::with_capacity(producer_count * fold_count);

        for producer_idx in 0..producer_count {
            for fold_idx in 0..fold_count {
                // Fold `k` holds every `fold_count`-th sample starting at `k`; ids and rows
                // are built together so they cannot drift apart.
                let (sample_ids, values): (Vec<_>, Vec<_>) = (fold_idx..sample_count)
                    .step_by(fold_count)
                    .map(|sample_idx| {
                        (
                            sid(&format!("s{sample_idx:05}")),
                            vec![producer_idx as f64, sample_idx as f64],
                        )
                    })
                    .unzip();
                blocks.push(PredictionBlock {
                    prediction_id: None,
                    producer_node: NodeId::new(format!("model:p{producer_idx}")).unwrap(),
                    partition: PredictionPartition::Validation,
                    fold_id: Some(FoldId::new(format!("fold:{fold_idx}")).unwrap()),
                    sample_ids,
                    values,
                    target_names: vec!["score".to_string(), "rank".to_string()],
                });
            }
        }

        let started = Instant::now();
        let joined =
            join_oof_campaign_features(&validation_policy("merge:perf"), &blocks, &required_samples)
                .unwrap();
        let elapsed = started.elapsed();

        assert_eq!(joined.sample_ids.len(), sample_count);
        assert_eq!(joined.columns.len(), producer_count * 2);
        assert!(
            elapsed <= Duration::from_millis(1_500),
            "large OOF join took {elapsed:?}"
        );
    }
}