Skip to main content

dag_ml_core/
bundle.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::aggregation::{AggregatedPredictionBlock, PredictionUnitId};
6use crate::campaign::stable_json_fingerprint;
7use crate::data::{
8    ExternalDataPlanEnvelope, RepresentationCompatibilityReport, RepresentationReplayManifest,
9};
10use crate::error::{DagMlError, Result};
11use crate::ids::{BundleId, ControllerId, FoldId, NodeId, SampleId, VariantId};
12use crate::oof::{PredictionBlock, PredictionPartition};
13use crate::phase::Phase;
14use crate::plan::ExecutionPlan;
15use crate::policy::PredictionLevel;
16use crate::runtime::ArtifactRef;
17use crate::selection::SelectionDecision;
18
/// Current schema version of execution bundles; used as the `current_version`
/// of [`execution_bundle_schema_migration_policy`].
pub const EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
/// Current schema version of prediction cache payload sets; used as the
/// `current_version` of [`prediction_cache_payload_schema_migration_policy`] and
/// as the serde default for a missing payload-set `schema_version`.
pub const PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
/// The only `format` value accepted by prediction cache records and payloads.
pub const BUNDLE_PREDICTION_CACHE_FORMAT: &str = "dag-ml-json-prediction-blocks-v1";

/// Oldest execution bundle schema version the migration policy will read.
pub const MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
/// Oldest execution bundle schema version recorded as writable in the policy.
pub const MIN_WRITABLE_EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
/// Oldest prediction cache payload schema version the migration policy will read.
pub const MIN_READABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
/// Oldest prediction cache payload schema version recorded as writable in the policy.
pub const MIN_WRITABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
27
/// Default execution bundle `schema_version`: [`EXECUTION_BUNDLE_SCHEMA_VERSION`].
fn default_execution_bundle_schema_version() -> u32 {
    EXECUTION_BUNDLE_SCHEMA_VERSION
}

/// Serde default for a missing `schema_version` on
/// `BundlePredictionCachePayloadSet`: [`PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION`].
fn default_prediction_cache_payload_schema_version() -> u32 {
    PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION
}

/// Serde default for absent `prediction_level` fields on the bundle
/// requirement / cache types: sample-level predictions.
fn default_prediction_level() -> PredictionLevel {
    PredictionLevel::Sample
}
39
/// Version boundaries and declared automatic migrations for one persisted artifact
/// schema.
///
/// Self-consistency is checked by [`SchemaMigrationPolicy::validate`]; readers use
/// [`SchemaMigrationPolicy::validate_read_version`] to decide whether a stored
/// `schema_version` is acceptable.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SchemaMigrationPolicy {
    /// Artifact name used in error messages; must not be blank.
    pub artifact: String,
    /// Newest schema version; reads of anything newer are rejected.
    pub current_version: u32,
    /// Oldest schema version `validate_read_version` will accept.
    pub min_readable_version: u32,
    /// Oldest writable schema version; must not exceed `current_version`.
    pub min_writable_version: u32,
    /// Declared migration edges keyed by source version (`from -> to`); an absent
    /// field deserializes to no migrations.
    #[serde(default)]
    pub automatic_migrations: BTreeMap<u32, u32>,
}
49
50impl SchemaMigrationPolicy {
51    pub fn validate(&self) -> Result<()> {
52        validate_non_empty("schema migration artifact", &self.artifact)?;
53        if self.current_version == 0
54            || self.min_readable_version == 0
55            || self.min_writable_version == 0
56        {
57            return Err(DagMlError::RuntimeValidation(format!(
58                "schema migration policy `{}` has zero version boundary",
59                self.artifact
60            )));
61        }
62        if self.min_readable_version > self.current_version {
63            return Err(DagMlError::RuntimeValidation(format!(
64                "schema migration policy `{}` min_readable_version exceeds current_version",
65                self.artifact
66            )));
67        }
68        if self.min_writable_version > self.current_version {
69            return Err(DagMlError::RuntimeValidation(format!(
70                "schema migration policy `{}` min_writable_version exceeds current_version",
71                self.artifact
72            )));
73        }
74        for (from, to) in &self.automatic_migrations {
75            if *from == 0 || *to == 0 {
76                return Err(DagMlError::RuntimeValidation(format!(
77                    "schema migration policy `{}` contains a zero migration version",
78                    self.artifact
79                )));
80            }
81            if from == to {
82                return Err(DagMlError::RuntimeValidation(format!(
83                    "schema migration policy `{}` contains a no-op migration {from}->{to}",
84                    self.artifact
85                )));
86            }
87            if *to > self.current_version {
88                return Err(DagMlError::RuntimeValidation(format!(
89                    "schema migration policy `{}` migrates to unsupported future version {to}",
90                    self.artifact
91                )));
92            }
93        }
94        Ok(())
95    }
96
97    pub fn validate_read_version(&self, version: u32, owner: &str) -> Result<()> {
98        self.validate()?;
99        if version < self.min_readable_version {
100            return Err(DagMlError::RuntimeValidation(format!(
101                "{owner} uses schema_version {version}, below minimum readable {} for {}",
102                self.min_readable_version, self.artifact
103            )));
104        }
105        if version > self.current_version {
106            return Err(DagMlError::RuntimeValidation(format!(
107                "{owner} uses future schema_version {version}, current readable {} for {}",
108                self.current_version, self.artifact
109            )));
110        }
111        if version != self.current_version && !self.automatic_migrations.contains_key(&version) {
112            return Err(DagMlError::RuntimeValidation(format!(
113                "{owner} uses schema_version {version}, but {} declares no automatic migration to current version {}",
114                self.artifact, self.current_version
115            )));
116        }
117        Ok(())
118    }
119}
120
121pub fn execution_bundle_schema_migration_policy() -> SchemaMigrationPolicy {
122    SchemaMigrationPolicy {
123        artifact: "execution_bundle".to_string(),
124        current_version: EXECUTION_BUNDLE_SCHEMA_VERSION,
125        min_readable_version: MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION,
126        min_writable_version: MIN_WRITABLE_EXECUTION_BUNDLE_SCHEMA_VERSION,
127        automatic_migrations: BTreeMap::new(),
128    }
129}
130
131pub fn prediction_cache_payload_schema_migration_policy() -> SchemaMigrationPolicy {
132    SchemaMigrationPolicy {
133        artifact: "prediction_cache_payload".to_string(),
134        current_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
135        min_readable_version: MIN_READABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
136        min_writable_version: MIN_WRITABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
137        automatic_migrations: BTreeMap::new(),
138    }
139}
140
/// A data input recorded in a bundle, keyed by `node_id.input_name` (see
/// [`BundleDataRequirement::key`]) and pinned by schema/plan fingerprints.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundleDataRequirement {
    /// Node owning the input; first half of the requirement key.
    pub node_id: NodeId,
    /// Input name on `node_id`; must not be blank. Second half of the key.
    pub input_name: String,
    /// Schema fingerprint, checked by `validate_fingerprint`.
    pub schema_fingerprint: String,
    /// Plan fingerprint, checked by `validate_fingerprint`.
    pub plan_fingerprint: String,
    /// Optional relation fingerprint. When present it is checked by
    /// `validate_fingerprint`. It must also equal the replay manifest's relation
    /// fingerprint when both are set.
    #[serde(default)]
    pub relation_fingerprint: Option<String>,
    /// Output representation identifier; must not be blank.
    pub output_representation: String,
    /// Optional feature set identifier; when present it must not be blank.
    #[serde(default)]
    pub feature_set_id: Option<String>,
    /// Optional representation replay manifest, validated on its own. Omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_replay_manifest: Option<RepresentationReplayManifest>,
    /// Optional representation compatibility report, validated on its own. Omitted
    /// from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_compatibility: Option<RepresentationCompatibilityReport>,
}
157
158impl BundleDataRequirement {
159    pub fn key(&self) -> String {
160        format!("{}.{}", self.node_id, self.input_name)
161    }
162
163    fn matches_plan_requirement(&self, expected: &Self) -> bool {
164        self.node_id == expected.node_id
165            && self.input_name == expected.input_name
166            && self.schema_fingerprint == expected.schema_fingerprint
167            && self.plan_fingerprint == expected.plan_fingerprint
168            && self.relation_fingerprint == expected.relation_fingerprint
169            && self.output_representation == expected.output_representation
170            && self.feature_set_id == expected.feature_set_id
171    }
172
173    pub fn validate(&self) -> Result<()> {
174        if self.input_name.trim().is_empty() {
175            return Err(DagMlError::CampaignValidation(format!(
176                "bundle data requirement for `{}` has empty input_name",
177                self.node_id
178            )));
179        }
180        validate_fingerprint("schema", &self.schema_fingerprint)?;
181        validate_fingerprint("plan", &self.plan_fingerprint)?;
182        if let Some(relation_fingerprint) = &self.relation_fingerprint {
183            validate_fingerprint("relation", relation_fingerprint)?;
184        }
185        if let Some(replay_manifest) = &self.representation_replay_manifest {
186            replay_manifest.validate()?;
187            if let (Some(requirement), Some(manifest)) = (
188                self.relation_fingerprint.as_deref(),
189                replay_manifest.relation_fingerprint.as_deref(),
190            ) {
191                if requirement != manifest {
192                    return Err(DagMlError::CampaignValidation(format!(
193                        "bundle data requirement `{}` relation_fingerprint does not match representation replay manifest",
194                        self.key()
195                    )));
196                }
197            }
198        }
199        if let Some(report) = &self.representation_compatibility {
200            report.validate()?;
201        }
202        if self.output_representation.trim().is_empty() {
203            return Err(DagMlError::CampaignValidation(format!(
204                "bundle data requirement `{}` has empty output representation",
205                self.key()
206            )));
207        }
208        if let Some(feature_set_id) = &self.feature_set_id {
209            if feature_set_id.trim().is_empty() {
210                return Err(DagMlError::CampaignValidation(format!(
211                    "bundle data requirement `{}` has empty feature_set_id",
212                    self.key()
213                )));
214            }
215        }
216        Ok(())
217    }
218}
219
/// A prediction edge (`producer.source_port -> consumer.target_port`) whose
/// validation OOF predictions a bundle needs to replay.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundlePredictionRequirement {
    /// Node producing the predictions.
    pub producer_node: NodeId,
    /// Output port on `producer_node`; must not be blank.
    pub source_port: String,
    /// Node consuming the predictions.
    pub consumer_node: NodeId,
    /// Input port on `consumer_node`; must not be blank.
    pub target_port: String,
    /// Must be `PredictionPartition::Validation`.
    pub partition: PredictionPartition,
    /// Prediction granularity; defaults to sample level when absent. Observation
    /// level is rejected by `validate`.
    #[serde(default = "default_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Folds covered by the requirement; must be unique.
    #[serde(default)]
    pub fold_ids: Vec<FoldId>,
    /// Required, non-empty unit ids for target/group level. At sample level they
    /// are optional, but when present they must equal the units derived from
    /// `sample_ids`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Required, unique, non-empty sample ids at sample level; must be empty at
    /// target/group level.
    #[serde(default)]
    pub sample_ids: Vec<SampleId>,
    /// Number of prediction columns; non-zero and equal to `target_names.len()`.
    pub prediction_width: usize,
    /// One non-blank name per prediction column.
    pub target_names: Vec<String>,
}
238
239impl BundlePredictionRequirement {
240    pub fn key(&self) -> String {
241        bundle_prediction_requirement_key(
242            &self.producer_node,
243            &self.source_port,
244            &self.consumer_node,
245            &self.target_port,
246        )
247    }
248
249    pub fn validate(&self) -> Result<()> {
250        validate_non_empty("source_port", &self.source_port)?;
251        validate_non_empty("target_port", &self.target_port)?;
252        if self.partition != PredictionPartition::Validation {
253            return Err(DagMlError::RuntimeValidation(format!(
254                "bundle prediction requirement `{}` must use validation OOF predictions",
255                self.key()
256            )));
257        }
258        validate_unique_ids("fold id", &self.fold_ids)?;
259        validate_prediction_requirement_units(self)?;
260        if self.prediction_width == 0 {
261            return Err(DagMlError::RuntimeValidation(format!(
262                "bundle prediction requirement `{}` has zero prediction width",
263                self.key()
264            )));
265        }
266        if self.target_names.len() != self.prediction_width {
267            return Err(DagMlError::RuntimeValidation(format!(
268                "bundle prediction requirement `{}` target name count does not match prediction width",
269                self.key()
270            )));
271        }
272        for target_name in &self.target_names {
273            validate_non_empty("target_name", target_name)?;
274        }
275        Ok(())
276    }
277}
278
279pub fn bundle_prediction_requirement_key(
280    producer_node: &NodeId,
281    source_port: &str,
282    consumer_node: &NodeId,
283    target_port: &str,
284) -> String {
285    format!("{producer_node}.{source_port}->{consumer_node}.{target_port}")
286}
287
/// Summary of one cached prediction block inside a [`BundlePredictionCacheRecord`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundlePredictionBlockCacheRecord {
    /// Optional block identifier; when present it must not be blank.
    #[serde(default)]
    pub prediction_id: Option<String>,
    /// Optional fold the block belongs to.
    #[serde(default)]
    pub fold_id: Option<FoldId>,
    /// Prediction granularity; defaults to sample level when absent and must
    /// match the owning cache record's level.
    #[serde(default = "default_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Number of rows; non-zero and equal to the number of sample ids (sample
    /// level) or unit ids (target/group level).
    pub row_count: usize,
    /// Unit ids; optional at sample level, required at target/group level.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Sample ids at sample level; must be empty at target/group level.
    #[serde(default)]
    pub sample_ids: Vec<SampleId>,
    /// Fingerprint of the block content, checked by `validate_fingerprint`.
    pub content_fingerprint: String,
}
303
304impl BundlePredictionBlockCacheRecord {
305    pub fn validate(&self) -> Result<()> {
306        if let Some(prediction_id) = &self.prediction_id {
307            validate_non_empty("prediction_id", prediction_id)?;
308        }
309        if self.row_count == 0 {
310            return Err(DagMlError::RuntimeValidation(
311                "prediction block cache record has zero rows".to_string(),
312            ));
313        }
314        validate_prediction_cache_block_record_units(self)?;
315        validate_fingerprint("prediction block cache content", &self.content_fingerprint)
316    }
317}
318
/// Bundle-level description of a cached set of validation OOF predictions for one
/// prediction requirement, with a summary record per cached block.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundlePredictionCacheRecord {
    /// Key of the prediction requirement this cache serves; must not be blank.
    pub requirement_key: String,
    /// Cache identifier; must not be blank.
    pub cache_id: String,
    /// Must equal [`BUNDLE_PREDICTION_CACHE_FORMAT`].
    pub format: String,
    /// Must be `PredictionPartition::Validation`.
    pub partition: PredictionPartition,
    /// Prediction granularity; defaults to sample level when absent. Observation
    /// level is rejected.
    #[serde(default = "default_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Folds covered by the cache; must be unique.
    #[serde(default)]
    pub fold_ids: Vec<FoldId>,
    /// Unit ids of all rows (required at target/group level).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub unit_ids: Vec<PredictionUnitId>,
    /// Sample ids of all rows (sample level only).
    #[serde(default)]
    pub sample_ids: Vec<SampleId>,
    /// Number of prediction columns; non-zero and equal to `target_names.len()`.
    pub prediction_width: usize,
    /// One non-blank name per prediction column.
    pub target_names: Vec<String>,
    /// Number of block records; non-zero and equal to `blocks.len()`.
    pub block_count: usize,
    /// Total rows; non-zero and equal to the sum of the blocks' `row_count`.
    pub row_count: usize,
    /// Fingerprint of the cached content, checked by `validate_fingerprint`.
    pub content_fingerprint: String,
    /// Per-block summaries.
    #[serde(default)]
    pub blocks: Vec<BundlePredictionBlockCacheRecord>,
}
341
342impl BundlePredictionCacheRecord {
343    pub fn validate(&self) -> Result<()> {
344        validate_non_empty("requirement_key", &self.requirement_key)?;
345        validate_non_empty("cache_id", &self.cache_id)?;
346        validate_non_empty("format", &self.format)?;
347        if self.format != BUNDLE_PREDICTION_CACHE_FORMAT {
348            return Err(DagMlError::RuntimeValidation(format!(
349                "prediction cache `{}` uses unsupported format `{}`",
350                self.cache_id, self.format
351            )));
352        }
353        if self.partition != PredictionPartition::Validation {
354            return Err(DagMlError::RuntimeValidation(format!(
355                "prediction cache `{}` must cache validation OOF predictions",
356                self.cache_id
357            )));
358        }
359        validate_unique_ids("fold id", &self.fold_ids)?;
360        validate_prediction_cache_record_units(self)?;
361        if self.prediction_width == 0 {
362            return Err(DagMlError::RuntimeValidation(format!(
363                "prediction cache `{}` has zero prediction width",
364                self.cache_id
365            )));
366        }
367        if self.target_names.len() != self.prediction_width {
368            return Err(DagMlError::RuntimeValidation(format!(
369                "prediction cache `{}` target name count does not match prediction width",
370                self.cache_id
371            )));
372        }
373        for target_name in &self.target_names {
374            validate_non_empty("target_name", target_name)?;
375        }
376        if self.block_count == 0 || self.block_count != self.blocks.len() {
377            return Err(DagMlError::RuntimeValidation(format!(
378                "prediction cache `{}` block_count does not match block records",
379                self.cache_id
380            )));
381        }
382        validate_prediction_cache_record_blocks(self)?;
383        validate_fingerprint("prediction cache content", &self.content_fingerprint)?;
384        Ok(())
385    }
386}
387
388fn validate_prediction_requirement_units(requirement: &BundlePredictionRequirement) -> Result<()> {
389    match requirement.prediction_level {
390        PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
391            "bundle prediction requirement `{}` cannot replay observation-level caches; aggregate to sample first",
392            requirement.key()
393        ))),
394        PredictionLevel::Sample => {
395            validate_unique_ids("sample id", &requirement.sample_ids)?;
396            if requirement.sample_ids.is_empty() {
397                return Err(DagMlError::RuntimeValidation(format!(
398                    "bundle prediction requirement `{}` has no sample ids",
399                    requirement.key()
400                )));
401            }
402            if !requirement.unit_ids.is_empty()
403                && requirement.unit_ids != sample_prediction_units(&requirement.sample_ids)
404            {
405                return Err(DagMlError::RuntimeValidation(format!(
406                    "bundle prediction requirement `{}` sample ids do not match unit ids",
407                    requirement.key()
408                )));
409            }
410            Ok(())
411        }
412        PredictionLevel::Target | PredictionLevel::Group => {
413            if !requirement.sample_ids.is_empty() {
414                return Err(DagMlError::RuntimeValidation(format!(
415                    "bundle prediction requirement `{}` uses {:?} unit ids but also carries sample ids",
416                    requirement.key(),
417                    requirement.prediction_level
418                )));
419            }
420            validate_prediction_units(
421                "bundle prediction requirement unit",
422                requirement.prediction_level,
423                &requirement.unit_ids,
424            )?;
425            if requirement.unit_ids.is_empty() {
426                return Err(DagMlError::RuntimeValidation(format!(
427                    "bundle prediction requirement `{}` has no unit ids",
428                    requirement.key()
429                )));
430            }
431            Ok(())
432        }
433    }
434}
435
436fn validate_prediction_cache_block_record_units(
437    block: &BundlePredictionBlockCacheRecord,
438) -> Result<()> {
439    match block.prediction_level {
440        PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(
441            "prediction block cache record cannot use observation-level predictions".to_string(),
442        )),
443        PredictionLevel::Sample => {
444            validate_unique_ids("sample id", &block.sample_ids)?;
445            if block.row_count != block.sample_ids.len() {
446                return Err(DagMlError::RuntimeValidation(format!(
447                    "prediction block cache record row_count {} does not match {} sample ids",
448                    block.row_count,
449                    block.sample_ids.len()
450                )));
451            }
452            if !block.unit_ids.is_empty()
453                && block.unit_ids != sample_prediction_units(&block.sample_ids)
454            {
455                return Err(DagMlError::RuntimeValidation(
456                    "prediction block cache record sample ids do not match unit ids".to_string(),
457                ));
458            }
459            Ok(())
460        }
461        PredictionLevel::Target | PredictionLevel::Group => {
462            if !block.sample_ids.is_empty() {
463                return Err(DagMlError::RuntimeValidation(format!(
464                    "prediction block cache record uses {:?} unit ids but also carries sample ids",
465                    block.prediction_level
466                )));
467            }
468            validate_prediction_units(
469                "prediction block cache record unit",
470                block.prediction_level,
471                &block.unit_ids,
472            )?;
473            if block.row_count != block.unit_ids.len() {
474                return Err(DagMlError::RuntimeValidation(format!(
475                    "prediction block cache record row_count {} does not match {} unit ids",
476                    block.row_count,
477                    block.unit_ids.len()
478                )));
479            }
480            Ok(())
481        }
482    }
483}
484
485fn validate_prediction_cache_record_units(cache: &BundlePredictionCacheRecord) -> Result<()> {
486    match cache.prediction_level {
487        PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
488            "prediction cache `{}` cannot use observation-level predictions",
489            cache.cache_id
490        ))),
491        PredictionLevel::Sample => {
492            validate_unique_ids("sample id", &cache.sample_ids)?;
493            if cache.row_count != cache.sample_ids.len() {
494                return Err(DagMlError::RuntimeValidation(format!(
495                    "prediction cache `{}` row_count does not match unique sample ids",
496                    cache.cache_id
497                )));
498            }
499            if !cache.unit_ids.is_empty()
500                && cache.unit_ids != sample_prediction_units(&cache.sample_ids)
501            {
502                return Err(DagMlError::RuntimeValidation(format!(
503                    "prediction cache `{}` sample ids do not match unit ids",
504                    cache.cache_id
505                )));
506            }
507            Ok(())
508        }
509        PredictionLevel::Target | PredictionLevel::Group => {
510            if !cache.sample_ids.is_empty() {
511                return Err(DagMlError::RuntimeValidation(format!(
512                    "prediction cache `{}` uses {:?} unit ids but also carries sample ids",
513                    cache.cache_id, cache.prediction_level
514                )));
515            }
516            validate_prediction_units(
517                "prediction cache unit",
518                cache.prediction_level,
519                &cache.unit_ids,
520            )?;
521            if cache.row_count != cache.unit_ids.len() {
522                return Err(DagMlError::RuntimeValidation(format!(
523                    "prediction cache `{}` row_count does not match unique unit ids",
524                    cache.cache_id
525                )));
526            }
527            Ok(())
528        }
529    }
530}
531
532fn validate_prediction_cache_record_blocks(cache: &BundlePredictionCacheRecord) -> Result<()> {
533    let mut row_count = 0usize;
534    let mut samples = BTreeSet::new();
535    let mut units = BTreeSet::new();
536    for block in &cache.blocks {
537        block.validate()?;
538        if block.prediction_level != cache.prediction_level {
539            return Err(DagMlError::RuntimeValidation(format!(
540                "prediction cache `{}` mixes block prediction levels",
541                cache.cache_id
542            )));
543        }
544        row_count += block.row_count;
545        match cache.prediction_level {
546            PredictionLevel::Sample => {
547                for sample_id in &block.sample_ids {
548                    if !samples.insert(sample_id.clone()) {
549                        return Err(DagMlError::RuntimeValidation(format!(
550                            "prediction cache `{}` contains duplicate sample `{sample_id}`",
551                            cache.cache_id
552                        )));
553                    }
554                }
555            }
556            PredictionLevel::Target | PredictionLevel::Group => {
557                for unit_id in &block.unit_ids {
558                    if !units.insert(unit_id.clone()) {
559                        return Err(DagMlError::RuntimeValidation(format!(
560                            "prediction cache `{}` contains duplicate unit `{unit_id}`",
561                            cache.cache_id
562                        )));
563                    }
564                }
565            }
566            PredictionLevel::Observation => {
567                unreachable!("record unit validation rejects observation")
568            }
569        }
570    }
571    if cache.row_count == 0 || cache.row_count != row_count {
572        return Err(DagMlError::RuntimeValidation(format!(
573            "prediction cache `{}` row_count does not match block records",
574            cache.cache_id
575        )));
576    }
577    if cache.prediction_level == PredictionLevel::Sample {
578        let expected = cache.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
579        if samples != expected {
580            return Err(DagMlError::RuntimeValidation(format!(
581                "prediction cache `{}` block samples do not match cache sample ids",
582                cache.cache_id
583            )));
584        }
585    } else {
586        let expected = cache.unit_ids.iter().cloned().collect::<BTreeSet<_>>();
587        if units != expected {
588            return Err(DagMlError::RuntimeValidation(format!(
589                "prediction cache `{}` block units do not match cache unit ids",
590                cache.cache_id
591            )));
592        }
593    }
594    Ok(())
595}
596
597fn validate_prediction_cache_payload_blocks(
598    payload: &BundlePredictionCachePayload,
599) -> Result<usize> {
600    match payload.prediction_level {
601        PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
602            "prediction cache payload `{}` cannot use observation-level predictions",
603            payload.cache_id
604        ))),
605        PredictionLevel::Sample => validate_sample_prediction_cache_payload_blocks(payload),
606        PredictionLevel::Target | PredictionLevel::Group => {
607            validate_aggregated_prediction_cache_payload_blocks(payload)
608        }
609    }
610}
611
612fn validate_sample_prediction_cache_payload_blocks(
613    payload: &BundlePredictionCachePayload,
614) -> Result<usize> {
615    let mut row_count = 0usize;
616    let mut sample_ids = BTreeSet::new();
617    for block in &payload.blocks {
618        block.validate_shape()?;
619        if block.partition != payload.partition {
620            return Err(DagMlError::RuntimeValidation(format!(
621                "prediction cache payload `{}` contains a block from partition {:?}",
622                payload.cache_id, block.partition
623            )));
624        }
625        for sample_id in &block.sample_ids {
626            if !sample_ids.insert(sample_id) {
627                return Err(DagMlError::RuntimeValidation(format!(
628                    "prediction cache payload `{}` contains duplicate sample `{}`",
629                    payload.cache_id, sample_id
630                )));
631            }
632        }
633        row_count += block.sample_ids.len();
634    }
635    Ok(row_count)
636}
637
638fn validate_aggregated_prediction_cache_payload_blocks(
639    payload: &BundlePredictionCachePayload,
640) -> Result<usize> {
641    let mut row_count = 0usize;
642    let mut unit_ids = BTreeSet::new();
643    for block in &payload.aggregated_blocks {
644        block.validate_shape()?;
645        if block.partition != payload.partition {
646            return Err(DagMlError::RuntimeValidation(format!(
647                "prediction cache payload `{}` contains an aggregated block from partition {:?}",
648                payload.cache_id, block.partition
649            )));
650        }
651        if block.level != payload.prediction_level {
652            return Err(DagMlError::RuntimeValidation(format!(
653                "prediction cache payload `{}` contains {:?} block inside {:?} payload",
654                payload.cache_id, block.level, payload.prediction_level
655            )));
656        }
657        for unit_id in &block.unit_ids {
658            if !unit_ids.insert(unit_id) {
659                return Err(DagMlError::RuntimeValidation(format!(
660                    "prediction cache payload `{}` contains duplicate unit `{unit_id}`",
661                    payload.cache_id
662                )));
663            }
664        }
665        row_count += block.unit_ids.len();
666    }
667    Ok(row_count)
668}
669
/// The cached prediction blocks themselves for one prediction requirement.
///
/// Sample-level payloads carry `blocks`. Target/group-level payloads carry
/// `aggregated_blocks`. `validate` rejects payloads that mix the two.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundlePredictionCachePayload {
    /// Key of the prediction requirement this payload serves; must not be blank.
    pub requirement_key: String,
    /// Cache identifier; must not be blank.
    pub cache_id: String,
    /// Must equal [`BUNDLE_PREDICTION_CACHE_FORMAT`].
    pub format: String,
    /// Must be `PredictionPartition::Validation`; every block must share it.
    pub partition: PredictionPartition,
    /// Prediction granularity; defaults to sample level when absent. Observation
    /// level is rejected.
    #[serde(default = "default_prediction_level")]
    pub prediction_level: PredictionLevel,
    /// Number of blocks in the active block list; non-zero.
    pub block_count: usize,
    /// Total rows across the active block list; non-zero.
    pub row_count: usize,
    /// Must equal `stable_json_fingerprint` of the active block list.
    pub content_fingerprint: String,
    /// Sample-level prediction blocks.
    #[serde(default)]
    pub blocks: Vec<PredictionBlock>,
    /// Target/group-level prediction blocks; omitted from serialized output when
    /// empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub aggregated_blocks: Vec<AggregatedPredictionBlock>,
}
686
687impl BundlePredictionCachePayload {
688    pub fn validate(&self) -> Result<()> {
689        validate_non_empty("requirement_key", &self.requirement_key)?;
690        validate_non_empty("cache_id", &self.cache_id)?;
691        validate_non_empty("format", &self.format)?;
692        if self.format != BUNDLE_PREDICTION_CACHE_FORMAT {
693            return Err(DagMlError::RuntimeValidation(format!(
694                "prediction cache payload `{}` uses unsupported format `{}`",
695                self.cache_id, self.format
696            )));
697        }
698        if self.partition != PredictionPartition::Validation {
699            return Err(DagMlError::RuntimeValidation(format!(
700                "prediction cache payload `{}` must cache validation OOF predictions",
701                self.cache_id
702            )));
703        }
704        let expected_block_count = if self.prediction_level == PredictionLevel::Sample {
705            if !self.aggregated_blocks.is_empty() {
706                return Err(DagMlError::RuntimeValidation(format!(
707                    "prediction cache payload `{}` mixes sample and aggregated blocks",
708                    self.cache_id
709                )));
710            }
711            self.blocks.len()
712        } else {
713            if !self.blocks.is_empty() {
714                return Err(DagMlError::RuntimeValidation(format!(
715                    "prediction cache payload `{}` mixes aggregated and sample blocks",
716                    self.cache_id
717                )));
718            }
719            self.aggregated_blocks.len()
720        };
721        if self.block_count == 0 || self.block_count != expected_block_count {
722            return Err(DagMlError::RuntimeValidation(format!(
723                "prediction cache payload `{}` block_count does not match blocks",
724                self.cache_id
725            )));
726        }
727        let row_count = validate_prediction_cache_payload_blocks(self)?;
728        if self.row_count == 0 || self.row_count != row_count {
729            return Err(DagMlError::RuntimeValidation(format!(
730                "prediction cache payload `{}` row_count does not match blocks",
731                self.cache_id
732            )));
733        }
734        validate_fingerprint(
735            "prediction cache payload content",
736            &self.content_fingerprint,
737        )?;
738        let actual_fingerprint = if self.prediction_level == PredictionLevel::Sample {
739            stable_json_fingerprint(&self.blocks)?
740        } else {
741            stable_json_fingerprint(&self.aggregated_blocks)?
742        };
743        if actual_fingerprint != self.content_fingerprint {
744            return Err(DagMlError::RuntimeValidation(format!(
745                "prediction cache payload `{}` content fingerprint does not match blocks",
746                self.cache_id
747            )));
748        }
749        Ok(())
750    }
751}
752
/// The full set of prediction cache payloads stored for one execution bundle.
///
/// When deserializing, a missing `schema_version` defaults to the current
/// prediction cache payload schema version and a missing `caches` list defaults
/// to empty.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BundlePredictionCachePayloadSet {
    /// Bundle these payloads belong to; must equal the bundle's id when checked
    /// with `validate_against_bundle`.
    pub bundle_id: BundleId,
    /// Schema version of the payload set; `validate` checks it is readable under
    /// the prediction cache payload migration policy.
    #[serde(default = "default_prediction_cache_payload_schema_version")]
    pub schema_version: u32,
    /// Payloads, each of which must have a unique requirement key and cache id.
    #[serde(default)]
    pub caches: Vec<BundlePredictionCachePayload>,
}
761
762impl BundlePredictionCachePayloadSet {
763    pub fn validate(&self) -> Result<()> {
764        prediction_cache_payload_schema_migration_policy().validate_read_version(
765            self.schema_version,
766            &format!(
767                "prediction cache payload set for bundle `{}`",
768                self.bundle_id
769            ),
770        )?;
771        let mut requirement_keys = BTreeSet::new();
772        let mut cache_ids = BTreeSet::new();
773        for payload in &self.caches {
774            payload.validate()?;
775            if !requirement_keys.insert(payload.requirement_key.as_str()) {
776                return Err(DagMlError::RuntimeValidation(format!(
777                    "prediction cache payload set for bundle `{}` has duplicate requirement `{}`",
778                    self.bundle_id, payload.requirement_key
779                )));
780            }
781            if !cache_ids.insert(payload.cache_id.as_str()) {
782                return Err(DagMlError::RuntimeValidation(format!(
783                    "prediction cache payload set for bundle `{}` has duplicate cache id `{}`",
784                    self.bundle_id, payload.cache_id
785                )));
786            }
787        }
788        Ok(())
789    }
790
791    pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
792        self.validate()?;
793        bundle.validate()?;
794        if self.bundle_id != bundle.bundle_id {
795            return Err(DagMlError::RuntimeValidation(format!(
796                "prediction cache payload set bundle `{}` does not match bundle `{}`",
797                self.bundle_id, bundle.bundle_id
798            )));
799        }
800        if self.caches.len() != bundle.prediction_caches.len() {
801            return Err(DagMlError::RuntimeValidation(format!(
802                "prediction cache payload set for bundle `{}` has {} payload(s) for {} cache record(s)",
803                self.bundle_id,
804                self.caches.len(),
805                bundle.prediction_caches.len()
806            )));
807        }
808        let records_by_requirement = bundle
809            .prediction_caches
810            .iter()
811            .map(|record| (record.requirement_key.as_str(), record))
812            .collect::<BTreeMap<_, _>>();
813        let payloads_by_requirement = self
814            .caches
815            .iter()
816            .map(|payload| (payload.requirement_key.as_str(), payload))
817            .collect::<BTreeMap<_, _>>();
818        for (requirement_key, record) in records_by_requirement {
819            let payload = payloads_by_requirement
820                .get(requirement_key)
821                .ok_or_else(|| {
822                    DagMlError::RuntimeValidation(format!(
823                        "prediction cache payload set for bundle `{}` is missing requirement `{}`",
824                        self.bundle_id, requirement_key
825                    ))
826                })?;
827            validate_prediction_cache_payload_matches_record(payload, record)?;
828        }
829        for requirement_key in payloads_by_requirement.keys() {
830            if !bundle
831                .prediction_caches
832                .iter()
833                .any(|record| record.requirement_key.as_str() == *requirement_key)
834            {
835                return Err(DagMlError::RuntimeValidation(format!(
836                    "prediction cache payload set for bundle `{}` contains unknown requirement `{}`",
837                    self.bundle_id, requirement_key
838                )));
839            }
840        }
841        Ok(())
842    }
843}
844
/// A refit artifact stored in a bundle, together with the bundle requirement
/// keys it depends on.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RefitArtifactRecord {
    /// Node the artifact belongs to.
    pub node_id: NodeId,
    /// Controller for the artifact; must equal `artifact.controller_id`.
    pub controller_id: ControllerId,
    /// Reference to the stored artifact.
    pub artifact: ArtifactRef,
    /// Fingerprint of the node parameters used for the refit.
    pub params_fingerprint: String,
    /// Keys of bundle data requirements used by this artifact (non-blank, unique).
    #[serde(default)]
    pub data_requirement_keys: Vec<String>,
    /// Keys of bundle prediction requirements used by this artifact (non-blank, unique).
    #[serde(default)]
    pub prediction_requirement_keys: Vec<String>,
}
856
857impl RefitArtifactRecord {
858    pub fn validate(&self) -> Result<()> {
859        self.artifact.validate()?;
860        if self.artifact.id.as_str().is_empty() {
861            return Err(DagMlError::RuntimeValidation(format!(
862                "refit artifact for `{}` has empty artifact id",
863                self.node_id
864            )));
865        }
866        if self.artifact.kind.trim().is_empty() {
867            return Err(DagMlError::RuntimeValidation(format!(
868                "refit artifact `{}` has empty artifact kind",
869                self.artifact.id
870            )));
871        }
872        if self.artifact.controller_id != self.controller_id {
873            return Err(DagMlError::RuntimeValidation(format!(
874                "refit artifact `{}` controller `{}` does not match record controller `{}`",
875                self.artifact.id, self.artifact.controller_id, self.controller_id
876            )));
877        }
878        validate_fingerprint("params", &self.params_fingerprint)?;
879        let mut seen_keys = BTreeSet::new();
880        for key in &self.data_requirement_keys {
881            if key.trim().is_empty() {
882                return Err(DagMlError::RuntimeValidation(format!(
883                    "refit artifact `{}` has empty data requirement key",
884                    self.artifact.id
885                )));
886            }
887            if !seen_keys.insert(key.as_str()) {
888                return Err(DagMlError::RuntimeValidation(format!(
889                    "refit artifact `{}` has duplicate data requirement key `{key}`",
890                    self.artifact.id
891                )));
892            }
893        }
894        let mut seen_prediction_keys = BTreeSet::new();
895        for key in &self.prediction_requirement_keys {
896            if key.trim().is_empty() {
897                return Err(DagMlError::RuntimeValidation(format!(
898                    "refit artifact `{}` has empty prediction requirement key",
899                    self.artifact.id
900                )));
901            }
902            if !seen_prediction_keys.insert(key.as_str()) {
903                return Err(DagMlError::RuntimeValidation(format!(
904                    "refit artifact `{}` has duplicate prediction requirement key `{key}`",
905                    self.artifact.id
906                )));
907            }
908        }
909        Ok(())
910    }
911}
912
/// A bundle recording an execution plan's identity together with the artifacts
/// and requirements needed to check or replay its outcome.
///
/// When deserializing, a missing `schema_version` defaults to the current
/// execution bundle schema version. The other defaulted fields default to empty
/// or `None`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExecutionBundle {
    /// Identifier of this bundle.
    pub bundle_id: BundleId,
    /// Bundle schema version, checked against the execution bundle migration policy.
    #[serde(default = "default_execution_bundle_schema_version")]
    pub schema_version: u32,
    /// Id of the execution plan this bundle was built from (must be non-blank).
    pub plan_id: String,
    /// Graph fingerprint; must equal the plan's in `validate_against_plan`.
    pub graph_fingerprint: String,
    /// Campaign fingerprint; must equal the plan's in `validate_against_plan`.
    pub campaign_fingerprint: String,
    /// Controller fingerprint; must equal the plan's in `validate_against_plan`.
    pub controller_fingerprint: String,
    /// Variant chosen from the plan's variants, if any.
    #[serde(default)]
    pub selected_variant_id: Option<VariantId>,
    /// Selection decisions keyed by non-blank selection key.
    #[serde(default)]
    pub selections: BTreeMap<String, SelectionDecision>,
    /// Refit artifacts recorded for plan nodes.
    #[serde(default)]
    pub refit_artifacts: Vec<RefitArtifactRecord>,
    /// Prediction requirements, unique by `key()`.
    #[serde(default)]
    pub prediction_requirements: Vec<BundlePredictionRequirement>,
    /// Prediction cache records; at most one per prediction requirement.
    #[serde(default)]
    pub prediction_caches: Vec<BundlePredictionCacheRecord>,
    /// External data requirements, unique by `key()`.
    #[serde(default)]
    pub data_requirements: Vec<BundleDataRequirement>,
    /// Non-blank unsafe flags.
    #[serde(default)]
    pub unsafe_flags: BTreeSet<String>,
    /// Free-form metadata (not checked by the validators in this file).
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
939
940impl ExecutionBundle {
941    pub fn validate(&self) -> Result<()> {
942        execution_bundle_schema_migration_policy()
943            .validate_read_version(self.schema_version, &format!("bundle `{}`", self.bundle_id))?;
944        if self.plan_id.trim().is_empty() {
945            return Err(DagMlError::RuntimeValidation(format!(
946                "bundle `{}` has empty plan_id",
947                self.bundle_id
948            )));
949        }
950        validate_fingerprint("graph", &self.graph_fingerprint)?;
951        validate_fingerprint("campaign", &self.campaign_fingerprint)?;
952        validate_fingerprint("controller", &self.controller_fingerprint)?;
953        for (key, decision) in &self.selections {
954            if key.trim().is_empty() {
955                return Err(DagMlError::RuntimeValidation(format!(
956                    "bundle `{}` contains empty selection key",
957                    self.bundle_id
958                )));
959            }
960            decision.validate()?;
961        }
962        let mut data_keys = BTreeMap::new();
963        for requirement in &self.data_requirements {
964            requirement.validate()?;
965            let key = requirement.key();
966            if data_keys.insert(key.clone(), requirement).is_some() {
967                return Err(DagMlError::RuntimeValidation(format!(
968                    "bundle `{}` has duplicate data requirement `{}`",
969                    self.bundle_id, key
970                )));
971            }
972        }
973        let mut prediction_keys = BTreeMap::new();
974        for requirement in &self.prediction_requirements {
975            requirement.validate()?;
976            let key = requirement.key();
977            if prediction_keys.insert(key.clone(), requirement).is_some() {
978                return Err(DagMlError::RuntimeValidation(format!(
979                    "bundle `{}` has duplicate prediction requirement `{}`",
980                    self.bundle_id, key
981                )));
982            }
983        }
984        let mut prediction_cache_keys = BTreeMap::new();
985        for cache in &self.prediction_caches {
986            cache.validate()?;
987            let requirement = prediction_keys.get(&cache.requirement_key).ok_or_else(|| {
988                DagMlError::RuntimeValidation(format!(
989                    "prediction cache `{}` references unknown prediction requirement `{}`",
990                    cache.cache_id, cache.requirement_key
991                ))
992            })?;
993            validate_prediction_cache_matches_requirement(cache, requirement)?;
994            if prediction_cache_keys
995                .insert(cache.requirement_key.clone(), cache)
996                .is_some()
997            {
998                return Err(DagMlError::RuntimeValidation(format!(
999                    "bundle `{}` has duplicate prediction cache for requirement `{}`",
1000                    self.bundle_id, cache.requirement_key
1001                )));
1002            }
1003        }
1004        for artifact in &self.refit_artifacts {
1005            artifact.validate()?;
1006            for key in &artifact.data_requirement_keys {
1007                match data_keys.get(key) {
1008                    Some(requirement) if requirement.node_id == artifact.node_id => {}
1009                    Some(requirement) => {
1010                        return Err(DagMlError::RuntimeValidation(format!(
1011                            "refit artifact `{}` for `{}` references data requirement `{key}` owned by `{}`",
1012                            artifact.artifact.id, artifact.node_id, requirement.node_id
1013                        )));
1014                    }
1015                    None => {
1016                        return Err(DagMlError::RuntimeValidation(format!(
1017                            "refit artifact `{}` references unknown data requirement `{key}`",
1018                            artifact.artifact.id
1019                        )));
1020                    }
1021                }
1022            }
1023            for key in &artifact.prediction_requirement_keys {
1024                match prediction_keys.get(key) {
1025                    Some(requirement) if requirement.consumer_node == artifact.node_id => {}
1026                    Some(requirement) => {
1027                        return Err(DagMlError::RuntimeValidation(format!(
1028                            "refit artifact `{}` for `{}` references prediction requirement `{key}` consumed by `{}`",
1029                            artifact.artifact.id, artifact.node_id, requirement.consumer_node
1030                        )));
1031                    }
1032                    None => {
1033                        return Err(DagMlError::RuntimeValidation(format!(
1034                            "refit artifact `{}` references unknown prediction requirement `{key}`",
1035                            artifact.artifact.id
1036                        )));
1037                    }
1038                }
1039                if !prediction_cache_keys.contains_key(key) {
1040                    return Err(DagMlError::RuntimeValidation(format!(
1041                        "refit artifact `{}` references prediction requirement `{key}` without a prediction cache record",
1042                        artifact.artifact.id
1043                    )));
1044                }
1045            }
1046        }
1047        for unsafe_flag in &self.unsafe_flags {
1048            if unsafe_flag.trim().is_empty() {
1049                return Err(DagMlError::RuntimeValidation(format!(
1050                    "bundle `{}` contains an empty unsafe flag",
1051                    self.bundle_id
1052                )));
1053            }
1054        }
1055        Ok(())
1056    }
1057
1058    pub fn validate_against_plan(&self, plan: &ExecutionPlan) -> Result<()> {
1059        self.validate()?;
1060        plan.validate()?;
1061        if self.plan_id != plan.id {
1062            return Err(DagMlError::RuntimeValidation(format!(
1063                "bundle `{}` plan_id `{}` does not match plan `{}`",
1064                self.bundle_id, self.plan_id, plan.id
1065            )));
1066        }
1067        if self.graph_fingerprint != plan.graph_fingerprint
1068            || self.campaign_fingerprint != plan.campaign_fingerprint
1069            || self.controller_fingerprint != plan.controller_fingerprint
1070        {
1071            return Err(DagMlError::RuntimeValidation(format!(
1072                "bundle `{}` fingerprints do not match execution plan",
1073                self.bundle_id
1074            )));
1075        }
1076        let selected_variant = match &self.selected_variant_id {
1077            Some(selected_variant_id) => Some(
1078                plan.variants
1079                    .iter()
1080                    .find(|variant| &variant.variant_id == selected_variant_id)
1081                    .ok_or_else(|| {
1082                        DagMlError::RuntimeValidation(format!(
1083                            "bundle `{}` selected unknown variant `{selected_variant_id}`",
1084                            self.bundle_id
1085                        ))
1086                    })?,
1087            ),
1088            None => None,
1089        };
1090        self.validate_selections_against_plan(plan)?;
1091        let expected_requirements = collect_data_requirements(plan)?;
1092        let expected_by_key = expected_requirements
1093            .iter()
1094            .map(|requirement| (requirement.key(), requirement))
1095            .collect::<BTreeMap<_, _>>();
1096        if self.data_requirements.len() != expected_by_key.len() {
1097            return Err(DagMlError::RuntimeValidation(format!(
1098                "bundle `{}` data requirement count does not match execution plan",
1099                self.bundle_id
1100            )));
1101        }
1102        for requirement in &self.data_requirements {
1103            let key = requirement.key();
1104            let expected = expected_by_key.get(&key).ok_or_else(|| {
1105                DagMlError::RuntimeValidation(format!(
1106                    "bundle `{}` data requirement `{key}` does not exist in execution plan",
1107                    self.bundle_id
1108                ))
1109            })?;
1110            if !requirement.matches_plan_requirement(expected) {
1111                return Err(DagMlError::RuntimeValidation(format!(
1112                    "bundle `{}` data requirement `{key}` does not match execution plan",
1113                    self.bundle_id
1114                )));
1115            }
1116        }
1117        for artifact in &self.refit_artifacts {
1118            let node_plan = plan.node_plans.get(&artifact.node_id).ok_or_else(|| {
1119                DagMlError::RuntimeValidation(format!(
1120                    "bundle `{}` artifact references unknown node `{}`",
1121                    self.bundle_id, artifact.node_id
1122                ))
1123            })?;
1124            if artifact.controller_id != node_plan.controller_id {
1125                return Err(DagMlError::RuntimeValidation(format!(
1126                    "bundle `{}` artifact controller for `{}` does not match plan",
1127                    self.bundle_id, artifact.node_id
1128                )));
1129            }
1130            let expected_params_fingerprint =
1131                expected_refit_artifact_params_fingerprint(node_plan, selected_variant)?;
1132            if artifact.params_fingerprint != expected_params_fingerprint {
1133                return Err(DagMlError::RuntimeValidation(format!(
1134                    "bundle `{}` artifact params for `{}` do not match plan",
1135                    self.bundle_id, artifact.node_id
1136                )));
1137            }
1138        }
1139        for requirement in &self.prediction_requirements {
1140            let edge = plan
1141                .graph_plan
1142                .graph
1143                .edges
1144                .iter()
1145                .find(|edge| {
1146                    edge.source.node_id == requirement.producer_node
1147                    && edge.source.port_name == requirement.source_port
1148                    && edge.target.node_id == requirement.consumer_node
1149                    && edge.target.port_name == requirement.target_port
1150                    && edge.contract.requires_oof
1151                })
1152                .ok_or_else(|| {
1153                    DagMlError::RuntimeValidation(format!(
1154                        "bundle `{}` prediction requirement `{}` does not match an OOF edge in the plan",
1155                        self.bundle_id,
1156                        requirement.key()
1157                    ))
1158                })?;
1159            let cache = self
1160                .prediction_caches
1161                .iter()
1162                .find(|cache| cache.requirement_key == requirement.key());
1163            validate_prediction_requirement_against_plan(self, plan, edge, requirement, cache)?;
1164        }
1165        Ok(())
1166    }
1167
1168    fn validate_selections_against_plan(&self, plan: &ExecutionPlan) -> Result<()> {
1169        if self.selections.is_empty() {
1170            return Ok(());
1171        }
1172        let artifact_node_ids = self
1173            .refit_artifacts
1174            .iter()
1175            .map(|artifact| artifact.node_id.clone())
1176            .collect::<BTreeSet<_>>();
1177        let required_metric_level = plan.campaign.aggregation_policy.selection_metric_level;
1178        for (selection_key, decision) in &self.selections {
1179            match decision.metric_level {
1180                Some(metric_level) if metric_level == required_metric_level => {}
1181                Some(metric_level) => {
1182                    return Err(DagMlError::RuntimeValidation(format!(
1183                        "bundle `{}` selection `{selection_key}` metric_level {:?} does not match campaign selection_metric_level {:?}",
1184                        self.bundle_id, metric_level, required_metric_level
1185                    )));
1186                }
1187                None => {
1188                    return Err(DagMlError::RuntimeValidation(format!(
1189                        "bundle `{}` selection `{selection_key}` is missing metric_level for campaign selection_metric_level {:?}",
1190                        self.bundle_id, required_metric_level
1191                    )));
1192                }
1193            }
1194            let selected_candidate_id = decision.selected_candidate_id.as_str();
1195            if let Ok(selected_node_id) = NodeId::new(selected_candidate_id) {
1196                if let Some(node_plan) = plan.node_plans.get(&selected_node_id) {
1197                    if node_plan.supported_phases.contains(&Phase::Refit)
1198                        && !artifact_node_ids.contains(&node_plan.node_id)
1199                    {
1200                        return Err(DagMlError::RuntimeValidation(format!(
1201                            "bundle `{}` selection `{selection_key}` chose refittable node `{}` without a matching refit artifact",
1202                            self.bundle_id, node_plan.node_id
1203                        )));
1204                    }
1205                    continue;
1206                }
1207            }
1208            if VariantId::new(selected_candidate_id).is_ok()
1209                && plan
1210                    .variants
1211                    .iter()
1212                    .any(|variant| variant.variant_id.as_str() == selected_candidate_id)
1213            {
1214                continue;
1215            }
1216            return Err(DagMlError::RuntimeValidation(format!(
1217                "bundle `{}` selection `{selection_key}` chose unknown candidate `{selected_candidate_id}` for plan `{}`",
1218                self.bundle_id, plan.id
1219            )));
1220        }
1221        Ok(())
1222    }
1223
1224    pub fn validate_replay_envelopes(
1225        &self,
1226        envelopes: &BTreeMap<String, ExternalDataPlanEnvelope>,
1227    ) -> Result<()> {
1228        self.validate()?;
1229        for requirement in &self.data_requirements {
1230            let key = requirement.key();
1231            let envelope = envelopes.get(&key).ok_or_else(|| {
1232                DagMlError::RuntimeValidation(format!(
1233                    "replay is missing external data envelope for `{key}`"
1234                ))
1235            })?;
1236            envelope.validate()?;
1237            if requirement.schema_fingerprint != envelope.schema_fingerprint
1238                || requirement.plan_fingerprint != envelope.plan_fingerprint
1239                || requirement.relation_fingerprint != envelope.relation_fingerprint
1240            {
1241                return Err(DagMlError::RuntimeValidation(format!(
1242                    "replay envelope for `{key}` does not match bundle data requirement"
1243                )));
1244            }
1245        }
1246        Ok(())
1247    }
1248}
1249
1250fn expected_refit_artifact_params_fingerprint(
1251    node_plan: &crate::plan::NodePlan,
1252    selected_variant: Option<&crate::generation::VariantPlan>,
1253) -> Result<String> {
1254    let Some(variant) = selected_variant else {
1255        return Ok(node_plan.params_fingerprint.clone());
1256    };
1257    let effective_params =
1258        variant.effective_params_for_node(&node_plan.node_id, &node_plan.params)?;
1259    stable_json_fingerprint(&effective_params)
1260}
1261
1262fn validate_prediction_requirement_against_plan(
1263    bundle: &ExecutionBundle,
1264    plan: &ExecutionPlan,
1265    edge: &crate::graph::EdgeSpec,
1266    requirement: &BundlePredictionRequirement,
1267    cache: Option<&BundlePredictionCacheRecord>,
1268) -> Result<()> {
1269    if !edge.contract.requires_fold_alignment {
1270        return Ok(());
1271    }
1272    let fold_set = plan.fold_set.as_ref().ok_or_else(|| {
1273        DagMlError::RuntimeValidation(format!(
1274            "bundle `{}` prediction requirement `{}` needs fold alignment but plan `{}` has no fold set",
1275            bundle.bundle_id,
1276            requirement.key(),
1277            plan.id
1278        ))
1279    })?;
1280    let expected_fold_ids = fold_set
1281        .folds
1282        .iter()
1283        .map(|fold| fold.fold_id.clone())
1284        .collect::<BTreeSet<_>>();
1285    let requirement_fold_ids = requirement
1286        .fold_ids
1287        .iter()
1288        .cloned()
1289        .collect::<BTreeSet<_>>();
1290    if requirement_fold_ids != expected_fold_ids {
1291        return Err(DagMlError::RuntimeValidation(format!(
1292            "bundle `{}` prediction requirement `{}` fold ids do not match plan fold set",
1293            bundle.bundle_id,
1294            requirement.key()
1295        )));
1296    }
1297    if requirement.prediction_level != PredictionLevel::Sample {
1298        if let Some(cache) = cache {
1299            validate_aggregated_prediction_cache_blocks_match_requirement(
1300                bundle,
1301                requirement,
1302                cache,
1303            )?;
1304        }
1305        return Ok(());
1306    }
1307    let expected_sample_ids = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1308    let requirement_sample_ids = requirement
1309        .sample_ids
1310        .iter()
1311        .cloned()
1312        .collect::<BTreeSet<_>>();
1313    if requirement_sample_ids != expected_sample_ids {
1314        return Err(DagMlError::RuntimeValidation(format!(
1315            "bundle `{}` prediction requirement `{}` sample ids do not match plan fold set",
1316            bundle.bundle_id,
1317            requirement.key()
1318        )));
1319    }
1320    if let Some(cache) = cache {
1321        validate_prediction_cache_blocks_match_fold_set(bundle, requirement, cache, fold_set)?;
1322    }
1323    Ok(())
1324}
1325
1326fn validate_prediction_cache_blocks_match_fold_set(
1327    bundle: &ExecutionBundle,
1328    requirement: &BundlePredictionRequirement,
1329    cache: &BundlePredictionCacheRecord,
1330    fold_set: &crate::fold::FoldSet,
1331) -> Result<()> {
1332    let folds = fold_set
1333        .folds
1334        .iter()
1335        .map(|fold| (&fold.fold_id, fold))
1336        .collect::<BTreeMap<_, _>>();
1337    let expected_fold_ids = fold_set
1338        .folds
1339        .iter()
1340        .map(|fold| fold.fold_id.clone())
1341        .collect::<BTreeSet<_>>();
1342    let mut covered_fold_ids = BTreeSet::new();
1343    let mut covered_sample_ids = BTreeSet::new();
1344    for block in &cache.blocks {
1345        let fold_id = block.fold_id.as_ref().ok_or_else(|| {
1346            DagMlError::RuntimeValidation(format!(
1347                "bundle `{}` prediction cache `{}` has an OOF block without a fold id",
1348                bundle.bundle_id, cache.cache_id
1349            ))
1350        })?;
1351        covered_fold_ids.insert(fold_id.clone());
1352        let fold = folds.get(fold_id).ok_or_else(|| {
1353            DagMlError::RuntimeValidation(format!(
1354                "bundle `{}` prediction cache `{}` references unknown fold `{fold_id}`",
1355                bundle.bundle_id, cache.cache_id
1356            ))
1357        })?;
1358        let block_samples = block.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1359        let expected_samples = fold
1360            .validation_sample_ids
1361            .iter()
1362            .cloned()
1363            .collect::<BTreeSet<_>>();
1364        if block_samples != expected_samples {
1365            return Err(DagMlError::RuntimeValidation(format!(
1366                "bundle `{}` prediction cache `{}` block for fold `{fold_id}` does not match validation samples for requirement `{}`",
1367                bundle.bundle_id,
1368                cache.cache_id,
1369                requirement.key()
1370            )));
1371        }
1372        for sample_id in block_samples {
1373            if !covered_sample_ids.insert(sample_id.clone()) {
1374                return Err(DagMlError::RuntimeValidation(format!(
1375                    "bundle `{}` prediction cache `{}` has duplicate OOF sample `{sample_id}`",
1376                    bundle.bundle_id, cache.cache_id
1377                )));
1378            }
1379        }
1380    }
1381    if covered_fold_ids != expected_fold_ids {
1382        return Err(DagMlError::RuntimeValidation(format!(
1383            "bundle `{}` prediction cache `{}` does not cover all folds for requirement `{}`",
1384            bundle.bundle_id,
1385            cache.cache_id,
1386            requirement.key()
1387        )));
1388    }
1389    let expected_sample_ids = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1390    if covered_sample_ids != expected_sample_ids {
1391        return Err(DagMlError::RuntimeValidation(format!(
1392            "bundle `{}` prediction cache `{}` does not cover the full OOF sample universe for requirement `{}`",
1393            bundle.bundle_id,
1394            cache.cache_id,
1395            requirement.key()
1396        )));
1397    }
1398    Ok(())
1399}
1400
1401fn validate_aggregated_prediction_cache_blocks_match_requirement(
1402    bundle: &ExecutionBundle,
1403    requirement: &BundlePredictionRequirement,
1404    cache: &BundlePredictionCacheRecord,
1405) -> Result<()> {
1406    let mut covered_fold_ids = BTreeSet::new();
1407    let mut covered_unit_ids = BTreeSet::new();
1408    for block in &cache.blocks {
1409        if block.prediction_level != requirement.prediction_level {
1410            return Err(DagMlError::RuntimeValidation(format!(
1411                "bundle `{}` prediction cache `{}` block level does not match requirement `{}`",
1412                bundle.bundle_id,
1413                cache.cache_id,
1414                requirement.key()
1415            )));
1416        }
1417        if let Some(fold_id) = &block.fold_id {
1418            covered_fold_ids.insert(fold_id.clone());
1419        }
1420        for unit_id in &block.unit_ids {
1421            if !covered_unit_ids.insert(unit_id.clone()) {
1422                return Err(DagMlError::RuntimeValidation(format!(
1423                    "bundle `{}` prediction cache `{}` has duplicate aggregated unit `{unit_id}`",
1424                    bundle.bundle_id, cache.cache_id
1425                )));
1426            }
1427        }
1428    }
1429    let expected_fold_ids = requirement
1430        .fold_ids
1431        .iter()
1432        .cloned()
1433        .collect::<BTreeSet<_>>();
1434    if covered_fold_ids != expected_fold_ids {
1435        return Err(DagMlError::RuntimeValidation(format!(
1436            "bundle `{}` prediction cache `{}` does not cover all folds for aggregated requirement `{}`",
1437            bundle.bundle_id,
1438            cache.cache_id,
1439            requirement.key()
1440        )));
1441    }
1442    let expected_unit_ids = requirement
1443        .unit_ids
1444        .iter()
1445        .cloned()
1446        .collect::<BTreeSet<_>>();
1447    if covered_unit_ids != expected_unit_ids {
1448        return Err(DagMlError::RuntimeValidation(format!(
1449            "bundle `{}` prediction cache `{}` does not cover all units for aggregated requirement `{}`",
1450            bundle.bundle_id,
1451            cache.cache_id,
1452            requirement.key()
1453        )));
1454    }
1455    Ok(())
1456}
1457
1458pub fn build_execution_bundle(
1459    bundle_id: BundleId,
1460    plan: &ExecutionPlan,
1461    selected_variant_id: Option<VariantId>,
1462    selections: BTreeMap<String, SelectionDecision>,
1463    refit_artifacts: Vec<RefitArtifactRecord>,
1464) -> Result<ExecutionBundle> {
1465    build_execution_bundle_with_prediction_requirements(
1466        bundle_id,
1467        plan,
1468        selected_variant_id,
1469        selections,
1470        refit_artifacts,
1471        Vec::new(),
1472    )
1473}
1474
1475pub fn build_execution_bundle_with_prediction_requirements(
1476    bundle_id: BundleId,
1477    plan: &ExecutionPlan,
1478    selected_variant_id: Option<VariantId>,
1479    selections: BTreeMap<String, SelectionDecision>,
1480    refit_artifacts: Vec<RefitArtifactRecord>,
1481    prediction_requirements: Vec<BundlePredictionRequirement>,
1482) -> Result<ExecutionBundle> {
1483    build_execution_bundle_with_prediction_contracts(
1484        bundle_id,
1485        plan,
1486        selected_variant_id,
1487        selections,
1488        refit_artifacts,
1489        prediction_requirements,
1490        Vec::new(),
1491    )
1492}
1493
1494pub fn build_execution_bundle_with_prediction_contracts(
1495    bundle_id: BundleId,
1496    plan: &ExecutionPlan,
1497    selected_variant_id: Option<VariantId>,
1498    selections: BTreeMap<String, SelectionDecision>,
1499    refit_artifacts: Vec<RefitArtifactRecord>,
1500    prediction_requirements: Vec<BundlePredictionRequirement>,
1501    prediction_caches: Vec<BundlePredictionCacheRecord>,
1502) -> Result<ExecutionBundle> {
1503    plan.validate()?;
1504    let bundle = ExecutionBundle {
1505        bundle_id,
1506        schema_version: EXECUTION_BUNDLE_SCHEMA_VERSION,
1507        plan_id: plan.id.clone(),
1508        graph_fingerprint: plan.graph_fingerprint.clone(),
1509        campaign_fingerprint: plan.campaign_fingerprint.clone(),
1510        controller_fingerprint: plan.controller_fingerprint.clone(),
1511        selected_variant_id,
1512        selections,
1513        refit_artifacts,
1514        prediction_requirements,
1515        prediction_caches,
1516        data_requirements: collect_data_requirements(plan)?,
1517        unsafe_flags: BTreeSet::new(),
1518        metadata: BTreeMap::new(),
1519    };
1520    bundle.validate_against_plan(plan)?;
1521    Ok(bundle)
1522}
1523
1524fn collect_data_requirements(plan: &ExecutionPlan) -> Result<Vec<BundleDataRequirement>> {
1525    let mut requirements = Vec::new();
1526    for node_plan in plan.node_plans.values() {
1527        for binding in &node_plan.data_bindings {
1528            requirements.push(BundleDataRequirement {
1529                node_id: node_plan.node_id.clone(),
1530                input_name: binding.input_name.clone(),
1531                schema_fingerprint: binding.schema_fingerprint.clone(),
1532                plan_fingerprint: binding.plan_fingerprint.clone(),
1533                relation_fingerprint: binding.relation_fingerprint.clone(),
1534                output_representation: binding.output_representation.clone(),
1535                feature_set_id: binding.feature_set_id.clone(),
1536                representation_replay_manifest: None,
1537                representation_compatibility: None,
1538            });
1539        }
1540    }
1541    requirements.sort_by_key(BundleDataRequirement::key);
1542    for requirement in &requirements {
1543        requirement.validate()?;
1544    }
1545    Ok(requirements)
1546}
1547
1548pub fn build_prediction_cache_record(
1549    requirement: &BundlePredictionRequirement,
1550    blocks: &[PredictionBlock],
1551) -> Result<BundlePredictionCacheRecord> {
1552    let selected = select_prediction_cache_blocks(requirement, blocks)?;
1553    build_prediction_cache_record_from_selected(requirement, &selected)
1554}
1555
1556pub fn build_prediction_cache_payload(
1557    requirement: &BundlePredictionRequirement,
1558    blocks: &[PredictionBlock],
1559) -> Result<BundlePredictionCachePayload> {
1560    let selected = select_prediction_cache_blocks(requirement, blocks)?;
1561    let payload = BundlePredictionCachePayload {
1562        requirement_key: requirement.key(),
1563        cache_id: format!("prediction-cache:{}", requirement.key()),
1564        format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1565        partition: requirement.partition.clone(),
1566        prediction_level: requirement.prediction_level,
1567        block_count: selected.len(),
1568        row_count: selected.iter().map(|block| block.sample_ids.len()).sum(),
1569        content_fingerprint: stable_json_fingerprint(&selected)?,
1570        blocks: selected,
1571        aggregated_blocks: Vec::new(),
1572    };
1573    payload.validate()?;
1574    let record = build_prediction_cache_record(requirement, &payload.blocks)?;
1575    validate_prediction_cache_payload_matches_record(&payload, &record)?;
1576    Ok(payload)
1577}
1578
1579pub fn build_aggregated_prediction_cache_record(
1580    requirement: &BundlePredictionRequirement,
1581    blocks: &[AggregatedPredictionBlock],
1582) -> Result<BundlePredictionCacheRecord> {
1583    let selected = select_aggregated_prediction_cache_blocks(requirement, blocks)?;
1584    build_aggregated_prediction_cache_record_from_selected(requirement, &selected)
1585}
1586
1587pub fn build_aggregated_prediction_cache_payload(
1588    requirement: &BundlePredictionRequirement,
1589    blocks: &[AggregatedPredictionBlock],
1590) -> Result<BundlePredictionCachePayload> {
1591    let selected = select_aggregated_prediction_cache_blocks(requirement, blocks)?;
1592    let payload = BundlePredictionCachePayload {
1593        requirement_key: requirement.key(),
1594        cache_id: format!("prediction-cache:{}", requirement.key()),
1595        format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1596        partition: requirement.partition.clone(),
1597        prediction_level: requirement.prediction_level,
1598        block_count: selected.len(),
1599        row_count: selected.iter().map(|block| block.unit_ids.len()).sum(),
1600        content_fingerprint: stable_json_fingerprint(&selected)?,
1601        blocks: Vec::new(),
1602        aggregated_blocks: selected,
1603    };
1604    payload.validate()?;
1605    let record = build_aggregated_prediction_cache_record(requirement, &payload.aggregated_blocks)?;
1606    validate_prediction_cache_payload_matches_record(&payload, &record)?;
1607    Ok(payload)
1608}
1609
1610pub fn validate_prediction_cache_payload_matches_record(
1611    payload: &BundlePredictionCachePayload,
1612    record: &BundlePredictionCacheRecord,
1613) -> Result<()> {
1614    payload.validate()?;
1615    record.validate()?;
1616    if payload.requirement_key != record.requirement_key
1617        || payload.cache_id != record.cache_id
1618        || payload.format != record.format
1619        || payload.partition != record.partition
1620        || payload.prediction_level != record.prediction_level
1621        || payload.block_count != record.block_count
1622        || payload.row_count != record.row_count
1623        || payload.content_fingerprint != record.content_fingerprint
1624    {
1625        return Err(DagMlError::RuntimeValidation(format!(
1626            "prediction cache payload `{}` does not match cache record `{}`",
1627            payload.cache_id, record.cache_id
1628        )));
1629    }
1630    let block_records = if payload.prediction_level == PredictionLevel::Sample {
1631        payload
1632            .blocks
1633            .iter()
1634            .map(|block| {
1635                Ok(BundlePredictionBlockCacheRecord {
1636                    prediction_id: block.prediction_id.clone(),
1637                    fold_id: block.fold_id.clone(),
1638                    prediction_level: PredictionLevel::Sample,
1639                    row_count: block.sample_ids.len(),
1640                    unit_ids: Vec::new(),
1641                    sample_ids: block.sample_ids.clone(),
1642                    content_fingerprint: stable_json_fingerprint(block)?,
1643                })
1644            })
1645            .collect::<Result<Vec<_>>>()?
1646    } else {
1647        payload
1648            .aggregated_blocks
1649            .iter()
1650            .map(|block| {
1651                Ok(BundlePredictionBlockCacheRecord {
1652                    prediction_id: block.prediction_id.clone(),
1653                    fold_id: block.fold_id.clone(),
1654                    prediction_level: block.level,
1655                    row_count: block.unit_ids.len(),
1656                    unit_ids: block.unit_ids.clone(),
1657                    sample_ids: Vec::new(),
1658                    content_fingerprint: stable_json_fingerprint(block)?,
1659                })
1660            })
1661            .collect::<Result<Vec<_>>>()?
1662    };
1663    if block_records != record.blocks {
1664        return Err(DagMlError::RuntimeValidation(format!(
1665            "prediction cache payload `{}` block fingerprints do not match cache record",
1666            payload.cache_id
1667        )));
1668    }
1669    Ok(())
1670}
1671
1672fn select_prediction_cache_blocks(
1673    requirement: &BundlePredictionRequirement,
1674    blocks: &[PredictionBlock],
1675) -> Result<Vec<PredictionBlock>> {
1676    requirement.validate()?;
1677    let mut selected = blocks
1678        .iter()
1679        .filter(|block| {
1680            block.producer_node == requirement.producer_node
1681                && block.partition == requirement.partition
1682        })
1683        .cloned()
1684        .collect::<Vec<_>>();
1685    if selected.is_empty() {
1686        return Err(DagMlError::RuntimeValidation(format!(
1687            "prediction cache requirement `{}` has no matching prediction blocks",
1688            requirement.key()
1689        )));
1690    }
1691    selected.sort_by(|left, right| {
1692        (
1693            left.fold_id.as_ref().map(ToString::to_string),
1694            left.prediction_id.clone(),
1695        )
1696            .cmp(&(
1697                right.fold_id.as_ref().map(ToString::to_string),
1698                right.prediction_id.clone(),
1699            ))
1700    });
1701    Ok(selected)
1702}
1703
1704fn select_aggregated_prediction_cache_blocks(
1705    requirement: &BundlePredictionRequirement,
1706    blocks: &[AggregatedPredictionBlock],
1707) -> Result<Vec<AggregatedPredictionBlock>> {
1708    requirement.validate()?;
1709    if requirement.prediction_level == PredictionLevel::Sample {
1710        return Err(DagMlError::RuntimeValidation(format!(
1711            "aggregated prediction cache requirement `{}` must use target or group level",
1712            requirement.key()
1713        )));
1714    }
1715    let mut selected = blocks
1716        .iter()
1717        .filter(|block| {
1718            block.producer_node == requirement.producer_node
1719                && block.partition == requirement.partition
1720                && block.level == requirement.prediction_level
1721        })
1722        .cloned()
1723        .collect::<Vec<_>>();
1724    if selected.is_empty() {
1725        return Err(DagMlError::RuntimeValidation(format!(
1726            "aggregated prediction cache requirement `{}` has no matching prediction blocks",
1727            requirement.key()
1728        )));
1729    }
1730    selected.sort_by(|left, right| {
1731        (
1732            left.fold_id.as_ref().map(ToString::to_string),
1733            left.prediction_id.clone(),
1734        )
1735            .cmp(&(
1736                right.fold_id.as_ref().map(ToString::to_string),
1737                right.prediction_id.clone(),
1738            ))
1739    });
1740    Ok(selected)
1741}
1742
1743fn build_prediction_cache_record_from_selected(
1744    requirement: &BundlePredictionRequirement,
1745    selected: &[PredictionBlock],
1746) -> Result<BundlePredictionCacheRecord> {
1747    requirement.validate()?;
1748    if selected.is_empty() {
1749        return Err(DagMlError::RuntimeValidation(format!(
1750            "prediction cache requirement `{}` has no matching prediction blocks",
1751            requirement.key()
1752        )));
1753    }
1754    let mut fold_ids = BTreeSet::new();
1755    let mut sample_ids = BTreeSet::new();
1756    let mut target_names: Option<Vec<String>> = None;
1757    let mut prediction_width: Option<usize> = None;
1758    let mut row_count = 0usize;
1759    let mut block_records = Vec::new();
1760    for block in selected {
1761        if block.producer_node != requirement.producer_node
1762            || block.partition != requirement.partition
1763        {
1764            return Err(DagMlError::RuntimeValidation(format!(
1765                "prediction cache `{}` contains a block outside the requirement scope",
1766                requirement.key()
1767            )));
1768        }
1769        let width = block.validate_shape()?;
1770        if prediction_width.is_some_and(|expected| expected != width) {
1771            return Err(DagMlError::RuntimeValidation(format!(
1772                "prediction cache `{}` has inconsistent prediction width",
1773                requirement.key()
1774            )));
1775        }
1776        prediction_width = Some(width);
1777        let block_target_names = normalized_prediction_targets(block, width);
1778        if target_names
1779            .as_ref()
1780            .is_some_and(|expected| expected != &block_target_names)
1781        {
1782            return Err(DagMlError::RuntimeValidation(format!(
1783                "prediction cache `{}` has inconsistent target names",
1784                requirement.key()
1785            )));
1786        }
1787        target_names = Some(block_target_names);
1788        if let Some(fold_id) = &block.fold_id {
1789            fold_ids.insert(fold_id.clone());
1790        }
1791        sample_ids.extend(block.sample_ids.iter().cloned());
1792        row_count += block.sample_ids.len();
1793        block_records.push(BundlePredictionBlockCacheRecord {
1794            prediction_id: block.prediction_id.clone(),
1795            fold_id: block.fold_id.clone(),
1796            prediction_level: PredictionLevel::Sample,
1797            row_count: block.sample_ids.len(),
1798            unit_ids: Vec::new(),
1799            sample_ids: block.sample_ids.clone(),
1800            content_fingerprint: stable_json_fingerprint(block)?,
1801        });
1802    }
1803
1804    let record = BundlePredictionCacheRecord {
1805        requirement_key: requirement.key(),
1806        cache_id: format!("prediction-cache:{}", requirement.key()),
1807        format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1808        partition: requirement.partition.clone(),
1809        prediction_level: requirement.prediction_level,
1810        fold_ids: fold_ids.into_iter().collect(),
1811        unit_ids: requirement.unit_ids.clone(),
1812        sample_ids: sample_ids.into_iter().collect(),
1813        prediction_width: prediction_width.unwrap_or_default(),
1814        target_names: target_names.unwrap_or_default(),
1815        block_count: block_records.len(),
1816        row_count,
1817        content_fingerprint: stable_json_fingerprint(selected)?,
1818        blocks: block_records,
1819    };
1820    validate_prediction_cache_matches_requirement(&record, requirement)?;
1821    record.validate()?;
1822    Ok(record)
1823}
1824
1825fn build_aggregated_prediction_cache_record_from_selected(
1826    requirement: &BundlePredictionRequirement,
1827    selected: &[AggregatedPredictionBlock],
1828) -> Result<BundlePredictionCacheRecord> {
1829    requirement.validate()?;
1830    if requirement.prediction_level == PredictionLevel::Sample {
1831        return Err(DagMlError::RuntimeValidation(format!(
1832            "aggregated prediction cache requirement `{}` must use target or group level",
1833            requirement.key()
1834        )));
1835    }
1836    if selected.is_empty() {
1837        return Err(DagMlError::RuntimeValidation(format!(
1838            "aggregated prediction cache requirement `{}` has no matching prediction blocks",
1839            requirement.key()
1840        )));
1841    }
1842    let mut fold_ids = BTreeSet::new();
1843    let mut unit_ids = BTreeSet::new();
1844    let mut target_names: Option<Vec<String>> = None;
1845    let mut prediction_width: Option<usize> = None;
1846    let mut row_count = 0usize;
1847    let mut block_records = Vec::new();
1848    for block in selected {
1849        if block.producer_node != requirement.producer_node
1850            || block.partition != requirement.partition
1851            || block.level != requirement.prediction_level
1852        {
1853            return Err(DagMlError::RuntimeValidation(format!(
1854                "aggregated prediction cache `{}` contains a block outside the requirement scope",
1855                requirement.key()
1856            )));
1857        }
1858        let width = block.validate_shape()?;
1859        if prediction_width.is_some_and(|expected| expected != width) {
1860            return Err(DagMlError::RuntimeValidation(format!(
1861                "aggregated prediction cache `{}` has inconsistent prediction width",
1862                requirement.key()
1863            )));
1864        }
1865        prediction_width = Some(width);
1866        let block_target_names = normalized_aggregated_prediction_targets(block, width);
1867        if target_names
1868            .as_ref()
1869            .is_some_and(|expected| expected != &block_target_names)
1870        {
1871            return Err(DagMlError::RuntimeValidation(format!(
1872                "aggregated prediction cache `{}` has inconsistent target names",
1873                requirement.key()
1874            )));
1875        }
1876        target_names = Some(block_target_names);
1877        if let Some(fold_id) = &block.fold_id {
1878            fold_ids.insert(fold_id.clone());
1879        }
1880        unit_ids.extend(block.unit_ids.iter().cloned());
1881        row_count += block.unit_ids.len();
1882        block_records.push(BundlePredictionBlockCacheRecord {
1883            prediction_id: block.prediction_id.clone(),
1884            fold_id: block.fold_id.clone(),
1885            prediction_level: block.level,
1886            row_count: block.unit_ids.len(),
1887            unit_ids: block.unit_ids.clone(),
1888            sample_ids: Vec::new(),
1889            content_fingerprint: stable_json_fingerprint(block)?,
1890        });
1891    }
1892
1893    let record = BundlePredictionCacheRecord {
1894        requirement_key: requirement.key(),
1895        cache_id: format!("prediction-cache:{}", requirement.key()),
1896        format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1897        partition: requirement.partition.clone(),
1898        prediction_level: requirement.prediction_level,
1899        fold_ids: fold_ids.into_iter().collect(),
1900        unit_ids: unit_ids.into_iter().collect(),
1901        sample_ids: Vec::new(),
1902        prediction_width: prediction_width.unwrap_or_default(),
1903        target_names: target_names.unwrap_or_default(),
1904        block_count: block_records.len(),
1905        row_count,
1906        content_fingerprint: stable_json_fingerprint(selected)?,
1907        blocks: block_records,
1908    };
1909    validate_prediction_cache_matches_requirement(&record, requirement)?;
1910    record.validate()?;
1911    Ok(record)
1912}
1913
1914fn validate_prediction_cache_matches_requirement(
1915    cache: &BundlePredictionCacheRecord,
1916    requirement: &BundlePredictionRequirement,
1917) -> Result<()> {
1918    if cache.requirement_key != requirement.key()
1919        || cache.partition != requirement.partition
1920        || cache.prediction_level != requirement.prediction_level
1921        || cache.fold_ids != requirement.fold_ids
1922        || cache.unit_ids != requirement.unit_ids
1923        || cache.sample_ids != requirement.sample_ids
1924        || cache.prediction_width != requirement.prediction_width
1925        || cache.target_names != requirement.target_names
1926    {
1927        return Err(DagMlError::RuntimeValidation(format!(
1928            "prediction cache `{}` does not match requirement `{}`",
1929            cache.cache_id,
1930            requirement.key()
1931        )));
1932    }
1933    Ok(())
1934}
1935
1936fn normalized_prediction_targets(block: &PredictionBlock, width: usize) -> Vec<String> {
1937    if block.target_names.is_empty() {
1938        (0..width).map(|index| format!("p{index}")).collect()
1939    } else {
1940        block.target_names.clone()
1941    }
1942}
1943
1944fn normalized_aggregated_prediction_targets(
1945    block: &AggregatedPredictionBlock,
1946    width: usize,
1947) -> Vec<String> {
1948    if block.target_names.is_empty() {
1949        (0..width).map(|index| format!("p{index}")).collect()
1950    } else {
1951        block.target_names.clone()
1952    }
1953}
1954
1955fn sample_prediction_units(sample_ids: &[SampleId]) -> Vec<PredictionUnitId> {
1956    sample_ids
1957        .iter()
1958        .cloned()
1959        .map(PredictionUnitId::Sample)
1960        .collect()
1961}
1962
1963fn validate_prediction_units(
1964    label: &str,
1965    expected_level: PredictionLevel,
1966    unit_ids: &[PredictionUnitId],
1967) -> Result<()> {
1968    validate_unique_ids(label, unit_ids)?;
1969    for unit_id in unit_ids {
1970        if unit_id.level() != expected_level {
1971            return Err(DagMlError::RuntimeValidation(format!(
1972                "{label} `{unit_id}` does not match prediction level {:?}",
1973                expected_level
1974            )));
1975        }
1976    }
1977    Ok(())
1978}
1979
1980fn validate_fingerprint(label: &str, value: &str) -> Result<()> {
1981    if value.len() != 64 || !value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
1982        return Err(DagMlError::RuntimeValidation(format!(
1983            "{label} fingerprint must be a 64-character hex digest"
1984        )));
1985    }
1986    Ok(())
1987}
1988
1989fn validate_non_empty(label: &str, value: &str) -> Result<()> {
1990    if value.trim().is_empty() {
1991        return Err(DagMlError::RuntimeValidation(format!("{label} is empty")));
1992    }
1993    Ok(())
1994}
1995
1996fn validate_unique_ids<T>(label: &str, values: &[T]) -> Result<()>
1997where
1998    T: Ord + ToString,
1999{
2000    let mut seen = BTreeSet::new();
2001    for value in values {
2002        if !seen.insert(value) {
2003            return Err(DagMlError::RuntimeValidation(format!(
2004                "duplicate {label} `{}`",
2005                value.to_string()
2006            )));
2007        }
2008    }
2009    Ok(())
2010}
2011
/// Request to replay one phase of a previously built [`ExecutionBundle`].
///
/// Checked by `ReplayPhaseRequest::validate_for_bundle` and its variants. The bundle id must
/// match the bundle, and the phase must be `Predict`, `Explain` or `Refit`. In addition,
/// `data_envelope_keys` must name every bundle data requirement key exactly once.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ReplayPhaseRequest {
    /// Id of the bundle this request targets; must equal the bundle's own `bundle_id`.
    pub bundle_id: BundleId,
    /// Phase to replay.
    pub phase: Phase,
    /// Keys of the data envelopes supplied for the replay, one per bundle data requirement
    /// key. Deserializes to an empty list when the field is absent.
    #[serde(default)]
    pub data_envelope_keys: Vec<String>,
}
2019
2020impl ReplayPhaseRequest {
2021    pub fn validate_for_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
2022        self.validate_for_bundle_with_prediction_cache_store(bundle, false)
2023    }
2024
2025    pub fn validate_for_bundle_with_prediction_cache_store(
2026        &self,
2027        bundle: &ExecutionBundle,
2028        prediction_cache_available: bool,
2029    ) -> Result<()> {
2030        self.validate_for_bundle_internal(bundle, prediction_cache_available)
2031    }
2032
2033    pub fn validate_for_bundle_with_prediction_cache_payloads(
2034        &self,
2035        bundle: &ExecutionBundle,
2036        prediction_cache_payloads: Option<&BundlePredictionCachePayloadSet>,
2037    ) -> Result<()> {
2038        if let Some(payloads) = prediction_cache_payloads {
2039            payloads.validate_against_bundle(bundle)?;
2040        }
2041        self.validate_for_bundle_internal(bundle, prediction_cache_payloads.is_some())
2042    }
2043
2044    fn validate_for_bundle_internal(
2045        &self,
2046        bundle: &ExecutionBundle,
2047        prediction_cache_available: bool,
2048    ) -> Result<()> {
2049        bundle.validate()?;
2050        if self.bundle_id != bundle.bundle_id {
2051            return Err(DagMlError::RuntimeValidation(format!(
2052                "replay request bundle `{}` does not match bundle `{}`",
2053                self.bundle_id, bundle.bundle_id
2054            )));
2055        }
2056        if !matches!(self.phase, Phase::Predict | Phase::Explain | Phase::Refit) {
2057            return Err(DagMlError::RuntimeValidation(format!(
2058                "bundle replay phase {:?} is not supported",
2059                self.phase
2060            )));
2061        }
2062        if self.phase == Phase::Refit && !bundle.prediction_requirements.is_empty() {
2063            if prediction_cache_available {
2064                return self.validate_data_envelope_keys(bundle);
2065            }
2066            return Err(DagMlError::RuntimeValidation(format!(
2067                "bundle `{}` cannot replay REFIT because it depends on {} OOF prediction requirement(s) but stores only prediction cache manifests",
2068                bundle.bundle_id,
2069                bundle.prediction_requirements.len()
2070            )));
2071        }
2072        self.validate_data_envelope_keys(bundle)
2073    }
2074
2075    fn validate_data_envelope_keys(&self, bundle: &ExecutionBundle) -> Result<()> {
2076        let expected = bundle
2077            .data_requirements
2078            .iter()
2079            .map(BundleDataRequirement::key)
2080            .collect::<BTreeSet<_>>();
2081        let mut requested = BTreeSet::new();
2082        for key in &self.data_envelope_keys {
2083            if key.trim().is_empty() {
2084                return Err(DagMlError::RuntimeValidation(
2085                    "replay request contains an empty data envelope key".to_string(),
2086                ));
2087            }
2088            if !requested.insert(key.as_str()) {
2089                return Err(DagMlError::RuntimeValidation(format!(
2090                    "replay request contains duplicate data envelope key `{key}`"
2091                )));
2092            }
2093            if !expected.contains(key.as_str()) {
2094                return Err(DagMlError::RuntimeValidation(format!(
2095                    "replay request references unknown data envelope key `{key}`"
2096                )));
2097            }
2098        }
2099        for requirement in &bundle.data_requirements {
2100            let key = requirement.key();
2101            if !requested.contains(key.as_str()) {
2102                return Err(DagMlError::RuntimeValidation(format!(
2103                    "replay request is missing data envelope key `{key}`"
2104                )));
2105            }
2106        }
2107        Ok(())
2108    }
2109}
2110
2111#[cfg(test)]
2112mod tests {
2113    use super::*;
2114    use crate::controller::{ControllerManifest, ControllerRegistry};
2115    use crate::data::{
2116        AggregateRepresentation, RepresentationCardinality, RepresentationCompatibilityOutcome,
2117        RepresentationCompatibilityReport, RepresentationMissingSourcePolicy, RepresentationPlan,
2118        RepresentationReplayManifest,
2119    };
2120    use crate::dsl::{compile_pipeline_dsl_with_generation, PipelineDslSpec};
2121    use crate::graph::GraphSpec;
2122    use crate::ids::{ArtifactId, FoldId, SampleId, TargetId};
2123    use crate::plan::{build_execution_plan, CampaignSpec};
2124    use crate::relation::EntityUnitLevel;
2125    use crate::selection::{
2126        select_candidate, CandidateScore, MetricObjective, SelectionMetric, SelectionPolicy,
2127    };
2128
2129    fn plan() -> ExecutionPlan {
2130        let graph: GraphSpec =
2131            serde_json::from_str(include_str!("../../../examples/minimal_graph.json")).unwrap();
2132        let campaign: CampaignSpec = serde_json::from_str(include_str!(
2133            "../../../examples/campaign_oof_generation.json"
2134        ))
2135        .unwrap();
2136        let manifests: Vec<ControllerManifest> =
2137            serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2138                .unwrap();
2139        let mut registry = ControllerRegistry::new();
2140        for manifest in manifests {
2141            registry.register(manifest).unwrap();
2142        }
2143        build_execution_plan("plan:bundle", graph, campaign, &registry).unwrap()
2144    }
2145
2146    fn branch_merge_plan() -> ExecutionPlan {
2147        let graph: GraphSpec = serde_json::from_str(include_str!(
2148            "../../../examples/branch_merge_oof_graph.json"
2149        ))
2150        .unwrap();
2151        let campaign: CampaignSpec = serde_json::from_str(include_str!(
2152            "../../../examples/campaign_branch_merge_oof.json"
2153        ))
2154        .unwrap();
2155        let manifests: Vec<ControllerManifest> =
2156            serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2157                .unwrap();
2158        let mut registry = ControllerRegistry::new();
2159        for manifest in manifests {
2160            registry.register(manifest).unwrap();
2161        }
2162        build_execution_plan("plan:branch.merge.bundle", graph, campaign, &registry).unwrap()
2163    }
2164
2165    fn executable_dsl_plan() -> ExecutionPlan {
2166        let spec: PipelineDslSpec = serde_json::from_str(include_str!(
2167            "../../../examples/pipeline_dsl_branch_merge_executable.json"
2168        ))
2169        .unwrap();
2170        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
2171        let manifests: Vec<ControllerManifest> =
2172            serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2173                .unwrap();
2174        let mut registry = ControllerRegistry::new();
2175        for manifest in manifests {
2176            registry.register(manifest).unwrap();
2177        }
2178        build_execution_plan(
2179            "plan:dsl.branch.merge.bundle",
2180            compiled.graph,
2181            compiled.campaign_template,
2182            &registry,
2183        )
2184        .unwrap()
2185    }
2186
2187    fn branch_merge_selection_decisions() -> BTreeMap<String, SelectionDecision> {
2188        serde_json::from_str(include_str!(
2189            "../../../examples/fixtures/bundle/selection_decisions_branch_merge.json"
2190        ))
2191        .unwrap()
2192    }
2193
2194    fn refit_artifact(
2195        plan: &ExecutionPlan,
2196        node_id: &str,
2197        data_requirement_keys: Vec<String>,
2198        prediction_requirement_keys: Vec<String>,
2199    ) -> RefitArtifactRecord {
2200        let node_id = NodeId::new(node_id).unwrap();
2201        let node_plan = plan.node_plans.get(&node_id).unwrap();
2202        RefitArtifactRecord {
2203            node_id: node_plan.node_id.clone(),
2204            controller_id: node_plan.controller_id.clone(),
2205            artifact: ArtifactRef {
2206                id: ArtifactId::new(format!("artifact:{}:refit", node_plan.node_id)).unwrap(),
2207                kind: "mock_model".to_string(),
2208                controller_id: node_plan.controller_id.clone(),
2209                backend: None,
2210                uri: None,
2211                content_fingerprint: None,
2212                size_bytes: Some(128),
2213                plugin: None,
2214                plugin_version: None,
2215            },
2216            params_fingerprint: node_plan.params_fingerprint.clone(),
2217            data_requirement_keys,
2218            prediction_requirement_keys,
2219        }
2220    }
2221
2222    fn branch_merge_samples() -> Vec<SampleId> {
2223        vec![
2224            SampleId::new("sample:1").unwrap(),
2225            SampleId::new("sample:2").unwrap(),
2226            SampleId::new("sample:3").unwrap(),
2227            SampleId::new("sample:4").unwrap(),
2228        ]
2229    }
2230
2231    fn branch_merge_requirement(
2232        producer_node: &str,
2233        target_port: &str,
2234    ) -> BundlePredictionRequirement {
2235        BundlePredictionRequirement {
2236            producer_node: NodeId::new(producer_node).unwrap(),
2237            source_port: "oof".to_string(),
2238            consumer_node: NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap(),
2239            target_port: target_port.to_string(),
2240            partition: PredictionPartition::Validation,
2241            prediction_level: PredictionLevel::Sample,
2242            fold_ids: vec![
2243                FoldId::new("fold:0").unwrap(),
2244                FoldId::new("fold:1").unwrap(),
2245            ],
2246            unit_ids: Vec::new(),
2247            sample_ids: branch_merge_samples(),
2248            prediction_width: 1,
2249            target_names: vec!["y".to_string()],
2250        }
2251    }
2252
2253    fn branch_merge_prediction_blocks(producer_node: &str, offset: f64) -> Vec<PredictionBlock> {
2254        let producer_node = NodeId::new(producer_node).unwrap();
2255        let samples = branch_merge_samples();
2256        vec![
2257            PredictionBlock {
2258                prediction_id: Some(format!("prediction:{producer_node}:fold0")),
2259                producer_node: producer_node.clone(),
2260                partition: PredictionPartition::Validation,
2261                fold_id: Some(FoldId::new("fold:0").unwrap()),
2262                sample_ids: samples[0..2].to_vec(),
2263                values: vec![vec![offset + 0.1], vec![offset + 0.2]],
2264                target_names: vec!["y".to_string()],
2265            },
2266            PredictionBlock {
2267                prediction_id: Some(format!("prediction:{producer_node}:fold1")),
2268                producer_node,
2269                partition: PredictionPartition::Validation,
2270                fold_id: Some(FoldId::new("fold:1").unwrap()),
2271                sample_ids: samples[2..4].to_vec(),
2272                values: vec![vec![offset + 0.3], vec![offset + 0.4]],
2273                target_names: vec!["y".to_string()],
2274            },
2275        ]
2276    }
2277
2278    fn decision() -> SelectionDecision {
2279        select_candidate(
2280            &SelectionPolicy {
2281                id: "select:merge".to_string(),
2282                metric: SelectionMetric {
2283                    name: "rmse".to_string(),
2284                    objective: MetricObjective::Minimize,
2285                },
2286                required_metric_level: Some(crate::policy::PredictionLevel::Sample),
2287                require_finite: true,
2288                evaluation_scope: None,
2289                refit_slot_plan: None,
2290                stacking_fit_contract: None,
2291                reduction_id: None,
2292            },
2293            &[
2294                CandidateScore {
2295                    candidate_id: "model:base".to_string(),
2296                    metrics: BTreeMap::from([("rmse".to_string(), 1.0)]),
2297                    metadata: BTreeMap::from([(
2298                        "metric_level".to_string(),
2299                        serde_json::Value::String("sample".to_string()),
2300                    )]),
2301                },
2302                CandidateScore {
2303                    candidate_id: "model:other".to_string(),
2304                    metrics: BTreeMap::from([("rmse".to_string(), 2.0)]),
2305                    metadata: BTreeMap::from([(
2306                        "metric_level".to_string(),
2307                        serde_json::Value::String("sample".to_string()),
2308                    )]),
2309                },
2310            ],
2311        )
2312        .unwrap()
2313    }
2314
    /// Selection decision used by tests that expect `model:base` to be the selected
    /// candidate. It simply returns `decision()`, where `model:base` has the lower
    /// rmse (1.0 vs 2.0) under a minimising objective.
    fn selected_model_base_decision() -> SelectionDecision {
        decision()
    }
2318
2319    fn model_base_refit_artifact(plan: &ExecutionPlan) -> RefitArtifactRecord {
2320        let model_plan = plan
2321            .node_plans
2322            .get(&NodeId::new("model:base").unwrap())
2323            .unwrap();
2324        RefitArtifactRecord {
2325            node_id: model_plan.node_id.clone(),
2326            controller_id: model_plan.controller_id.clone(),
2327            artifact: ArtifactRef {
2328                id: ArtifactId::new("artifact:model:base:refit").unwrap(),
2329                kind: "sklearn_pickle".to_string(),
2330                controller_id: model_plan.controller_id.clone(),
2331                backend: None,
2332                uri: None,
2333                content_fingerprint: None,
2334                size_bytes: Some(128),
2335                plugin: None,
2336                plugin_version: None,
2337            },
2338            params_fingerprint: model_plan.params_fingerprint.clone(),
2339            data_requirement_keys: vec!["model:base.x".to_string()],
2340            prediction_requirement_keys: Vec::new(),
2341        }
2342    }
2343
2344    #[test]
2345    fn builds_bundle_from_execution_plan() {
2346        let plan = plan();
2347        let artifact = model_base_refit_artifact(&plan);
2348
2349        let bundle = build_execution_bundle(
2350            BundleId::new("bundle:demo").unwrap(),
2351            &plan,
2352            Some(plan.variants[0].variant_id.clone()),
2353            BTreeMap::from([("merge".to_string(), decision())]),
2354            vec![artifact],
2355        )
2356        .unwrap();
2357
2358        bundle.validate_against_plan(&plan).unwrap();
2359        assert_eq!(bundle.data_requirements.len(), 1);
2360    }
2361
    /// Attaches a D7 representation replay manifest and compatibility report to the
    /// first data requirement and checks that the bundle still validates against the plan.
    /// It then tampers with the manifest's relation fingerprint and expects `validate()`
    /// to fail, but only when the requirement carries its own relation fingerprint.
    #[test]
    fn bundle_data_requirements_accept_d7_replay_contracts() {
        let plan = plan();
        let artifact = model_base_refit_artifact(&plan);
        let mut bundle = build_execution_bundle(
            BundleId::new("bundle:d7.replay").unwrap(),
            &plan,
            Some(plan.variants[0].variant_id.clone()),
            BTreeMap::from([("merge".to_string(), decision())]),
            vec![artifact],
        )
        .unwrap();
        // Reuse the requirement's relation fingerprint when it has one, so the manifest
        // and report below agree with it; otherwise fall back to a synthetic 64-char value.
        let relation_fingerprint = bundle.data_requirements[0]
            .relation_fingerprint
            .clone()
            .unwrap_or_else(|| "a".repeat(64));
        // Observation -> physical-sample "mean" aggregation replay manifest.
        bundle.data_requirements[0].representation_replay_manifest =
            Some(RepresentationReplayManifest {
                manifest_id: "repr:d7.bundle".to_string(),
                representation_plan: RepresentationPlan::Aggregate(AggregateRepresentation {
                    input_unit_level: EntityUnitLevel::Observation,
                    output_unit_level: EntityUnitLevel::PhysicalSample,
                    reducer_id: None,
                    method: Some("mean".to_string()),
                    cardinality: RepresentationCardinality::ManyToOne,
                }),
                combination_plan: None,
                output_unit_level: EntityUnitLevel::PhysicalSample,
                output_representation: Some("tabular_numeric".to_string()),
                relation_fingerprint: Some(relation_fingerprint.clone()),
                feature_schema_fingerprint: Some("b".repeat(64)),
                final_reduction_id: None,
                sample_observation_mapping: Vec::new(),
                combo_selection: Vec::new(),
                qc_policy_refs: Vec::new(),
                outlier_policy_refs: Vec::new(),
                missing_source_policy: None,
                missing_repetition_policy: None,
                prediction_representation: None,
                final_output_unit_level: Some(EntityUnitLevel::PhysicalSample),
                train_compatibility: None,
                predict_compatibility: None,
                metadata: BTreeMap::new(),
            });
        // Strict, fully compatible report whose train fingerprint matches the manifest.
        bundle.data_requirements[0].representation_compatibility =
            Some(RepresentationCompatibilityReport {
                policy: RepresentationMissingSourcePolicy::Strict,
                outcome: RepresentationCompatibilityOutcome::Compatible,
                fallback_used: None,
                warning_severity: None,
                affected_source_count: 0,
                affected_repetition_count: 0,
                affected_sample_count: 0,
                train_relation_fingerprint: Some(relation_fingerprint),
                predict_relation_fingerprint: None,
                train_unit_count: Some(2),
                predict_unit_count: Some(2),
                fixed_width_required: false,
                final_reducer_stabilizes_output: true,
                cartesian_combo_count_changed: false,
                late_fusion_branch_delta: false,
                messages: Vec::new(),
                metadata: BTreeMap::new(),
            });
        bundle.validate_against_plan(&plan).unwrap();

        // Tamper: the manifest now disagrees with the requirement/report fingerprint.
        bundle.data_requirements[0]
            .representation_replay_manifest
            .as_mut()
            .unwrap()
            .relation_fingerprint = Some("c".repeat(64));
        // NOTE(review): this negative check only runs when the requirement has its own
        // relation_fingerprint. When it is None, the test asserts nothing after tampering.
        // Confirm whether validate() should also reject a manifest/report fingerprint
        // mismatch in that case; if so, make this assertion unconditional.
        if bundle.data_requirements[0].relation_fingerprint.is_some() {
            assert!(bundle.validate().is_err());
        }
    }
2437
2438    #[test]
2439    fn d9_negative_prediction_cache_refuses_missing_aggregated_unit_ids() {
2440        let cache = BundlePredictionCacheRecord {
2441            requirement_key: "model:base.oof->model:meta.pred".to_string(),
2442            cache_id: "prediction-cache:d9.missing-units".to_string(),
2443            format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
2444            partition: PredictionPartition::Validation,
2445            prediction_level: PredictionLevel::Target,
2446            fold_ids: vec![FoldId::new("fold:0").unwrap()],
2447            unit_ids: Vec::new(),
2448            sample_ids: Vec::new(),
2449            prediction_width: 1,
2450            target_names: vec!["y".to_string()],
2451            block_count: 1,
2452            row_count: 1,
2453            content_fingerprint: "d".repeat(64),
2454            blocks: vec![BundlePredictionBlockCacheRecord {
2455                prediction_id: Some("prediction:d9.target.fold0".to_string()),
2456                fold_id: Some(FoldId::new("fold:0").unwrap()),
2457                prediction_level: PredictionLevel::Target,
2458                row_count: 1,
2459                unit_ids: vec![PredictionUnitId::Target(TargetId::new("target:a").unwrap())],
2460                sample_ids: Vec::new(),
2461                content_fingerprint: "e".repeat(64),
2462            }],
2463        };
2464
2465        let error = cache.validate().unwrap_err().to_string();
2466        assert!(
2467            error.contains("row_count does not match unique unit ids"),
2468            "unexpected D9 missing-unit-id cache error: {error}"
2469        );
2470    }
2471
2472    #[test]
2473    fn refit_artifact_validation_checks_portable_artifact_metadata() {
2474        let plan = plan();
2475        let mut artifact = model_base_refit_artifact(&plan);
2476        artifact.artifact.backend = Some(crate::runtime::ArtifactBackend::Joblib);
2477        artifact.artifact.uri = Some("artifacts/model.joblib".to_string());
2478        artifact.artifact.content_fingerprint = Some("c".repeat(64));
2479        artifact.artifact.plugin = Some("dagml.sklearn".to_string());
2480        artifact.artifact.plugin_version = Some("1.0.0".to_string());
2481        artifact.validate().unwrap();
2482
2483        artifact.artifact.content_fingerprint = Some("short".to_string());
2484        assert!(artifact
2485            .validate()
2486            .unwrap_err()
2487            .to_string()
2488            .contains("artifact content fingerprint"));
2489    }
2490
2491    #[test]
2492    fn bundle_selections_must_match_plan_and_refit_artifacts() {
2493        let plan = plan();
2494        let artifact = model_base_refit_artifact(&plan);
2495        let valid = build_execution_bundle(
2496            BundleId::new("bundle:selected.model").unwrap(),
2497            &plan,
2498            Some(plan.variants[0].variant_id.clone()),
2499            BTreeMap::from([("model".to_string(), selected_model_base_decision())]),
2500            vec![artifact.clone()],
2501        )
2502        .unwrap();
2503        valid.validate_against_plan(&plan).unwrap();
2504
2505        assert!(build_execution_bundle(
2506            BundleId::new("bundle:selected.model.missing.artifact").unwrap(),
2507            &plan,
2508            Some(plan.variants[0].variant_id.clone()),
2509            BTreeMap::from([("model".to_string(), selected_model_base_decision())]),
2510            Vec::new(),
2511        )
2512        .is_err());
2513
2514        let mut missing_level = selected_model_base_decision();
2515        missing_level.metric_level = None;
2516        assert!(build_execution_bundle(
2517            BundleId::new("bundle:selected.missing.level").unwrap(),
2518            &plan,
2519            Some(plan.variants[0].variant_id.clone()),
2520            BTreeMap::from([("model".to_string(), missing_level)]),
2521            vec![artifact.clone()],
2522        )
2523        .is_err());
2524
2525        let mut wrong_level = selected_model_base_decision();
2526        wrong_level.metric_level = Some(crate::policy::PredictionLevel::Target);
2527        assert!(build_execution_bundle(
2528            BundleId::new("bundle:selected.wrong.level").unwrap(),
2529            &plan,
2530            Some(plan.variants[0].variant_id.clone()),
2531            BTreeMap::from([("model".to_string(), wrong_level)]),
2532            vec![artifact.clone()],
2533        )
2534        .is_err());
2535
2536        let mut unknown = selected_model_base_decision();
2537        unknown.selected_candidate_id = "model:missing".to_string();
2538        unknown.ranked_candidates[0].candidate_id = "model:missing".to_string();
2539        assert!(build_execution_bundle(
2540            BundleId::new("bundle:selected.unknown").unwrap(),
2541            &plan,
2542            Some(plan.variants[0].variant_id.clone()),
2543            BTreeMap::from([("model".to_string(), unknown)]),
2544            vec![artifact],
2545        )
2546        .is_err());
2547    }
2548
2549    #[test]
2550    fn bundle_artifact_params_follow_selected_generation_variant() {
2551        let plan = executable_dsl_plan();
2552        let selected_variant = &plan.variants[0];
2553        let node_plan = plan
2554            .node_plans
2555            .get(&NodeId::new("branch:b0.model:ridge").unwrap())
2556            .unwrap();
2557        let effective_params = selected_variant
2558            .effective_params_for_node(&node_plan.node_id, &node_plan.params)
2559            .unwrap();
2560        let effective_fingerprint = stable_json_fingerprint(&effective_params).unwrap();
2561        assert_ne!(effective_fingerprint, node_plan.params_fingerprint);
2562
2563        let artifact = RefitArtifactRecord {
2564            node_id: node_plan.node_id.clone(),
2565            controller_id: node_plan.controller_id.clone(),
2566            artifact: ArtifactRef {
2567                id: ArtifactId::new("artifact:branch:b0.model:ridge:refit").unwrap(),
2568                kind: "mock_model".to_string(),
2569                controller_id: node_plan.controller_id.clone(),
2570                backend: None,
2571                uri: None,
2572                content_fingerprint: None,
2573                size_bytes: Some(128),
2574                plugin: None,
2575                plugin_version: None,
2576            },
2577            params_fingerprint: effective_fingerprint,
2578            data_requirement_keys: vec!["branch:b0.model:ridge.x".to_string()],
2579            prediction_requirement_keys: Vec::new(),
2580        };
2581
2582        build_execution_bundle(
2583            BundleId::new("bundle:dsl.variant.params").unwrap(),
2584            &plan,
2585            Some(selected_variant.variant_id.clone()),
2586            BTreeMap::new(),
2587            vec![artifact.clone()],
2588        )
2589        .unwrap();
2590
2591        let mut stale_artifact = artifact;
2592        stale_artifact.params_fingerprint = node_plan.params_fingerprint.clone();
2593        let error = build_execution_bundle(
2594            BundleId::new("bundle:dsl.variant.params.stale").unwrap(),
2595            &plan,
2596            Some(selected_variant.variant_id.clone()),
2597            BTreeMap::new(),
2598            vec![stale_artifact],
2599        )
2600        .unwrap_err();
2601        assert!(format!("{error}").contains("artifact params"));
2602    }
2603
2604    #[test]
2605    fn branch_merge_bundle_links_selected_refits_and_fold_aligned_oof_caches() {
2606        let plan = branch_merge_plan();
2607        let b0_requirement = branch_merge_requirement("branch:b0.model:ridge", "b0_oof");
2608        let b1_requirement = branch_merge_requirement("branch:b1.model:rf", "b1_oof");
2609        let b0_cache = build_prediction_cache_record(
2610            &b0_requirement,
2611            &branch_merge_prediction_blocks("branch:b0.model:ridge", 0.0),
2612        )
2613        .unwrap();
2614        let b1_cache = build_prediction_cache_record(
2615            &b1_requirement,
2616            &branch_merge_prediction_blocks("branch:b1.model:rf", 1.0),
2617        )
2618        .unwrap();
2619        let b0_artifact = refit_artifact(
2620            &plan,
2621            "branch:b0.model:ridge",
2622            vec!["branch:b0.model:ridge.x".to_string()],
2623            Vec::new(),
2624        );
2625        let b1_artifact = refit_artifact(
2626            &plan,
2627            "branch:b1.model:rf",
2628            vec!["branch:b1.model:rf.x".to_string()],
2629            Vec::new(),
2630        );
2631        let merge_artifact = refit_artifact(
2632            &plan,
2633            "merge:stack.pred_plus_original.meta:ridge",
2634            vec!["merge:stack.pred_plus_original.meta:ridge.x_original".to_string()],
2635            vec![b0_requirement.key(), b1_requirement.key()],
2636        );
2637
2638        let bundle = build_execution_bundle_with_prediction_contracts(
2639            BundleId::new("bundle:branch.merge.selected.refit").unwrap(),
2640            &plan,
2641            Some(plan.variants[0].variant_id.clone()),
2642            branch_merge_selection_decisions(),
2643            vec![
2644                b0_artifact.clone(),
2645                b1_artifact.clone(),
2646                merge_artifact.clone(),
2647            ],
2648            vec![b0_requirement.clone(), b1_requirement.clone()],
2649            vec![b0_cache.clone(), b1_cache.clone()],
2650        )
2651        .unwrap();
2652        bundle.validate_against_plan(&plan).unwrap();
2653        assert_eq!(bundle.selections.len(), 3);
2654        assert_eq!(bundle.prediction_requirements.len(), 2);
2655        assert_eq!(
2656            bundle.refit_artifacts[2].data_requirement_keys,
2657            vec!["merge:stack.pred_plus_original.meta:ridge.x_original"]
2658        );
2659        assert_eq!(
2660            bundle.refit_artifacts[2].prediction_requirement_keys,
2661            vec![
2662                "branch:b0.model:ridge.oof->merge:stack.pred_plus_original.meta:ridge.b0_oof",
2663                "branch:b1.model:rf.oof->merge:stack.pred_plus_original.meta:ridge.b1_oof",
2664            ]
2665        );
2666
2667        assert!(build_execution_bundle_with_prediction_contracts(
2668            BundleId::new("bundle:branch.merge.missing.branch.refit").unwrap(),
2669            &plan,
2670            Some(plan.variants[0].variant_id.clone()),
2671            branch_merge_selection_decisions(),
2672            vec![b0_artifact.clone(), merge_artifact.clone()],
2673            vec![b0_requirement.clone(), b1_requirement.clone()],
2674            vec![b0_cache.clone(), b1_cache.clone()],
2675        )
2676        .is_err());
2677
2678        let mut misaligned_cache = b0_cache;
2679        misaligned_cache.blocks[0].sample_ids = vec![
2680            SampleId::new("sample:1").unwrap(),
2681            SampleId::new("sample:3").unwrap(),
2682        ];
2683        misaligned_cache.blocks[1].sample_ids = vec![
2684            SampleId::new("sample:2").unwrap(),
2685            SampleId::new("sample:4").unwrap(),
2686        ];
2687        let error = build_execution_bundle_with_prediction_contracts(
2688            BundleId::new("bundle:branch.merge.misaligned.oof.cache").unwrap(),
2689            &plan,
2690            Some(plan.variants[0].variant_id.clone()),
2691            branch_merge_selection_decisions(),
2692            vec![b0_artifact, b1_artifact, merge_artifact],
2693            vec![b0_requirement, b1_requirement],
2694            vec![misaligned_cache, b1_cache],
2695        )
2696        .unwrap_err()
2697        .to_string();
2698        assert!(
2699            error.contains("does not match validation samples"),
2700            "unexpected fold-alignment error: {error}"
2701        );
2702    }
2703
    /// Typed prediction requirements, caches and payloads must agree on prediction
    /// level. A refit artifact that references a prediction requirement key needs
    /// both the typed requirement and its cache in the bundle. Replay and payload
    /// validation reject tampered or missing cache payloads. Refit artifacts may
    /// only claim data and prediction requirements that belong to their own node.
    #[test]
    fn prediction_requirements_are_typed_and_validate_against_oof_edges() {
        let plan = branch_merge_plan();
        let meta_plan = plan
            .node_plans
            .get(&NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap())
            .unwrap();
        let producer_node = NodeId::new("branch:b0.model:ridge").unwrap();
        let fold0 = FoldId::new("fold:0").unwrap();
        let fold1 = FoldId::new("fold:1").unwrap();
        let samples = [
            SampleId::new("sample:1").unwrap(),
            SampleId::new("sample:2").unwrap(),
            SampleId::new("sample:3").unwrap(),
            SampleId::new("sample:4").unwrap(),
        ];
        // b0 OOF -> meta-learner `b0_oof` port, sample level, two folds, four samples.
        let requirement = BundlePredictionRequirement {
            producer_node: producer_node.clone(),
            source_port: "oof".to_string(),
            consumer_node: meta_plan.node_id.clone(),
            target_port: "b0_oof".to_string(),
            partition: PredictionPartition::Validation,
            prediction_level: PredictionLevel::Sample,
            fold_ids: vec![fold0.clone(), fold1.clone()],
            unit_ids: Vec::new(),
            sample_ids: samples.to_vec(),
            prediction_width: 1,
            target_names: vec!["y".to_string()],
        };
        // fold:0 holds samples 1-2 and fold:1 holds samples 3-4.
        let prediction_blocks = vec![
            PredictionBlock {
                prediction_id: Some("prediction:branch:b0.fold0".to_string()),
                producer_node: producer_node.clone(),
                partition: PredictionPartition::Validation,
                fold_id: Some(fold0),
                sample_ids: samples[0..2].to_vec(),
                values: vec![vec![0.1], vec![0.2]],
                target_names: vec!["y".to_string()],
            },
            PredictionBlock {
                prediction_id: Some("prediction:branch:b0.fold1".to_string()),
                producer_node: producer_node.clone(),
                partition: PredictionPartition::Validation,
                fold_id: Some(fold1),
                sample_ids: samples[2..4].to_vec(),
                values: vec![vec![0.3], vec![0.4]],
                target_names: vec!["y".to_string()],
            },
        ];
        // Record and payload built from the same blocks are sample-level and consistent.
        let cache = build_prediction_cache_record(&requirement, &prediction_blocks).unwrap();
        let payload = build_prediction_cache_payload(&requirement, &prediction_blocks).unwrap();
        assert_eq!(cache.prediction_level, PredictionLevel::Sample);
        assert_eq!(payload.prediction_level, PredictionLevel::Sample);
        assert!(cache
            .blocks
            .iter()
            .all(|block| block.prediction_level == PredictionLevel::Sample));
        validate_prediction_cache_payload_matches_record(&payload, &cache).unwrap();
        // Relabelling requirement, cache or payload as Target-level (while they still
        // carry sample ids) must fail each one's own validation.
        let mut wrong_level_requirement = requirement.clone();
        wrong_level_requirement.prediction_level = PredictionLevel::Target;
        assert!(wrong_level_requirement.validate().is_err());
        let mut wrong_level_cache = cache.clone();
        wrong_level_cache.prediction_level = PredictionLevel::Target;
        assert!(wrong_level_cache.validate().is_err());
        let mut wrong_level_payload = payload.clone();
        wrong_level_payload.prediction_level = PredictionLevel::Target;
        assert!(wrong_level_payload.validate().is_err());
        // Meta-learner refit artifact that depends on the b0 OOF requirement by key.
        let prediction_key = requirement.key();
        let artifact = RefitArtifactRecord {
            node_id: meta_plan.node_id.clone(),
            controller_id: meta_plan.controller_id.clone(),
            artifact: ArtifactRef {
                id: ArtifactId::new("artifact:merge:stack.pred_plus_original.meta:ridge:refit")
                    .unwrap(),
                kind: "mock_model".to_string(),
                controller_id: meta_plan.controller_id.clone(),
                backend: None,
                uri: None,
                content_fingerprint: None,
                size_bytes: Some(128),
                plugin: None,
                plugin_version: None,
            },
            params_fingerprint: meta_plan.params_fingerprint.clone(),
            data_requirement_keys: vec![
                "merge:stack.pred_plus_original.meta:ridge.x_original".to_string()
            ],
            prediction_requirement_keys: vec![prediction_key],
        };

        // Without the typed requirement, the artifact's prediction key must be rejected.
        assert!(build_execution_bundle(
            BundleId::new("bundle:missing.prediction.requirement").unwrap(),
            &plan,
            Some(plan.variants[0].variant_id.clone()),
            BTreeMap::new(),
            vec![artifact.clone()],
        )
        .is_err());

        // A typed requirement without a matching cache record is also rejected.
        assert!(build_execution_bundle_with_prediction_requirements(
            BundleId::new("bundle:typed.prediction.requirement.without.cache").unwrap(),
            &plan,
            Some(plan.variants[0].variant_id.clone()),
            BTreeMap::new(),
            vec![artifact.clone()],
            vec![requirement.clone()],
        )
        .is_err());

        // Requirement plus cache together yield a valid bundle.
        let bundle = build_execution_bundle_with_prediction_contracts(
            BundleId::new("bundle:typed.prediction.requirement").unwrap(),
            &plan,
            Some(plan.variants[0].variant_id.clone()),
            BTreeMap::new(),
            vec![artifact],
            vec![requirement],
            vec![cache],
        )
        .unwrap();
        bundle.validate_against_plan(&plan).unwrap();
        assert_eq!(bundle.prediction_requirements.len(), 1);
        assert_eq!(bundle.prediction_caches.len(), 1);
        assert_eq!(
            bundle.refit_artifacts[0].prediction_requirement_keys,
            vec!["branch:b0.model:ridge.oof->merge:stack.pred_plus_original.meta:ridge.b0_oof"]
        );
        // The matching payload set validates against the bundle and satisfies a refit
        // replay request that supplies every data envelope key.
        let payload_set = BundlePredictionCachePayloadSet {
            bundle_id: bundle.bundle_id.clone(),
            schema_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
            caches: vec![payload],
        };
        payload_set.validate_against_bundle(&bundle).unwrap();
        let refit_replay_request = ReplayPhaseRequest {
            bundle_id: bundle.bundle_id.clone(),
            phase: Phase::Refit,
            data_envelope_keys: bundle
                .data_requirements
                .iter()
                .map(BundleDataRequirement::key)
                .collect(),
        };
        refit_replay_request
            .validate_for_bundle_with_prediction_cache_payloads(&bundle, Some(&payload_set))
            .unwrap();
        // A payload value that no longer matches the recorded cache is rejected.
        let mut tampered_payload_set = payload_set.clone();
        tampered_payload_set.caches[0].blocks[0].values[0][0] = 99.0;
        assert!(tampered_payload_set
            .validate_against_bundle(&bundle)
            .is_err());
        // A payload set missing the required cache is rejected.
        let mut missing_payload_set = payload_set.clone();
        missing_payload_set.caches.clear();
        assert!(missing_payload_set
            .validate_against_bundle(&bundle)
            .is_err());
        // Refit replay without any prediction cache payloads is rejected.
        assert!(refit_replay_request.validate_for_bundle(&bundle).is_err());

        // The meta artifact may not claim another node's data requirement key.
        let mut wrong_data_owner = bundle.clone();
        wrong_data_owner.refit_artifacts[0].data_requirement_keys =
            vec!["branch:b0.model:ridge.x".to_string()];
        assert!(wrong_data_owner.validate().is_err());

        // Re-owning the artifact by the producer node leaves the prediction key attached
        // to a node that is not the requirement's consumer, which must be rejected.
        let mut wrong_prediction_consumer = bundle;
        wrong_prediction_consumer.refit_artifacts[0].node_id =
            NodeId::new("branch:b0.model:ridge").unwrap();
        wrong_prediction_consumer.refit_artifacts[0]
            .data_requirement_keys
            .clear();
        assert!(wrong_prediction_consumer.validate().is_err());
    }
2873
2874    #[test]
2875    fn aggregated_prediction_cache_contracts_preserve_unit_ids() {
2876        let plan = branch_merge_plan();
2877        let producer_node = NodeId::new("branch:b0.model:ridge").unwrap();
2878        let consumer_node = NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap();
2879        let fold0 = FoldId::new("fold:0").unwrap();
2880        let fold1 = FoldId::new("fold:1").unwrap();
2881        let target_a = PredictionUnitId::Target(TargetId::new("target:a").unwrap());
2882        let target_b = PredictionUnitId::Target(TargetId::new("target:b").unwrap());
2883        let requirement = BundlePredictionRequirement {
2884            producer_node: producer_node.clone(),
2885            source_port: "oof".to_string(),
2886            consumer_node: consumer_node.clone(),
2887            target_port: "b0_oof".to_string(),
2888            partition: PredictionPartition::Validation,
2889            prediction_level: PredictionLevel::Target,
2890            fold_ids: vec![fold0.clone(), fold1.clone()],
2891            unit_ids: vec![target_a.clone(), target_b.clone()],
2892            sample_ids: Vec::new(),
2893            prediction_width: 1,
2894            target_names: vec!["y".to_string()],
2895        };
2896        let aggregated_blocks = vec![
2897            AggregatedPredictionBlock {
2898                prediction_id: Some("prediction:branch:b0.target.fold0".to_string()),
2899                producer_node: producer_node.clone(),
2900                partition: PredictionPartition::Validation,
2901                fold_id: Some(fold0),
2902                level: PredictionLevel::Target,
2903                unit_ids: vec![target_a],
2904                values: vec![vec![0.15]],
2905                target_names: vec!["y".to_string()],
2906            },
2907            AggregatedPredictionBlock {
2908                prediction_id: Some("prediction:branch:b0.target.fold1".to_string()),
2909                producer_node,
2910                partition: PredictionPartition::Validation,
2911                fold_id: Some(fold1),
2912                level: PredictionLevel::Target,
2913                unit_ids: vec![target_b],
2914                values: vec![vec![0.35]],
2915                target_names: vec!["y".to_string()],
2916            },
2917        ];
2918
2919        let cache =
2920            build_aggregated_prediction_cache_record(&requirement, &aggregated_blocks).unwrap();
2921        let payload =
2922            build_aggregated_prediction_cache_payload(&requirement, &aggregated_blocks).unwrap();
2923        assert_eq!(cache.prediction_level, PredictionLevel::Target);
2924        assert_eq!(cache.unit_ids, requirement.unit_ids);
2925        assert!(cache.sample_ids.is_empty());
2926        assert!(payload.blocks.is_empty());
2927        assert_eq!(payload.aggregated_blocks.len(), 2);
2928        validate_prediction_cache_payload_matches_record(&payload, &cache).unwrap();
2929
2930        let artifact = refit_artifact(
2931            &plan,
2932            "merge:stack.pred_plus_original.meta:ridge",
2933            vec!["merge:stack.pred_plus_original.meta:ridge.x_original".to_string()],
2934            vec![requirement.key()],
2935        );
2936        let bundle = build_execution_bundle_with_prediction_contracts(
2937            BundleId::new("bundle:target.prediction.requirement").unwrap(),
2938            &plan,
2939            Some(plan.variants[0].variant_id.clone()),
2940            BTreeMap::new(),
2941            vec![artifact],
2942            vec![requirement],
2943            vec![cache],
2944        )
2945        .unwrap();
2946        bundle.validate_against_plan(&plan).unwrap();
2947
2948        let mut tampered_payload = payload;
2949        tampered_payload.aggregated_blocks[0].unit_ids =
2950            vec![PredictionUnitId::Target(TargetId::new("target:z").unwrap())];
2951        assert!(validate_prediction_cache_payload_matches_record(
2952            &tampered_payload,
2953            &bundle.prediction_caches[0]
2954        )
2955        .is_err());
2956    }
2957
2958    #[test]
2959    fn replay_envelopes_must_match_bundle_requirements() {
2960        let plan = plan();
2961        let bundle = build_execution_bundle(
2962            BundleId::new("bundle:demo").unwrap(),
2963            &plan,
2964            None,
2965            BTreeMap::new(),
2966            Vec::new(),
2967        )
2968        .unwrap();
2969        let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2970            "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2971        ))
2972        .unwrap();
2973
2974        bundle
2975            .validate_replay_envelopes(&BTreeMap::from([(
2976                "model:base.x".to_string(),
2977                envelope.clone(),
2978            )]))
2979            .unwrap();
2980
2981        let mut mismatched = envelope;
2982        mismatched.schema_fingerprint = "0".repeat(64);
2983        assert!(bundle
2984            .validate_replay_envelopes(&BTreeMap::from([("model:base.x".to_string(), mismatched,)]))
2985            .is_err());
2986    }
2987
2988    #[test]
2989    fn rejects_unsupported_bundle_schema_version() {
2990        let mut bundle = build_execution_bundle(
2991            BundleId::new("bundle:demo").unwrap(),
2992            &plan(),
2993            None,
2994            BTreeMap::new(),
2995            Vec::new(),
2996        )
2997        .unwrap();
2998        bundle.schema_version = EXECUTION_BUNDLE_SCHEMA_VERSION + 1;
2999
3000        assert!(bundle.validate().is_err());
3001
3002        bundle.schema_version = 0;
3003        assert!(bundle.validate().is_err());
3004    }
3005
3006    #[test]
3007    fn schema_migration_policy_is_explicit_and_refuses_implicit_migrations() {
3008        let bundle_policy = execution_bundle_schema_migration_policy();
3009        assert_eq!(
3010            bundle_policy.current_version,
3011            EXECUTION_BUNDLE_SCHEMA_VERSION
3012        );
3013        assert_eq!(
3014            bundle_policy.min_readable_version,
3015            MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION
3016        );
3017        assert!(bundle_policy.automatic_migrations.is_empty());
3018        bundle_policy
3019            .validate_read_version(EXECUTION_BUNDLE_SCHEMA_VERSION, "bundle `current`")
3020            .unwrap();
3021        assert!(bundle_policy
3022            .validate_read_version(EXECUTION_BUNDLE_SCHEMA_VERSION + 1, "bundle `future`")
3023            .is_err());
3024        assert!(bundle_policy
3025            .validate_read_version(0, "bundle `zero`")
3026            .is_err());
3027
3028        let mut future_policy = SchemaMigrationPolicy {
3029            artifact: "execution_bundle".to_string(),
3030            current_version: 2,
3031            min_readable_version: 1,
3032            min_writable_version: 2,
3033            automatic_migrations: BTreeMap::new(),
3034        };
3035        assert!(future_policy
3036            .validate_read_version(1, "bundle `old-without-migration`")
3037            .is_err());
3038        future_policy.automatic_migrations.insert(1, 2);
3039        future_policy
3040            .validate_read_version(1, "bundle `old-with-migration`")
3041            .unwrap();
3042    }
3043
3044    #[test]
3045    fn prediction_cache_payload_schema_policy_rejects_unsupported_versions() {
3046        let policy = prediction_cache_payload_schema_migration_policy();
3047        assert_eq!(
3048            policy.current_version,
3049            PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION
3050        );
3051        assert!(policy.automatic_migrations.is_empty());
3052
3053        let mut payload_set = BundlePredictionCachePayloadSet {
3054            bundle_id: BundleId::new("bundle:payload.schema").unwrap(),
3055            schema_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
3056            caches: Vec::new(),
3057        };
3058        payload_set.validate().unwrap();
3059
3060        payload_set.schema_version = PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION + 1;
3061        assert!(payload_set.validate().is_err());
3062
3063        payload_set.schema_version = 0;
3064        assert!(payload_set.validate().is_err());
3065    }
3066
3067    #[test]
3068    fn replay_request_requires_predict_explain_or_refit_phase() {
3069        let bundle = build_execution_bundle(
3070            BundleId::new("bundle:demo").unwrap(),
3071            &plan(),
3072            None,
3073            BTreeMap::new(),
3074            Vec::new(),
3075        )
3076        .unwrap();
3077
3078        ReplayPhaseRequest {
3079            bundle_id: bundle.bundle_id.clone(),
3080            phase: Phase::Predict,
3081            data_envelope_keys: vec!["model:base.x".to_string()],
3082        }
3083        .validate_for_bundle(&bundle)
3084        .unwrap();
3085        ReplayPhaseRequest {
3086            bundle_id: bundle.bundle_id.clone(),
3087            phase: Phase::Refit,
3088            data_envelope_keys: vec!["model:base.x".to_string()],
3089        }
3090        .validate_for_bundle(&bundle)
3091        .unwrap();
3092        assert!(ReplayPhaseRequest {
3093            bundle_id: bundle.bundle_id.clone(),
3094            phase: Phase::FitCv,
3095            data_envelope_keys: vec!["model:base.x".to_string()],
3096        }
3097        .validate_for_bundle(&bundle)
3098        .is_err());
3099        assert!(ReplayPhaseRequest {
3100            bundle_id: bundle.bundle_id.clone(),
3101            phase: Phase::Predict,
3102            data_envelope_keys: vec!["model:base.x".to_string(), "model:base.x".to_string()],
3103        }
3104        .validate_for_bundle(&bundle)
3105        .is_err());
3106        assert!(ReplayPhaseRequest {
3107            bundle_id: bundle.bundle_id.clone(),
3108            phase: Phase::Predict,
3109            data_envelope_keys: vec!["model:base.y".to_string()],
3110        }
3111        .validate_for_bundle(&bundle)
3112        .is_err());
3113    }
3114}