1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::aggregation::{AggregatedPredictionBlock, PredictionUnitId};
6use crate::campaign::stable_json_fingerprint;
7use crate::data::{
8 ExternalDataPlanEnvelope, RepresentationCompatibilityReport, RepresentationReplayManifest,
9};
10use crate::error::{DagMlError, Result};
11use crate::ids::{BundleId, ControllerId, FoldId, NodeId, SampleId, VariantId};
12use crate::oof::{PredictionBlock, PredictionPartition};
13use crate::phase::Phase;
14use crate::plan::ExecutionPlan;
15use crate::policy::PredictionLevel;
16use crate::runtime::ArtifactRef;
17use crate::selection::SelectionDecision;
18
19pub const EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
20pub const PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
21pub const BUNDLE_PREDICTION_CACHE_FORMAT: &str = "dag-ml-json-prediction-blocks-v1";
22
23pub const MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
24pub const MIN_WRITABLE_EXECUTION_BUNDLE_SCHEMA_VERSION: u32 = 1;
25pub const MIN_READABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
26pub const MIN_WRITABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION: u32 = 1;
27
28fn default_execution_bundle_schema_version() -> u32 {
29 EXECUTION_BUNDLE_SCHEMA_VERSION
30}
31
32fn default_prediction_cache_payload_schema_version() -> u32 {
33 PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION
34}
35
36fn default_prediction_level() -> PredictionLevel {
37 PredictionLevel::Sample
38}
39
40#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
41pub struct SchemaMigrationPolicy {
42 pub artifact: String,
43 pub current_version: u32,
44 pub min_readable_version: u32,
45 pub min_writable_version: u32,
46 #[serde(default)]
47 pub automatic_migrations: BTreeMap<u32, u32>,
48}
49
50impl SchemaMigrationPolicy {
51 pub fn validate(&self) -> Result<()> {
52 validate_non_empty("schema migration artifact", &self.artifact)?;
53 if self.current_version == 0
54 || self.min_readable_version == 0
55 || self.min_writable_version == 0
56 {
57 return Err(DagMlError::RuntimeValidation(format!(
58 "schema migration policy `{}` has zero version boundary",
59 self.artifact
60 )));
61 }
62 if self.min_readable_version > self.current_version {
63 return Err(DagMlError::RuntimeValidation(format!(
64 "schema migration policy `{}` min_readable_version exceeds current_version",
65 self.artifact
66 )));
67 }
68 if self.min_writable_version > self.current_version {
69 return Err(DagMlError::RuntimeValidation(format!(
70 "schema migration policy `{}` min_writable_version exceeds current_version",
71 self.artifact
72 )));
73 }
74 for (from, to) in &self.automatic_migrations {
75 if *from == 0 || *to == 0 {
76 return Err(DagMlError::RuntimeValidation(format!(
77 "schema migration policy `{}` contains a zero migration version",
78 self.artifact
79 )));
80 }
81 if from == to {
82 return Err(DagMlError::RuntimeValidation(format!(
83 "schema migration policy `{}` contains a no-op migration {from}->{to}",
84 self.artifact
85 )));
86 }
87 if *to > self.current_version {
88 return Err(DagMlError::RuntimeValidation(format!(
89 "schema migration policy `{}` migrates to unsupported future version {to}",
90 self.artifact
91 )));
92 }
93 }
94 Ok(())
95 }
96
97 pub fn validate_read_version(&self, version: u32, owner: &str) -> Result<()> {
98 self.validate()?;
99 if version < self.min_readable_version {
100 return Err(DagMlError::RuntimeValidation(format!(
101 "{owner} uses schema_version {version}, below minimum readable {} for {}",
102 self.min_readable_version, self.artifact
103 )));
104 }
105 if version > self.current_version {
106 return Err(DagMlError::RuntimeValidation(format!(
107 "{owner} uses future schema_version {version}, current readable {} for {}",
108 self.current_version, self.artifact
109 )));
110 }
111 if version != self.current_version && !self.automatic_migrations.contains_key(&version) {
112 return Err(DagMlError::RuntimeValidation(format!(
113 "{owner} uses schema_version {version}, but {} declares no automatic migration to current version {}",
114 self.artifact, self.current_version
115 )));
116 }
117 Ok(())
118 }
119}
120
121pub fn execution_bundle_schema_migration_policy() -> SchemaMigrationPolicy {
122 SchemaMigrationPolicy {
123 artifact: "execution_bundle".to_string(),
124 current_version: EXECUTION_BUNDLE_SCHEMA_VERSION,
125 min_readable_version: MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION,
126 min_writable_version: MIN_WRITABLE_EXECUTION_BUNDLE_SCHEMA_VERSION,
127 automatic_migrations: BTreeMap::new(),
128 }
129}
130
131pub fn prediction_cache_payload_schema_migration_policy() -> SchemaMigrationPolicy {
132 SchemaMigrationPolicy {
133 artifact: "prediction_cache_payload".to_string(),
134 current_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
135 min_readable_version: MIN_READABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
136 min_writable_version: MIN_WRITABLE_PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
137 automatic_migrations: BTreeMap::new(),
138 }
139}
140
141#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
142pub struct BundleDataRequirement {
143 pub node_id: NodeId,
144 pub input_name: String,
145 pub schema_fingerprint: String,
146 pub plan_fingerprint: String,
147 #[serde(default)]
148 pub relation_fingerprint: Option<String>,
149 pub output_representation: String,
150 #[serde(default)]
151 pub feature_set_id: Option<String>,
152 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub representation_replay_manifest: Option<RepresentationReplayManifest>,
154 #[serde(default, skip_serializing_if = "Option::is_none")]
155 pub representation_compatibility: Option<RepresentationCompatibilityReport>,
156}
157
158impl BundleDataRequirement {
159 pub fn key(&self) -> String {
160 format!("{}.{}", self.node_id, self.input_name)
161 }
162
163 fn matches_plan_requirement(&self, expected: &Self) -> bool {
164 self.node_id == expected.node_id
165 && self.input_name == expected.input_name
166 && self.schema_fingerprint == expected.schema_fingerprint
167 && self.plan_fingerprint == expected.plan_fingerprint
168 && self.relation_fingerprint == expected.relation_fingerprint
169 && self.output_representation == expected.output_representation
170 && self.feature_set_id == expected.feature_set_id
171 }
172
173 pub fn validate(&self) -> Result<()> {
174 if self.input_name.trim().is_empty() {
175 return Err(DagMlError::CampaignValidation(format!(
176 "bundle data requirement for `{}` has empty input_name",
177 self.node_id
178 )));
179 }
180 validate_fingerprint("schema", &self.schema_fingerprint)?;
181 validate_fingerprint("plan", &self.plan_fingerprint)?;
182 if let Some(relation_fingerprint) = &self.relation_fingerprint {
183 validate_fingerprint("relation", relation_fingerprint)?;
184 }
185 if let Some(replay_manifest) = &self.representation_replay_manifest {
186 replay_manifest.validate()?;
187 if let (Some(requirement), Some(manifest)) = (
188 self.relation_fingerprint.as_deref(),
189 replay_manifest.relation_fingerprint.as_deref(),
190 ) {
191 if requirement != manifest {
192 return Err(DagMlError::CampaignValidation(format!(
193 "bundle data requirement `{}` relation_fingerprint does not match representation replay manifest",
194 self.key()
195 )));
196 }
197 }
198 }
199 if let Some(report) = &self.representation_compatibility {
200 report.validate()?;
201 }
202 if self.output_representation.trim().is_empty() {
203 return Err(DagMlError::CampaignValidation(format!(
204 "bundle data requirement `{}` has empty output representation",
205 self.key()
206 )));
207 }
208 if let Some(feature_set_id) = &self.feature_set_id {
209 if feature_set_id.trim().is_empty() {
210 return Err(DagMlError::CampaignValidation(format!(
211 "bundle data requirement `{}` has empty feature_set_id",
212 self.key()
213 )));
214 }
215 }
216 Ok(())
217 }
218}
219
220#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
221pub struct BundlePredictionRequirement {
222 pub producer_node: NodeId,
223 pub source_port: String,
224 pub consumer_node: NodeId,
225 pub target_port: String,
226 pub partition: PredictionPartition,
227 #[serde(default = "default_prediction_level")]
228 pub prediction_level: PredictionLevel,
229 #[serde(default)]
230 pub fold_ids: Vec<FoldId>,
231 #[serde(default, skip_serializing_if = "Vec::is_empty")]
232 pub unit_ids: Vec<PredictionUnitId>,
233 #[serde(default)]
234 pub sample_ids: Vec<SampleId>,
235 pub prediction_width: usize,
236 pub target_names: Vec<String>,
237}
238
239impl BundlePredictionRequirement {
240 pub fn key(&self) -> String {
241 bundle_prediction_requirement_key(
242 &self.producer_node,
243 &self.source_port,
244 &self.consumer_node,
245 &self.target_port,
246 )
247 }
248
249 pub fn validate(&self) -> Result<()> {
250 validate_non_empty("source_port", &self.source_port)?;
251 validate_non_empty("target_port", &self.target_port)?;
252 if self.partition != PredictionPartition::Validation {
253 return Err(DagMlError::RuntimeValidation(format!(
254 "bundle prediction requirement `{}` must use validation OOF predictions",
255 self.key()
256 )));
257 }
258 validate_unique_ids("fold id", &self.fold_ids)?;
259 validate_prediction_requirement_units(self)?;
260 if self.prediction_width == 0 {
261 return Err(DagMlError::RuntimeValidation(format!(
262 "bundle prediction requirement `{}` has zero prediction width",
263 self.key()
264 )));
265 }
266 if self.target_names.len() != self.prediction_width {
267 return Err(DagMlError::RuntimeValidation(format!(
268 "bundle prediction requirement `{}` target name count does not match prediction width",
269 self.key()
270 )));
271 }
272 for target_name in &self.target_names {
273 validate_non_empty("target_name", target_name)?;
274 }
275 Ok(())
276 }
277}
278
279pub fn bundle_prediction_requirement_key(
280 producer_node: &NodeId,
281 source_port: &str,
282 consumer_node: &NodeId,
283 target_port: &str,
284) -> String {
285 format!("{producer_node}.{source_port}->{consumer_node}.{target_port}")
286}
287
288#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
289pub struct BundlePredictionBlockCacheRecord {
290 #[serde(default)]
291 pub prediction_id: Option<String>,
292 #[serde(default)]
293 pub fold_id: Option<FoldId>,
294 #[serde(default = "default_prediction_level")]
295 pub prediction_level: PredictionLevel,
296 pub row_count: usize,
297 #[serde(default, skip_serializing_if = "Vec::is_empty")]
298 pub unit_ids: Vec<PredictionUnitId>,
299 #[serde(default)]
300 pub sample_ids: Vec<SampleId>,
301 pub content_fingerprint: String,
302}
303
304impl BundlePredictionBlockCacheRecord {
305 pub fn validate(&self) -> Result<()> {
306 if let Some(prediction_id) = &self.prediction_id {
307 validate_non_empty("prediction_id", prediction_id)?;
308 }
309 if self.row_count == 0 {
310 return Err(DagMlError::RuntimeValidation(
311 "prediction block cache record has zero rows".to_string(),
312 ));
313 }
314 validate_prediction_cache_block_record_units(self)?;
315 validate_fingerprint("prediction block cache content", &self.content_fingerprint)
316 }
317}
318
319#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
320pub struct BundlePredictionCacheRecord {
321 pub requirement_key: String,
322 pub cache_id: String,
323 pub format: String,
324 pub partition: PredictionPartition,
325 #[serde(default = "default_prediction_level")]
326 pub prediction_level: PredictionLevel,
327 #[serde(default)]
328 pub fold_ids: Vec<FoldId>,
329 #[serde(default, skip_serializing_if = "Vec::is_empty")]
330 pub unit_ids: Vec<PredictionUnitId>,
331 #[serde(default)]
332 pub sample_ids: Vec<SampleId>,
333 pub prediction_width: usize,
334 pub target_names: Vec<String>,
335 pub block_count: usize,
336 pub row_count: usize,
337 pub content_fingerprint: String,
338 #[serde(default)]
339 pub blocks: Vec<BundlePredictionBlockCacheRecord>,
340}
341
342impl BundlePredictionCacheRecord {
343 pub fn validate(&self) -> Result<()> {
344 validate_non_empty("requirement_key", &self.requirement_key)?;
345 validate_non_empty("cache_id", &self.cache_id)?;
346 validate_non_empty("format", &self.format)?;
347 if self.format != BUNDLE_PREDICTION_CACHE_FORMAT {
348 return Err(DagMlError::RuntimeValidation(format!(
349 "prediction cache `{}` uses unsupported format `{}`",
350 self.cache_id, self.format
351 )));
352 }
353 if self.partition != PredictionPartition::Validation {
354 return Err(DagMlError::RuntimeValidation(format!(
355 "prediction cache `{}` must cache validation OOF predictions",
356 self.cache_id
357 )));
358 }
359 validate_unique_ids("fold id", &self.fold_ids)?;
360 validate_prediction_cache_record_units(self)?;
361 if self.prediction_width == 0 {
362 return Err(DagMlError::RuntimeValidation(format!(
363 "prediction cache `{}` has zero prediction width",
364 self.cache_id
365 )));
366 }
367 if self.target_names.len() != self.prediction_width {
368 return Err(DagMlError::RuntimeValidation(format!(
369 "prediction cache `{}` target name count does not match prediction width",
370 self.cache_id
371 )));
372 }
373 for target_name in &self.target_names {
374 validate_non_empty("target_name", target_name)?;
375 }
376 if self.block_count == 0 || self.block_count != self.blocks.len() {
377 return Err(DagMlError::RuntimeValidation(format!(
378 "prediction cache `{}` block_count does not match block records",
379 self.cache_id
380 )));
381 }
382 validate_prediction_cache_record_blocks(self)?;
383 validate_fingerprint("prediction cache content", &self.content_fingerprint)?;
384 Ok(())
385 }
386}
387
388fn validate_prediction_requirement_units(requirement: &BundlePredictionRequirement) -> Result<()> {
389 match requirement.prediction_level {
390 PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
391 "bundle prediction requirement `{}` cannot replay observation-level caches; aggregate to sample first",
392 requirement.key()
393 ))),
394 PredictionLevel::Sample => {
395 validate_unique_ids("sample id", &requirement.sample_ids)?;
396 if requirement.sample_ids.is_empty() {
397 return Err(DagMlError::RuntimeValidation(format!(
398 "bundle prediction requirement `{}` has no sample ids",
399 requirement.key()
400 )));
401 }
402 if !requirement.unit_ids.is_empty()
403 && requirement.unit_ids != sample_prediction_units(&requirement.sample_ids)
404 {
405 return Err(DagMlError::RuntimeValidation(format!(
406 "bundle prediction requirement `{}` sample ids do not match unit ids",
407 requirement.key()
408 )));
409 }
410 Ok(())
411 }
412 PredictionLevel::Target | PredictionLevel::Group => {
413 if !requirement.sample_ids.is_empty() {
414 return Err(DagMlError::RuntimeValidation(format!(
415 "bundle prediction requirement `{}` uses {:?} unit ids but also carries sample ids",
416 requirement.key(),
417 requirement.prediction_level
418 )));
419 }
420 validate_prediction_units(
421 "bundle prediction requirement unit",
422 requirement.prediction_level,
423 &requirement.unit_ids,
424 )?;
425 if requirement.unit_ids.is_empty() {
426 return Err(DagMlError::RuntimeValidation(format!(
427 "bundle prediction requirement `{}` has no unit ids",
428 requirement.key()
429 )));
430 }
431 Ok(())
432 }
433 }
434}
435
436fn validate_prediction_cache_block_record_units(
437 block: &BundlePredictionBlockCacheRecord,
438) -> Result<()> {
439 match block.prediction_level {
440 PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(
441 "prediction block cache record cannot use observation-level predictions".to_string(),
442 )),
443 PredictionLevel::Sample => {
444 validate_unique_ids("sample id", &block.sample_ids)?;
445 if block.row_count != block.sample_ids.len() {
446 return Err(DagMlError::RuntimeValidation(format!(
447 "prediction block cache record row_count {} does not match {} sample ids",
448 block.row_count,
449 block.sample_ids.len()
450 )));
451 }
452 if !block.unit_ids.is_empty()
453 && block.unit_ids != sample_prediction_units(&block.sample_ids)
454 {
455 return Err(DagMlError::RuntimeValidation(
456 "prediction block cache record sample ids do not match unit ids".to_string(),
457 ));
458 }
459 Ok(())
460 }
461 PredictionLevel::Target | PredictionLevel::Group => {
462 if !block.sample_ids.is_empty() {
463 return Err(DagMlError::RuntimeValidation(format!(
464 "prediction block cache record uses {:?} unit ids but also carries sample ids",
465 block.prediction_level
466 )));
467 }
468 validate_prediction_units(
469 "prediction block cache record unit",
470 block.prediction_level,
471 &block.unit_ids,
472 )?;
473 if block.row_count != block.unit_ids.len() {
474 return Err(DagMlError::RuntimeValidation(format!(
475 "prediction block cache record row_count {} does not match {} unit ids",
476 block.row_count,
477 block.unit_ids.len()
478 )));
479 }
480 Ok(())
481 }
482 }
483}
484
485fn validate_prediction_cache_record_units(cache: &BundlePredictionCacheRecord) -> Result<()> {
486 match cache.prediction_level {
487 PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
488 "prediction cache `{}` cannot use observation-level predictions",
489 cache.cache_id
490 ))),
491 PredictionLevel::Sample => {
492 validate_unique_ids("sample id", &cache.sample_ids)?;
493 if cache.row_count != cache.sample_ids.len() {
494 return Err(DagMlError::RuntimeValidation(format!(
495 "prediction cache `{}` row_count does not match unique sample ids",
496 cache.cache_id
497 )));
498 }
499 if !cache.unit_ids.is_empty()
500 && cache.unit_ids != sample_prediction_units(&cache.sample_ids)
501 {
502 return Err(DagMlError::RuntimeValidation(format!(
503 "prediction cache `{}` sample ids do not match unit ids",
504 cache.cache_id
505 )));
506 }
507 Ok(())
508 }
509 PredictionLevel::Target | PredictionLevel::Group => {
510 if !cache.sample_ids.is_empty() {
511 return Err(DagMlError::RuntimeValidation(format!(
512 "prediction cache `{}` uses {:?} unit ids but also carries sample ids",
513 cache.cache_id, cache.prediction_level
514 )));
515 }
516 validate_prediction_units(
517 "prediction cache unit",
518 cache.prediction_level,
519 &cache.unit_ids,
520 )?;
521 if cache.row_count != cache.unit_ids.len() {
522 return Err(DagMlError::RuntimeValidation(format!(
523 "prediction cache `{}` row_count does not match unique unit ids",
524 cache.cache_id
525 )));
526 }
527 Ok(())
528 }
529 }
530}
531
532fn validate_prediction_cache_record_blocks(cache: &BundlePredictionCacheRecord) -> Result<()> {
533 let mut row_count = 0usize;
534 let mut samples = BTreeSet::new();
535 let mut units = BTreeSet::new();
536 for block in &cache.blocks {
537 block.validate()?;
538 if block.prediction_level != cache.prediction_level {
539 return Err(DagMlError::RuntimeValidation(format!(
540 "prediction cache `{}` mixes block prediction levels",
541 cache.cache_id
542 )));
543 }
544 row_count += block.row_count;
545 match cache.prediction_level {
546 PredictionLevel::Sample => {
547 for sample_id in &block.sample_ids {
548 if !samples.insert(sample_id.clone()) {
549 return Err(DagMlError::RuntimeValidation(format!(
550 "prediction cache `{}` contains duplicate sample `{sample_id}`",
551 cache.cache_id
552 )));
553 }
554 }
555 }
556 PredictionLevel::Target | PredictionLevel::Group => {
557 for unit_id in &block.unit_ids {
558 if !units.insert(unit_id.clone()) {
559 return Err(DagMlError::RuntimeValidation(format!(
560 "prediction cache `{}` contains duplicate unit `{unit_id}`",
561 cache.cache_id
562 )));
563 }
564 }
565 }
566 PredictionLevel::Observation => {
567 unreachable!("record unit validation rejects observation")
568 }
569 }
570 }
571 if cache.row_count == 0 || cache.row_count != row_count {
572 return Err(DagMlError::RuntimeValidation(format!(
573 "prediction cache `{}` row_count does not match block records",
574 cache.cache_id
575 )));
576 }
577 if cache.prediction_level == PredictionLevel::Sample {
578 let expected = cache.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
579 if samples != expected {
580 return Err(DagMlError::RuntimeValidation(format!(
581 "prediction cache `{}` block samples do not match cache sample ids",
582 cache.cache_id
583 )));
584 }
585 } else {
586 let expected = cache.unit_ids.iter().cloned().collect::<BTreeSet<_>>();
587 if units != expected {
588 return Err(DagMlError::RuntimeValidation(format!(
589 "prediction cache `{}` block units do not match cache unit ids",
590 cache.cache_id
591 )));
592 }
593 }
594 Ok(())
595}
596
597fn validate_prediction_cache_payload_blocks(
598 payload: &BundlePredictionCachePayload,
599) -> Result<usize> {
600 match payload.prediction_level {
601 PredictionLevel::Observation => Err(DagMlError::RuntimeValidation(format!(
602 "prediction cache payload `{}` cannot use observation-level predictions",
603 payload.cache_id
604 ))),
605 PredictionLevel::Sample => validate_sample_prediction_cache_payload_blocks(payload),
606 PredictionLevel::Target | PredictionLevel::Group => {
607 validate_aggregated_prediction_cache_payload_blocks(payload)
608 }
609 }
610}
611
612fn validate_sample_prediction_cache_payload_blocks(
613 payload: &BundlePredictionCachePayload,
614) -> Result<usize> {
615 let mut row_count = 0usize;
616 let mut sample_ids = BTreeSet::new();
617 for block in &payload.blocks {
618 block.validate_shape()?;
619 if block.partition != payload.partition {
620 return Err(DagMlError::RuntimeValidation(format!(
621 "prediction cache payload `{}` contains a block from partition {:?}",
622 payload.cache_id, block.partition
623 )));
624 }
625 for sample_id in &block.sample_ids {
626 if !sample_ids.insert(sample_id) {
627 return Err(DagMlError::RuntimeValidation(format!(
628 "prediction cache payload `{}` contains duplicate sample `{}`",
629 payload.cache_id, sample_id
630 )));
631 }
632 }
633 row_count += block.sample_ids.len();
634 }
635 Ok(row_count)
636}
637
638fn validate_aggregated_prediction_cache_payload_blocks(
639 payload: &BundlePredictionCachePayload,
640) -> Result<usize> {
641 let mut row_count = 0usize;
642 let mut unit_ids = BTreeSet::new();
643 for block in &payload.aggregated_blocks {
644 block.validate_shape()?;
645 if block.partition != payload.partition {
646 return Err(DagMlError::RuntimeValidation(format!(
647 "prediction cache payload `{}` contains an aggregated block from partition {:?}",
648 payload.cache_id, block.partition
649 )));
650 }
651 if block.level != payload.prediction_level {
652 return Err(DagMlError::RuntimeValidation(format!(
653 "prediction cache payload `{}` contains {:?} block inside {:?} payload",
654 payload.cache_id, block.level, payload.prediction_level
655 )));
656 }
657 for unit_id in &block.unit_ids {
658 if !unit_ids.insert(unit_id) {
659 return Err(DagMlError::RuntimeValidation(format!(
660 "prediction cache payload `{}` contains duplicate unit `{unit_id}`",
661 payload.cache_id
662 )));
663 }
664 }
665 row_count += block.unit_ids.len();
666 }
667 Ok(row_count)
668}
669
670#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
671pub struct BundlePredictionCachePayload {
672 pub requirement_key: String,
673 pub cache_id: String,
674 pub format: String,
675 pub partition: PredictionPartition,
676 #[serde(default = "default_prediction_level")]
677 pub prediction_level: PredictionLevel,
678 pub block_count: usize,
679 pub row_count: usize,
680 pub content_fingerprint: String,
681 #[serde(default)]
682 pub blocks: Vec<PredictionBlock>,
683 #[serde(default, skip_serializing_if = "Vec::is_empty")]
684 pub aggregated_blocks: Vec<AggregatedPredictionBlock>,
685}
686
687impl BundlePredictionCachePayload {
688 pub fn validate(&self) -> Result<()> {
689 validate_non_empty("requirement_key", &self.requirement_key)?;
690 validate_non_empty("cache_id", &self.cache_id)?;
691 validate_non_empty("format", &self.format)?;
692 if self.format != BUNDLE_PREDICTION_CACHE_FORMAT {
693 return Err(DagMlError::RuntimeValidation(format!(
694 "prediction cache payload `{}` uses unsupported format `{}`",
695 self.cache_id, self.format
696 )));
697 }
698 if self.partition != PredictionPartition::Validation {
699 return Err(DagMlError::RuntimeValidation(format!(
700 "prediction cache payload `{}` must cache validation OOF predictions",
701 self.cache_id
702 )));
703 }
704 let expected_block_count = if self.prediction_level == PredictionLevel::Sample {
705 if !self.aggregated_blocks.is_empty() {
706 return Err(DagMlError::RuntimeValidation(format!(
707 "prediction cache payload `{}` mixes sample and aggregated blocks",
708 self.cache_id
709 )));
710 }
711 self.blocks.len()
712 } else {
713 if !self.blocks.is_empty() {
714 return Err(DagMlError::RuntimeValidation(format!(
715 "prediction cache payload `{}` mixes aggregated and sample blocks",
716 self.cache_id
717 )));
718 }
719 self.aggregated_blocks.len()
720 };
721 if self.block_count == 0 || self.block_count != expected_block_count {
722 return Err(DagMlError::RuntimeValidation(format!(
723 "prediction cache payload `{}` block_count does not match blocks",
724 self.cache_id
725 )));
726 }
727 let row_count = validate_prediction_cache_payload_blocks(self)?;
728 if self.row_count == 0 || self.row_count != row_count {
729 return Err(DagMlError::RuntimeValidation(format!(
730 "prediction cache payload `{}` row_count does not match blocks",
731 self.cache_id
732 )));
733 }
734 validate_fingerprint(
735 "prediction cache payload content",
736 &self.content_fingerprint,
737 )?;
738 let actual_fingerprint = if self.prediction_level == PredictionLevel::Sample {
739 stable_json_fingerprint(&self.blocks)?
740 } else {
741 stable_json_fingerprint(&self.aggregated_blocks)?
742 };
743 if actual_fingerprint != self.content_fingerprint {
744 return Err(DagMlError::RuntimeValidation(format!(
745 "prediction cache payload `{}` content fingerprint does not match blocks",
746 self.cache_id
747 )));
748 }
749 Ok(())
750 }
751}
752
753#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
754pub struct BundlePredictionCachePayloadSet {
755 pub bundle_id: BundleId,
756 #[serde(default = "default_prediction_cache_payload_schema_version")]
757 pub schema_version: u32,
758 #[serde(default)]
759 pub caches: Vec<BundlePredictionCachePayload>,
760}
761
762impl BundlePredictionCachePayloadSet {
763 pub fn validate(&self) -> Result<()> {
764 prediction_cache_payload_schema_migration_policy().validate_read_version(
765 self.schema_version,
766 &format!(
767 "prediction cache payload set for bundle `{}`",
768 self.bundle_id
769 ),
770 )?;
771 let mut requirement_keys = BTreeSet::new();
772 let mut cache_ids = BTreeSet::new();
773 for payload in &self.caches {
774 payload.validate()?;
775 if !requirement_keys.insert(payload.requirement_key.as_str()) {
776 return Err(DagMlError::RuntimeValidation(format!(
777 "prediction cache payload set for bundle `{}` has duplicate requirement `{}`",
778 self.bundle_id, payload.requirement_key
779 )));
780 }
781 if !cache_ids.insert(payload.cache_id.as_str()) {
782 return Err(DagMlError::RuntimeValidation(format!(
783 "prediction cache payload set for bundle `{}` has duplicate cache id `{}`",
784 self.bundle_id, payload.cache_id
785 )));
786 }
787 }
788 Ok(())
789 }
790
791 pub fn validate_against_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
792 self.validate()?;
793 bundle.validate()?;
794 if self.bundle_id != bundle.bundle_id {
795 return Err(DagMlError::RuntimeValidation(format!(
796 "prediction cache payload set bundle `{}` does not match bundle `{}`",
797 self.bundle_id, bundle.bundle_id
798 )));
799 }
800 if self.caches.len() != bundle.prediction_caches.len() {
801 return Err(DagMlError::RuntimeValidation(format!(
802 "prediction cache payload set for bundle `{}` has {} payload(s) for {} cache record(s)",
803 self.bundle_id,
804 self.caches.len(),
805 bundle.prediction_caches.len()
806 )));
807 }
808 let records_by_requirement = bundle
809 .prediction_caches
810 .iter()
811 .map(|record| (record.requirement_key.as_str(), record))
812 .collect::<BTreeMap<_, _>>();
813 let payloads_by_requirement = self
814 .caches
815 .iter()
816 .map(|payload| (payload.requirement_key.as_str(), payload))
817 .collect::<BTreeMap<_, _>>();
818 for (requirement_key, record) in records_by_requirement {
819 let payload = payloads_by_requirement
820 .get(requirement_key)
821 .ok_or_else(|| {
822 DagMlError::RuntimeValidation(format!(
823 "prediction cache payload set for bundle `{}` is missing requirement `{}`",
824 self.bundle_id, requirement_key
825 ))
826 })?;
827 validate_prediction_cache_payload_matches_record(payload, record)?;
828 }
829 for requirement_key in payloads_by_requirement.keys() {
830 if !bundle
831 .prediction_caches
832 .iter()
833 .any(|record| record.requirement_key.as_str() == *requirement_key)
834 {
835 return Err(DagMlError::RuntimeValidation(format!(
836 "prediction cache payload set for bundle `{}` contains unknown requirement `{}`",
837 self.bundle_id, requirement_key
838 )));
839 }
840 }
841 Ok(())
842 }
843}
844
845#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
846pub struct RefitArtifactRecord {
847 pub node_id: NodeId,
848 pub controller_id: ControllerId,
849 pub artifact: ArtifactRef,
850 pub params_fingerprint: String,
851 #[serde(default)]
852 pub data_requirement_keys: Vec<String>,
853 #[serde(default)]
854 pub prediction_requirement_keys: Vec<String>,
855}
856
857impl RefitArtifactRecord {
858 pub fn validate(&self) -> Result<()> {
859 self.artifact.validate()?;
860 if self.artifact.id.as_str().is_empty() {
861 return Err(DagMlError::RuntimeValidation(format!(
862 "refit artifact for `{}` has empty artifact id",
863 self.node_id
864 )));
865 }
866 if self.artifact.kind.trim().is_empty() {
867 return Err(DagMlError::RuntimeValidation(format!(
868 "refit artifact `{}` has empty artifact kind",
869 self.artifact.id
870 )));
871 }
872 if self.artifact.controller_id != self.controller_id {
873 return Err(DagMlError::RuntimeValidation(format!(
874 "refit artifact `{}` controller `{}` does not match record controller `{}`",
875 self.artifact.id, self.artifact.controller_id, self.controller_id
876 )));
877 }
878 validate_fingerprint("params", &self.params_fingerprint)?;
879 let mut seen_keys = BTreeSet::new();
880 for key in &self.data_requirement_keys {
881 if key.trim().is_empty() {
882 return Err(DagMlError::RuntimeValidation(format!(
883 "refit artifact `{}` has empty data requirement key",
884 self.artifact.id
885 )));
886 }
887 if !seen_keys.insert(key.as_str()) {
888 return Err(DagMlError::RuntimeValidation(format!(
889 "refit artifact `{}` has duplicate data requirement key `{key}`",
890 self.artifact.id
891 )));
892 }
893 }
894 let mut seen_prediction_keys = BTreeSet::new();
895 for key in &self.prediction_requirement_keys {
896 if key.trim().is_empty() {
897 return Err(DagMlError::RuntimeValidation(format!(
898 "refit artifact `{}` has empty prediction requirement key",
899 self.artifact.id
900 )));
901 }
902 if !seen_prediction_keys.insert(key.as_str()) {
903 return Err(DagMlError::RuntimeValidation(format!(
904 "refit artifact `{}` has duplicate prediction requirement key `{key}`",
905 self.artifact.id
906 )));
907 }
908 }
909 Ok(())
910 }
911}
912
/// Serializable summary of an execution: the plan it was built from (by id and
/// fingerprints), the selections made, the refit artifacts produced, and the data and
/// prediction requirements needed to replay it.
///
/// NOTE: field order and `serde` attributes define the serialized form of a bundle;
/// reordering fields changes the emitted JSON.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExecutionBundle {
    /// Identifier of this bundle (used in validation error messages).
    pub bundle_id: BundleId,
    /// Bundle schema version. Defaults to `EXECUTION_BUNDLE_SCHEMA_VERSION` when missing
    /// from serialized input; `validate` checks it with the execution-bundle schema
    /// migration policy's `validate_read_version`.
    #[serde(default = "default_execution_bundle_schema_version")]
    pub schema_version: u32,
    /// Id of the execution plan this bundle was built from; must be non-empty and, in
    /// `validate_against_plan`, equal to the plan's `id`.
    pub plan_id: String,
    /// Must equal the plan's `graph_fingerprint` in `validate_against_plan`.
    pub graph_fingerprint: String,
    /// Must equal the plan's `campaign_fingerprint` in `validate_against_plan`.
    pub campaign_fingerprint: String,
    /// Must equal the plan's `controller_fingerprint` in `validate_against_plan`.
    pub controller_fingerprint: String,
    /// Variant chosen for this execution, if any; when set it must name one of the plan's
    /// variants. Defaults to `None` when missing.
    #[serde(default)]
    pub selected_variant_id: Option<VariantId>,
    /// Selection decisions keyed by a non-empty selection key. Defaults to empty.
    #[serde(default)]
    pub selections: BTreeMap<String, SelectionDecision>,
    /// Refit artifacts recorded for plan nodes. Defaults to empty.
    #[serde(default)]
    pub refit_artifacts: Vec<RefitArtifactRecord>,
    /// Prediction requirements; each must correspond to an OOF edge of the plan.
    /// Defaults to empty.
    #[serde(default)]
    pub prediction_requirements: Vec<BundlePredictionRequirement>,
    /// Prediction cache records; at most one per prediction requirement key.
    /// Defaults to empty.
    #[serde(default)]
    pub prediction_caches: Vec<BundlePredictionCacheRecord>,
    /// External data requirements, one per node data binding in the plan. Defaults to empty.
    #[serde(default)]
    pub data_requirements: Vec<BundleDataRequirement>,
    /// Opt-in unsafe flags; each must be a non-blank string. Defaults to empty.
    #[serde(default)]
    pub unsafe_flags: BTreeSet<String>,
    /// Free-form metadata; not inspected by the validators in this impl. Defaults to empty.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
939
940impl ExecutionBundle {
941 pub fn validate(&self) -> Result<()> {
942 execution_bundle_schema_migration_policy()
943 .validate_read_version(self.schema_version, &format!("bundle `{}`", self.bundle_id))?;
944 if self.plan_id.trim().is_empty() {
945 return Err(DagMlError::RuntimeValidation(format!(
946 "bundle `{}` has empty plan_id",
947 self.bundle_id
948 )));
949 }
950 validate_fingerprint("graph", &self.graph_fingerprint)?;
951 validate_fingerprint("campaign", &self.campaign_fingerprint)?;
952 validate_fingerprint("controller", &self.controller_fingerprint)?;
953 for (key, decision) in &self.selections {
954 if key.trim().is_empty() {
955 return Err(DagMlError::RuntimeValidation(format!(
956 "bundle `{}` contains empty selection key",
957 self.bundle_id
958 )));
959 }
960 decision.validate()?;
961 }
962 let mut data_keys = BTreeMap::new();
963 for requirement in &self.data_requirements {
964 requirement.validate()?;
965 let key = requirement.key();
966 if data_keys.insert(key.clone(), requirement).is_some() {
967 return Err(DagMlError::RuntimeValidation(format!(
968 "bundle `{}` has duplicate data requirement `{}`",
969 self.bundle_id, key
970 )));
971 }
972 }
973 let mut prediction_keys = BTreeMap::new();
974 for requirement in &self.prediction_requirements {
975 requirement.validate()?;
976 let key = requirement.key();
977 if prediction_keys.insert(key.clone(), requirement).is_some() {
978 return Err(DagMlError::RuntimeValidation(format!(
979 "bundle `{}` has duplicate prediction requirement `{}`",
980 self.bundle_id, key
981 )));
982 }
983 }
984 let mut prediction_cache_keys = BTreeMap::new();
985 for cache in &self.prediction_caches {
986 cache.validate()?;
987 let requirement = prediction_keys.get(&cache.requirement_key).ok_or_else(|| {
988 DagMlError::RuntimeValidation(format!(
989 "prediction cache `{}` references unknown prediction requirement `{}`",
990 cache.cache_id, cache.requirement_key
991 ))
992 })?;
993 validate_prediction_cache_matches_requirement(cache, requirement)?;
994 if prediction_cache_keys
995 .insert(cache.requirement_key.clone(), cache)
996 .is_some()
997 {
998 return Err(DagMlError::RuntimeValidation(format!(
999 "bundle `{}` has duplicate prediction cache for requirement `{}`",
1000 self.bundle_id, cache.requirement_key
1001 )));
1002 }
1003 }
1004 for artifact in &self.refit_artifacts {
1005 artifact.validate()?;
1006 for key in &artifact.data_requirement_keys {
1007 match data_keys.get(key) {
1008 Some(requirement) if requirement.node_id == artifact.node_id => {}
1009 Some(requirement) => {
1010 return Err(DagMlError::RuntimeValidation(format!(
1011 "refit artifact `{}` for `{}` references data requirement `{key}` owned by `{}`",
1012 artifact.artifact.id, artifact.node_id, requirement.node_id
1013 )));
1014 }
1015 None => {
1016 return Err(DagMlError::RuntimeValidation(format!(
1017 "refit artifact `{}` references unknown data requirement `{key}`",
1018 artifact.artifact.id
1019 )));
1020 }
1021 }
1022 }
1023 for key in &artifact.prediction_requirement_keys {
1024 match prediction_keys.get(key) {
1025 Some(requirement) if requirement.consumer_node == artifact.node_id => {}
1026 Some(requirement) => {
1027 return Err(DagMlError::RuntimeValidation(format!(
1028 "refit artifact `{}` for `{}` references prediction requirement `{key}` consumed by `{}`",
1029 artifact.artifact.id, artifact.node_id, requirement.consumer_node
1030 )));
1031 }
1032 None => {
1033 return Err(DagMlError::RuntimeValidation(format!(
1034 "refit artifact `{}` references unknown prediction requirement `{key}`",
1035 artifact.artifact.id
1036 )));
1037 }
1038 }
1039 if !prediction_cache_keys.contains_key(key) {
1040 return Err(DagMlError::RuntimeValidation(format!(
1041 "refit artifact `{}` references prediction requirement `{key}` without a prediction cache record",
1042 artifact.artifact.id
1043 )));
1044 }
1045 }
1046 }
1047 for unsafe_flag in &self.unsafe_flags {
1048 if unsafe_flag.trim().is_empty() {
1049 return Err(DagMlError::RuntimeValidation(format!(
1050 "bundle `{}` contains an empty unsafe flag",
1051 self.bundle_id
1052 )));
1053 }
1054 }
1055 Ok(())
1056 }
1057
1058 pub fn validate_against_plan(&self, plan: &ExecutionPlan) -> Result<()> {
1059 self.validate()?;
1060 plan.validate()?;
1061 if self.plan_id != plan.id {
1062 return Err(DagMlError::RuntimeValidation(format!(
1063 "bundle `{}` plan_id `{}` does not match plan `{}`",
1064 self.bundle_id, self.plan_id, plan.id
1065 )));
1066 }
1067 if self.graph_fingerprint != plan.graph_fingerprint
1068 || self.campaign_fingerprint != plan.campaign_fingerprint
1069 || self.controller_fingerprint != plan.controller_fingerprint
1070 {
1071 return Err(DagMlError::RuntimeValidation(format!(
1072 "bundle `{}` fingerprints do not match execution plan",
1073 self.bundle_id
1074 )));
1075 }
1076 let selected_variant = match &self.selected_variant_id {
1077 Some(selected_variant_id) => Some(
1078 plan.variants
1079 .iter()
1080 .find(|variant| &variant.variant_id == selected_variant_id)
1081 .ok_or_else(|| {
1082 DagMlError::RuntimeValidation(format!(
1083 "bundle `{}` selected unknown variant `{selected_variant_id}`",
1084 self.bundle_id
1085 ))
1086 })?,
1087 ),
1088 None => None,
1089 };
1090 self.validate_selections_against_plan(plan)?;
1091 let expected_requirements = collect_data_requirements(plan)?;
1092 let expected_by_key = expected_requirements
1093 .iter()
1094 .map(|requirement| (requirement.key(), requirement))
1095 .collect::<BTreeMap<_, _>>();
1096 if self.data_requirements.len() != expected_by_key.len() {
1097 return Err(DagMlError::RuntimeValidation(format!(
1098 "bundle `{}` data requirement count does not match execution plan",
1099 self.bundle_id
1100 )));
1101 }
1102 for requirement in &self.data_requirements {
1103 let key = requirement.key();
1104 let expected = expected_by_key.get(&key).ok_or_else(|| {
1105 DagMlError::RuntimeValidation(format!(
1106 "bundle `{}` data requirement `{key}` does not exist in execution plan",
1107 self.bundle_id
1108 ))
1109 })?;
1110 if !requirement.matches_plan_requirement(expected) {
1111 return Err(DagMlError::RuntimeValidation(format!(
1112 "bundle `{}` data requirement `{key}` does not match execution plan",
1113 self.bundle_id
1114 )));
1115 }
1116 }
1117 for artifact in &self.refit_artifacts {
1118 let node_plan = plan.node_plans.get(&artifact.node_id).ok_or_else(|| {
1119 DagMlError::RuntimeValidation(format!(
1120 "bundle `{}` artifact references unknown node `{}`",
1121 self.bundle_id, artifact.node_id
1122 ))
1123 })?;
1124 if artifact.controller_id != node_plan.controller_id {
1125 return Err(DagMlError::RuntimeValidation(format!(
1126 "bundle `{}` artifact controller for `{}` does not match plan",
1127 self.bundle_id, artifact.node_id
1128 )));
1129 }
1130 let expected_params_fingerprint =
1131 expected_refit_artifact_params_fingerprint(node_plan, selected_variant)?;
1132 if artifact.params_fingerprint != expected_params_fingerprint {
1133 return Err(DagMlError::RuntimeValidation(format!(
1134 "bundle `{}` artifact params for `{}` do not match plan",
1135 self.bundle_id, artifact.node_id
1136 )));
1137 }
1138 }
1139 for requirement in &self.prediction_requirements {
1140 let edge = plan
1141 .graph_plan
1142 .graph
1143 .edges
1144 .iter()
1145 .find(|edge| {
1146 edge.source.node_id == requirement.producer_node
1147 && edge.source.port_name == requirement.source_port
1148 && edge.target.node_id == requirement.consumer_node
1149 && edge.target.port_name == requirement.target_port
1150 && edge.contract.requires_oof
1151 })
1152 .ok_or_else(|| {
1153 DagMlError::RuntimeValidation(format!(
1154 "bundle `{}` prediction requirement `{}` does not match an OOF edge in the plan",
1155 self.bundle_id,
1156 requirement.key()
1157 ))
1158 })?;
1159 let cache = self
1160 .prediction_caches
1161 .iter()
1162 .find(|cache| cache.requirement_key == requirement.key());
1163 validate_prediction_requirement_against_plan(self, plan, edge, requirement, cache)?;
1164 }
1165 Ok(())
1166 }
1167
1168 fn validate_selections_against_plan(&self, plan: &ExecutionPlan) -> Result<()> {
1169 if self.selections.is_empty() {
1170 return Ok(());
1171 }
1172 let artifact_node_ids = self
1173 .refit_artifacts
1174 .iter()
1175 .map(|artifact| artifact.node_id.clone())
1176 .collect::<BTreeSet<_>>();
1177 let required_metric_level = plan.campaign.aggregation_policy.selection_metric_level;
1178 for (selection_key, decision) in &self.selections {
1179 match decision.metric_level {
1180 Some(metric_level) if metric_level == required_metric_level => {}
1181 Some(metric_level) => {
1182 return Err(DagMlError::RuntimeValidation(format!(
1183 "bundle `{}` selection `{selection_key}` metric_level {:?} does not match campaign selection_metric_level {:?}",
1184 self.bundle_id, metric_level, required_metric_level
1185 )));
1186 }
1187 None => {
1188 return Err(DagMlError::RuntimeValidation(format!(
1189 "bundle `{}` selection `{selection_key}` is missing metric_level for campaign selection_metric_level {:?}",
1190 self.bundle_id, required_metric_level
1191 )));
1192 }
1193 }
1194 let selected_candidate_id = decision.selected_candidate_id.as_str();
1195 if let Ok(selected_node_id) = NodeId::new(selected_candidate_id) {
1196 if let Some(node_plan) = plan.node_plans.get(&selected_node_id) {
1197 if node_plan.supported_phases.contains(&Phase::Refit)
1198 && !artifact_node_ids.contains(&node_plan.node_id)
1199 {
1200 return Err(DagMlError::RuntimeValidation(format!(
1201 "bundle `{}` selection `{selection_key}` chose refittable node `{}` without a matching refit artifact",
1202 self.bundle_id, node_plan.node_id
1203 )));
1204 }
1205 continue;
1206 }
1207 }
1208 if VariantId::new(selected_candidate_id).is_ok()
1209 && plan
1210 .variants
1211 .iter()
1212 .any(|variant| variant.variant_id.as_str() == selected_candidate_id)
1213 {
1214 continue;
1215 }
1216 return Err(DagMlError::RuntimeValidation(format!(
1217 "bundle `{}` selection `{selection_key}` chose unknown candidate `{selected_candidate_id}` for plan `{}`",
1218 self.bundle_id, plan.id
1219 )));
1220 }
1221 Ok(())
1222 }
1223
1224 pub fn validate_replay_envelopes(
1225 &self,
1226 envelopes: &BTreeMap<String, ExternalDataPlanEnvelope>,
1227 ) -> Result<()> {
1228 self.validate()?;
1229 for requirement in &self.data_requirements {
1230 let key = requirement.key();
1231 let envelope = envelopes.get(&key).ok_or_else(|| {
1232 DagMlError::RuntimeValidation(format!(
1233 "replay is missing external data envelope for `{key}`"
1234 ))
1235 })?;
1236 envelope.validate()?;
1237 if requirement.schema_fingerprint != envelope.schema_fingerprint
1238 || requirement.plan_fingerprint != envelope.plan_fingerprint
1239 || requirement.relation_fingerprint != envelope.relation_fingerprint
1240 {
1241 return Err(DagMlError::RuntimeValidation(format!(
1242 "replay envelope for `{key}` does not match bundle data requirement"
1243 )));
1244 }
1245 }
1246 Ok(())
1247 }
1248}
1249
1250fn expected_refit_artifact_params_fingerprint(
1251 node_plan: &crate::plan::NodePlan,
1252 selected_variant: Option<&crate::generation::VariantPlan>,
1253) -> Result<String> {
1254 let Some(variant) = selected_variant else {
1255 return Ok(node_plan.params_fingerprint.clone());
1256 };
1257 let effective_params =
1258 variant.effective_params_for_node(&node_plan.node_id, &node_plan.params)?;
1259 stable_json_fingerprint(&effective_params)
1260}
1261
1262fn validate_prediction_requirement_against_plan(
1263 bundle: &ExecutionBundle,
1264 plan: &ExecutionPlan,
1265 edge: &crate::graph::EdgeSpec,
1266 requirement: &BundlePredictionRequirement,
1267 cache: Option<&BundlePredictionCacheRecord>,
1268) -> Result<()> {
1269 if !edge.contract.requires_fold_alignment {
1270 return Ok(());
1271 }
1272 let fold_set = plan.fold_set.as_ref().ok_or_else(|| {
1273 DagMlError::RuntimeValidation(format!(
1274 "bundle `{}` prediction requirement `{}` needs fold alignment but plan `{}` has no fold set",
1275 bundle.bundle_id,
1276 requirement.key(),
1277 plan.id
1278 ))
1279 })?;
1280 let expected_fold_ids = fold_set
1281 .folds
1282 .iter()
1283 .map(|fold| fold.fold_id.clone())
1284 .collect::<BTreeSet<_>>();
1285 let requirement_fold_ids = requirement
1286 .fold_ids
1287 .iter()
1288 .cloned()
1289 .collect::<BTreeSet<_>>();
1290 if requirement_fold_ids != expected_fold_ids {
1291 return Err(DagMlError::RuntimeValidation(format!(
1292 "bundle `{}` prediction requirement `{}` fold ids do not match plan fold set",
1293 bundle.bundle_id,
1294 requirement.key()
1295 )));
1296 }
1297 if requirement.prediction_level != PredictionLevel::Sample {
1298 if let Some(cache) = cache {
1299 validate_aggregated_prediction_cache_blocks_match_requirement(
1300 bundle,
1301 requirement,
1302 cache,
1303 )?;
1304 }
1305 return Ok(());
1306 }
1307 let expected_sample_ids = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1308 let requirement_sample_ids = requirement
1309 .sample_ids
1310 .iter()
1311 .cloned()
1312 .collect::<BTreeSet<_>>();
1313 if requirement_sample_ids != expected_sample_ids {
1314 return Err(DagMlError::RuntimeValidation(format!(
1315 "bundle `{}` prediction requirement `{}` sample ids do not match plan fold set",
1316 bundle.bundle_id,
1317 requirement.key()
1318 )));
1319 }
1320 if let Some(cache) = cache {
1321 validate_prediction_cache_blocks_match_fold_set(bundle, requirement, cache, fold_set)?;
1322 }
1323 Ok(())
1324}
1325
1326fn validate_prediction_cache_blocks_match_fold_set(
1327 bundle: &ExecutionBundle,
1328 requirement: &BundlePredictionRequirement,
1329 cache: &BundlePredictionCacheRecord,
1330 fold_set: &crate::fold::FoldSet,
1331) -> Result<()> {
1332 let folds = fold_set
1333 .folds
1334 .iter()
1335 .map(|fold| (&fold.fold_id, fold))
1336 .collect::<BTreeMap<_, _>>();
1337 let expected_fold_ids = fold_set
1338 .folds
1339 .iter()
1340 .map(|fold| fold.fold_id.clone())
1341 .collect::<BTreeSet<_>>();
1342 let mut covered_fold_ids = BTreeSet::new();
1343 let mut covered_sample_ids = BTreeSet::new();
1344 for block in &cache.blocks {
1345 let fold_id = block.fold_id.as_ref().ok_or_else(|| {
1346 DagMlError::RuntimeValidation(format!(
1347 "bundle `{}` prediction cache `{}` has an OOF block without a fold id",
1348 bundle.bundle_id, cache.cache_id
1349 ))
1350 })?;
1351 covered_fold_ids.insert(fold_id.clone());
1352 let fold = folds.get(fold_id).ok_or_else(|| {
1353 DagMlError::RuntimeValidation(format!(
1354 "bundle `{}` prediction cache `{}` references unknown fold `{fold_id}`",
1355 bundle.bundle_id, cache.cache_id
1356 ))
1357 })?;
1358 let block_samples = block.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1359 let expected_samples = fold
1360 .validation_sample_ids
1361 .iter()
1362 .cloned()
1363 .collect::<BTreeSet<_>>();
1364 if block_samples != expected_samples {
1365 return Err(DagMlError::RuntimeValidation(format!(
1366 "bundle `{}` prediction cache `{}` block for fold `{fold_id}` does not match validation samples for requirement `{}`",
1367 bundle.bundle_id,
1368 cache.cache_id,
1369 requirement.key()
1370 )));
1371 }
1372 for sample_id in block_samples {
1373 if !covered_sample_ids.insert(sample_id.clone()) {
1374 return Err(DagMlError::RuntimeValidation(format!(
1375 "bundle `{}` prediction cache `{}` has duplicate OOF sample `{sample_id}`",
1376 bundle.bundle_id, cache.cache_id
1377 )));
1378 }
1379 }
1380 }
1381 if covered_fold_ids != expected_fold_ids {
1382 return Err(DagMlError::RuntimeValidation(format!(
1383 "bundle `{}` prediction cache `{}` does not cover all folds for requirement `{}`",
1384 bundle.bundle_id,
1385 cache.cache_id,
1386 requirement.key()
1387 )));
1388 }
1389 let expected_sample_ids = fold_set.sample_ids.iter().cloned().collect::<BTreeSet<_>>();
1390 if covered_sample_ids != expected_sample_ids {
1391 return Err(DagMlError::RuntimeValidation(format!(
1392 "bundle `{}` prediction cache `{}` does not cover the full OOF sample universe for requirement `{}`",
1393 bundle.bundle_id,
1394 cache.cache_id,
1395 requirement.key()
1396 )));
1397 }
1398 Ok(())
1399}
1400
1401fn validate_aggregated_prediction_cache_blocks_match_requirement(
1402 bundle: &ExecutionBundle,
1403 requirement: &BundlePredictionRequirement,
1404 cache: &BundlePredictionCacheRecord,
1405) -> Result<()> {
1406 let mut covered_fold_ids = BTreeSet::new();
1407 let mut covered_unit_ids = BTreeSet::new();
1408 for block in &cache.blocks {
1409 if block.prediction_level != requirement.prediction_level {
1410 return Err(DagMlError::RuntimeValidation(format!(
1411 "bundle `{}` prediction cache `{}` block level does not match requirement `{}`",
1412 bundle.bundle_id,
1413 cache.cache_id,
1414 requirement.key()
1415 )));
1416 }
1417 if let Some(fold_id) = &block.fold_id {
1418 covered_fold_ids.insert(fold_id.clone());
1419 }
1420 for unit_id in &block.unit_ids {
1421 if !covered_unit_ids.insert(unit_id.clone()) {
1422 return Err(DagMlError::RuntimeValidation(format!(
1423 "bundle `{}` prediction cache `{}` has duplicate aggregated unit `{unit_id}`",
1424 bundle.bundle_id, cache.cache_id
1425 )));
1426 }
1427 }
1428 }
1429 let expected_fold_ids = requirement
1430 .fold_ids
1431 .iter()
1432 .cloned()
1433 .collect::<BTreeSet<_>>();
1434 if covered_fold_ids != expected_fold_ids {
1435 return Err(DagMlError::RuntimeValidation(format!(
1436 "bundle `{}` prediction cache `{}` does not cover all folds for aggregated requirement `{}`",
1437 bundle.bundle_id,
1438 cache.cache_id,
1439 requirement.key()
1440 )));
1441 }
1442 let expected_unit_ids = requirement
1443 .unit_ids
1444 .iter()
1445 .cloned()
1446 .collect::<BTreeSet<_>>();
1447 if covered_unit_ids != expected_unit_ids {
1448 return Err(DagMlError::RuntimeValidation(format!(
1449 "bundle `{}` prediction cache `{}` does not cover all units for aggregated requirement `{}`",
1450 bundle.bundle_id,
1451 cache.cache_id,
1452 requirement.key()
1453 )));
1454 }
1455 Ok(())
1456}
1457
1458pub fn build_execution_bundle(
1459 bundle_id: BundleId,
1460 plan: &ExecutionPlan,
1461 selected_variant_id: Option<VariantId>,
1462 selections: BTreeMap<String, SelectionDecision>,
1463 refit_artifacts: Vec<RefitArtifactRecord>,
1464) -> Result<ExecutionBundle> {
1465 build_execution_bundle_with_prediction_requirements(
1466 bundle_id,
1467 plan,
1468 selected_variant_id,
1469 selections,
1470 refit_artifacts,
1471 Vec::new(),
1472 )
1473}
1474
1475pub fn build_execution_bundle_with_prediction_requirements(
1476 bundle_id: BundleId,
1477 plan: &ExecutionPlan,
1478 selected_variant_id: Option<VariantId>,
1479 selections: BTreeMap<String, SelectionDecision>,
1480 refit_artifacts: Vec<RefitArtifactRecord>,
1481 prediction_requirements: Vec<BundlePredictionRequirement>,
1482) -> Result<ExecutionBundle> {
1483 build_execution_bundle_with_prediction_contracts(
1484 bundle_id,
1485 plan,
1486 selected_variant_id,
1487 selections,
1488 refit_artifacts,
1489 prediction_requirements,
1490 Vec::new(),
1491 )
1492}
1493
1494pub fn build_execution_bundle_with_prediction_contracts(
1495 bundle_id: BundleId,
1496 plan: &ExecutionPlan,
1497 selected_variant_id: Option<VariantId>,
1498 selections: BTreeMap<String, SelectionDecision>,
1499 refit_artifacts: Vec<RefitArtifactRecord>,
1500 prediction_requirements: Vec<BundlePredictionRequirement>,
1501 prediction_caches: Vec<BundlePredictionCacheRecord>,
1502) -> Result<ExecutionBundle> {
1503 plan.validate()?;
1504 let bundle = ExecutionBundle {
1505 bundle_id,
1506 schema_version: EXECUTION_BUNDLE_SCHEMA_VERSION,
1507 plan_id: plan.id.clone(),
1508 graph_fingerprint: plan.graph_fingerprint.clone(),
1509 campaign_fingerprint: plan.campaign_fingerprint.clone(),
1510 controller_fingerprint: plan.controller_fingerprint.clone(),
1511 selected_variant_id,
1512 selections,
1513 refit_artifacts,
1514 prediction_requirements,
1515 prediction_caches,
1516 data_requirements: collect_data_requirements(plan)?,
1517 unsafe_flags: BTreeSet::new(),
1518 metadata: BTreeMap::new(),
1519 };
1520 bundle.validate_against_plan(plan)?;
1521 Ok(bundle)
1522}
1523
1524fn collect_data_requirements(plan: &ExecutionPlan) -> Result<Vec<BundleDataRequirement>> {
1525 let mut requirements = Vec::new();
1526 for node_plan in plan.node_plans.values() {
1527 for binding in &node_plan.data_bindings {
1528 requirements.push(BundleDataRequirement {
1529 node_id: node_plan.node_id.clone(),
1530 input_name: binding.input_name.clone(),
1531 schema_fingerprint: binding.schema_fingerprint.clone(),
1532 plan_fingerprint: binding.plan_fingerprint.clone(),
1533 relation_fingerprint: binding.relation_fingerprint.clone(),
1534 output_representation: binding.output_representation.clone(),
1535 feature_set_id: binding.feature_set_id.clone(),
1536 representation_replay_manifest: None,
1537 representation_compatibility: None,
1538 });
1539 }
1540 }
1541 requirements.sort_by_key(BundleDataRequirement::key);
1542 for requirement in &requirements {
1543 requirement.validate()?;
1544 }
1545 Ok(requirements)
1546}
1547
1548pub fn build_prediction_cache_record(
1549 requirement: &BundlePredictionRequirement,
1550 blocks: &[PredictionBlock],
1551) -> Result<BundlePredictionCacheRecord> {
1552 let selected = select_prediction_cache_blocks(requirement, blocks)?;
1553 build_prediction_cache_record_from_selected(requirement, &selected)
1554}
1555
1556pub fn build_prediction_cache_payload(
1557 requirement: &BundlePredictionRequirement,
1558 blocks: &[PredictionBlock],
1559) -> Result<BundlePredictionCachePayload> {
1560 let selected = select_prediction_cache_blocks(requirement, blocks)?;
1561 let payload = BundlePredictionCachePayload {
1562 requirement_key: requirement.key(),
1563 cache_id: format!("prediction-cache:{}", requirement.key()),
1564 format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1565 partition: requirement.partition.clone(),
1566 prediction_level: requirement.prediction_level,
1567 block_count: selected.len(),
1568 row_count: selected.iter().map(|block| block.sample_ids.len()).sum(),
1569 content_fingerprint: stable_json_fingerprint(&selected)?,
1570 blocks: selected,
1571 aggregated_blocks: Vec::new(),
1572 };
1573 payload.validate()?;
1574 let record = build_prediction_cache_record(requirement, &payload.blocks)?;
1575 validate_prediction_cache_payload_matches_record(&payload, &record)?;
1576 Ok(payload)
1577}
1578
1579pub fn build_aggregated_prediction_cache_record(
1580 requirement: &BundlePredictionRequirement,
1581 blocks: &[AggregatedPredictionBlock],
1582) -> Result<BundlePredictionCacheRecord> {
1583 let selected = select_aggregated_prediction_cache_blocks(requirement, blocks)?;
1584 build_aggregated_prediction_cache_record_from_selected(requirement, &selected)
1585}
1586
1587pub fn build_aggregated_prediction_cache_payload(
1588 requirement: &BundlePredictionRequirement,
1589 blocks: &[AggregatedPredictionBlock],
1590) -> Result<BundlePredictionCachePayload> {
1591 let selected = select_aggregated_prediction_cache_blocks(requirement, blocks)?;
1592 let payload = BundlePredictionCachePayload {
1593 requirement_key: requirement.key(),
1594 cache_id: format!("prediction-cache:{}", requirement.key()),
1595 format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1596 partition: requirement.partition.clone(),
1597 prediction_level: requirement.prediction_level,
1598 block_count: selected.len(),
1599 row_count: selected.iter().map(|block| block.unit_ids.len()).sum(),
1600 content_fingerprint: stable_json_fingerprint(&selected)?,
1601 blocks: Vec::new(),
1602 aggregated_blocks: selected,
1603 };
1604 payload.validate()?;
1605 let record = build_aggregated_prediction_cache_record(requirement, &payload.aggregated_blocks)?;
1606 validate_prediction_cache_payload_matches_record(&payload, &record)?;
1607 Ok(payload)
1608}
1609
1610pub fn validate_prediction_cache_payload_matches_record(
1611 payload: &BundlePredictionCachePayload,
1612 record: &BundlePredictionCacheRecord,
1613) -> Result<()> {
1614 payload.validate()?;
1615 record.validate()?;
1616 if payload.requirement_key != record.requirement_key
1617 || payload.cache_id != record.cache_id
1618 || payload.format != record.format
1619 || payload.partition != record.partition
1620 || payload.prediction_level != record.prediction_level
1621 || payload.block_count != record.block_count
1622 || payload.row_count != record.row_count
1623 || payload.content_fingerprint != record.content_fingerprint
1624 {
1625 return Err(DagMlError::RuntimeValidation(format!(
1626 "prediction cache payload `{}` does not match cache record `{}`",
1627 payload.cache_id, record.cache_id
1628 )));
1629 }
1630 let block_records = if payload.prediction_level == PredictionLevel::Sample {
1631 payload
1632 .blocks
1633 .iter()
1634 .map(|block| {
1635 Ok(BundlePredictionBlockCacheRecord {
1636 prediction_id: block.prediction_id.clone(),
1637 fold_id: block.fold_id.clone(),
1638 prediction_level: PredictionLevel::Sample,
1639 row_count: block.sample_ids.len(),
1640 unit_ids: Vec::new(),
1641 sample_ids: block.sample_ids.clone(),
1642 content_fingerprint: stable_json_fingerprint(block)?,
1643 })
1644 })
1645 .collect::<Result<Vec<_>>>()?
1646 } else {
1647 payload
1648 .aggregated_blocks
1649 .iter()
1650 .map(|block| {
1651 Ok(BundlePredictionBlockCacheRecord {
1652 prediction_id: block.prediction_id.clone(),
1653 fold_id: block.fold_id.clone(),
1654 prediction_level: block.level,
1655 row_count: block.unit_ids.len(),
1656 unit_ids: block.unit_ids.clone(),
1657 sample_ids: Vec::new(),
1658 content_fingerprint: stable_json_fingerprint(block)?,
1659 })
1660 })
1661 .collect::<Result<Vec<_>>>()?
1662 };
1663 if block_records != record.blocks {
1664 return Err(DagMlError::RuntimeValidation(format!(
1665 "prediction cache payload `{}` block fingerprints do not match cache record",
1666 payload.cache_id
1667 )));
1668 }
1669 Ok(())
1670}
1671
1672fn select_prediction_cache_blocks(
1673 requirement: &BundlePredictionRequirement,
1674 blocks: &[PredictionBlock],
1675) -> Result<Vec<PredictionBlock>> {
1676 requirement.validate()?;
1677 let mut selected = blocks
1678 .iter()
1679 .filter(|block| {
1680 block.producer_node == requirement.producer_node
1681 && block.partition == requirement.partition
1682 })
1683 .cloned()
1684 .collect::<Vec<_>>();
1685 if selected.is_empty() {
1686 return Err(DagMlError::RuntimeValidation(format!(
1687 "prediction cache requirement `{}` has no matching prediction blocks",
1688 requirement.key()
1689 )));
1690 }
1691 selected.sort_by(|left, right| {
1692 (
1693 left.fold_id.as_ref().map(ToString::to_string),
1694 left.prediction_id.clone(),
1695 )
1696 .cmp(&(
1697 right.fold_id.as_ref().map(ToString::to_string),
1698 right.prediction_id.clone(),
1699 ))
1700 });
1701 Ok(selected)
1702}
1703
1704fn select_aggregated_prediction_cache_blocks(
1705 requirement: &BundlePredictionRequirement,
1706 blocks: &[AggregatedPredictionBlock],
1707) -> Result<Vec<AggregatedPredictionBlock>> {
1708 requirement.validate()?;
1709 if requirement.prediction_level == PredictionLevel::Sample {
1710 return Err(DagMlError::RuntimeValidation(format!(
1711 "aggregated prediction cache requirement `{}` must use target or group level",
1712 requirement.key()
1713 )));
1714 }
1715 let mut selected = blocks
1716 .iter()
1717 .filter(|block| {
1718 block.producer_node == requirement.producer_node
1719 && block.partition == requirement.partition
1720 && block.level == requirement.prediction_level
1721 })
1722 .cloned()
1723 .collect::<Vec<_>>();
1724 if selected.is_empty() {
1725 return Err(DagMlError::RuntimeValidation(format!(
1726 "aggregated prediction cache requirement `{}` has no matching prediction blocks",
1727 requirement.key()
1728 )));
1729 }
1730 selected.sort_by(|left, right| {
1731 (
1732 left.fold_id.as_ref().map(ToString::to_string),
1733 left.prediction_id.clone(),
1734 )
1735 .cmp(&(
1736 right.fold_id.as_ref().map(ToString::to_string),
1737 right.prediction_id.clone(),
1738 ))
1739 });
1740 Ok(selected)
1741}
1742
1743fn build_prediction_cache_record_from_selected(
1744 requirement: &BundlePredictionRequirement,
1745 selected: &[PredictionBlock],
1746) -> Result<BundlePredictionCacheRecord> {
1747 requirement.validate()?;
1748 if selected.is_empty() {
1749 return Err(DagMlError::RuntimeValidation(format!(
1750 "prediction cache requirement `{}` has no matching prediction blocks",
1751 requirement.key()
1752 )));
1753 }
1754 let mut fold_ids = BTreeSet::new();
1755 let mut sample_ids = BTreeSet::new();
1756 let mut target_names: Option<Vec<String>> = None;
1757 let mut prediction_width: Option<usize> = None;
1758 let mut row_count = 0usize;
1759 let mut block_records = Vec::new();
1760 for block in selected {
1761 if block.producer_node != requirement.producer_node
1762 || block.partition != requirement.partition
1763 {
1764 return Err(DagMlError::RuntimeValidation(format!(
1765 "prediction cache `{}` contains a block outside the requirement scope",
1766 requirement.key()
1767 )));
1768 }
1769 let width = block.validate_shape()?;
1770 if prediction_width.is_some_and(|expected| expected != width) {
1771 return Err(DagMlError::RuntimeValidation(format!(
1772 "prediction cache `{}` has inconsistent prediction width",
1773 requirement.key()
1774 )));
1775 }
1776 prediction_width = Some(width);
1777 let block_target_names = normalized_prediction_targets(block, width);
1778 if target_names
1779 .as_ref()
1780 .is_some_and(|expected| expected != &block_target_names)
1781 {
1782 return Err(DagMlError::RuntimeValidation(format!(
1783 "prediction cache `{}` has inconsistent target names",
1784 requirement.key()
1785 )));
1786 }
1787 target_names = Some(block_target_names);
1788 if let Some(fold_id) = &block.fold_id {
1789 fold_ids.insert(fold_id.clone());
1790 }
1791 sample_ids.extend(block.sample_ids.iter().cloned());
1792 row_count += block.sample_ids.len();
1793 block_records.push(BundlePredictionBlockCacheRecord {
1794 prediction_id: block.prediction_id.clone(),
1795 fold_id: block.fold_id.clone(),
1796 prediction_level: PredictionLevel::Sample,
1797 row_count: block.sample_ids.len(),
1798 unit_ids: Vec::new(),
1799 sample_ids: block.sample_ids.clone(),
1800 content_fingerprint: stable_json_fingerprint(block)?,
1801 });
1802 }
1803
1804 let record = BundlePredictionCacheRecord {
1805 requirement_key: requirement.key(),
1806 cache_id: format!("prediction-cache:{}", requirement.key()),
1807 format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
1808 partition: requirement.partition.clone(),
1809 prediction_level: requirement.prediction_level,
1810 fold_ids: fold_ids.into_iter().collect(),
1811 unit_ids: requirement.unit_ids.clone(),
1812 sample_ids: sample_ids.into_iter().collect(),
1813 prediction_width: prediction_width.unwrap_or_default(),
1814 target_names: target_names.unwrap_or_default(),
1815 block_count: block_records.len(),
1816 row_count,
1817 content_fingerprint: stable_json_fingerprint(selected)?,
1818 blocks: block_records,
1819 };
1820 validate_prediction_cache_matches_requirement(&record, requirement)?;
1821 record.validate()?;
1822 Ok(record)
1823}
1824
/// Builds a bundle prediction cache record from aggregated (non-sample-level)
/// prediction blocks that were selected for `requirement`.
///
/// Every block must come from the requirement's producer node, partition and
/// prediction level. All blocks must agree on prediction width and on their
/// normalized target names (positional `p{i}` names when a block has none).
/// Fold ids and unit ids are gathered into sorted, de-duplicated lists. The
/// per-block manifests carry unit ids and leave `sample_ids` empty.
///
/// # Errors
///
/// Returns `DagMlError::RuntimeValidation`, or propagates an error from the
/// requirement/block/record validators or from fingerprinting, when:
/// - the requirement is invalid or uses `PredictionLevel::Sample`;
/// - `selected` is empty;
/// - a block lies outside the requirement's producer/partition/level scope;
/// - blocks disagree on prediction width or target names;
/// - the assembled record does not match `requirement` or fails `validate()`.
fn build_aggregated_prediction_cache_record_from_selected(
    requirement: &BundlePredictionRequirement,
    selected: &[AggregatedPredictionBlock],
) -> Result<BundlePredictionCacheRecord> {
    requirement.validate()?;
    // This builder only handles aggregated levels; sample-level requirements are rejected.
    if requirement.prediction_level == PredictionLevel::Sample {
        return Err(DagMlError::RuntimeValidation(format!(
            "aggregated prediction cache requirement `{}` must use target or group level",
            requirement.key()
        )));
    }
    if selected.is_empty() {
        return Err(DagMlError::RuntimeValidation(format!(
            "aggregated prediction cache requirement `{}` has no matching prediction blocks",
            requirement.key()
        )));
    }
    // Accumulators: BTreeSets give sorted, de-duplicated id lists. Width and target
    // names are fixed by the first block, and every later block must agree with them.
    let mut fold_ids = BTreeSet::new();
    let mut unit_ids = BTreeSet::new();
    let mut target_names: Option<Vec<String>> = None;
    let mut prediction_width: Option<usize> = None;
    let mut row_count = 0usize;
    let mut block_records = Vec::new();
    for block in selected {
        // Scope check: producer, partition and level must all match the requirement.
        if block.producer_node != requirement.producer_node
            || block.partition != requirement.partition
            || block.level != requirement.prediction_level
        {
            return Err(DagMlError::RuntimeValidation(format!(
                "aggregated prediction cache `{}` contains a block outside the requirement scope",
                requirement.key()
            )));
        }
        let width = block.validate_shape()?;
        if prediction_width.is_some_and(|expected| expected != width) {
            return Err(DagMlError::RuntimeValidation(format!(
                "aggregated prediction cache `{}` has inconsistent prediction width",
                requirement.key()
            )));
        }
        prediction_width = Some(width);
        // Compare normalized names so that unnamed blocks of equal width compare equal.
        let block_target_names = normalized_aggregated_prediction_targets(block, width);
        if target_names
            .as_ref()
            .is_some_and(|expected| expected != &block_target_names)
        {
            return Err(DagMlError::RuntimeValidation(format!(
                "aggregated prediction cache `{}` has inconsistent target names",
                requirement.key()
            )));
        }
        target_names = Some(block_target_names);
        if let Some(fold_id) = &block.fold_id {
            fold_ids.insert(fold_id.clone());
        }
        // row_count counts every row, while unit_ids is de-duplicated across blocks.
        // A unit repeated in two blocks therefore leaves row_count > unit_ids.len().
        // NOTE(review): presumably that case is rejected by `record.validate()`; confirm.
        unit_ids.extend(block.unit_ids.iter().cloned());
        row_count += block.unit_ids.len();
        block_records.push(BundlePredictionBlockCacheRecord {
            prediction_id: block.prediction_id.clone(),
            fold_id: block.fold_id.clone(),
            prediction_level: block.level,
            row_count: block.unit_ids.len(),
            unit_ids: block.unit_ids.clone(),
            sample_ids: Vec::new(),
            content_fingerprint: stable_json_fingerprint(block)?,
        });
    }

    // The record-level fingerprint covers the whole selected slice, in the order the caller gave.
    let record = BundlePredictionCacheRecord {
        requirement_key: requirement.key(),
        cache_id: format!("prediction-cache:{}", requirement.key()),
        format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
        partition: requirement.partition.clone(),
        prediction_level: requirement.prediction_level,
        fold_ids: fold_ids.into_iter().collect(),
        unit_ids: unit_ids.into_iter().collect(),
        sample_ids: Vec::new(),
        prediction_width: prediction_width.unwrap_or_default(),
        target_names: target_names.unwrap_or_default(),
        block_count: block_records.len(),
        row_count,
        content_fingerprint: stable_json_fingerprint(selected)?,
        blocks: block_records,
    };
    validate_prediction_cache_matches_requirement(&record, requirement)?;
    record.validate()?;
    Ok(record)
}
1913
1914fn validate_prediction_cache_matches_requirement(
1915 cache: &BundlePredictionCacheRecord,
1916 requirement: &BundlePredictionRequirement,
1917) -> Result<()> {
1918 if cache.requirement_key != requirement.key()
1919 || cache.partition != requirement.partition
1920 || cache.prediction_level != requirement.prediction_level
1921 || cache.fold_ids != requirement.fold_ids
1922 || cache.unit_ids != requirement.unit_ids
1923 || cache.sample_ids != requirement.sample_ids
1924 || cache.prediction_width != requirement.prediction_width
1925 || cache.target_names != requirement.target_names
1926 {
1927 return Err(DagMlError::RuntimeValidation(format!(
1928 "prediction cache `{}` does not match requirement `{}`",
1929 cache.cache_id,
1930 requirement.key()
1931 )));
1932 }
1933 Ok(())
1934}
1935
1936fn normalized_prediction_targets(block: &PredictionBlock, width: usize) -> Vec<String> {
1937 if block.target_names.is_empty() {
1938 (0..width).map(|index| format!("p{index}")).collect()
1939 } else {
1940 block.target_names.clone()
1941 }
1942}
1943
1944fn normalized_aggregated_prediction_targets(
1945 block: &AggregatedPredictionBlock,
1946 width: usize,
1947) -> Vec<String> {
1948 if block.target_names.is_empty() {
1949 (0..width).map(|index| format!("p{index}")).collect()
1950 } else {
1951 block.target_names.clone()
1952 }
1953}
1954
1955fn sample_prediction_units(sample_ids: &[SampleId]) -> Vec<PredictionUnitId> {
1956 sample_ids
1957 .iter()
1958 .cloned()
1959 .map(PredictionUnitId::Sample)
1960 .collect()
1961}
1962
1963fn validate_prediction_units(
1964 label: &str,
1965 expected_level: PredictionLevel,
1966 unit_ids: &[PredictionUnitId],
1967) -> Result<()> {
1968 validate_unique_ids(label, unit_ids)?;
1969 for unit_id in unit_ids {
1970 if unit_id.level() != expected_level {
1971 return Err(DagMlError::RuntimeValidation(format!(
1972 "{label} `{unit_id}` does not match prediction level {:?}",
1973 expected_level
1974 )));
1975 }
1976 }
1977 Ok(())
1978}
1979
1980fn validate_fingerprint(label: &str, value: &str) -> Result<()> {
1981 if value.len() != 64 || !value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
1982 return Err(DagMlError::RuntimeValidation(format!(
1983 "{label} fingerprint must be a 64-character hex digest"
1984 )));
1985 }
1986 Ok(())
1987}
1988
1989fn validate_non_empty(label: &str, value: &str) -> Result<()> {
1990 if value.trim().is_empty() {
1991 return Err(DagMlError::RuntimeValidation(format!("{label} is empty")));
1992 }
1993 Ok(())
1994}
1995
1996fn validate_unique_ids<T>(label: &str, values: &[T]) -> Result<()>
1997where
1998 T: Ord + ToString,
1999{
2000 let mut seen = BTreeSet::new();
2001 for value in values {
2002 if !seen.insert(value) {
2003 return Err(DagMlError::RuntimeValidation(format!(
2004 "duplicate {label} `{}`",
2005 value.to_string()
2006 )));
2007 }
2008 }
2009 Ok(())
2010}
2011
/// A request to replay one phase of a stored execution bundle.
///
/// Checked against a bundle by the `validate_for_bundle*` methods. The bundle id
/// must match, the phase must be PREDICT, EXPLAIN or REFIT, and the data envelope
/// keys must cover the bundle's data requirements exactly.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ReplayPhaseRequest {
    /// Id of the bundle this request targets; must equal the bundle's own id.
    pub bundle_id: BundleId,
    /// Phase to replay; only `Predict`, `Explain` and `Refit` are accepted.
    pub phase: Phase,
    /// Keys of the data envelopes supplied for replay. They must be non-blank and
    /// unique, and must form exactly the set of the bundle's data requirement keys.
    /// Defaults to an empty list when absent during deserialization.
    #[serde(default)]
    pub data_envelope_keys: Vec<String>,
}
2019
2020impl ReplayPhaseRequest {
2021 pub fn validate_for_bundle(&self, bundle: &ExecutionBundle) -> Result<()> {
2022 self.validate_for_bundle_with_prediction_cache_store(bundle, false)
2023 }
2024
2025 pub fn validate_for_bundle_with_prediction_cache_store(
2026 &self,
2027 bundle: &ExecutionBundle,
2028 prediction_cache_available: bool,
2029 ) -> Result<()> {
2030 self.validate_for_bundle_internal(bundle, prediction_cache_available)
2031 }
2032
2033 pub fn validate_for_bundle_with_prediction_cache_payloads(
2034 &self,
2035 bundle: &ExecutionBundle,
2036 prediction_cache_payloads: Option<&BundlePredictionCachePayloadSet>,
2037 ) -> Result<()> {
2038 if let Some(payloads) = prediction_cache_payloads {
2039 payloads.validate_against_bundle(bundle)?;
2040 }
2041 self.validate_for_bundle_internal(bundle, prediction_cache_payloads.is_some())
2042 }
2043
2044 fn validate_for_bundle_internal(
2045 &self,
2046 bundle: &ExecutionBundle,
2047 prediction_cache_available: bool,
2048 ) -> Result<()> {
2049 bundle.validate()?;
2050 if self.bundle_id != bundle.bundle_id {
2051 return Err(DagMlError::RuntimeValidation(format!(
2052 "replay request bundle `{}` does not match bundle `{}`",
2053 self.bundle_id, bundle.bundle_id
2054 )));
2055 }
2056 if !matches!(self.phase, Phase::Predict | Phase::Explain | Phase::Refit) {
2057 return Err(DagMlError::RuntimeValidation(format!(
2058 "bundle replay phase {:?} is not supported",
2059 self.phase
2060 )));
2061 }
2062 if self.phase == Phase::Refit && !bundle.prediction_requirements.is_empty() {
2063 if prediction_cache_available {
2064 return self.validate_data_envelope_keys(bundle);
2065 }
2066 return Err(DagMlError::RuntimeValidation(format!(
2067 "bundle `{}` cannot replay REFIT because it depends on {} OOF prediction requirement(s) but stores only prediction cache manifests",
2068 bundle.bundle_id,
2069 bundle.prediction_requirements.len()
2070 )));
2071 }
2072 self.validate_data_envelope_keys(bundle)
2073 }
2074
2075 fn validate_data_envelope_keys(&self, bundle: &ExecutionBundle) -> Result<()> {
2076 let expected = bundle
2077 .data_requirements
2078 .iter()
2079 .map(BundleDataRequirement::key)
2080 .collect::<BTreeSet<_>>();
2081 let mut requested = BTreeSet::new();
2082 for key in &self.data_envelope_keys {
2083 if key.trim().is_empty() {
2084 return Err(DagMlError::RuntimeValidation(
2085 "replay request contains an empty data envelope key".to_string(),
2086 ));
2087 }
2088 if !requested.insert(key.as_str()) {
2089 return Err(DagMlError::RuntimeValidation(format!(
2090 "replay request contains duplicate data envelope key `{key}`"
2091 )));
2092 }
2093 if !expected.contains(key.as_str()) {
2094 return Err(DagMlError::RuntimeValidation(format!(
2095 "replay request references unknown data envelope key `{key}`"
2096 )));
2097 }
2098 }
2099 for requirement in &bundle.data_requirements {
2100 let key = requirement.key();
2101 if !requested.contains(key.as_str()) {
2102 return Err(DagMlError::RuntimeValidation(format!(
2103 "replay request is missing data envelope key `{key}`"
2104 )));
2105 }
2106 }
2107 Ok(())
2108 }
2109}
2110
2111#[cfg(test)]
2112mod tests {
2113 use super::*;
2114 use crate::controller::{ControllerManifest, ControllerRegistry};
2115 use crate::data::{
2116 AggregateRepresentation, RepresentationCardinality, RepresentationCompatibilityOutcome,
2117 RepresentationCompatibilityReport, RepresentationMissingSourcePolicy, RepresentationPlan,
2118 RepresentationReplayManifest,
2119 };
2120 use crate::dsl::{compile_pipeline_dsl_with_generation, PipelineDslSpec};
2121 use crate::graph::GraphSpec;
2122 use crate::ids::{ArtifactId, FoldId, SampleId, TargetId};
2123 use crate::plan::{build_execution_plan, CampaignSpec};
2124 use crate::relation::EntityUnitLevel;
2125 use crate::selection::{
2126 select_candidate, CandidateScore, MetricObjective, SelectionMetric, SelectionPolicy,
2127 };
2128
2129 fn plan() -> ExecutionPlan {
2130 let graph: GraphSpec =
2131 serde_json::from_str(include_str!("../../../examples/minimal_graph.json")).unwrap();
2132 let campaign: CampaignSpec = serde_json::from_str(include_str!(
2133 "../../../examples/campaign_oof_generation.json"
2134 ))
2135 .unwrap();
2136 let manifests: Vec<ControllerManifest> =
2137 serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2138 .unwrap();
2139 let mut registry = ControllerRegistry::new();
2140 for manifest in manifests {
2141 registry.register(manifest).unwrap();
2142 }
2143 build_execution_plan("plan:bundle", graph, campaign, ®istry).unwrap()
2144 }
2145
2146 fn branch_merge_plan() -> ExecutionPlan {
2147 let graph: GraphSpec = serde_json::from_str(include_str!(
2148 "../../../examples/branch_merge_oof_graph.json"
2149 ))
2150 .unwrap();
2151 let campaign: CampaignSpec = serde_json::from_str(include_str!(
2152 "../../../examples/campaign_branch_merge_oof.json"
2153 ))
2154 .unwrap();
2155 let manifests: Vec<ControllerManifest> =
2156 serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2157 .unwrap();
2158 let mut registry = ControllerRegistry::new();
2159 for manifest in manifests {
2160 registry.register(manifest).unwrap();
2161 }
2162 build_execution_plan("plan:branch.merge.bundle", graph, campaign, ®istry).unwrap()
2163 }
2164
2165 fn executable_dsl_plan() -> ExecutionPlan {
2166 let spec: PipelineDslSpec = serde_json::from_str(include_str!(
2167 "../../../examples/pipeline_dsl_branch_merge_executable.json"
2168 ))
2169 .unwrap();
2170 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
2171 let manifests: Vec<ControllerManifest> =
2172 serde_json::from_str(include_str!("../../../examples/controller_manifests.json"))
2173 .unwrap();
2174 let mut registry = ControllerRegistry::new();
2175 for manifest in manifests {
2176 registry.register(manifest).unwrap();
2177 }
2178 build_execution_plan(
2179 "plan:dsl.branch.merge.bundle",
2180 compiled.graph,
2181 compiled.campaign_template,
2182 ®istry,
2183 )
2184 .unwrap()
2185 }
2186
2187 fn branch_merge_selection_decisions() -> BTreeMap<String, SelectionDecision> {
2188 serde_json::from_str(include_str!(
2189 "../../../examples/fixtures/bundle/selection_decisions_branch_merge.json"
2190 ))
2191 .unwrap()
2192 }
2193
2194 fn refit_artifact(
2195 plan: &ExecutionPlan,
2196 node_id: &str,
2197 data_requirement_keys: Vec<String>,
2198 prediction_requirement_keys: Vec<String>,
2199 ) -> RefitArtifactRecord {
2200 let node_id = NodeId::new(node_id).unwrap();
2201 let node_plan = plan.node_plans.get(&node_id).unwrap();
2202 RefitArtifactRecord {
2203 node_id: node_plan.node_id.clone(),
2204 controller_id: node_plan.controller_id.clone(),
2205 artifact: ArtifactRef {
2206 id: ArtifactId::new(format!("artifact:{}:refit", node_plan.node_id)).unwrap(),
2207 kind: "mock_model".to_string(),
2208 controller_id: node_plan.controller_id.clone(),
2209 backend: None,
2210 uri: None,
2211 content_fingerprint: None,
2212 size_bytes: Some(128),
2213 plugin: None,
2214 plugin_version: None,
2215 },
2216 params_fingerprint: node_plan.params_fingerprint.clone(),
2217 data_requirement_keys,
2218 prediction_requirement_keys,
2219 }
2220 }
2221
2222 fn branch_merge_samples() -> Vec<SampleId> {
2223 vec![
2224 SampleId::new("sample:1").unwrap(),
2225 SampleId::new("sample:2").unwrap(),
2226 SampleId::new("sample:3").unwrap(),
2227 SampleId::new("sample:4").unwrap(),
2228 ]
2229 }
2230
2231 fn branch_merge_requirement(
2232 producer_node: &str,
2233 target_port: &str,
2234 ) -> BundlePredictionRequirement {
2235 BundlePredictionRequirement {
2236 producer_node: NodeId::new(producer_node).unwrap(),
2237 source_port: "oof".to_string(),
2238 consumer_node: NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap(),
2239 target_port: target_port.to_string(),
2240 partition: PredictionPartition::Validation,
2241 prediction_level: PredictionLevel::Sample,
2242 fold_ids: vec![
2243 FoldId::new("fold:0").unwrap(),
2244 FoldId::new("fold:1").unwrap(),
2245 ],
2246 unit_ids: Vec::new(),
2247 sample_ids: branch_merge_samples(),
2248 prediction_width: 1,
2249 target_names: vec!["y".to_string()],
2250 }
2251 }
2252
2253 fn branch_merge_prediction_blocks(producer_node: &str, offset: f64) -> Vec<PredictionBlock> {
2254 let producer_node = NodeId::new(producer_node).unwrap();
2255 let samples = branch_merge_samples();
2256 vec![
2257 PredictionBlock {
2258 prediction_id: Some(format!("prediction:{producer_node}:fold0")),
2259 producer_node: producer_node.clone(),
2260 partition: PredictionPartition::Validation,
2261 fold_id: Some(FoldId::new("fold:0").unwrap()),
2262 sample_ids: samples[0..2].to_vec(),
2263 values: vec![vec![offset + 0.1], vec![offset + 0.2]],
2264 target_names: vec!["y".to_string()],
2265 },
2266 PredictionBlock {
2267 prediction_id: Some(format!("prediction:{producer_node}:fold1")),
2268 producer_node,
2269 partition: PredictionPartition::Validation,
2270 fold_id: Some(FoldId::new("fold:1").unwrap()),
2271 sample_ids: samples[2..4].to_vec(),
2272 values: vec![vec![offset + 0.3], vec![offset + 0.4]],
2273 target_names: vec!["y".to_string()],
2274 },
2275 ]
2276 }
2277
2278 fn decision() -> SelectionDecision {
2279 select_candidate(
2280 &SelectionPolicy {
2281 id: "select:merge".to_string(),
2282 metric: SelectionMetric {
2283 name: "rmse".to_string(),
2284 objective: MetricObjective::Minimize,
2285 },
2286 required_metric_level: Some(crate::policy::PredictionLevel::Sample),
2287 require_finite: true,
2288 evaluation_scope: None,
2289 refit_slot_plan: None,
2290 stacking_fit_contract: None,
2291 reduction_id: None,
2292 },
2293 &[
2294 CandidateScore {
2295 candidate_id: "model:base".to_string(),
2296 metrics: BTreeMap::from([("rmse".to_string(), 1.0)]),
2297 metadata: BTreeMap::from([(
2298 "metric_level".to_string(),
2299 serde_json::Value::String("sample".to_string()),
2300 )]),
2301 },
2302 CandidateScore {
2303 candidate_id: "model:other".to_string(),
2304 metrics: BTreeMap::from([("rmse".to_string(), 2.0)]),
2305 metadata: BTreeMap::from([(
2306 "metric_level".to_string(),
2307 serde_json::Value::String("sample".to_string()),
2308 )]),
2309 },
2310 ],
2311 )
2312 .unwrap()
2313 }
2314
2315 fn selected_model_base_decision() -> SelectionDecision {
2316 decision()
2317 }
2318
2319 fn model_base_refit_artifact(plan: &ExecutionPlan) -> RefitArtifactRecord {
2320 let model_plan = plan
2321 .node_plans
2322 .get(&NodeId::new("model:base").unwrap())
2323 .unwrap();
2324 RefitArtifactRecord {
2325 node_id: model_plan.node_id.clone(),
2326 controller_id: model_plan.controller_id.clone(),
2327 artifact: ArtifactRef {
2328 id: ArtifactId::new("artifact:model:base:refit").unwrap(),
2329 kind: "sklearn_pickle".to_string(),
2330 controller_id: model_plan.controller_id.clone(),
2331 backend: None,
2332 uri: None,
2333 content_fingerprint: None,
2334 size_bytes: Some(128),
2335 plugin: None,
2336 plugin_version: None,
2337 },
2338 params_fingerprint: model_plan.params_fingerprint.clone(),
2339 data_requirement_keys: vec!["model:base.x".to_string()],
2340 prediction_requirement_keys: Vec::new(),
2341 }
2342 }
2343
2344 #[test]
2345 fn builds_bundle_from_execution_plan() {
2346 let plan = plan();
2347 let artifact = model_base_refit_artifact(&plan);
2348
2349 let bundle = build_execution_bundle(
2350 BundleId::new("bundle:demo").unwrap(),
2351 &plan,
2352 Some(plan.variants[0].variant_id.clone()),
2353 BTreeMap::from([("merge".to_string(), decision())]),
2354 vec![artifact],
2355 )
2356 .unwrap();
2357
2358 bundle.validate_against_plan(&plan).unwrap();
2359 assert_eq!(bundle.data_requirements.len(), 1);
2360 }
2361
2362 #[test]
2363 fn bundle_data_requirements_accept_d7_replay_contracts() {
2364 let plan = plan();
2365 let artifact = model_base_refit_artifact(&plan);
2366 let mut bundle = build_execution_bundle(
2367 BundleId::new("bundle:d7.replay").unwrap(),
2368 &plan,
2369 Some(plan.variants[0].variant_id.clone()),
2370 BTreeMap::from([("merge".to_string(), decision())]),
2371 vec![artifact],
2372 )
2373 .unwrap();
2374 let relation_fingerprint = bundle.data_requirements[0]
2375 .relation_fingerprint
2376 .clone()
2377 .unwrap_or_else(|| "a".repeat(64));
2378 bundle.data_requirements[0].representation_replay_manifest =
2379 Some(RepresentationReplayManifest {
2380 manifest_id: "repr:d7.bundle".to_string(),
2381 representation_plan: RepresentationPlan::Aggregate(AggregateRepresentation {
2382 input_unit_level: EntityUnitLevel::Observation,
2383 output_unit_level: EntityUnitLevel::PhysicalSample,
2384 reducer_id: None,
2385 method: Some("mean".to_string()),
2386 cardinality: RepresentationCardinality::ManyToOne,
2387 }),
2388 combination_plan: None,
2389 output_unit_level: EntityUnitLevel::PhysicalSample,
2390 output_representation: Some("tabular_numeric".to_string()),
2391 relation_fingerprint: Some(relation_fingerprint.clone()),
2392 feature_schema_fingerprint: Some("b".repeat(64)),
2393 final_reduction_id: None,
2394 sample_observation_mapping: Vec::new(),
2395 combo_selection: Vec::new(),
2396 qc_policy_refs: Vec::new(),
2397 outlier_policy_refs: Vec::new(),
2398 missing_source_policy: None,
2399 missing_repetition_policy: None,
2400 prediction_representation: None,
2401 final_output_unit_level: Some(EntityUnitLevel::PhysicalSample),
2402 train_compatibility: None,
2403 predict_compatibility: None,
2404 metadata: BTreeMap::new(),
2405 });
2406 bundle.data_requirements[0].representation_compatibility =
2407 Some(RepresentationCompatibilityReport {
2408 policy: RepresentationMissingSourcePolicy::Strict,
2409 outcome: RepresentationCompatibilityOutcome::Compatible,
2410 fallback_used: None,
2411 warning_severity: None,
2412 affected_source_count: 0,
2413 affected_repetition_count: 0,
2414 affected_sample_count: 0,
2415 train_relation_fingerprint: Some(relation_fingerprint),
2416 predict_relation_fingerprint: None,
2417 train_unit_count: Some(2),
2418 predict_unit_count: Some(2),
2419 fixed_width_required: false,
2420 final_reducer_stabilizes_output: true,
2421 cartesian_combo_count_changed: false,
2422 late_fusion_branch_delta: false,
2423 messages: Vec::new(),
2424 metadata: BTreeMap::new(),
2425 });
2426 bundle.validate_against_plan(&plan).unwrap();
2427
2428 bundle.data_requirements[0]
2429 .representation_replay_manifest
2430 .as_mut()
2431 .unwrap()
2432 .relation_fingerprint = Some("c".repeat(64));
2433 if bundle.data_requirements[0].relation_fingerprint.is_some() {
2434 assert!(bundle.validate().is_err());
2435 }
2436 }
2437
2438 #[test]
2439 fn d9_negative_prediction_cache_refuses_missing_aggregated_unit_ids() {
2440 let cache = BundlePredictionCacheRecord {
2441 requirement_key: "model:base.oof->model:meta.pred".to_string(),
2442 cache_id: "prediction-cache:d9.missing-units".to_string(),
2443 format: BUNDLE_PREDICTION_CACHE_FORMAT.to_string(),
2444 partition: PredictionPartition::Validation,
2445 prediction_level: PredictionLevel::Target,
2446 fold_ids: vec![FoldId::new("fold:0").unwrap()],
2447 unit_ids: Vec::new(),
2448 sample_ids: Vec::new(),
2449 prediction_width: 1,
2450 target_names: vec!["y".to_string()],
2451 block_count: 1,
2452 row_count: 1,
2453 content_fingerprint: "d".repeat(64),
2454 blocks: vec![BundlePredictionBlockCacheRecord {
2455 prediction_id: Some("prediction:d9.target.fold0".to_string()),
2456 fold_id: Some(FoldId::new("fold:0").unwrap()),
2457 prediction_level: PredictionLevel::Target,
2458 row_count: 1,
2459 unit_ids: vec![PredictionUnitId::Target(TargetId::new("target:a").unwrap())],
2460 sample_ids: Vec::new(),
2461 content_fingerprint: "e".repeat(64),
2462 }],
2463 };
2464
2465 let error = cache.validate().unwrap_err().to_string();
2466 assert!(
2467 error.contains("row_count does not match unique unit ids"),
2468 "unexpected D9 missing-unit-id cache error: {error}"
2469 );
2470 }
2471
2472 #[test]
2473 fn refit_artifact_validation_checks_portable_artifact_metadata() {
2474 let plan = plan();
2475 let mut artifact = model_base_refit_artifact(&plan);
2476 artifact.artifact.backend = Some(crate::runtime::ArtifactBackend::Joblib);
2477 artifact.artifact.uri = Some("artifacts/model.joblib".to_string());
2478 artifact.artifact.content_fingerprint = Some("c".repeat(64));
2479 artifact.artifact.plugin = Some("dagml.sklearn".to_string());
2480 artifact.artifact.plugin_version = Some("1.0.0".to_string());
2481 artifact.validate().unwrap();
2482
2483 artifact.artifact.content_fingerprint = Some("short".to_string());
2484 assert!(artifact
2485 .validate()
2486 .unwrap_err()
2487 .to_string()
2488 .contains("artifact content fingerprint"));
2489 }
2490
2491 #[test]
2492 fn bundle_selections_must_match_plan_and_refit_artifacts() {
2493 let plan = plan();
2494 let artifact = model_base_refit_artifact(&plan);
2495 let valid = build_execution_bundle(
2496 BundleId::new("bundle:selected.model").unwrap(),
2497 &plan,
2498 Some(plan.variants[0].variant_id.clone()),
2499 BTreeMap::from([("model".to_string(), selected_model_base_decision())]),
2500 vec![artifact.clone()],
2501 )
2502 .unwrap();
2503 valid.validate_against_plan(&plan).unwrap();
2504
2505 assert!(build_execution_bundle(
2506 BundleId::new("bundle:selected.model.missing.artifact").unwrap(),
2507 &plan,
2508 Some(plan.variants[0].variant_id.clone()),
2509 BTreeMap::from([("model".to_string(), selected_model_base_decision())]),
2510 Vec::new(),
2511 )
2512 .is_err());
2513
2514 let mut missing_level = selected_model_base_decision();
2515 missing_level.metric_level = None;
2516 assert!(build_execution_bundle(
2517 BundleId::new("bundle:selected.missing.level").unwrap(),
2518 &plan,
2519 Some(plan.variants[0].variant_id.clone()),
2520 BTreeMap::from([("model".to_string(), missing_level)]),
2521 vec![artifact.clone()],
2522 )
2523 .is_err());
2524
2525 let mut wrong_level = selected_model_base_decision();
2526 wrong_level.metric_level = Some(crate::policy::PredictionLevel::Target);
2527 assert!(build_execution_bundle(
2528 BundleId::new("bundle:selected.wrong.level").unwrap(),
2529 &plan,
2530 Some(plan.variants[0].variant_id.clone()),
2531 BTreeMap::from([("model".to_string(), wrong_level)]),
2532 vec![artifact.clone()],
2533 )
2534 .is_err());
2535
2536 let mut unknown = selected_model_base_decision();
2537 unknown.selected_candidate_id = "model:missing".to_string();
2538 unknown.ranked_candidates[0].candidate_id = "model:missing".to_string();
2539 assert!(build_execution_bundle(
2540 BundleId::new("bundle:selected.unknown").unwrap(),
2541 &plan,
2542 Some(plan.variants[0].variant_id.clone()),
2543 BTreeMap::from([("model".to_string(), unknown)]),
2544 vec![artifact],
2545 )
2546 .is_err());
2547 }
2548
2549 #[test]
2550 fn bundle_artifact_params_follow_selected_generation_variant() {
2551 let plan = executable_dsl_plan();
2552 let selected_variant = &plan.variants[0];
2553 let node_plan = plan
2554 .node_plans
2555 .get(&NodeId::new("branch:b0.model:ridge").unwrap())
2556 .unwrap();
2557 let effective_params = selected_variant
2558 .effective_params_for_node(&node_plan.node_id, &node_plan.params)
2559 .unwrap();
2560 let effective_fingerprint = stable_json_fingerprint(&effective_params).unwrap();
2561 assert_ne!(effective_fingerprint, node_plan.params_fingerprint);
2562
2563 let artifact = RefitArtifactRecord {
2564 node_id: node_plan.node_id.clone(),
2565 controller_id: node_plan.controller_id.clone(),
2566 artifact: ArtifactRef {
2567 id: ArtifactId::new("artifact:branch:b0.model:ridge:refit").unwrap(),
2568 kind: "mock_model".to_string(),
2569 controller_id: node_plan.controller_id.clone(),
2570 backend: None,
2571 uri: None,
2572 content_fingerprint: None,
2573 size_bytes: Some(128),
2574 plugin: None,
2575 plugin_version: None,
2576 },
2577 params_fingerprint: effective_fingerprint,
2578 data_requirement_keys: vec!["branch:b0.model:ridge.x".to_string()],
2579 prediction_requirement_keys: Vec::new(),
2580 };
2581
2582 build_execution_bundle(
2583 BundleId::new("bundle:dsl.variant.params").unwrap(),
2584 &plan,
2585 Some(selected_variant.variant_id.clone()),
2586 BTreeMap::new(),
2587 vec![artifact.clone()],
2588 )
2589 .unwrap();
2590
2591 let mut stale_artifact = artifact;
2592 stale_artifact.params_fingerprint = node_plan.params_fingerprint.clone();
2593 let error = build_execution_bundle(
2594 BundleId::new("bundle:dsl.variant.params.stale").unwrap(),
2595 &plan,
2596 Some(selected_variant.variant_id.clone()),
2597 BTreeMap::new(),
2598 vec![stale_artifact],
2599 )
2600 .unwrap_err();
2601 assert!(format!("{error}").contains("artifact params"));
2602 }
2603
2604 #[test]
2605 fn branch_merge_bundle_links_selected_refits_and_fold_aligned_oof_caches() {
2606 let plan = branch_merge_plan();
2607 let b0_requirement = branch_merge_requirement("branch:b0.model:ridge", "b0_oof");
2608 let b1_requirement = branch_merge_requirement("branch:b1.model:rf", "b1_oof");
2609 let b0_cache = build_prediction_cache_record(
2610 &b0_requirement,
2611 &branch_merge_prediction_blocks("branch:b0.model:ridge", 0.0),
2612 )
2613 .unwrap();
2614 let b1_cache = build_prediction_cache_record(
2615 &b1_requirement,
2616 &branch_merge_prediction_blocks("branch:b1.model:rf", 1.0),
2617 )
2618 .unwrap();
2619 let b0_artifact = refit_artifact(
2620 &plan,
2621 "branch:b0.model:ridge",
2622 vec!["branch:b0.model:ridge.x".to_string()],
2623 Vec::new(),
2624 );
2625 let b1_artifact = refit_artifact(
2626 &plan,
2627 "branch:b1.model:rf",
2628 vec!["branch:b1.model:rf.x".to_string()],
2629 Vec::new(),
2630 );
2631 let merge_artifact = refit_artifact(
2632 &plan,
2633 "merge:stack.pred_plus_original.meta:ridge",
2634 vec!["merge:stack.pred_plus_original.meta:ridge.x_original".to_string()],
2635 vec![b0_requirement.key(), b1_requirement.key()],
2636 );
2637
2638 let bundle = build_execution_bundle_with_prediction_contracts(
2639 BundleId::new("bundle:branch.merge.selected.refit").unwrap(),
2640 &plan,
2641 Some(plan.variants[0].variant_id.clone()),
2642 branch_merge_selection_decisions(),
2643 vec![
2644 b0_artifact.clone(),
2645 b1_artifact.clone(),
2646 merge_artifact.clone(),
2647 ],
2648 vec![b0_requirement.clone(), b1_requirement.clone()],
2649 vec![b0_cache.clone(), b1_cache.clone()],
2650 )
2651 .unwrap();
2652 bundle.validate_against_plan(&plan).unwrap();
2653 assert_eq!(bundle.selections.len(), 3);
2654 assert_eq!(bundle.prediction_requirements.len(), 2);
2655 assert_eq!(
2656 bundle.refit_artifacts[2].data_requirement_keys,
2657 vec!["merge:stack.pred_plus_original.meta:ridge.x_original"]
2658 );
2659 assert_eq!(
2660 bundle.refit_artifacts[2].prediction_requirement_keys,
2661 vec![
2662 "branch:b0.model:ridge.oof->merge:stack.pred_plus_original.meta:ridge.b0_oof",
2663 "branch:b1.model:rf.oof->merge:stack.pred_plus_original.meta:ridge.b1_oof",
2664 ]
2665 );
2666
2667 assert!(build_execution_bundle_with_prediction_contracts(
2668 BundleId::new("bundle:branch.merge.missing.branch.refit").unwrap(),
2669 &plan,
2670 Some(plan.variants[0].variant_id.clone()),
2671 branch_merge_selection_decisions(),
2672 vec![b0_artifact.clone(), merge_artifact.clone()],
2673 vec![b0_requirement.clone(), b1_requirement.clone()],
2674 vec![b0_cache.clone(), b1_cache.clone()],
2675 )
2676 .is_err());
2677
2678 let mut misaligned_cache = b0_cache;
2679 misaligned_cache.blocks[0].sample_ids = vec![
2680 SampleId::new("sample:1").unwrap(),
2681 SampleId::new("sample:3").unwrap(),
2682 ];
2683 misaligned_cache.blocks[1].sample_ids = vec![
2684 SampleId::new("sample:2").unwrap(),
2685 SampleId::new("sample:4").unwrap(),
2686 ];
2687 let error = build_execution_bundle_with_prediction_contracts(
2688 BundleId::new("bundle:branch.merge.misaligned.oof.cache").unwrap(),
2689 &plan,
2690 Some(plan.variants[0].variant_id.clone()),
2691 branch_merge_selection_decisions(),
2692 vec![b0_artifact, b1_artifact, merge_artifact],
2693 vec![b0_requirement, b1_requirement],
2694 vec![misaligned_cache, b1_cache],
2695 )
2696 .unwrap_err()
2697 .to_string();
2698 assert!(
2699 error.contains("does not match validation samples"),
2700 "unexpected fold-alignment error: {error}"
2701 );
2702 }
2703
2704 #[test]
2705 fn prediction_requirements_are_typed_and_validate_against_oof_edges() {
2706 let plan = branch_merge_plan();
2707 let meta_plan = plan
2708 .node_plans
2709 .get(&NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap())
2710 .unwrap();
2711 let producer_node = NodeId::new("branch:b0.model:ridge").unwrap();
2712 let fold0 = FoldId::new("fold:0").unwrap();
2713 let fold1 = FoldId::new("fold:1").unwrap();
2714 let samples = [
2715 SampleId::new("sample:1").unwrap(),
2716 SampleId::new("sample:2").unwrap(),
2717 SampleId::new("sample:3").unwrap(),
2718 SampleId::new("sample:4").unwrap(),
2719 ];
2720 let requirement = BundlePredictionRequirement {
2721 producer_node: producer_node.clone(),
2722 source_port: "oof".to_string(),
2723 consumer_node: meta_plan.node_id.clone(),
2724 target_port: "b0_oof".to_string(),
2725 partition: PredictionPartition::Validation,
2726 prediction_level: PredictionLevel::Sample,
2727 fold_ids: vec![fold0.clone(), fold1.clone()],
2728 unit_ids: Vec::new(),
2729 sample_ids: samples.to_vec(),
2730 prediction_width: 1,
2731 target_names: vec!["y".to_string()],
2732 };
2733 let prediction_blocks = vec![
2734 PredictionBlock {
2735 prediction_id: Some("prediction:branch:b0.fold0".to_string()),
2736 producer_node: producer_node.clone(),
2737 partition: PredictionPartition::Validation,
2738 fold_id: Some(fold0),
2739 sample_ids: samples[0..2].to_vec(),
2740 values: vec![vec![0.1], vec![0.2]],
2741 target_names: vec!["y".to_string()],
2742 },
2743 PredictionBlock {
2744 prediction_id: Some("prediction:branch:b0.fold1".to_string()),
2745 producer_node: producer_node.clone(),
2746 partition: PredictionPartition::Validation,
2747 fold_id: Some(fold1),
2748 sample_ids: samples[2..4].to_vec(),
2749 values: vec![vec![0.3], vec![0.4]],
2750 target_names: vec!["y".to_string()],
2751 },
2752 ];
2753 let cache = build_prediction_cache_record(&requirement, &prediction_blocks).unwrap();
2754 let payload = build_prediction_cache_payload(&requirement, &prediction_blocks).unwrap();
2755 assert_eq!(cache.prediction_level, PredictionLevel::Sample);
2756 assert_eq!(payload.prediction_level, PredictionLevel::Sample);
2757 assert!(cache
2758 .blocks
2759 .iter()
2760 .all(|block| block.prediction_level == PredictionLevel::Sample));
2761 validate_prediction_cache_payload_matches_record(&payload, &cache).unwrap();
2762 let mut wrong_level_requirement = requirement.clone();
2763 wrong_level_requirement.prediction_level = PredictionLevel::Target;
2764 assert!(wrong_level_requirement.validate().is_err());
2765 let mut wrong_level_cache = cache.clone();
2766 wrong_level_cache.prediction_level = PredictionLevel::Target;
2767 assert!(wrong_level_cache.validate().is_err());
2768 let mut wrong_level_payload = payload.clone();
2769 wrong_level_payload.prediction_level = PredictionLevel::Target;
2770 assert!(wrong_level_payload.validate().is_err());
2771 let prediction_key = requirement.key();
2772 let artifact = RefitArtifactRecord {
2773 node_id: meta_plan.node_id.clone(),
2774 controller_id: meta_plan.controller_id.clone(),
2775 artifact: ArtifactRef {
2776 id: ArtifactId::new("artifact:merge:stack.pred_plus_original.meta:ridge:refit")
2777 .unwrap(),
2778 kind: "mock_model".to_string(),
2779 controller_id: meta_plan.controller_id.clone(),
2780 backend: None,
2781 uri: None,
2782 content_fingerprint: None,
2783 size_bytes: Some(128),
2784 plugin: None,
2785 plugin_version: None,
2786 },
2787 params_fingerprint: meta_plan.params_fingerprint.clone(),
2788 data_requirement_keys: vec![
2789 "merge:stack.pred_plus_original.meta:ridge.x_original".to_string()
2790 ],
2791 prediction_requirement_keys: vec![prediction_key],
2792 };
2793
2794 assert!(build_execution_bundle(
2795 BundleId::new("bundle:missing.prediction.requirement").unwrap(),
2796 &plan,
2797 Some(plan.variants[0].variant_id.clone()),
2798 BTreeMap::new(),
2799 vec![artifact.clone()],
2800 )
2801 .is_err());
2802
2803 assert!(build_execution_bundle_with_prediction_requirements(
2804 BundleId::new("bundle:typed.prediction.requirement.without.cache").unwrap(),
2805 &plan,
2806 Some(plan.variants[0].variant_id.clone()),
2807 BTreeMap::new(),
2808 vec![artifact.clone()],
2809 vec![requirement.clone()],
2810 )
2811 .is_err());
2812
2813 let bundle = build_execution_bundle_with_prediction_contracts(
2814 BundleId::new("bundle:typed.prediction.requirement").unwrap(),
2815 &plan,
2816 Some(plan.variants[0].variant_id.clone()),
2817 BTreeMap::new(),
2818 vec![artifact],
2819 vec![requirement],
2820 vec![cache],
2821 )
2822 .unwrap();
2823 bundle.validate_against_plan(&plan).unwrap();
2824 assert_eq!(bundle.prediction_requirements.len(), 1);
2825 assert_eq!(bundle.prediction_caches.len(), 1);
2826 assert_eq!(
2827 bundle.refit_artifacts[0].prediction_requirement_keys,
2828 vec!["branch:b0.model:ridge.oof->merge:stack.pred_plus_original.meta:ridge.b0_oof"]
2829 );
2830 let payload_set = BundlePredictionCachePayloadSet {
2831 bundle_id: bundle.bundle_id.clone(),
2832 schema_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
2833 caches: vec![payload],
2834 };
2835 payload_set.validate_against_bundle(&bundle).unwrap();
2836 let refit_replay_request = ReplayPhaseRequest {
2837 bundle_id: bundle.bundle_id.clone(),
2838 phase: Phase::Refit,
2839 data_envelope_keys: bundle
2840 .data_requirements
2841 .iter()
2842 .map(BundleDataRequirement::key)
2843 .collect(),
2844 };
2845 refit_replay_request
2846 .validate_for_bundle_with_prediction_cache_payloads(&bundle, Some(&payload_set))
2847 .unwrap();
2848 let mut tampered_payload_set = payload_set.clone();
2849 tampered_payload_set.caches[0].blocks[0].values[0][0] = 99.0;
2850 assert!(tampered_payload_set
2851 .validate_against_bundle(&bundle)
2852 .is_err());
2853 let mut missing_payload_set = payload_set.clone();
2854 missing_payload_set.caches.clear();
2855 assert!(missing_payload_set
2856 .validate_against_bundle(&bundle)
2857 .is_err());
2858 assert!(refit_replay_request.validate_for_bundle(&bundle).is_err());
2859
2860 let mut wrong_data_owner = bundle.clone();
2861 wrong_data_owner.refit_artifacts[0].data_requirement_keys =
2862 vec!["branch:b0.model:ridge.x".to_string()];
2863 assert!(wrong_data_owner.validate().is_err());
2864
2865 let mut wrong_prediction_consumer = bundle;
2866 wrong_prediction_consumer.refit_artifacts[0].node_id =
2867 NodeId::new("branch:b0.model:ridge").unwrap();
2868 wrong_prediction_consumer.refit_artifacts[0]
2869 .data_requirement_keys
2870 .clear();
2871 assert!(wrong_prediction_consumer.validate().is_err());
2872 }
2873
2874 #[test]
2875 fn aggregated_prediction_cache_contracts_preserve_unit_ids() {
2876 let plan = branch_merge_plan();
2877 let producer_node = NodeId::new("branch:b0.model:ridge").unwrap();
2878 let consumer_node = NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap();
2879 let fold0 = FoldId::new("fold:0").unwrap();
2880 let fold1 = FoldId::new("fold:1").unwrap();
2881 let target_a = PredictionUnitId::Target(TargetId::new("target:a").unwrap());
2882 let target_b = PredictionUnitId::Target(TargetId::new("target:b").unwrap());
2883 let requirement = BundlePredictionRequirement {
2884 producer_node: producer_node.clone(),
2885 source_port: "oof".to_string(),
2886 consumer_node: consumer_node.clone(),
2887 target_port: "b0_oof".to_string(),
2888 partition: PredictionPartition::Validation,
2889 prediction_level: PredictionLevel::Target,
2890 fold_ids: vec![fold0.clone(), fold1.clone()],
2891 unit_ids: vec![target_a.clone(), target_b.clone()],
2892 sample_ids: Vec::new(),
2893 prediction_width: 1,
2894 target_names: vec!["y".to_string()],
2895 };
2896 let aggregated_blocks = vec![
2897 AggregatedPredictionBlock {
2898 prediction_id: Some("prediction:branch:b0.target.fold0".to_string()),
2899 producer_node: producer_node.clone(),
2900 partition: PredictionPartition::Validation,
2901 fold_id: Some(fold0),
2902 level: PredictionLevel::Target,
2903 unit_ids: vec![target_a],
2904 values: vec![vec![0.15]],
2905 target_names: vec!["y".to_string()],
2906 },
2907 AggregatedPredictionBlock {
2908 prediction_id: Some("prediction:branch:b0.target.fold1".to_string()),
2909 producer_node,
2910 partition: PredictionPartition::Validation,
2911 fold_id: Some(fold1),
2912 level: PredictionLevel::Target,
2913 unit_ids: vec![target_b],
2914 values: vec![vec![0.35]],
2915 target_names: vec!["y".to_string()],
2916 },
2917 ];
2918
2919 let cache =
2920 build_aggregated_prediction_cache_record(&requirement, &aggregated_blocks).unwrap();
2921 let payload =
2922 build_aggregated_prediction_cache_payload(&requirement, &aggregated_blocks).unwrap();
2923 assert_eq!(cache.prediction_level, PredictionLevel::Target);
2924 assert_eq!(cache.unit_ids, requirement.unit_ids);
2925 assert!(cache.sample_ids.is_empty());
2926 assert!(payload.blocks.is_empty());
2927 assert_eq!(payload.aggregated_blocks.len(), 2);
2928 validate_prediction_cache_payload_matches_record(&payload, &cache).unwrap();
2929
2930 let artifact = refit_artifact(
2931 &plan,
2932 "merge:stack.pred_plus_original.meta:ridge",
2933 vec!["merge:stack.pred_plus_original.meta:ridge.x_original".to_string()],
2934 vec![requirement.key()],
2935 );
2936 let bundle = build_execution_bundle_with_prediction_contracts(
2937 BundleId::new("bundle:target.prediction.requirement").unwrap(),
2938 &plan,
2939 Some(plan.variants[0].variant_id.clone()),
2940 BTreeMap::new(),
2941 vec![artifact],
2942 vec![requirement],
2943 vec![cache],
2944 )
2945 .unwrap();
2946 bundle.validate_against_plan(&plan).unwrap();
2947
2948 let mut tampered_payload = payload;
2949 tampered_payload.aggregated_blocks[0].unit_ids =
2950 vec![PredictionUnitId::Target(TargetId::new("target:z").unwrap())];
2951 assert!(validate_prediction_cache_payload_matches_record(
2952 &tampered_payload,
2953 &bundle.prediction_caches[0]
2954 )
2955 .is_err());
2956 }
2957
2958 #[test]
2959 fn replay_envelopes_must_match_bundle_requirements() {
2960 let plan = plan();
2961 let bundle = build_execution_bundle(
2962 BundleId::new("bundle:demo").unwrap(),
2963 &plan,
2964 None,
2965 BTreeMap::new(),
2966 Vec::new(),
2967 )
2968 .unwrap();
2969 let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2970 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2971 ))
2972 .unwrap();
2973
2974 bundle
2975 .validate_replay_envelopes(&BTreeMap::from([(
2976 "model:base.x".to_string(),
2977 envelope.clone(),
2978 )]))
2979 .unwrap();
2980
2981 let mut mismatched = envelope;
2982 mismatched.schema_fingerprint = "0".repeat(64);
2983 assert!(bundle
2984 .validate_replay_envelopes(&BTreeMap::from([("model:base.x".to_string(), mismatched,)]))
2985 .is_err());
2986 }
2987
2988 #[test]
2989 fn rejects_unsupported_bundle_schema_version() {
2990 let mut bundle = build_execution_bundle(
2991 BundleId::new("bundle:demo").unwrap(),
2992 &plan(),
2993 None,
2994 BTreeMap::new(),
2995 Vec::new(),
2996 )
2997 .unwrap();
2998 bundle.schema_version = EXECUTION_BUNDLE_SCHEMA_VERSION + 1;
2999
3000 assert!(bundle.validate().is_err());
3001
3002 bundle.schema_version = 0;
3003 assert!(bundle.validate().is_err());
3004 }
3005
3006 #[test]
3007 fn schema_migration_policy_is_explicit_and_refuses_implicit_migrations() {
3008 let bundle_policy = execution_bundle_schema_migration_policy();
3009 assert_eq!(
3010 bundle_policy.current_version,
3011 EXECUTION_BUNDLE_SCHEMA_VERSION
3012 );
3013 assert_eq!(
3014 bundle_policy.min_readable_version,
3015 MIN_READABLE_EXECUTION_BUNDLE_SCHEMA_VERSION
3016 );
3017 assert!(bundle_policy.automatic_migrations.is_empty());
3018 bundle_policy
3019 .validate_read_version(EXECUTION_BUNDLE_SCHEMA_VERSION, "bundle `current`")
3020 .unwrap();
3021 assert!(bundle_policy
3022 .validate_read_version(EXECUTION_BUNDLE_SCHEMA_VERSION + 1, "bundle `future`")
3023 .is_err());
3024 assert!(bundle_policy
3025 .validate_read_version(0, "bundle `zero`")
3026 .is_err());
3027
3028 let mut future_policy = SchemaMigrationPolicy {
3029 artifact: "execution_bundle".to_string(),
3030 current_version: 2,
3031 min_readable_version: 1,
3032 min_writable_version: 2,
3033 automatic_migrations: BTreeMap::new(),
3034 };
3035 assert!(future_policy
3036 .validate_read_version(1, "bundle `old-without-migration`")
3037 .is_err());
3038 future_policy.automatic_migrations.insert(1, 2);
3039 future_policy
3040 .validate_read_version(1, "bundle `old-with-migration`")
3041 .unwrap();
3042 }
3043
3044 #[test]
3045 fn prediction_cache_payload_schema_policy_rejects_unsupported_versions() {
3046 let policy = prediction_cache_payload_schema_migration_policy();
3047 assert_eq!(
3048 policy.current_version,
3049 PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION
3050 );
3051 assert!(policy.automatic_migrations.is_empty());
3052
3053 let mut payload_set = BundlePredictionCachePayloadSet {
3054 bundle_id: BundleId::new("bundle:payload.schema").unwrap(),
3055 schema_version: PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION,
3056 caches: Vec::new(),
3057 };
3058 payload_set.validate().unwrap();
3059
3060 payload_set.schema_version = PREDICTION_CACHE_PAYLOAD_SCHEMA_VERSION + 1;
3061 assert!(payload_set.validate().is_err());
3062
3063 payload_set.schema_version = 0;
3064 assert!(payload_set.validate().is_err());
3065 }
3066
3067 #[test]
3068 fn replay_request_requires_predict_explain_or_refit_phase() {
3069 let bundle = build_execution_bundle(
3070 BundleId::new("bundle:demo").unwrap(),
3071 &plan(),
3072 None,
3073 BTreeMap::new(),
3074 Vec::new(),
3075 )
3076 .unwrap();
3077
3078 ReplayPhaseRequest {
3079 bundle_id: bundle.bundle_id.clone(),
3080 phase: Phase::Predict,
3081 data_envelope_keys: vec!["model:base.x".to_string()],
3082 }
3083 .validate_for_bundle(&bundle)
3084 .unwrap();
3085 ReplayPhaseRequest {
3086 bundle_id: bundle.bundle_id.clone(),
3087 phase: Phase::Refit,
3088 data_envelope_keys: vec!["model:base.x".to_string()],
3089 }
3090 .validate_for_bundle(&bundle)
3091 .unwrap();
3092 assert!(ReplayPhaseRequest {
3093 bundle_id: bundle.bundle_id.clone(),
3094 phase: Phase::FitCv,
3095 data_envelope_keys: vec!["model:base.x".to_string()],
3096 }
3097 .validate_for_bundle(&bundle)
3098 .is_err());
3099 assert!(ReplayPhaseRequest {
3100 bundle_id: bundle.bundle_id.clone(),
3101 phase: Phase::Predict,
3102 data_envelope_keys: vec!["model:base.x".to_string(), "model:base.x".to_string()],
3103 }
3104 .validate_for_bundle(&bundle)
3105 .is_err());
3106 assert!(ReplayPhaseRequest {
3107 bundle_id: bundle.bundle_id.clone(),
3108 phase: Phase::Predict,
3109 data_envelope_keys: vec!["model:base.y".to_string()],
3110 }
3111 .validate_for_bundle(&bundle)
3112 .is_err());
3113 }
3114}