1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, OofLeakageReport, OofLeakageViolation, Result};
7use crate::fold::FoldSet;
8use crate::ids::{FoldId, NodeId, SampleId};
9
/// Partition of the data on which a block of predictions was produced.
///
/// Serialized with `snake_case` variant names.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredictionPartition {
    /// Predictions checked against a fold's `train_sample_ids` by
    /// `validate_prediction_blocks_against_folds`.
    Train,
    /// Predictions checked against a fold's `validation_sample_ids`; the only
    /// partition `validate_prediction_blocks_are_oof` accepts unless train
    /// predictions are explicitly allowed.
    Validation,
    /// Not checked against fold membership in this module.
    Test,
    /// Not checked against fold membership in this module.
    Final,
}
18
/// Key on which prediction blocks are joined.
///
/// Serialized with `snake_case` variant names.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredictionJoinKey {
    /// Join rows by their sample id (the only key currently defined).
    SampleId,
}
24
/// Serde default for `PredictionJoinPolicy::join_on`: join by sample id.
fn default_prediction_join_key() -> PredictionJoinKey {
    PredictionJoinKey::SampleId
}
28
/// One batch of predictions emitted by a single producer node for one
/// partition (and optionally one fold).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PredictionBlock {
    /// Optional identifier of this block; `None` when absent from the input.
    #[serde(default)]
    pub prediction_id: Option<String>,
    /// Node that produced the predictions; used as the column-name prefix.
    pub producer_node: NodeId,
    /// Partition the predictions were computed on.
    pub partition: PredictionPartition,
    /// Fold the predictions belong to. Required for train and validation
    /// blocks by `validate_prediction_blocks_against_folds`.
    pub fold_id: Option<FoldId>,
    /// Sample ids, one per row of `values` (enforced by `validate_shape`).
    pub sample_ids: Vec<SampleId>,
    /// Prediction rows; all rows must have the same non-zero length.
    pub values: Vec<Vec<f64>>,
    /// Optional per-column names. When empty, the join functions fall back to
    /// `p0`, `p1`, ...; when non-empty the length must equal the row width.
    #[serde(default)]
    pub target_names: Vec<String>,
}
41
42impl PredictionBlock {
43 pub fn validate_shape(&self) -> Result<usize> {
44 if self.sample_ids.len() != self.values.len() {
45 return Err(DagMlError::OofValidation(format!(
46 "producer `{}` has {} sample ids but {} prediction rows",
47 self.producer_node,
48 self.sample_ids.len(),
49 self.values.len()
50 )));
51 }
52 let width = self.values.first().map_or(0, Vec::len);
53 if width == 0 {
54 return Err(DagMlError::OofValidation(format!(
55 "producer `{}` emitted empty prediction rows",
56 self.producer_node
57 )));
58 }
59 if self.values.iter().any(|row| row.len() != width) {
60 return Err(DagMlError::OofValidation(format!(
61 "producer `{}` emitted ragged prediction rows",
62 self.producer_node
63 )));
64 }
65 if !self.target_names.is_empty() && self.target_names.len() != width {
66 return Err(DagMlError::OofValidation(format!(
67 "producer `{}` has {} target names for width {}",
68 self.producer_node,
69 self.target_names.len(),
70 width
71 )));
72 }
73 Ok(width)
74 }
75}
76
/// Result of joining prediction blocks into one feature matrix.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct OofMatrix {
    /// Row labels, in the order the caller requested.
    pub sample_ids: Vec<SampleId>,
    /// Column labels of the form `{producer_node}__{target}`.
    pub columns: Vec<String>,
    /// One row per entry of `sample_ids`, one value per entry of `columns`.
    pub values: Vec<Vec<f64>>,
}
83
/// Everything `validate_oof_campaign` needs to validate and join a set of
/// prediction blocks.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct OofCampaign {
    /// Fold definitions the prediction blocks are checked against.
    pub fold_set: FoldSet,
    /// Policy controlling which partitions may be joined.
    pub join_policy: PredictionJoinPolicy,
    /// Row order of the joined matrix; must cover exactly the fold set's samples.
    pub requested_sample_order: Vec<SampleId>,
    /// Prediction blocks to validate and join.
    pub prediction_blocks: Vec<PredictionBlock>,
}
91
/// Policy describing how a join node may consume prediction blocks.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PredictionJoinPolicy {
    /// Id of the join node; reported in leakage errors and refusal events.
    pub node_id: NodeId,
    /// Join key; defaults to `PredictionJoinKey::SampleId` when absent.
    #[serde(default = "default_prediction_join_key")]
    pub join_on: PredictionJoinKey,
    /// When `true`, `validate_prediction_blocks_are_oof` skips its check
    /// entirely. Defaults to `false`.
    #[serde(default)]
    pub allow_train_predictions_as_features: bool,
    /// Partitions to join. An empty list is treated as validation-only by
    /// `effective_partitions`.
    #[serde(default)]
    pub include_partitions: Vec<PredictionPartition>,
}
102
/// Predictions accumulated for one producer across all of its selected blocks
/// while `join_oof_campaign_features` runs.
#[derive(Clone, Debug)]
struct ProducerPredictions {
    /// Row width established by the producer's first block.
    width: usize,
    /// Column names established by the producer's first block.
    target_names: Vec<String>,
    /// Prediction row for each sample seen so far.
    by_sample: BTreeMap<SampleId, Vec<f64>>,
}
109
110pub fn join_oof_features(
111 blocks: &[PredictionBlock],
112 required_samples: &[SampleId],
113) -> Result<OofMatrix> {
114 validate_prediction_blocks_are_oof(
115 &PredictionJoinPolicy {
116 node_id: NodeId::new("prediction_join")?,
117 join_on: PredictionJoinKey::SampleId,
118 allow_train_predictions_as_features: false,
119 include_partitions: vec![PredictionPartition::Validation],
120 },
121 blocks,
122 )?;
123 if required_samples.is_empty() {
124 return Err(DagMlError::OofValidation(
125 "required sample set is empty".to_string(),
126 ));
127 }
128
129 let required = required_samples.iter().collect::<BTreeSet<_>>();
130 if required.len() != required_samples.len() {
131 return Err(DagMlError::OofValidation(
132 "required sample set contains duplicates".to_string(),
133 ));
134 }
135
136 let mut rows = required_samples
137 .iter()
138 .cloned()
139 .map(|sample_id| (sample_id, Vec::<f64>::new()))
140 .collect::<BTreeMap<_, _>>();
141 let mut columns = Vec::new();
142
143 for block in blocks {
144 let width = block.validate_shape()?;
145 let mut seen = BTreeSet::new();
146 let mut by_sample = BTreeMap::new();
147 for (sample_id, values) in block.sample_ids.iter().zip(block.values.iter()) {
148 if !seen.insert(sample_id) {
149 return Err(DagMlError::OofValidation(format!(
150 "producer `{}` emitted duplicate prediction for sample `{}`",
151 block.producer_node, sample_id
152 )));
153 }
154 by_sample.insert(sample_id, values);
155 }
156
157 for sample_id in required_samples {
158 let values = by_sample.get(sample_id).ok_or_else(|| {
159 DagMlError::OofValidation(format!(
160 "producer `{}` is missing required sample `{}`",
161 block.producer_node, sample_id
162 ))
163 })?;
164 rows.get_mut(sample_id)
165 .expect("required sample row exists")
166 .extend(values.iter().copied());
167 }
168
169 for column_idx in 0..width {
170 let target = block
171 .target_names
172 .get(column_idx)
173 .cloned()
174 .unwrap_or_else(|| format!("p{column_idx}"));
175 columns.push(format!("{}__{target}", block.producer_node));
176 }
177 }
178
179 Ok(OofMatrix {
180 sample_ids: required_samples.to_vec(),
181 columns,
182 values: required_samples
183 .iter()
184 .map(|sample_id| rows.remove(sample_id).expect("row exists"))
185 .collect(),
186 })
187}
188
189pub fn join_oof_campaign_features(
190 policy: &PredictionJoinPolicy,
191 blocks: &[PredictionBlock],
192 required_samples: &[SampleId],
193) -> Result<OofMatrix> {
194 validate_prediction_blocks_are_oof(policy, blocks)?;
195 ensure_required_samples(required_samples)?;
196
197 let required = required_samples.iter().collect::<BTreeSet<_>>();
198 let included_partitions = effective_partitions(policy);
199 let mut producers = BTreeMap::<NodeId, ProducerPredictions>::new();
200
201 for block in blocks {
202 if !included_partitions.contains(&block.partition) {
203 continue;
204 }
205 let width = block.validate_shape()?;
206 let target_names = normalized_targets(block, width);
207 let producer = producers
208 .entry(block.producer_node.clone())
209 .or_insert_with(|| ProducerPredictions {
210 width,
211 target_names: target_names.clone(),
212 by_sample: BTreeMap::new(),
213 });
214 if producer.width != width {
215 return Err(DagMlError::OofValidation(format!(
216 "producer `{}` changed prediction width from {} to {}",
217 block.producer_node, producer.width, width
218 )));
219 }
220 if producer.target_names != target_names {
221 return Err(DagMlError::OofValidation(format!(
222 "producer `{}` changed target names across folds",
223 block.producer_node
224 )));
225 }
226
227 for (sample_id, values) in block.sample_ids.iter().zip(block.values.iter()) {
228 if !required.contains(sample_id) {
229 return Err(DagMlError::OofValidation(format!(
230 "producer `{}` emitted unexpected sample `{}`",
231 block.producer_node, sample_id
232 )));
233 }
234 if producer
235 .by_sample
236 .insert(sample_id.clone(), values.clone())
237 .is_some()
238 {
239 return Err(DagMlError::OofValidation(format!(
240 "producer `{}` emitted duplicate OOF prediction for sample `{}`",
241 block.producer_node, sample_id
242 )));
243 }
244 }
245 }
246
247 if producers.is_empty() {
248 return Err(DagMlError::OofValidation(
249 "no prediction blocks were selected for OOF join".to_string(),
250 ));
251 }
252
253 for (producer_node, producer) in &producers {
254 for sample_id in required_samples {
255 if !producer.by_sample.contains_key(sample_id) {
256 return Err(DagMlError::OofValidation(format!(
257 "producer `{producer_node}` is missing required sample `{sample_id}`"
258 )));
259 }
260 }
261 }
262
263 let producer_predictions = producers.into_iter().collect::<Vec<_>>();
264 let columns = producer_predictions
265 .iter()
266 .flat_map(|(producer_node, producer)| {
267 producer
268 .target_names
269 .iter()
270 .map(move |target| format!("{producer_node}__{target}"))
271 })
272 .collect::<Vec<_>>();
273 let values = required_samples
274 .iter()
275 .map(|sample_id| {
276 let mut row = Vec::new();
277 for (_producer_node, producer) in &producer_predictions {
278 row.extend(
279 producer
280 .by_sample
281 .get(sample_id)
282 .expect("required sample was checked")
283 .iter()
284 .copied(),
285 );
286 }
287 row
288 })
289 .collect::<Vec<_>>();
290
291 Ok(OofMatrix {
292 sample_ids: required_samples.to_vec(),
293 columns,
294 values,
295 })
296}
297
298pub fn validate_oof_campaign(campaign: &OofCampaign) -> Result<OofMatrix> {
299 campaign.fold_set.validate()?;
300 validate_requested_samples_match_fold_set(
301 &campaign.requested_sample_order,
302 &campaign.fold_set,
303 )?;
304 validate_prediction_blocks_against_folds(&campaign.fold_set, &campaign.prediction_blocks)?;
305 join_oof_campaign_features(
306 &campaign.join_policy,
307 &campaign.prediction_blocks,
308 &campaign.requested_sample_order,
309 )
310}
311
312pub fn oof_campaign_fingerprint(campaign: &OofCampaign) -> Result<String> {
313 campaign.fold_set.validate()?;
314 validate_requested_samples_match_fold_set(
315 &campaign.requested_sample_order,
316 &campaign.fold_set,
317 )?;
318 validate_prediction_blocks_against_folds(&campaign.fold_set, &campaign.prediction_blocks)?;
319 stable_json_fingerprint(campaign)
320}
321
322pub fn validate_prediction_blocks_against_folds(
323 fold_set: &FoldSet,
324 blocks: &[PredictionBlock],
325) -> Result<()> {
326 fold_set.validate()?;
327 let folds = fold_set
328 .folds
329 .iter()
330 .map(|fold| (&fold.fold_id, fold))
331 .collect::<BTreeMap<_, _>>();
332 for block in blocks {
333 block.validate_shape()?;
334 let Some(fold_id) = &block.fold_id else {
335 if matches!(
336 block.partition,
337 PredictionPartition::Train | PredictionPartition::Validation
338 ) {
339 return Err(DagMlError::OofValidation(format!(
340 "producer `{}` emitted {:?} predictions without fold_id",
341 block.producer_node, block.partition
342 )));
343 }
344 continue;
345 };
346 let fold = folds.get(fold_id).ok_or_else(|| {
347 DagMlError::OofValidation(format!(
348 "producer `{}` references unknown fold `{fold_id}`",
349 block.producer_node
350 ))
351 })?;
352 match block.partition {
353 PredictionPartition::Train => {
354 assert_exact_partition_samples(block, &fold.train_sample_ids, "train")?
355 }
356 PredictionPartition::Validation => {
357 assert_exact_partition_samples(block, &fold.validation_sample_ids, "validation")?
358 }
359 PredictionPartition::Test | PredictionPartition::Final => {}
360 }
361 }
362 Ok(())
363}
364
365pub fn validate_prediction_blocks_are_oof(
366 policy: &PredictionJoinPolicy,
367 blocks: &[PredictionBlock],
368) -> Result<()> {
369 if policy.allow_train_predictions_as_features {
370 return Ok(());
371 }
372 let violators = blocks
373 .iter()
374 .filter(|block| block.partition != PredictionPartition::Validation)
375 .map(|block| OofLeakageViolation {
376 producer_node: block.producer_node.to_string(),
377 partition: format!("{:?}", block.partition).to_lowercase(),
378 fold_id: block.fold_id.as_ref().map(ToString::to_string),
379 })
380 .collect::<Vec<_>>();
381 if violators.is_empty() {
382 Ok(())
383 } else {
384 crate::observability::emit_oof_refusal(policy.node_id.as_str(), violators.len());
385 Err(DagMlError::OofLeakage(Box::new(OofLeakageReport {
386 node_id: policy.node_id.to_string(),
387 violators,
388 allow_train_predictions_as_features: policy.allow_train_predictions_as_features,
389 remediation: "Use only OOF validation predictions as training features, or explicitly set allow_train_predictions_as_features=true for an unsafe run.".to_string(),
390 })))
391 }
392}
393
394fn validate_requested_samples_match_fold_set(
395 requested_sample_order: &[SampleId],
396 fold_set: &FoldSet,
397) -> Result<()> {
398 ensure_required_samples(requested_sample_order)?;
399 let requested = requested_sample_order.iter().collect::<BTreeSet<_>>();
400 let expected = fold_set.sample_ids.iter().collect::<BTreeSet<_>>();
401 if requested != expected {
402 return Err(DagMlError::OofValidation(
403 "requested sample order does not match fold-set sample universe".to_string(),
404 ));
405 }
406 Ok(())
407}
408
409fn assert_exact_partition_samples(
410 block: &PredictionBlock,
411 expected_samples: &[SampleId],
412 partition_name: &str,
413) -> Result<()> {
414 let actual = unique_block_samples(block)?;
415 let expected = expected_samples.iter().collect::<BTreeSet<_>>();
416 if actual != expected {
417 return Err(DagMlError::OofValidation(format!(
418 "producer `{}` fold `{}` {} predictions do not match fold {} samples",
419 block.producer_node,
420 block.fold_id.as_ref().expect("fold id exists"),
421 partition_name,
422 partition_name
423 )));
424 }
425 Ok(())
426}
427
428fn unique_block_samples(block: &PredictionBlock) -> Result<BTreeSet<&SampleId>> {
429 let mut seen = BTreeSet::new();
430 for sample_id in &block.sample_ids {
431 if !seen.insert(sample_id) {
432 return Err(DagMlError::OofValidation(format!(
433 "producer `{}` emitted duplicate prediction for sample `{sample_id}`",
434 block.producer_node
435 )));
436 }
437 }
438 Ok(seen)
439}
440
441fn ensure_required_samples(required_samples: &[SampleId]) -> Result<()> {
442 if required_samples.is_empty() {
443 return Err(DagMlError::OofValidation(
444 "required sample set is empty".to_string(),
445 ));
446 }
447 let required = required_samples.iter().collect::<BTreeSet<_>>();
448 if required.len() != required_samples.len() {
449 return Err(DagMlError::OofValidation(
450 "required sample set contains duplicates".to_string(),
451 ));
452 }
453 Ok(())
454}
455
456fn effective_partitions(policy: &PredictionJoinPolicy) -> BTreeSet<PredictionPartition> {
457 if policy.include_partitions.is_empty() {
458 BTreeSet::from([PredictionPartition::Validation])
459 } else {
460 policy.include_partitions.iter().cloned().collect()
461 }
462}
463
464fn normalized_targets(block: &PredictionBlock, width: usize) -> Vec<String> {
465 if block.target_names.is_empty() {
466 (0..width)
467 .map(|column_idx| format!("p{column_idx}"))
468 .collect()
469 } else {
470 block.target_names.clone()
471 }
472}
473
// Unit tests for the OOF join and validation functions, plus an ignored
// timing probe for a large campaign.
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;

    /// Builds a `SampleId`, panicking if the value is rejected.
    fn sid(value: &str) -> SampleId {
        SampleId::new(value).unwrap()
    }

    /// Producer node id shared by the single-block tests.
    fn producer() -> NodeId {
        NodeId::new("model:base").unwrap()
    }

    /// Two-sample block for `fold0` whose samples are listed as `s2`, `s1`,
    /// i.e. not in the order the tests request them.
    fn block(partition: PredictionPartition) -> PredictionBlock {
        PredictionBlock {
            prediction_id: None,
            producer_node: producer(),
            partition,
            fold_id: Some(FoldId::new("fold0").unwrap()),
            sample_ids: vec![sid("s2"), sid("s1")],
            values: vec![vec![20.0], vec![10.0]],
            target_names: vec!["y".to_string()],
        }
    }

    /// Validation block whose single prediction per sample is the numeric
    /// suffix of the sample id (`"s4"` -> `4.0`).
    fn campaign_block(producer_node: &str, fold_id: &str, samples: &[&str]) -> PredictionBlock {
        PredictionBlock {
            prediction_id: None,
            producer_node: NodeId::new(producer_node).unwrap(),
            partition: PredictionPartition::Validation,
            fold_id: Some(FoldId::new(fold_id).unwrap()),
            sample_ids: samples.iter().copied().map(sid).collect(),
            values: samples
                .iter()
                .map(|sample_id| {
                    let suffix = sample_id.trim_start_matches('s').parse::<f64>().unwrap();
                    vec![suffix]
                })
                .collect(),
            target_names: vec!["y".to_string()],
        }
    }

    /// Deserializes a JSON campaign fixture, panicking on malformed input.
    fn load_fixture(source: &str) -> OofCampaign {
        serde_json::from_str(source).unwrap()
    }

    #[test]
    fn aligns_oof_by_sample_id_not_position() {
        // The block lists s2 before s1; the join must follow the requested order.
        let joined = join_oof_features(
            &[block(PredictionPartition::Validation)],
            &[sid("s1"), sid("s2")],
        )
        .unwrap();

        assert_eq!(joined.values, vec![vec![10.0], vec![20.0]]);
        assert_eq!(joined.columns, vec!["model:base__y"]);
    }

    #[test]
    fn rejects_train_predictions_as_training_features() {
        let err = join_oof_features(
            &[block(PredictionPartition::Train)],
            &[sid("s1"), sid("s2")],
        )
        .unwrap_err();

        match err {
            DagMlError::OofLeakage(report) => {
                assert_eq!(report.violators[0].producer_node, "model:base");
                assert_eq!(report.violators[0].partition, "train");
            }
            other => panic!("expected OOF leakage error, got {other:?}"),
        }
    }

    #[test]
    fn rejects_duplicate_samples() {
        let mut duplicate = block(PredictionPartition::Validation);
        duplicate.sample_ids = vec![sid("s1"), sid("s1")];

        assert!(join_oof_features(&[duplicate], &[sid("s1")]).is_err());
    }

    #[test]
    fn joins_fold_blocks_by_producer_for_campaigns() {
        // Blocks are passed with producer b1 first; the expected columns below
        // list b0 before b1.
        let mut b1_fold0 = campaign_block("branch:b1.model:rf", "fold0", &["s4", "s1"]);
        b1_fold0.values = vec![vec![40.0], vec![10.0]];
        let mut b1_fold1 = campaign_block("branch:b1.model:rf", "fold1", &["s2", "s3"]);
        b1_fold1.values = vec![vec![20.0], vec![30.0]];
        let mut b0_fold0 = campaign_block("branch:b0.model:pls", "fold0", &["s4", "s1"]);
        b0_fold0.values = vec![vec![4.0], vec![1.0]];
        let mut b0_fold1 = campaign_block("branch:b0.model:pls", "fold1", &["s2", "s3"]);
        b0_fold1.values = vec![vec![2.0], vec![3.0]];

        let joined = join_oof_campaign_features(
            &PredictionJoinPolicy {
                node_id: NodeId::new("merge:pred").unwrap(),
                join_on: PredictionJoinKey::SampleId,
                allow_train_predictions_as_features: false,
                include_partitions: vec![PredictionPartition::Validation],
            },
            &[b1_fold0, b1_fold1, b0_fold0, b0_fold1],
            &[sid("s1"), sid("s2"), sid("s3"), sid("s4")],
        )
        .unwrap();

        assert_eq!(
            joined.columns,
            vec!["branch:b0.model:pls__y", "branch:b1.model:rf__y"]
        );
        assert_eq!(
            joined.values,
            vec![
                vec![1.0, 10.0],
                vec![2.0, 20.0],
                vec![3.0, 30.0],
                vec![4.0, 40.0]
            ]
        );
    }

    #[test]
    fn uc6_fixture_joins_successfully() {
        let fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc6_oof_success_predictions.json"
        ));

        let joined = validate_oof_campaign(&fixture).unwrap();
        // Fingerprinting the same campaign twice must give the same value.
        assert_eq!(
            oof_campaign_fingerprint(&fixture).unwrap(),
            oof_campaign_fingerprint(&fixture).unwrap()
        );

        assert_eq!(joined.columns.len(), 3);
        assert_eq!(joined.values[0], vec![1.0, 10.0, 100.0]);
        assert_eq!(joined.values[5], vec![6.0, 60.0, 600.0]);
    }

    #[test]
    fn uc11_fixture_refuses_train_predictions() {
        let fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc11_train_prediction_refusal.json"
        ));

        let err = validate_oof_campaign(&fixture).unwrap_err();

        match err {
            DagMlError::OofLeakage(report) => {
                assert_eq!(report.node_id, "merge:pred");
                assert!(!report.allow_train_predictions_as_features);
                assert_eq!(report.violators.len(), 1);
                assert_eq!(report.violators[0].partition, "train");
            }
            other => panic!("expected OOF leakage error, got {other:?}"),
        }
    }

    #[test]
    fn fold_validation_rejects_wrong_validation_partition_samples() {
        let mut fixture = load_fixture(include_str!(
            "../../../examples/fixtures/oof_campaign/uc6_oof_success_predictions.json"
        ));
        // Overwrite the first block's samples so they no longer match its fold.
        fixture.prediction_blocks[0].sample_ids = vec![sid("S001"), sid("S002")];

        let err = validate_oof_campaign(&fixture).unwrap_err();

        assert!(err
            .to_string()
            .contains("do not match fold validation samples"));
    }

    #[test]
    #[ignore = "perf sanity probe; run with --release --ignored --nocapture"]
    fn oof_join_large_campaign_under_1500ms() {
        let sample_count = 12_000usize;
        let producer_count = 4usize;
        let fold_count = 6usize;
        let required_samples = (0..sample_count)
            .map(|sample_idx| sid(&format!("s{sample_idx:05}")))
            .collect::<Vec<_>>();
        let mut blocks = Vec::new();

        // Each producer gets one block per fold; fold `k` holds every
        // `fold_count`-th sample starting at index `k`.
        for producer_idx in 0..producer_count {
            for fold_idx in 0..fold_count {
                let sample_ids = (fold_idx..sample_count)
                    .step_by(fold_count)
                    .map(|sample_idx| sid(&format!("s{sample_idx:05}")))
                    .collect::<Vec<_>>();
                let values = (fold_idx..sample_count)
                    .step_by(fold_count)
                    .map(|sample_idx| vec![producer_idx as f64, sample_idx as f64])
                    .collect::<Vec<_>>();
                blocks.push(PredictionBlock {
                    prediction_id: None,
                    producer_node: NodeId::new(format!("model:p{producer_idx}")).unwrap(),
                    partition: PredictionPartition::Validation,
                    fold_id: Some(FoldId::new(format!("fold:{fold_idx}")).unwrap()),
                    sample_ids,
                    values,
                    target_names: vec!["score".to_string(), "rank".to_string()],
                });
            }
        }

        // Only the join itself is timed; fixture construction is excluded.
        let started = Instant::now();
        let joined = join_oof_campaign_features(
            &PredictionJoinPolicy {
                node_id: NodeId::new("merge:perf").unwrap(),
                join_on: PredictionJoinKey::SampleId,
                allow_train_predictions_as_features: false,
                include_partitions: vec![PredictionPartition::Validation],
            },
            &blocks,
            &required_samples,
        )
        .unwrap();
        let elapsed = started.elapsed();

        assert_eq!(joined.sample_ids.len(), sample_count);
        assert_eq!(joined.columns.len(), producer_count * 2);
        assert!(
            elapsed <= Duration::from_millis(1_500),
            "large OOF join took {elapsed:?}"
        );
    }
}