1use std::cell::RefCell;
2use std::collections::{BTreeMap, BTreeSet};
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{DagMlError, Result};
7use crate::ids::{ControllerId, FoldId, NodeId, RunId, VariantId};
8use crate::phase::Phase;
9use crate::policy::FitInfluencePolicy;
10use crate::relation::{EntityUnitLevel, SampleRelationSet};
11use crate::runtime::{
12 DataMaterializationRequest, DataProviderViewSpec, DataViewRequest, HandleKind, HandleRef,
13 RuntimeDataProvider,
14};
15
/// Current schema version of the external data-plan envelope.
pub const EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION: u32 = 1;

/// Returns [`EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION`].
///
/// Presumably referenced as a serde `default = "..."` hook elsewhere in the file.
fn default_external_data_plan_envelope_schema_version() -> u32 {
    EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION
}

/// Current schema version of the model-input spec document.
pub const MODEL_INPUT_SPEC_SCHEMA_VERSION: u32 = 1;

/// Schema identifier URL for version 1 of the model-input spec.
pub const MODEL_INPUT_SPEC_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/model_input_spec.v1.schema.json";

/// Returns [`MODEL_INPUT_SPEC_SCHEMA_VERSION`].
///
/// Presumably referenced as a serde `default = "..."` hook elsewhere in the file.
fn default_model_input_spec_schema_version() -> u32 {
    MODEL_INPUT_SPEC_SCHEMA_VERSION
}

/// Current schema version of the data-plan document.
pub const DATA_PLAN_SCHEMA_VERSION: u32 = 1;

/// Schema identifier URL for version 1 of the data plan.
pub const DATA_PLAN_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/data_plan.v1.schema.json";

/// Returns [`DATA_PLAN_SCHEMA_VERSION`].
///
/// Presumably referenced as a serde `default = "..."` hook elsewhere in the file.
fn default_data_plan_schema_version() -> u32 {
    DATA_PLAN_SCHEMA_VERSION
}
35
/// Partition targeted by a data request.
///
/// Serialized in snake_case: `fold_train`, `fold_validation`, `full_train`, `predict`.
/// Variant order is significant because the derived `Ord`/`PartialOrd` compare variants by
/// declaration order.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DataRequestPartition {
    FoldTrain,
    FoldValidation,
    FullTrain,
    Predict,
}
44
/// Fusion strategy declared for a model's inputs.
///
/// The concrete semantics of each mode live in the runtime and are not shown here. Serialized
/// in snake_case (e.g. `concatenate_features`, `dict_by_source`). Declaration order drives the
/// derived `Ord`. `Custom` additionally requires an `adapter_id` in
/// `ModelInputFusionPolicy::validate`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelInputFusionMode {
    SingleSource,
    ConcatenateFeatures,
    StackSamples,
    DictBySource,
    Custom,
}
54
/// How a branch view selects its data.
///
/// In `BranchViewPlan::validate`, each `By*` mode requires the matching selector field
/// (`source_ids`, `metadata`, `tags`, `filter`). `Separation` adds no extra requirement.
/// Serialized in snake_case; declaration order drives the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BranchViewMode {
    Separation,
    BySource,
    ByMetadata,
    ByTag,
    ByFilter,
}
64
/// Declarative selector that narrows a data view.
///
/// All four dimensions are optional on the wire (`#[serde(default)]`) and are omitted from the
/// serialized form when empty / `None`. [`DataViewSelector::validate`] rejects a selector in
/// which none of them is set.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct DataViewSelector {
    /// Source ids to select; validated with `validate_string_list_entries` and
    /// `validate_unique_strings`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub source_ids: Vec<String>,
    /// Metadata constraints; keys must be non-blank (values are not inspected).
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Tags to select; validated like `source_ids`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Opaque filter expression; `Some(Value::Null)` is rejected by `validate`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter: Option<serde_json::Value>,
}
76
77impl DataViewSelector {
78 pub fn validate(&self, label: &str) -> Result<()> {
79 if self.source_ids.is_empty()
80 && self.metadata.is_empty()
81 && self.tags.is_empty()
82 && self.filter.is_none()
83 {
84 return Err(DagMlError::CampaignValidation(format!(
85 "{label} selector must constrain source_ids, metadata, tags or filter"
86 )));
87 }
88 validate_string_list_entries(&format!("{label} selector source_ids"), &self.source_ids)?;
89 validate_unique_strings(&format!("{label} selector source_ids"), &self.source_ids)?;
90 validate_string_list_entries(&format!("{label} selector tags"), &self.tags)?;
91 validate_unique_strings(&format!("{label} selector tags"), &self.tags)?;
92 for key in self.metadata.keys() {
93 if key.trim().is_empty() {
94 return Err(DagMlError::CampaignValidation(format!(
95 "{label} selector contains an empty metadata key"
96 )));
97 }
98 }
99 if matches!(self.filter, Some(serde_json::Value::Null)) {
100 return Err(DagMlError::CampaignValidation(format!(
101 "{label} selector filter must not be null"
102 )));
103 }
104 Ok(())
105 }
106}
107
/// A branch-specific view of the data, described by a [`DataViewSelector`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BranchViewPlan {
    /// Identifier of this view; checked with `validate_non_empty`.
    pub view_id: String,
    /// Identifier of the owning branch; checked with `validate_non_empty`.
    pub branch_id: String,
    /// Selection mode; each `by_*` mode requires the matching selector field.
    pub mode: BranchViewMode,
    pub selector: DataViewSelector,
    /// Defaults to `false` when absent; not inspected by `validate`.
    #[serde(default)]
    pub allow_overlap: bool,
    /// Free-form metadata (keys must be non-blank); omitted from output when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
119
120impl BranchViewPlan {
121 pub fn validate(&self) -> Result<()> {
122 validate_non_empty("branch view plan view_id", &self.view_id)?;
123 validate_non_empty("branch view plan branch_id", &self.branch_id)?;
124 self.selector
125 .validate(&format!("branch view `{}`", self.view_id))?;
126 match self.mode {
127 BranchViewMode::BySource if self.selector.source_ids.is_empty() => {
128 return Err(DagMlError::CampaignValidation(format!(
129 "branch view `{}` mode=by_source requires source_ids",
130 self.view_id
131 )));
132 }
133 BranchViewMode::ByMetadata if self.selector.metadata.is_empty() => {
134 return Err(DagMlError::CampaignValidation(format!(
135 "branch view `{}` mode=by_metadata requires metadata",
136 self.view_id
137 )));
138 }
139 BranchViewMode::ByTag if self.selector.tags.is_empty() => {
140 return Err(DagMlError::CampaignValidation(format!(
141 "branch view `{}` mode=by_tag requires tags",
142 self.view_id
143 )));
144 }
145 BranchViewMode::ByFilter if self.selector.filter.is_none() => {
146 return Err(DagMlError::CampaignValidation(format!(
147 "branch view `{}` mode=by_filter requires filter",
148 self.view_id
149 )));
150 }
151 _ => {}
152 }
153 for key in self.metadata.keys() {
154 if key.trim().is_empty() {
155 return Err(DagMlError::CampaignValidation(format!(
156 "branch view `{}` metadata contains an empty key",
157 self.view_id
158 )));
159 }
160 }
161 Ok(())
162 }
163}
164
/// Strategy used by a [`CombinationPlan`] to combine component sources.
///
/// Defaults to `Cartesian` (`#[default]`). Serialized in snake_case (`cartesian`, `zip`,
/// `match_by`, `sample_k`, `reference_broadcast`). Declaration order drives the derived `Ord`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CombinationMode {
    #[default]
    Cartesian,
    Zip,
    MatchBy,
    SampleK,
    ReferenceBroadcast,
}
175
/// Policy applied when a representation's expected sources/repetitions are missing.
///
/// Serialized in snake_case (`strict`, `warn`, `drop_incomplete`, `impute_declared`, `mask`,
/// `partial_model`, `pad`). In `RepresentationCompatibilityReport::validate`, `strict` forces an
/// `incompatible` outcome when units are affected. Declaration order drives the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RepresentationMissingSourcePolicy {
    Strict,
    Warn,
    DropIncomplete,
    ImputeDeclared,
    Mask,
    PartialModel,
    Pad,
}
187
/// Input-to-output unit cardinality declared by a representation.
///
/// Each representation kind pins one value in its `validate`:
/// - aggregate = `many_to_one`;
/// - cartesian_product = `many_to_many`;
/// - monte_carlo_cartesian / stack_padded_masked = `bounded_many`;
/// - stack_fixed = `one_to_many`.
///
/// Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RepresentationCardinality {
    OneToOne,
    OneToMany,
    ManyToOne,
    ManyToMany,
    BoundedMany,
}
197
/// Declarative description of how component sources/units are combined.
///
/// Unknown fields are rejected on deserialization (`deny_unknown_fields`). `mode` has no serde
/// default, so it must be present; every other field is optional and omitted from output when
/// empty / `None`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct CombinationPlan {
    pub mode: CombinationMode,
    /// Sources to combine; `cartesian` and `zip` require at least two.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub component_source_ids: Vec<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub component_unit_ids: Vec<String>,
    /// Required by `match_by`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub match_key: Option<String>,
    /// Required by `reference_broadcast`; must appear in `component_source_ids` when that list
    /// is non-empty.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reference_source_id: Option<String>,
    /// Required by `sample_k`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Must be positive when present; required by `sample_k`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cap: Option<usize>,
    /// Must be positive when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub budget: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub missing_source_policy: Option<RepresentationMissingSourcePolicy>,
    /// Free-form metadata; keys must be non-blank.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
221
222impl CombinationPlan {
223 pub fn validate(&self) -> Result<()> {
224 validate_string_list_entries(
225 "combination plan component_source_ids",
226 &self.component_source_ids,
227 )?;
228 validate_unique_strings(
229 "combination plan component_source_ids",
230 &self.component_source_ids,
231 )?;
232 validate_string_list_entries(
233 "combination plan component_unit_ids",
234 &self.component_unit_ids,
235 )?;
236 validate_unique_strings(
237 "combination plan component_unit_ids",
238 &self.component_unit_ids,
239 )?;
240 validate_optional_non_empty("combination plan match_key", &self.match_key)?;
241 validate_optional_non_empty(
242 "combination plan reference_source_id",
243 &self.reference_source_id,
244 )?;
245 if self.cap == Some(0) {
246 return Err(DagMlError::CampaignValidation(
247 "combination plan cap must be positive when present".to_string(),
248 ));
249 }
250 if self.budget == Some(0) {
251 return Err(DagMlError::CampaignValidation(
252 "combination plan budget must be positive when present".to_string(),
253 ));
254 }
255 match self.mode {
256 CombinationMode::Cartesian => {
257 if self.component_source_ids.len() < 2 {
258 return Err(DagMlError::CampaignValidation(
259 "cartesian combination requires at least two component_source_ids"
260 .to_string(),
261 ));
262 }
263 }
264 CombinationMode::Zip => {
265 if self.component_source_ids.len() < 2 {
266 return Err(DagMlError::CampaignValidation(
267 "zip combination requires at least two component_source_ids".to_string(),
268 ));
269 }
270 }
271 CombinationMode::MatchBy => {
272 if self.match_key.is_none() {
273 return Err(DagMlError::CampaignValidation(
274 "match_by combination requires match_key".to_string(),
275 ));
276 }
277 }
278 CombinationMode::SampleK => {
279 if self.seed.is_none() {
280 return Err(DagMlError::CampaignValidation(
281 "sample_k combination requires seed".to_string(),
282 ));
283 }
284 if self.cap.is_none() {
285 return Err(DagMlError::CampaignValidation(
286 "sample_k combination requires cap".to_string(),
287 ));
288 }
289 }
290 CombinationMode::ReferenceBroadcast => {
291 let Some(reference) = &self.reference_source_id else {
292 return Err(DagMlError::CampaignValidation(
293 "reference_broadcast combination requires reference_source_id".to_string(),
294 ));
295 };
296 if !self.component_source_ids.is_empty()
297 && !self
298 .component_source_ids
299 .iter()
300 .any(|source| source == reference)
301 {
302 return Err(DagMlError::CampaignValidation(format!(
303 "reference_broadcast reference_source_id `{reference}` is not in component_source_ids"
304 )));
305 }
306 }
307 }
308 for key in self.metadata.keys() {
309 if key.trim().is_empty() {
310 return Err(DagMlError::CampaignValidation(
311 "combination plan metadata contains an empty key".to_string(),
312 ));
313 }
314 }
315 Ok(())
316 }
317}
318
/// A representation transformation, internally tagged on the wire by a `kind` field.
///
/// The tag values are the snake_case variant names: `aggregate`, `cartesian_product`,
/// `monte_carlo_cartesian`, `stack_fixed`, `stack_padded_masked`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case", deny_unknown_fields)]
pub enum RepresentationPlan {
    Aggregate(AggregateRepresentation),
    CartesianProduct(CartesianProductRepresentation),
    MonteCarloCartesian(MonteCarloCartesianRepresentation),
    StackFixed(StackFixedRepresentation),
    StackPaddedMasked(StackPaddedMaskedRepresentation),
}
328
329impl RepresentationPlan {
330 pub fn validate(&self) -> Result<()> {
331 match self {
332 Self::Aggregate(plan) => plan.validate(),
333 Self::CartesianProduct(plan) => plan.validate(),
334 Self::MonteCarloCartesian(plan) => plan.validate(),
335 Self::StackFixed(plan) => plan.validate(),
336 Self::StackPaddedMasked(plan) => plan.validate(),
337 }
338 }
339
340 pub fn output_unit_level(&self) -> EntityUnitLevel {
341 match self {
342 Self::Aggregate(plan) => plan.output_unit_level,
343 Self::CartesianProduct(plan) => plan.output_unit_level,
344 Self::MonteCarloCartesian(plan) => plan.output_unit_level,
345 Self::StackFixed(plan) => plan.output_unit_level,
346 Self::StackPaddedMasked(plan) => plan.output_unit_level,
347 }
348 }
349}
350
/// Aggregation from `input_unit_level` to `output_unit_level`.
///
/// `validate` requires `cardinality == many_to_one`. It does not compare the two unit levels.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct AggregateRepresentation {
    pub input_unit_level: EntityUnitLevel,
    pub output_unit_level: EntityUnitLevel,
    /// Optional reducer identifier; must be non-blank when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reducer_id: Option<String>,
    /// Optional method name; must be non-blank when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub method: Option<String>,
    pub cardinality: RepresentationCardinality,
}
362
363impl AggregateRepresentation {
364 pub fn validate(&self) -> Result<()> {
365 validate_optional_non_empty("aggregate representation reducer_id", &self.reducer_id)?;
366 validate_optional_non_empty("aggregate representation method", &self.method)?;
367 if self.cardinality != RepresentationCardinality::ManyToOne {
368 return Err(DagMlError::CampaignValidation(
369 "aggregate representation cardinality must be many_to_one".to_string(),
370 ));
371 }
372 Ok(())
373 }
374}
375
/// Full cartesian combination of component sources.
///
/// `validate` requires `combination_plan.mode == cartesian` and `cardinality == many_to_many`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct CartesianProductRepresentation {
    pub combination_plan: CombinationPlan,
    pub output_unit_level: EntityUnitLevel,
    pub cardinality: RepresentationCardinality,
    /// Defaults via `default_true` (defined elsewhere; presumably `true`) when absent.
    #[serde(default = "default_true")]
    pub preserve_provenance: bool,
}
385
386impl CartesianProductRepresentation {
387 pub fn validate(&self) -> Result<()> {
388 self.combination_plan.validate()?;
389 if self.combination_plan.mode != CombinationMode::Cartesian {
390 return Err(DagMlError::CampaignValidation(
391 "cartesian_product representation requires combination_plan.mode=cartesian"
392 .to_string(),
393 ));
394 }
395 validate_combo_like_output("cartesian_product", self.output_unit_level)?;
396 if self.cardinality != RepresentationCardinality::ManyToMany {
397 return Err(DagMlError::CampaignValidation(
398 "cartesian_product representation cardinality must be many_to_many".to_string(),
399 ));
400 }
401 Ok(())
402 }
403}
404
/// Sampled cartesian combination of component sources.
///
/// `validate` requires `combination_plan.mode == sample_k` (which itself requires `seed` and
/// `cap`) and `cardinality == bounded_many`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct MonteCarloCartesianRepresentation {
    pub combination_plan: CombinationPlan,
    pub output_unit_level: EntityUnitLevel,
    pub cardinality: RepresentationCardinality,
    /// Defaults via `default_true` (defined elsewhere; presumably `true`) when absent.
    #[serde(default = "default_true")]
    pub preserve_provenance: bool,
}
414
415impl MonteCarloCartesianRepresentation {
416 pub fn validate(&self) -> Result<()> {
417 self.combination_plan.validate()?;
418 if self.combination_plan.mode != CombinationMode::SampleK {
419 return Err(DagMlError::CampaignValidation(
420 "monte_carlo_cartesian representation requires combination_plan.mode=sample_k"
421 .to_string(),
422 ));
423 }
424 validate_combo_like_output("monte_carlo_cartesian", self.output_unit_level)?;
425 if self.cardinality != RepresentationCardinality::BoundedMany {
426 return Err(DagMlError::CampaignValidation(
427 "monte_carlo_cartesian representation cardinality must be bounded_many".to_string(),
428 ));
429 }
430 Ok(())
431 }
432}
433
/// Stacking with a fixed, positive `expected_cardinality`.
///
/// `validate` requires `cardinality == one_to_many`. It does not relate
/// `component_source_ids.len()` to `expected_cardinality`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct StackFixedRepresentation {
    pub output_unit_level: EntityUnitLevel,
    pub cardinality: RepresentationCardinality,
    pub expected_cardinality: usize,
    /// Optional source ids; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub component_source_ids: Vec<String>,
}
443
444impl StackFixedRepresentation {
445 pub fn validate(&self) -> Result<()> {
446 if self.expected_cardinality == 0 {
447 return Err(DagMlError::CampaignValidation(
448 "stack_fixed representation expected_cardinality must be positive".to_string(),
449 ));
450 }
451 if self.cardinality != RepresentationCardinality::OneToMany {
452 return Err(DagMlError::CampaignValidation(
453 "stack_fixed representation cardinality must be one_to_many".to_string(),
454 ));
455 }
456 validate_string_list_entries(
457 "stack_fixed representation component_source_ids",
458 &self.component_source_ids,
459 )?;
460 validate_unique_strings(
461 "stack_fixed representation component_source_ids",
462 &self.component_source_ids,
463 )
464 }
465}
466
/// Stacking up to `expected_cardinality` components, padding/masking missing ones.
///
/// `validate` requires:
/// - `cardinality == bounded_many`;
/// - a `mask` or `pad` missing-source policy;
/// - `requires_missing_masks == true`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct StackPaddedMaskedRepresentation {
    pub output_unit_level: EntityUnitLevel,
    pub cardinality: RepresentationCardinality,
    pub expected_cardinality: usize,
    pub missing_source_policy: RepresentationMissingSourcePolicy,
    /// Defaults via `default_true` (defined elsewhere; presumably `true`); `false` is rejected.
    #[serde(default = "default_true")]
    pub requires_missing_masks: bool,
    /// Optional source ids; omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub component_source_ids: Vec<String>,
}
479
480impl StackPaddedMaskedRepresentation {
481 pub fn validate(&self) -> Result<()> {
482 if self.expected_cardinality == 0 {
483 return Err(DagMlError::CampaignValidation(
484 "stack_padded_masked representation expected_cardinality must be positive"
485 .to_string(),
486 ));
487 }
488 if self.cardinality != RepresentationCardinality::BoundedMany {
489 return Err(DagMlError::CampaignValidation(
490 "stack_padded_masked representation cardinality must be bounded_many".to_string(),
491 ));
492 }
493 if !matches!(
494 self.missing_source_policy,
495 RepresentationMissingSourcePolicy::Mask | RepresentationMissingSourcePolicy::Pad
496 ) {
497 return Err(DagMlError::CampaignValidation(
498 "stack_padded_masked representation requires missing_source_policy=mask or pad"
499 .to_string(),
500 ));
501 }
502 if !self.requires_missing_masks {
503 return Err(DagMlError::CampaignValidation(
504 "stack_padded_masked representation requires missing-mask controller support"
505 .to_string(),
506 ));
507 }
508 validate_string_list_entries(
509 "stack_padded_masked representation component_source_ids",
510 &self.component_source_ids,
511 )?;
512 validate_unique_strings(
513 "stack_padded_masked representation component_source_ids",
514 &self.component_source_ids,
515 )
516 }
517}
518
/// Observations recorded for one (physical sample, source) pair in a replay manifest.
///
/// `RepresentationReplayManifest::validate` rejects duplicate
/// (`physical_sample_id`, `source_id`) pairs.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RepresentationSampleObservationMapping {
    pub physical_sample_id: String,
    pub source_id: String,
    /// Must be a non-empty list of unique ids.
    pub observation_ids: Vec<String>,
}
526
527impl RepresentationSampleObservationMapping {
528 pub fn validate(&self) -> Result<()> {
529 validate_non_empty(
530 "representation sample observation mapping physical_sample_id",
531 &self.physical_sample_id,
532 )?;
533 validate_non_empty(
534 "representation sample observation mapping source_id",
535 &self.source_id,
536 )?;
537 validate_non_empty_list(
538 "representation sample observation mapping observation_ids",
539 &self.observation_ids,
540 )?;
541 validate_unique_strings(
542 "representation sample observation mapping observation_ids",
543 &self.observation_ids,
544 )
545 }
546}
547
/// Records which observations were combined into one combo unit.
///
/// `RepresentationReplayManifest::validate` rejects duplicate `combo_unit_id`s.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RepresentationComboSelectionRecord {
    pub combo_unit_id: String,
    pub physical_sample_id: String,
    /// Must be a non-empty list of unique ids.
    pub component_observation_ids: Vec<String>,
    /// Optional seed; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
}
557
558impl RepresentationComboSelectionRecord {
559 pub fn validate(&self) -> Result<()> {
560 validate_non_empty(
561 "representation combo selection combo_unit_id",
562 &self.combo_unit_id,
563 )?;
564 validate_non_empty(
565 "representation combo selection physical_sample_id",
566 &self.physical_sample_id,
567 )?;
568 validate_non_empty_list(
569 "representation combo selection component_observation_ids",
570 &self.component_observation_ids,
571 )?;
572 validate_unique_strings(
573 "representation combo selection component_observation_ids",
574 &self.component_observation_ids,
575 )
576 }
577}
578
/// Severity attached to a non-strict compatibility report with affected units.
///
/// Serialized in snake_case (`info`, `warning`, `error`). Declaration order drives the derived
/// `Ord`, so `Info < Warning < Error`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RepresentationCompatibilitySeverity {
    Info,
    Warning,
    Error,
}
586
/// Overall verdict of a representation compatibility check.
///
/// Serialized in snake_case (`compatible`, `compatible_with_fallback`, `incompatible`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RepresentationCompatibilityOutcome {
    Compatible,
    CompatibleWithFallback,
    Incompatible,
}
594
/// Report comparing a representation between train and predict.
///
/// Optional fields are omitted from output when `None` / empty. Counters default to `0` and
/// flags default to `false` when absent.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RepresentationCompatibilityReport {
    pub policy: RepresentationMissingSourcePolicy,
    pub outcome: RepresentationCompatibilityOutcome,
    /// Name of the fallback applied; `validate` recognises `drop_incomplete`,
    /// `impute_declared`, `mask`, `partial_model` and `pad` as covering deltas.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fallback_used: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub warning_severity: Option<RepresentationCompatibilitySeverity>,
    // The three counters are summed (saturating) to decide whether any unit is affected.
    #[serde(default)]
    pub affected_source_count: u64,
    #[serde(default)]
    pub affected_repetition_count: u64,
    #[serde(default)]
    pub affected_sample_count: u64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub train_relation_fingerprint: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub predict_relation_fingerprint: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub train_unit_count: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub predict_unit_count: Option<u64>,
    /// When set and unit counts differ, a `mask`/`pad` policy or fallback is required unless
    /// the outcome is `incompatible`.
    #[serde(default)]
    pub fixed_width_required: bool,
    #[serde(default)]
    pub final_reducer_stabilizes_output: bool,
    /// A changed combo count is accepted only with `final_reducer_stabilizes_output` (or an
    /// `incompatible` outcome).
    #[serde(default)]
    pub cartesian_combo_count_changed: bool,
    #[serde(default)]
    pub late_fusion_branch_delta: bool,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub messages: Vec<String>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
631
632impl RepresentationCompatibilityReport {
633 pub fn validate(&self) -> Result<()> {
634 validate_optional_non_empty(
635 "representation compatibility fallback_used",
636 &self.fallback_used,
637 )?;
638 if let Some(fingerprint) = &self.train_relation_fingerprint {
639 validate_fingerprint("representation compatibility train relation", fingerprint)?;
640 }
641 if let Some(fingerprint) = &self.predict_relation_fingerprint {
642 validate_fingerprint("representation compatibility predict relation", fingerprint)?;
643 }
644 validate_string_list_entries("representation compatibility messages", &self.messages)?;
645 for key in self.metadata.keys() {
646 if key.trim().is_empty() {
647 return Err(DagMlError::CampaignValidation(
648 "representation compatibility metadata contains an empty key".to_string(),
649 ));
650 }
651 }
652
653 let affected_total = self
654 .affected_source_count
655 .saturating_add(self.affected_repetition_count)
656 .saturating_add(self.affected_sample_count);
657 let relation_fingerprint_changed = matches!(
658 (
659 self.train_relation_fingerprint.as_deref(),
660 self.predict_relation_fingerprint.as_deref()
661 ),
662 (Some(train), Some(predict)) if train != predict
663 );
664 let unit_count_changed = matches!(
665 (self.train_unit_count, self.predict_unit_count),
666 (Some(train), Some(predict)) if train != predict
667 );
668 if affected_total == 0 {
669 if relation_fingerprint_changed {
670 return Err(DagMlError::CampaignValidation(
671 "representation compatibility relation fingerprint mismatch requires affected units"
672 .to_string(),
673 ));
674 }
675 if unit_count_changed {
676 return Err(DagMlError::CampaignValidation(
677 "representation compatibility unit count mismatch requires affected units"
678 .to_string(),
679 ));
680 }
681 if self.outcome == RepresentationCompatibilityOutcome::CompatibleWithFallback {
682 return Err(DagMlError::CampaignValidation(
683 "representation compatibility cannot use fallback when no units are affected"
684 .to_string(),
685 ));
686 }
687 if self.warning_severity.is_some() {
688 return Err(DagMlError::CampaignValidation(
689 "representation compatibility warning_severity requires affected units"
690 .to_string(),
691 ));
692 }
693 } else if self.policy == RepresentationMissingSourcePolicy::Strict {
694 if self.outcome != RepresentationCompatibilityOutcome::Incompatible {
695 return Err(DagMlError::CampaignValidation(
696 "strict representation compatibility with affected units must be incompatible"
697 .to_string(),
698 ));
699 }
700 if self.fallback_used.is_some() {
701 return Err(DagMlError::CampaignValidation(
702 "strict representation compatibility cannot declare fallback_used".to_string(),
703 ));
704 }
705 } else {
706 if self.warning_severity.is_none() {
707 return Err(DagMlError::CampaignValidation(
708 "non-strict representation compatibility with affected units requires warning_severity"
709 .to_string(),
710 ));
711 }
712 if self.outcome == RepresentationCompatibilityOutcome::Compatible {
713 return Err(DagMlError::CampaignValidation(
714 "representation compatibility with affected units cannot be compatible"
715 .to_string(),
716 ));
717 }
718 if self.outcome == RepresentationCompatibilityOutcome::CompatibleWithFallback
719 && self.fallback_used.is_none()
720 {
721 return Err(DagMlError::CampaignValidation(
722 "compatible_with_fallback representation compatibility requires fallback_used"
723 .to_string(),
724 ));
725 }
726 }
727
728 if self.outcome == RepresentationCompatibilityOutcome::Incompatible
729 && self.fallback_used.is_some()
730 {
731 return Err(DagMlError::CampaignValidation(
732 "incompatible representation compatibility cannot declare fallback_used"
733 .to_string(),
734 ));
735 }
736
737 if self.fixed_width_required && unit_count_changed && !self.allows_fixed_width_fallback() {
738 if self.outcome == RepresentationCompatibilityOutcome::Incompatible {
739 return Ok(());
740 }
741 return Err(DagMlError::CampaignValidation(
742 "fixed-width representation compatibility mismatch requires mask or pad policy/fallback"
743 .to_string(),
744 ));
745 }
746 if self.cartesian_combo_count_changed && !self.final_reducer_stabilizes_output {
747 if self.outcome == RepresentationCompatibilityOutcome::Incompatible {
748 return Ok(());
749 }
750 return Err(DagMlError::CampaignValidation(
751 "cartesian representation combo count may vary only when final reducer stabilizes output"
752 .to_string(),
753 ));
754 }
755 if self.late_fusion_branch_delta && !self.allows_late_fusion_delta() {
756 if self.outcome == RepresentationCompatibilityOutcome::Incompatible {
757 return Ok(());
758 }
759 return Err(DagMlError::CampaignValidation(
760 "late-fusion source deltas require an explicit drop/impute/mask/partial-model/pad policy or fallback"
761 .to_string(),
762 ));
763 }
764 Ok(())
765 }
766
767 fn allows_fixed_width_fallback(&self) -> bool {
768 matches!(
769 self.policy,
770 RepresentationMissingSourcePolicy::Mask | RepresentationMissingSourcePolicy::Pad
771 ) || self
772 .fallback_used
773 .as_deref()
774 .is_some_and(|fallback| matches!(fallback, "mask" | "pad"))
775 }
776
777 fn allows_late_fusion_delta(&self) -> bool {
778 matches!(
779 self.policy,
780 RepresentationMissingSourcePolicy::DropIncomplete
781 | RepresentationMissingSourcePolicy::ImputeDeclared
782 | RepresentationMissingSourcePolicy::Mask
783 | RepresentationMissingSourcePolicy::PartialModel
784 | RepresentationMissingSourcePolicy::Pad
785 ) || self.fallback_used.as_deref().is_some_and(|fallback| {
786 matches!(
787 fallback,
788 "drop_incomplete" | "impute_declared" | "mask" | "partial_model" | "pad"
789 )
790 })
791 }
792}
793
/// Everything needed to replay a representation transformation.
///
/// `representation_plan` and `output_unit_level` are required, and `output_unit_level` must
/// equal `representation_plan.output_unit_level()`. All other fields are optional and omitted
/// from output when empty / `None`. Unknown fields are rejected.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RepresentationReplayManifest {
    pub manifest_id: String,
    pub representation_plan: RepresentationPlan,
    /// Validated on its own when present; not cross-checked against `representation_plan`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub combination_plan: Option<CombinationPlan>,
    pub output_unit_level: EntityUnitLevel,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_representation: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub relation_fingerprint: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub feature_schema_fingerprint: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_reduction_id: Option<String>,
    /// At most one entry per (`physical_sample_id`, `source_id`) pair.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub sample_observation_mapping: Vec<RepresentationSampleObservationMapping>,
    /// At most one entry per `combo_unit_id`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub combo_selection: Vec<RepresentationComboSelectionRecord>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub qc_policy_refs: Vec<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub outlier_policy_refs: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub missing_source_policy: Option<RepresentationMissingSourcePolicy>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub missing_repetition_policy: Option<RepresentationMissingSourcePolicy>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction_representation: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_output_unit_level: Option<EntityUnitLevel>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub train_compatibility: Option<RepresentationCompatibilityReport>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub predict_compatibility: Option<RepresentationCompatibilityReport>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
833
834impl RepresentationReplayManifest {
835 pub fn validate(&self) -> Result<()> {
836 validate_non_empty("representation replay manifest_id", &self.manifest_id)?;
837 self.representation_plan.validate()?;
838 if let Some(combination_plan) = &self.combination_plan {
839 combination_plan.validate()?;
840 }
841 if self.output_unit_level != self.representation_plan.output_unit_level() {
842 return Err(DagMlError::CampaignValidation(
843 "representation replay output_unit_level must match representation_plan"
844 .to_string(),
845 ));
846 }
847 validate_optional_non_empty(
848 "representation replay output_representation",
849 &self.output_representation,
850 )?;
851 validate_optional_non_empty(
852 "representation replay final_reduction_id",
853 &self.final_reduction_id,
854 )?;
855 validate_string_list_entries("representation replay qc_policy_refs", &self.qc_policy_refs)?;
856 validate_unique_strings("representation replay qc_policy_refs", &self.qc_policy_refs)?;
857 validate_string_list_entries(
858 "representation replay outlier_policy_refs",
859 &self.outlier_policy_refs,
860 )?;
861 validate_unique_strings(
862 "representation replay outlier_policy_refs",
863 &self.outlier_policy_refs,
864 )?;
865 validate_optional_non_empty(
866 "representation replay prediction_representation",
867 &self.prediction_representation,
868 )?;
869 let mut sample_source_pairs = BTreeSet::new();
870 for mapping in &self.sample_observation_mapping {
871 mapping.validate()?;
872 if !sample_source_pairs.insert((
873 mapping.physical_sample_id.as_str(),
874 mapping.source_id.as_str(),
875 )) {
876 return Err(DagMlError::CampaignValidation(format!(
877 "representation replay sample_observation_mapping contains duplicate physical_sample_id/source_id `{}`/`{}`",
878 mapping.physical_sample_id, mapping.source_id
879 )));
880 }
881 }
882 let mut combo_unit_ids = BTreeSet::new();
883 for record in &self.combo_selection {
884 record.validate()?;
885 if !combo_unit_ids.insert(record.combo_unit_id.as_str()) {
886 return Err(DagMlError::CampaignValidation(format!(
887 "representation replay combo_selection contains duplicate combo_unit_id `{}`",
888 record.combo_unit_id
889 )));
890 }
891 }
892 if let Some(report) = &self.train_compatibility {
893 report.validate()?;
894 }
895 if let Some(report) = &self.predict_compatibility {
896 report.validate()?;
897 }
898 if let Some(fingerprint) = &self.relation_fingerprint {
899 validate_fingerprint("representation replay relation", fingerprint)?;
900 }
901 if let Some(fingerprint) = &self.feature_schema_fingerprint {
902 validate_fingerprint("representation replay feature schema", fingerprint)?;
903 }
904 for key in self.metadata.keys() {
905 if key.trim().is_empty() {
906 return Err(DagMlError::CampaignValidation(
907 "representation replay metadata contains an empty key".to_string(),
908 ));
909 }
910 }
911 Ok(())
912 }
913}
914
/// How a model's inputs are fused, with optional adapter and representation plan.
///
/// Note: `alignment`, `adapter_id` and `params` carry `#[serde(default)]` but no
/// `skip_serializing_if`, so they are always serialized (`None` as `null`, empty map as `{}`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelInputFusionPolicy {
    pub mode: ModelInputFusionMode,
    /// Must be non-blank when present.
    #[serde(default)]
    pub alignment: Option<String>,
    /// Must be non-blank when present; required when `mode` is `custom`.
    #[serde(default)]
    pub adapter_id: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation_plan: Option<RepresentationPlan>,
    /// Free-form parameters; keys must be non-blank.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
}
928
929impl ModelInputFusionPolicy {
930 pub fn validate(&self) -> Result<()> {
931 if self
932 .alignment
933 .as_ref()
934 .is_some_and(|alignment| alignment.trim().is_empty())
935 {
936 return Err(DagMlError::CampaignValidation(
937 "model input fusion policy has empty alignment".to_string(),
938 ));
939 }
940 if self
941 .adapter_id
942 .as_ref()
943 .is_some_and(|adapter_id| adapter_id.trim().is_empty())
944 {
945 return Err(DagMlError::CampaignValidation(
946 "model input fusion policy has empty adapter_id".to_string(),
947 ));
948 }
949 if self.mode == ModelInputFusionMode::Custom && self.adapter_id.is_none() {
950 return Err(DagMlError::CampaignValidation(
951 "custom model input fusion policy requires adapter_id".to_string(),
952 ));
953 }
954 if let Some(representation_plan) = &self.representation_plan {
955 representation_plan.validate()?;
956 }
957 for key in self.params.keys() {
958 if key.trim().is_empty() {
959 return Err(DagMlError::CampaignValidation(
960 "model input fusion policy contains an empty param key".to_string(),
961 ));
962 }
963 }
964 Ok(())
965 }
966}
967
/// Declaration of one model input port: what it is called and what data it accepts.
///
/// Unknown JSON fields are rejected on deserialization; see [`ModelInputPortSpec::validate`]
/// for the structural rules.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelInputPortSpec {
    /// Port name; non-blank and unique within a [`ModelInputSpec`].
    pub name: String,
    /// Accepted representation identifiers; non-empty, non-blank, no duplicates.
    pub accepted_representations: Vec<String>,
    /// Accepted type identifiers; non-empty, non-blank, no duplicates.
    pub accepted_types: Vec<String>,
    /// Optional rank constraint; `validate` rejects values above 16.
    #[serde(default)]
    pub rank: Option<u32>,
    /// Whether the port accepts more than one source (not checked by `validate`).
    #[serde(default)]
    pub multi_source: bool,
    /// Whether the port may be left unbound (not checked by `validate`).
    #[serde(default)]
    pub optional: bool,
    /// Free-form metadata; keys must not be blank.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
983
984impl ModelInputPortSpec {
985 pub fn validate(&self) -> Result<()> {
986 validate_non_empty("model input port name", &self.name)?;
987 validate_non_empty_list(
988 "model input port accepted_representations",
989 &self.accepted_representations,
990 )?;
991 validate_non_empty_list("model input port accepted_types", &self.accepted_types)?;
992 validate_unique_strings(
993 "model input port accepted_representations",
994 &self.accepted_representations,
995 )?;
996 validate_unique_strings("model input port accepted_types", &self.accepted_types)?;
997 if self.rank.is_some_and(|rank| rank > 16) {
998 return Err(DagMlError::CampaignValidation(format!(
999 "model input port `{}` rank must be <= 16",
1000 self.name
1001 )));
1002 }
1003 for key in self.metadata.keys() {
1004 if key.trim().is_empty() {
1005 return Err(DagMlError::CampaignValidation(format!(
1006 "model input port `{}` contains an empty metadata key",
1007 self.name
1008 )));
1009 }
1010 }
1011 Ok(())
1012 }
1013}
1014
/// Versioned description of every input port a model exposes.
///
/// Unknown JSON fields are rejected on deserialization; see [`ModelInputSpec::validate`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelInputSpec {
    /// Contract version; defaults to [`MODEL_INPUT_SPEC_SCHEMA_VERSION`] and must equal it.
    #[serde(default = "default_model_input_spec_schema_version")]
    pub schema_version: u32,
    /// Declared ports; at least one, with unique names.
    pub ports: Vec<ModelInputPortSpec>,
    /// Fusion policy applied when none is given elsewhere; validated when present.
    #[serde(default)]
    pub default_fusion: Option<ModelInputFusionPolicy>,
    /// Optional fit-influence policy; omitted from serialized output when absent.
    // NOTE(review): `ModelInputSpec::validate` does not inspect this field — confirm whether it
    // should be validated there.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fit_influence_policy: Option<FitInfluencePolicy>,
    /// Free-form metadata; keys must not be blank.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
1028
1029impl ModelInputSpec {
1030 pub fn validate(&self) -> Result<()> {
1031 if self.schema_version != MODEL_INPUT_SPEC_SCHEMA_VERSION {
1032 return Err(DagMlError::CampaignValidation(format!(
1033 "model input spec uses unsupported schema_version {}, expected {}",
1034 self.schema_version, MODEL_INPUT_SPEC_SCHEMA_VERSION
1035 )));
1036 }
1037 if self.ports.is_empty() {
1038 return Err(DagMlError::CampaignValidation(
1039 "model input spec must declare at least one port".to_string(),
1040 ));
1041 }
1042 let mut names = BTreeSet::new();
1043 for port in &self.ports {
1044 port.validate()?;
1045 if !names.insert(port.name.as_str()) {
1046 return Err(DagMlError::CampaignValidation(format!(
1047 "model input spec contains duplicate port `{}`",
1048 port.name
1049 )));
1050 }
1051 }
1052 if let Some(default_fusion) = &self.default_fusion {
1053 default_fusion.validate()?;
1054 }
1055 for key in self.metadata.keys() {
1056 if key.trim().is_empty() {
1057 return Err(DagMlError::CampaignValidation(
1058 "model input spec contains an empty metadata key".to_string(),
1059 ));
1060 }
1061 }
1062 Ok(())
1063 }
1064}
1065
/// Kind of operation performed by a [`DataPlanStep`]; serialized in snake_case.
///
/// `Materialize` steps may have no inputs, and their inputs are not resolved against earlier
/// steps. Every other kind must consume at least one output produced by an earlier step.
/// Variant declaration order defines the derived `Ord`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DataPlanStepKind {
    Materialize,
    Adapt,
    Align,
    Join,
    Collate,
}
1075
/// One step of a [`DataPlan`]: consumes named inputs and produces a single named output.
///
/// Unknown JSON fields are rejected on deserialization; see [`DataPlanStep::validate`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DataPlanStep {
    /// Operation performed by this step.
    pub kind: DataPlanStepKind,
    /// Names consumed by the step. For non-`Materialize` steps each must be the output of an
    /// earlier step.
    #[serde(default)]
    pub inputs: Vec<String>,
    /// Name of the produced value; non-blank and unique within the plan.
    pub output: String,
    /// Optional adapter identifier; never blank when present.
    #[serde(default)]
    pub adapter_id: Option<String>,
    /// Free-form parameters; keys must not be blank.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
}
1088
1089impl DataPlanStep {
1090 pub fn validate(&self, previous_outputs: &BTreeSet<String>) -> Result<()> {
1091 validate_non_empty("data plan step output", &self.output)?;
1092 if self.kind != DataPlanStepKind::Materialize && self.inputs.is_empty() {
1093 return Err(DagMlError::CampaignValidation(format!(
1094 "data plan step `{}` requires at least one input",
1095 self.output
1096 )));
1097 }
1098 for (index, input) in self.inputs.iter().enumerate() {
1099 validate_non_empty("data plan step input", input)?;
1100 if self.kind != DataPlanStepKind::Materialize && !previous_outputs.contains(input) {
1101 return Err(DagMlError::CampaignValidation(format!(
1102 "data plan step `{}` input #{index} references `{input}` before it is produced",
1103 self.output
1104 )));
1105 }
1106 }
1107 if self
1108 .adapter_id
1109 .as_ref()
1110 .is_some_and(|adapter_id| adapter_id.trim().is_empty())
1111 {
1112 return Err(DagMlError::CampaignValidation(format!(
1113 "data plan step `{}` has empty adapter_id",
1114 self.output
1115 )));
1116 }
1117 for key in self.params.keys() {
1118 if key.trim().is_empty() {
1119 return Err(DagMlError::CampaignValidation(format!(
1120 "data plan step `{}` contains an empty param key",
1121 self.output
1122 )));
1123 }
1124 }
1125 Ok(())
1126 }
1127}
1128
/// Ordered, versioned sequence of data-preparation steps and the ports they feed.
///
/// Unknown JSON fields are rejected on deserialization; see [`DataPlan::validate`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct DataPlan {
    /// Contract version; defaults to [`DATA_PLAN_SCHEMA_VERSION`] and must equal it.
    #[serde(default = "default_data_plan_schema_version")]
    pub schema_version: u32,
    /// Plan identifier; must not be blank.
    pub id: String,
    /// Steps in execution order; at least one, with unique outputs.
    pub steps: Vec<DataPlanStep>,
    /// Port name -> step output that feeds it; at least one entry, every output must exist.
    pub output_ports: BTreeMap<String, String>,
    /// Warning messages; entries must not be blank.
    #[serde(default)]
    pub warnings: Vec<String>,
    /// Items needing a user decision; entries must not be blank.
    #[serde(default)]
    pub requires_user_choice: Vec<String>,
    /// Free-form metadata; keys must not be blank.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
1144
1145impl DataPlan {
1146 pub fn validate(&self) -> Result<()> {
1147 if self.schema_version != DATA_PLAN_SCHEMA_VERSION {
1148 return Err(DagMlError::CampaignValidation(format!(
1149 "data plan uses unsupported schema_version {}, expected {}",
1150 self.schema_version, DATA_PLAN_SCHEMA_VERSION
1151 )));
1152 }
1153 validate_non_empty("data plan id", &self.id)?;
1154 if self.steps.is_empty() {
1155 return Err(DagMlError::CampaignValidation(format!(
1156 "data plan `{}` must contain at least one step",
1157 self.id
1158 )));
1159 }
1160 let mut outputs = BTreeSet::new();
1161 for step in &self.steps {
1162 step.validate(&outputs)?;
1163 if !outputs.insert(step.output.clone()) {
1164 return Err(DagMlError::CampaignValidation(format!(
1165 "data plan `{}` contains duplicate step output `{}`",
1166 self.id, step.output
1167 )));
1168 }
1169 }
1170 if self.output_ports.is_empty() {
1171 return Err(DagMlError::CampaignValidation(format!(
1172 "data plan `{}` must declare at least one output port",
1173 self.id
1174 )));
1175 }
1176 for (port_name, output) in &self.output_ports {
1177 validate_non_empty("data plan output port", port_name)?;
1178 validate_non_empty("data plan output reference", output)?;
1179 if !outputs.contains(output) {
1180 return Err(DagMlError::CampaignValidation(format!(
1181 "data plan `{}` output port `{port_name}` references unknown output `{output}`",
1182 self.id
1183 )));
1184 }
1185 }
1186 validate_string_list_entries("data plan warnings", &self.warnings)?;
1187 validate_string_list_entries("data plan requires_user_choice", &self.requires_user_choice)?;
1188 for key in self.metadata.keys() {
1189 if key.trim().is_empty() {
1190 return Err(DagMlError::CampaignValidation(format!(
1191 "data plan `{}` contains an empty metadata key",
1192 self.id
1193 )));
1194 }
1195 }
1196 Ok(())
1197 }
1198}
1199
1200#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
1201pub struct DataViewPolicy {
1202 #[serde(default = "default_fit_partition")]
1203 pub fit_partition: DataRequestPartition,
1204 #[serde(default = "default_predict_partition")]
1205 pub predict_partition: DataRequestPartition,
1206 #[serde(default)]
1207 pub include_augmented_train: bool,
1208 #[serde(default)]
1209 pub include_augmented_validation: bool,
1210 #[serde(default)]
1211 pub include_excluded: bool,
1212 #[serde(default = "default_true")]
1213 pub require_sample_ids: bool,
1214 #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
1215 pub unsafe_flags: BTreeSet<String>,
1216}
1217
1218impl Default for DataViewPolicy {
1219 fn default() -> Self {
1220 Self {
1221 fit_partition: DataRequestPartition::FoldTrain,
1222 predict_partition: DataRequestPartition::FoldValidation,
1223 include_augmented_train: true,
1224 include_augmented_validation: false,
1225 include_excluded: false,
1226 require_sample_ids: true,
1227 unsafe_flags: BTreeSet::new(),
1228 }
1229 }
1230}
1231
1232impl DataViewPolicy {
1233 pub const ALLOW_FIT_CV_FULL_TRAIN_VIEW: &'static str = "allow_fit_cv_full_train_view";
1234 pub const ALLOW_FIT_CV_VALIDATION_VIEW: &'static str = "allow_fit_cv_validation_view";
1235 pub const ALLOW_AUGMENTED_VALIDATION_VIEW: &'static str = "allow_augmented_validation_view";
1236 pub const ALLOW_EXCLUDED_ROWS: &'static str = "allow_excluded_rows";
1237
1238 pub fn validate(&self) -> Result<()> {
1239 for unsafe_flag in &self.unsafe_flags {
1240 if unsafe_flag.trim().is_empty() {
1241 return Err(DagMlError::CampaignValidation(
1242 "data view policy contains an empty unsafe flag".to_string(),
1243 ));
1244 }
1245 }
1246 match self.fit_partition {
1247 DataRequestPartition::FoldTrain => {}
1248 DataRequestPartition::FullTrain
1249 if self
1250 .unsafe_flags
1251 .contains(Self::ALLOW_FIT_CV_FULL_TRAIN_VIEW) => {}
1252 DataRequestPartition::FoldValidation
1253 if self
1254 .unsafe_flags
1255 .contains(Self::ALLOW_FIT_CV_VALIDATION_VIEW) => {}
1256 DataRequestPartition::FullTrain => {
1257 return Err(DagMlError::CampaignValidation(
1258 "data view policy fit_partition=full_train would leak validation rows during FIT_CV; add explicit unsafe flag allow_fit_cv_full_train_view".to_string(),
1259 ));
1260 }
1261 DataRequestPartition::FoldValidation => {
1262 return Err(DagMlError::CampaignValidation(
1263 "data view policy fit_partition=fold_validation would train on validation rows during FIT_CV; add explicit unsafe flag allow_fit_cv_validation_view".to_string(),
1264 ));
1265 }
1266 DataRequestPartition::Predict => {
1267 return Err(DagMlError::CampaignValidation(
1268 "data view policy fit_partition=predict is not valid for FIT_CV".to_string(),
1269 ));
1270 }
1271 }
1272 match self.predict_partition {
1273 DataRequestPartition::FoldValidation | DataRequestPartition::Predict => {}
1274 DataRequestPartition::FoldTrain | DataRequestPartition::FullTrain => {
1275 return Err(DagMlError::CampaignValidation(format!(
1276 "data view policy predict_partition={:?} is not valid for validation/predict views",
1277 self.predict_partition
1278 )));
1279 }
1280 }
1281 if self.include_augmented_validation
1282 && !self
1283 .unsafe_flags
1284 .contains(Self::ALLOW_AUGMENTED_VALIDATION_VIEW)
1285 {
1286 return Err(DagMlError::CampaignValidation(
1287 "data view policy include_augmented_validation=true can leak augmented validation/test rows; add explicit unsafe flag allow_augmented_validation_view".to_string(),
1288 ));
1289 }
1290 if self.include_excluded && !self.unsafe_flags.contains(Self::ALLOW_EXCLUDED_ROWS) {
1291 return Err(DagMlError::CampaignValidation(
1292 "data view policy include_excluded=true requires explicit unsafe flag allow_excluded_rows".to_string(),
1293 ));
1294 }
1295 Ok(())
1296 }
1297}
1298
1299fn default_fit_partition() -> DataRequestPartition {
1300 DataRequestPartition::FoldTrain
1301}
1302
1303fn default_predict_partition() -> DataRequestPartition {
1304 DataRequestPartition::FoldValidation
1305}
1306
1307fn default_true() -> bool {
1308 true
1309}
1310
/// Binds one node input to an externally planned data request, pinned by fingerprints.
///
/// See [`DataBinding::validate`] for standalone checks and [`DataBinding::validate_envelope`]
/// for checks against a registered [`ExternalDataPlanEnvelope`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DataBinding {
    /// Node whose input this binding feeds.
    pub node_id: NodeId,
    /// Node input name; also the fallback feature-set id (see `feature_set_id()`).
    pub input_name: String,
    /// Identifier of the data request; must not be blank.
    pub request_id: String,
    /// 64-character hex schema fingerprint; must equal the envelope's.
    pub schema_fingerprint: String,
    /// 64-character hex plan fingerprint; must equal the envelope's.
    pub plan_fingerprint: String,
    /// Optional relation fingerprint; required when `require_relations` is set and must equal
    /// the envelope's (including both being absent).
    #[serde(default)]
    pub relation_fingerprint: Option<String>,
    /// Representation produced for the node; must not be blank.
    pub output_representation: String,
    /// Explicit feature-set id; when absent `input_name` is used instead.
    #[serde(default)]
    pub feature_set_id: Option<String>,
    /// Source identifiers; entries must not be blank.
    #[serde(default)]
    pub source_ids: Vec<String>,
    /// Whether the binding needs a relation fingerprint and coordinator relations.
    #[serde(default)]
    pub require_relations: bool,
    /// Row/partition policy; falls back to `DataViewPolicy::default()` when omitted.
    #[serde(default)]
    pub view_policy: DataViewPolicy,
    /// Free-form metadata (not checked by `validate`).
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
1332
1333impl DataBinding {
1334 pub fn validate(&self) -> Result<()> {
1335 self.view_policy.validate()?;
1336 if self.input_name.trim().is_empty() {
1337 return Err(DagMlError::CampaignValidation(format!(
1338 "data binding for `{}` has empty input_name",
1339 self.node_id
1340 )));
1341 }
1342 if self.request_id.trim().is_empty() {
1343 return Err(DagMlError::CampaignValidation(format!(
1344 "data binding `{}` on `{}` has empty request_id",
1345 self.input_name, self.node_id
1346 )));
1347 }
1348 validate_fingerprint("schema", &self.schema_fingerprint)?;
1349 validate_fingerprint("plan", &self.plan_fingerprint)?;
1350 if let Some(relation_fingerprint) = &self.relation_fingerprint {
1351 validate_fingerprint("relation", relation_fingerprint)?;
1352 } else if self.require_relations {
1353 return Err(DagMlError::CampaignValidation(format!(
1354 "data binding `{}` on `{}` requires relations but has no relation_fingerprint",
1355 self.input_name, self.node_id
1356 )));
1357 }
1358 if self.output_representation.trim().is_empty() {
1359 return Err(DagMlError::CampaignValidation(format!(
1360 "data binding `{}` on `{}` has empty output_representation",
1361 self.input_name, self.node_id
1362 )));
1363 }
1364 if let Some(feature_set_id) = &self.feature_set_id {
1365 if feature_set_id.trim().is_empty() {
1366 return Err(DagMlError::CampaignValidation(format!(
1367 "data binding `{}` on `{}` has empty feature_set_id",
1368 self.input_name, self.node_id
1369 )));
1370 }
1371 }
1372 for source_id in &self.source_ids {
1373 if source_id.trim().is_empty() {
1374 return Err(DagMlError::CampaignValidation(format!(
1375 "data binding `{}` on `{}` has empty source id",
1376 self.input_name, self.node_id
1377 )));
1378 }
1379 }
1380 Ok(())
1381 }
1382
1383 pub fn feature_set_id(&self) -> &str {
1384 self.feature_set_id.as_deref().unwrap_or(&self.input_name)
1385 }
1386
1387 pub fn validate_envelope(&self, envelope: &ExternalDataPlanEnvelope) -> Result<()> {
1388 self.validate()?;
1389 envelope.validate()?;
1390 if self.schema_fingerprint != envelope.schema_fingerprint {
1391 return Err(DagMlError::CampaignValidation(format!(
1392 "data binding `{}` on `{}` schema fingerprint mismatch",
1393 self.input_name, self.node_id
1394 )));
1395 }
1396 if self.plan_fingerprint != envelope.plan_fingerprint {
1397 return Err(DagMlError::CampaignValidation(format!(
1398 "data binding `{}` on `{}` plan fingerprint mismatch",
1399 self.input_name, self.node_id
1400 )));
1401 }
1402 if self.relation_fingerprint != envelope.relation_fingerprint {
1403 return Err(DagMlError::CampaignValidation(format!(
1404 "data binding `{}` on `{}` relation fingerprint mismatch",
1405 self.input_name, self.node_id
1406 )));
1407 }
1408 if self.require_relations && envelope.coordinator_relations.is_none() {
1409 return Err(DagMlError::CampaignValidation(format!(
1410 "data binding `{}` on `{}` requires coordinator relations",
1411 self.input_name, self.node_id
1412 )));
1413 }
1414 Ok(())
1415 }
1416}
1417
/// Fingerprinted description of an externally produced data plan, optionally carrying the
/// coordinator's sample relations.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExternalDataPlanEnvelope {
    /// Envelope contract version; defaults to and must equal
    /// [`EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION`].
    #[serde(default = "default_external_data_plan_envelope_schema_version")]
    pub schema_version: u32,
    /// 64-character hex schema fingerprint.
    pub schema_fingerprint: String,
    /// 64-character hex plan fingerprint.
    pub plan_fingerprint: String,
    /// Optional 64-character hex relation fingerprint; when set, `coordinator_relations` is
    /// required.
    #[serde(default)]
    pub relation_fingerprint: Option<String>,
    /// Sample relations supplied with the plan; validated when present.
    #[serde(default)]
    pub coordinator_relations: Option<SampleRelationSet>,
}
1429
1430impl ExternalDataPlanEnvelope {
1431 pub fn validate(&self) -> Result<()> {
1432 if self.schema_version != EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION {
1433 return Err(DagMlError::CampaignValidation(format!(
1434 "external data-plan envelope uses unsupported schema_version {}, expected {}",
1435 self.schema_version, EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION
1436 )));
1437 }
1438 validate_fingerprint("schema", &self.schema_fingerprint)?;
1439 validate_fingerprint("plan", &self.plan_fingerprint)?;
1440 if let Some(relation_fingerprint) = &self.relation_fingerprint {
1441 validate_fingerprint("relation", relation_fingerprint)?;
1442 if self.coordinator_relations.is_none() {
1443 return Err(DagMlError::CampaignValidation(
1444 "relation_fingerprint requires coordinator_relations".to_string(),
1445 ));
1446 }
1447 }
1448 if let Some(relations) = &self.coordinator_relations {
1449 relations.validate()?;
1450 }
1451 Ok(())
1452 }
1453}
1454
1455pub fn validate_data_binding_envelope(
1456 binding: &DataBinding,
1457 envelope: &ExternalDataPlanEnvelope,
1458) -> Result<()> {
1459 binding.validate_envelope(envelope)
1460}
1461
/// Lookup key for registered envelopes: the (schema, plan, relation) fingerprint triple.
///
/// A binding resolves to the envelope registered under an identical triple. Field order
/// defines the derived `Ord`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
struct DataEnvelopeKey {
    schema_fingerprint: String,
    plan_fingerprint: String,
    /// `None` when the binding/envelope has no relation fingerprint.
    relation_fingerprint: Option<String>,
}
1468
1469impl DataEnvelopeKey {
1470 fn from_binding(binding: &DataBinding) -> Self {
1471 Self {
1472 schema_fingerprint: binding.schema_fingerprint.clone(),
1473 plan_fingerprint: binding.plan_fingerprint.clone(),
1474 relation_fingerprint: binding.relation_fingerprint.clone(),
1475 }
1476 }
1477
1478 fn from_envelope(envelope: &ExternalDataPlanEnvelope) -> Self {
1479 Self {
1480 schema_fingerprint: envelope.schema_fingerprint.clone(),
1481 plan_fingerprint: envelope.plan_fingerprint.clone(),
1482 relation_fingerprint: envelope.relation_fingerprint.clone(),
1483 }
1484 }
1485}
1486
/// Provenance recorded by [`InMemoryDataProvider`] for each data handle it materializes.
///
/// Fields are copied from the materialization request and its binding at creation time.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DataHandleRecord {
    /// The issued data handle.
    pub handle: HandleRef,
    pub run_id: RunId,
    pub node_id: NodeId,
    pub input_name: String,
    pub phase: Phase,
    pub variant_id: Option<VariantId>,
    pub fold_id: Option<FoldId>,
    /// Binding `request_id` at materialization time.
    pub request_id: String,
    /// Binding fingerprints at materialization time.
    pub schema_fingerprint: String,
    pub plan_fingerprint: String,
    pub relation_fingerprint: Option<String>,
    pub output_representation: String,
    /// The binding's explicit feature-set id (no `input_name` fallback applied here).
    #[serde(default)]
    pub feature_set_id: Option<String>,
    #[serde(default)]
    pub source_ids: Vec<String>,
    /// Number of records in the envelope's coordinator relations, or `None` without relations.
    pub relation_record_count: Option<usize>,
}
1507
/// Provenance recorded by [`InMemoryDataProvider`] for each view handle it creates.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DataViewHandleRecord {
    /// The issued view handle.
    pub handle: HandleRef,
    /// The data handle the view was created from.
    pub parent_handle: HandleRef,
    pub run_id: RunId,
    pub node_id: NodeId,
    pub input_name: String,
    pub phase: Phase,
    pub variant_id: Option<VariantId>,
    pub fold_id: Option<FoldId>,
    /// Binding `request_id` supplied with the view request.
    pub request_id: String,
    /// Effective feature-set id (falls back to `input_name` when the binding has none).
    pub feature_set_id: String,
    /// The requested view specification.
    pub view: DataProviderViewSpec,
}
1522
/// In-memory [`RuntimeDataProvider`] that keeps registered envelopes and handle provenance.
///
/// Bookkeeping lives in `RefCell`s so the trait methods can take `&self`. As a consequence
/// the provider is not `Sync`.
#[derive(Debug)]
pub struct InMemoryDataProvider {
    /// Controller recorded as the owner of every handle this provider issues.
    owner_controller: ControllerId,
    /// Registered envelopes keyed by their (schema, plan, relation) fingerprints.
    envelopes: BTreeMap<DataEnvelopeKey, ExternalDataPlanEnvelope>,
    /// Next handle id to issue; starts at 1 and is shared by data and view handles.
    next_handle: RefCell<u64>,
    /// Provenance of handles returned by `materialize`, keyed by handle id.
    records: RefCell<BTreeMap<u64, DataHandleRecord>>,
    /// Provenance of handles returned by `make_view`, keyed by handle id.
    view_records: RefCell<BTreeMap<u64, DataViewHandleRecord>>,
}
1531
1532impl InMemoryDataProvider {
1533 pub fn new(owner_controller: ControllerId) -> Self {
1534 Self {
1535 owner_controller,
1536 envelopes: BTreeMap::new(),
1537 next_handle: RefCell::new(1),
1538 records: RefCell::new(BTreeMap::new()),
1539 view_records: RefCell::new(BTreeMap::new()),
1540 }
1541 }
1542
1543 pub fn with_envelope(
1544 owner_controller: ControllerId,
1545 envelope: ExternalDataPlanEnvelope,
1546 ) -> Result<Self> {
1547 let mut provider = Self::new(owner_controller);
1548 provider.register_envelope(envelope)?;
1549 Ok(provider)
1550 }
1551
1552 pub fn register_envelope(&mut self, envelope: ExternalDataPlanEnvelope) -> Result<()> {
1553 envelope.validate()?;
1554 let key = DataEnvelopeKey::from_envelope(&envelope);
1555 if let Some(existing) = self.envelopes.get(&key) {
1556 if existing == &envelope {
1557 return Ok(());
1558 }
1559 return Err(DagMlError::RuntimeValidation(
1560 "duplicate external data-plan envelope with different payload".to_string(),
1561 ));
1562 }
1563 self.envelopes.insert(key, envelope);
1564 Ok(())
1565 }
1566
1567 pub fn handle_record(&self, handle: u64) -> Option<DataHandleRecord> {
1568 self.records.borrow().get(&handle).cloned()
1569 }
1570
1571 pub fn handle_records(&self) -> Vec<DataHandleRecord> {
1572 self.records.borrow().values().cloned().collect()
1573 }
1574
1575 pub fn view_record(&self, handle: u64) -> Option<DataViewHandleRecord> {
1576 self.view_records.borrow().get(&handle).cloned()
1577 }
1578
1579 pub fn view_records(&self) -> Vec<DataViewHandleRecord> {
1580 self.view_records.borrow().values().cloned().collect()
1581 }
1582
1583 fn next_handle(&self) -> u64 {
1584 let mut next = self.next_handle.borrow_mut();
1585 let handle = *next;
1586 *next += 1;
1587 handle
1588 }
1589}
1590
1591impl RuntimeDataProvider for InMemoryDataProvider {
1592 fn materialize(&self, request: &DataMaterializationRequest) -> Result<HandleRef> {
1593 if request.node_id != request.binding.node_id {
1594 return Err(DagMlError::RuntimeValidation(format!(
1595 "data materialization request node `{}` does not match binding node `{}`",
1596 request.node_id, request.binding.node_id
1597 )));
1598 }
1599 if request.input_name != request.binding.input_name {
1600 return Err(DagMlError::RuntimeValidation(format!(
1601 "data materialization request input `{}` does not match binding input `{}`",
1602 request.input_name, request.binding.input_name
1603 )));
1604 }
1605 let envelope = self
1606 .envelopes
1607 .get(&DataEnvelopeKey::from_binding(&request.binding))
1608 .ok_or_else(|| {
1609 DagMlError::RuntimeValidation(format!(
1610 "no external data-plan envelope registered for binding `{}` on `{}`",
1611 request.binding.input_name, request.binding.node_id
1612 ))
1613 })?;
1614 request.binding.validate_envelope(envelope)?;
1615
1616 let handle = HandleRef {
1617 handle: self.next_handle(),
1618 kind: HandleKind::Data,
1619 owner_controller: self.owner_controller.clone(),
1620 };
1621 let record = DataHandleRecord {
1622 handle: handle.clone(),
1623 run_id: request.run_id.clone(),
1624 node_id: request.node_id.clone(),
1625 input_name: request.input_name.clone(),
1626 phase: request.phase,
1627 variant_id: request.variant_id.clone(),
1628 fold_id: request.fold_id.clone(),
1629 request_id: request.binding.request_id.clone(),
1630 schema_fingerprint: request.binding.schema_fingerprint.clone(),
1631 plan_fingerprint: request.binding.plan_fingerprint.clone(),
1632 relation_fingerprint: request.binding.relation_fingerprint.clone(),
1633 output_representation: request.binding.output_representation.clone(),
1634 feature_set_id: request.binding.feature_set_id.clone(),
1635 source_ids: request.binding.source_ids.clone(),
1636 relation_record_count: envelope
1637 .coordinator_relations
1638 .as_ref()
1639 .map(|relations| relations.records.len()),
1640 };
1641 self.records.borrow_mut().insert(handle.handle, record);
1642 Ok(handle)
1643 }
1644
1645 fn make_view(&self, request: &DataViewRequest) -> Result<HandleRef> {
1646 request.view.validate()?;
1647 if request.node_id != request.binding.node_id {
1648 return Err(DagMlError::RuntimeValidation(format!(
1649 "data view request node `{}` does not match binding node `{}`",
1650 request.node_id, request.binding.node_id
1651 )));
1652 }
1653 if request.input_name != request.binding.input_name {
1654 return Err(DagMlError::RuntimeValidation(format!(
1655 "data view request input `{}` does not match binding input `{}`",
1656 request.input_name, request.binding.input_name
1657 )));
1658 }
1659 if request.data_handle.kind != HandleKind::Data {
1660 return Err(DagMlError::RuntimeValidation(format!(
1661 "data view request for `{}` on `{}` received non-data parent handle",
1662 request.input_name, request.node_id
1663 )));
1664 }
1665 let parent = self
1666 .records
1667 .borrow()
1668 .get(&request.data_handle.handle)
1669 .cloned()
1670 .ok_or_else(|| {
1671 DagMlError::RuntimeValidation(format!(
1672 "unknown data handle `{}` for view request `{}` on `{}`",
1673 request.data_handle.handle, request.input_name, request.node_id
1674 ))
1675 })?;
1676 if parent.handle != request.data_handle {
1677 return Err(DagMlError::RuntimeValidation(format!(
1678 "data view request parent handle `{}` does not match provider record",
1679 request.data_handle.handle
1680 )));
1681 }
1682 request.binding.validate()?;
1683 let handle = HandleRef {
1684 handle: self.next_handle(),
1685 kind: HandleKind::DataView,
1686 owner_controller: self.owner_controller.clone(),
1687 };
1688 let record = DataViewHandleRecord {
1689 handle: handle.clone(),
1690 parent_handle: request.data_handle.clone(),
1691 run_id: request.run_id.clone(),
1692 node_id: request.node_id.clone(),
1693 input_name: request.input_name.clone(),
1694 phase: request.phase,
1695 variant_id: request.variant_id.clone(),
1696 fold_id: request.fold_id.clone(),
1697 request_id: request.binding.request_id.clone(),
1698 feature_set_id: request.binding.feature_set_id().to_string(),
1699 view: request.view.clone(),
1700 };
1701 self.view_records.borrow_mut().insert(handle.handle, record);
1702 Ok(handle)
1703 }
1704
1705 fn coordinator_relations(&self, binding: &DataBinding) -> Result<Option<SampleRelationSet>> {
1706 let envelope = self
1707 .envelopes
1708 .get(&DataEnvelopeKey::from_binding(binding))
1709 .ok_or_else(|| {
1710 DagMlError::RuntimeValidation(format!(
1711 "no external data-plan envelope registered for binding `{}` on `{}`",
1712 binding.input_name, binding.node_id
1713 ))
1714 })?;
1715 binding.validate_envelope(envelope)?;
1716 Ok(envelope.coordinator_relations.clone())
1717 }
1718}
1719
1720fn validate_fingerprint(label: &str, value: &str) -> Result<()> {
1721 if value.len() != 64 || !value.bytes().all(|byte| byte.is_ascii_hexdigit()) {
1722 return Err(DagMlError::CampaignValidation(format!(
1723 "{label} fingerprint must be a 64-character hex digest"
1724 )));
1725 }
1726 Ok(())
1727}
1728
1729fn validate_non_empty(label: &str, value: &str) -> Result<()> {
1730 if value.trim().is_empty() {
1731 return Err(DagMlError::CampaignValidation(format!(
1732 "{label} must be a non-empty string"
1733 )));
1734 }
1735 Ok(())
1736}
1737
1738fn validate_optional_non_empty(label: &str, value: &Option<String>) -> Result<()> {
1739 if let Some(value) = value {
1740 validate_non_empty(label, value)?;
1741 }
1742 Ok(())
1743}
1744
1745fn validate_combo_like_output(label: &str, unit_level: EntityUnitLevel) -> Result<()> {
1746 if matches!(
1747 unit_level,
1748 EntityUnitLevel::Combo | EntityUnitLevel::Observation
1749 ) {
1750 return Ok(());
1751 }
1752 Err(DagMlError::CampaignValidation(format!(
1753 "{label} representation output_unit_level must be combo or observation"
1754 )))
1755}
1756
1757fn validate_non_empty_list(label: &str, values: &[String]) -> Result<()> {
1758 if values.is_empty() {
1759 return Err(DagMlError::CampaignValidation(format!(
1760 "{label} must be a non-empty list"
1761 )));
1762 }
1763 validate_string_list_entries(label, values)
1764}
1765
1766fn validate_string_list_entries(label: &str, values: &[String]) -> Result<()> {
1767 for (index, value) in values.iter().enumerate() {
1768 if value.trim().is_empty() {
1769 return Err(DagMlError::CampaignValidation(format!(
1770 "{label}[{index}] must be a non-empty string"
1771 )));
1772 }
1773 }
1774 Ok(())
1775}
1776
1777fn validate_unique_strings(label: &str, values: &[String]) -> Result<()> {
1778 let mut seen = BTreeSet::new();
1779 for value in values {
1780 if !seen.insert(value.as_str()) {
1781 return Err(DagMlError::CampaignValidation(format!(
1782 "{label} contains duplicate value `{value}`"
1783 )));
1784 }
1785 }
1786 Ok(())
1787}
1788
1789#[cfg(test)]
1790mod tests {
1791 use super::*;
1792 use crate::ids::NodeId;
1793 use crate::runtime::DataMaterializationRequest;
1794
1795 fn binding() -> DataBinding {
1796 DataBinding {
1797 node_id: NodeId::new("model:base").unwrap(),
1798 input_name: "x".to_string(),
1799 request_id: "nir-to-tabular".to_string(),
1800 schema_fingerprint: "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b"
1801 .to_string(),
1802 plan_fingerprint: "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d"
1803 .to_string(),
1804 relation_fingerprint: Some(
1805 "a3a7e329df35db9f2883a17b8611b7fae6dcaa031875e3ec2c9be1b9e29cbe10".to_string(),
1806 ),
1807 output_representation: "tabular_numeric".to_string(),
1808 feature_set_id: Some("x".to_string()),
1809 source_ids: vec!["nir".to_string()],
1810 require_relations: true,
1811 view_policy: DataViewPolicy::default(),
1812 metadata: BTreeMap::new(),
1813 }
1814 }
1815
1816 #[test]
1817 fn validates_data_binding_contract() {
1818 let binding = binding();
1819 binding.validate().unwrap();
1820 assert_eq!(binding.feature_set_id(), "x");
1821 }
1822
1823 #[test]
1824 fn published_model_input_and_data_plan_schemas_declare_current_contract() {
1825 let model_input_schema: serde_json::Value = serde_json::from_str(include_str!(
1826 "../../../docs/contracts/model_input_spec.schema.json"
1827 ))
1828 .unwrap();
1829 assert_eq!(model_input_schema["$id"], MODEL_INPUT_SPEC_SCHEMA_ID);
1830 assert_eq!(
1831 model_input_schema["properties"]["schema_version"]["const"].as_u64(),
1832 Some(MODEL_INPUT_SPEC_SCHEMA_VERSION as u64)
1833 );
1834 assert!(model_input_schema["$defs"]["input_port"]["required"]
1835 .as_array()
1836 .unwrap()
1837 .iter()
1838 .any(|field| field.as_str() == Some("accepted_representations")));
1839 assert!(model_input_schema["$defs"]["fusion_policy"]["properties"]
1840 .as_object()
1841 .unwrap()
1842 .contains_key("representation_plan"));
1843 assert!(model_input_schema["$defs"]
1844 .as_object()
1845 .unwrap()
1846 .contains_key("combination_plan"));
1847 assert!(model_input_schema["$defs"]
1848 .as_object()
1849 .unwrap()
1850 .contains_key("representation_plan"));
1851
1852 let data_plan_schema: serde_json::Value = serde_json::from_str(include_str!(
1853 "../../../docs/contracts/data_plan.schema.json"
1854 ))
1855 .unwrap();
1856 assert_eq!(data_plan_schema["$id"], DATA_PLAN_SCHEMA_ID);
1857 assert_eq!(
1858 data_plan_schema["properties"]["schema_version"]["const"].as_u64(),
1859 Some(DATA_PLAN_SCHEMA_VERSION as u64)
1860 );
1861 assert!(data_plan_schema["$defs"]["data_plan_step_kind"]["enum"]
1862 .as_array()
1863 .unwrap()
1864 .iter()
1865 .any(|kind| kind.as_str() == Some("collate")));
1866 }
1867
1868 #[test]
1869 fn validates_model_input_and_data_plan_fixtures() {
1870 let model_input: ModelInputSpec = serde_json::from_str(include_str!(
1871 "../../../examples/fixtures/data/model_input_spec_tabular_regressor.json"
1872 ))
1873 .unwrap();
1874 model_input.validate().unwrap();
1875 assert_eq!(model_input.ports[0].rank, Some(2));
1876 assert!(model_input.ports[0].multi_source);
1877
1878 let data_plan: DataPlan = serde_json::from_str(include_str!(
1879 "../../../examples/fixtures/data/data_plan_tabular_fusion.json"
1880 ))
1881 .unwrap();
1882 data_plan.validate().unwrap();
1883 assert_eq!(data_plan.output_ports.get("x").unwrap(), "x_collated");
1884 }
1885
1886 #[test]
1887 fn data_plan_rejects_forward_step_references() {
1888 let data_plan = DataPlan {
1889 schema_version: DATA_PLAN_SCHEMA_VERSION,
1890 id: "data-plan:bad".to_string(),
1891 steps: vec![DataPlanStep {
1892 kind: DataPlanStepKind::Adapt,
1893 inputs: vec!["missing".to_string()],
1894 output: "adapted".to_string(),
1895 adapter_id: Some("adapter:adapt".to_string()),
1896 params: BTreeMap::new(),
1897 }],
1898 output_ports: BTreeMap::from([("x".to_string(), "adapted".to_string())]),
1899 warnings: Vec::new(),
1900 requires_user_choice: Vec::new(),
1901 metadata: BTreeMap::new(),
1902 };
1903
1904 let error = data_plan.validate().unwrap_err().to_string();
1905 assert!(error.contains("before it is produced"));
1906 }
1907
1908 #[test]
1909 fn data_view_policy_rejects_unsafe_fit_and_validation_augmentation_by_default() {
1910 let mut full_train_binding = binding();
1911 full_train_binding.view_policy.fit_partition = DataRequestPartition::FullTrain;
1912 let full_train_error = full_train_binding.validate().unwrap_err().to_string();
1913 assert!(
1914 full_train_error.contains("fit_partition=full_train"),
1915 "unexpected full-train error: {full_train_error}"
1916 );
1917
1918 let mut augmented_validation_binding = binding();
1919 augmented_validation_binding
1920 .view_policy
1921 .include_augmented_validation = true;
1922 let augmented_error = augmented_validation_binding
1923 .validate()
1924 .unwrap_err()
1925 .to_string();
1926 assert!(
1927 augmented_error.contains("include_augmented_validation=true"),
1928 "unexpected augmented-validation error: {augmented_error}"
1929 );
1930
1931 let mut excluded_binding = binding();
1932 excluded_binding.view_policy.include_excluded = true;
1933 let excluded_error = excluded_binding.validate().unwrap_err().to_string();
1934 assert!(
1935 excluded_error.contains("include_excluded=true"),
1936 "unexpected excluded-row error: {excluded_error}"
1937 );
1938 }
1939
1940 #[test]
1941 fn data_view_policy_requires_explicit_unsafe_flags_for_debug_views() {
1942 let mut binding = binding();
1943 binding.view_policy.fit_partition = DataRequestPartition::FullTrain;
1944 binding.view_policy.include_augmented_validation = true;
1945 binding.view_policy.include_excluded = true;
1946 binding.view_policy.unsafe_flags = BTreeSet::from([
1947 DataViewPolicy::ALLOW_FIT_CV_FULL_TRAIN_VIEW.to_string(),
1948 DataViewPolicy::ALLOW_AUGMENTED_VALIDATION_VIEW.to_string(),
1949 DataViewPolicy::ALLOW_EXCLUDED_ROWS.to_string(),
1950 ]);
1951
1952 binding.validate().unwrap();
1953 }
1954
1955 #[test]
1956 fn validates_external_data_envelope_subset() {
1957 let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
1958 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
1959 ))
1960 .unwrap();
1961
1962 assert_eq!(
1963 envelope.schema_version,
1964 EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION
1965 );
1966 binding().validate_envelope(&envelope).unwrap();
1967 }
1968
1969 #[test]
1970 fn validates_multisource_repetition_envelope_fixture() {
1971 let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
1972 "../../../examples/fixtures/data/coordinator_data_plan_envelope_multisource_repetitions.json"
1973 ))
1974 .unwrap();
1975
1976 envelope.validate().unwrap();
1977 let relations = envelope.coordinator_relations.as_ref().unwrap();
1978 assert_eq!(relations.records.len(), 8);
1979 let source_counts = relations.records.iter().fold(
1980 BTreeMap::<String, usize>::new(),
1981 |mut counts, record| {
1982 if record.unit_level == EntityUnitLevel::Observation {
1983 *counts
1984 .entry(record.source_id.clone().expect("source_id"))
1985 .or_default() += 1;
1986 }
1987 counts
1988 },
1989 );
1990 assert_eq!(source_counts["A"], 2);
1991 assert_eq!(source_counts["B"], 3);
1992 assert_eq!(source_counts["C"], 2);
1993 let combo = relations
1994 .records
1995 .iter()
1996 .find(|record| record.unit_level == EntityUnitLevel::Combo)
1997 .expect("relation-backed combo row");
1998 assert_eq!(combo.sample_id.as_str(), "sample:1");
1999 assert_eq!(
2000 combo.origin_sample_id.as_ref().unwrap().as_str(),
2001 combo.sample_id.as_str()
2002 );
2003 assert_eq!(combo.component_observation_ids.len(), 3);
2004 for source_id in ["A", "B", "C"] {
2005 assert!(combo
2006 .component_observation_ids
2007 .iter()
2008 .any(|observation_id| observation_id.as_str().contains(source_id)));
2009 }
2010 assert_eq!(
2011 relations
2012 .sample_for_observation(
2013 &crate::ids::ObservationId::new("obs.s1.combo.A0.B0.C0").unwrap()
2014 )
2015 .unwrap()
2016 .as_str(),
2017 "sample:1"
2018 );
2019 }
2020
2021 #[test]
2022 fn published_external_data_envelope_schema_declares_current_version() {
2023 let schema: serde_json::Value = serde_json::from_str(include_str!(
2024 "../../../docs/contracts/coordinator_data_plan_envelope.schema.json"
2025 ))
2026 .unwrap();
2027
2028 assert_eq!(
2029 schema["properties"]["schema_version"]["const"].as_u64(),
2030 Some(EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION as u64)
2031 );
2032 assert!(schema["required"]
2033 .as_array()
2034 .unwrap()
2035 .iter()
2036 .any(|field| field.as_str() == Some("schema_version")));
2037 }
2038
2039 #[test]
2040 fn refuses_unsupported_external_data_envelope_schema_version() {
2041 let mut envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2042 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2043 ))
2044 .unwrap();
2045 envelope.schema_version = EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION + 1;
2046
2047 assert!(binding().validate_envelope(&envelope).is_err());
2048 }
2049
2050 #[test]
2051 fn refuses_envelope_fingerprint_mismatch() {
2052 let mut envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2053 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2054 ))
2055 .unwrap();
2056 envelope.plan_fingerprint = "0".repeat(64);
2057
2058 assert!(binding().validate_envelope(&envelope).is_err());
2059 }
2060
2061 #[test]
2062 fn in_memory_provider_materializes_validated_data_handles() {
2063 let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2064 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2065 ))
2066 .unwrap();
2067 let provider = InMemoryDataProvider::with_envelope(
2068 ControllerId::new("controller:data.provider").unwrap(),
2069 envelope,
2070 )
2071 .unwrap();
2072
2073 let handle = provider
2074 .materialize(&DataMaterializationRequest {
2075 run_id: RunId::new("run:data").unwrap(),
2076 node_id: NodeId::new("model:base").unwrap(),
2077 input_name: "x".to_string(),
2078 phase: Phase::FitCv,
2079 variant_id: Some(VariantId::new("variant:base").unwrap()),
2080 fold_id: Some(FoldId::new("fold:0").unwrap()),
2081 binding: binding(),
2082 })
2083 .unwrap();
2084
2085 let record = provider.handle_record(handle.handle).unwrap();
2086 assert_eq!(record.input_name, "x");
2087 assert_eq!(record.relation_record_count, Some(4));
2088 assert_eq!(provider.handle_records().len(), 1);
2089 }
2090
2091 #[test]
2092 fn in_memory_provider_registration_is_idempotent_for_same_envelope() {
2093 let envelope: ExternalDataPlanEnvelope = serde_json::from_str(include_str!(
2094 "../../../examples/fixtures/data/coordinator_data_plan_envelope_sample12.json"
2095 ))
2096 .unwrap();
2097 let mut provider =
2098 InMemoryDataProvider::new(ControllerId::new("controller:data.provider").unwrap());
2099
2100 provider.register_envelope(envelope.clone()).unwrap();
2101 provider.register_envelope(envelope).unwrap();
2102 }
2103
2104 #[test]
2105 fn in_memory_provider_refuses_unknown_envelope() {
2106 let provider =
2107 InMemoryDataProvider::new(ControllerId::new("controller:data.provider").unwrap());
2108
2109 assert!(provider
2110 .materialize(&DataMaterializationRequest {
2111 run_id: RunId::new("run:data").unwrap(),
2112 node_id: NodeId::new("model:base").unwrap(),
2113 input_name: "x".to_string(),
2114 phase: Phase::FitCv,
2115 variant_id: None,
2116 fold_id: None,
2117 binding: binding(),
2118 })
2119 .is_err());
2120 }
2121
2122 fn cartesian_combination() -> CombinationPlan {
2123 CombinationPlan {
2124 mode: CombinationMode::Cartesian,
2125 component_source_ids: vec!["source:a".to_string(), "source:b".to_string()],
2126 component_unit_ids: Vec::new(),
2127 match_key: None,
2128 reference_source_id: None,
2129 seed: None,
2130 cap: None,
2131 budget: Some(32),
2132 missing_source_policy: Some(RepresentationMissingSourcePolicy::Strict),
2133 metadata: BTreeMap::new(),
2134 }
2135 }
2136
2137 fn compatibility_report() -> RepresentationCompatibilityReport {
2138 RepresentationCompatibilityReport {
2139 policy: RepresentationMissingSourcePolicy::Mask,
2140 outcome: RepresentationCompatibilityOutcome::CompatibleWithFallback,
2141 fallback_used: Some("mask".to_string()),
2142 warning_severity: Some(RepresentationCompatibilitySeverity::Warning),
2143 affected_source_count: 1,
2144 affected_repetition_count: 2,
2145 affected_sample_count: 3,
2146 train_relation_fingerprint: Some("c".repeat(64)),
2147 predict_relation_fingerprint: Some("d".repeat(64)),
2148 train_unit_count: Some(6),
2149 predict_unit_count: Some(4),
2150 fixed_width_required: true,
2151 final_reducer_stabilizes_output: true,
2152 cartesian_combo_count_changed: true,
2153 late_fusion_branch_delta: true,
2154 messages: vec!["mask fallback applied for missing source".to_string()],
2155 metadata: BTreeMap::new(),
2156 }
2157 }
2158
    /// Top-level shape of `d9_golden_multisource_scenarios.json`.
    ///
    /// `deny_unknown_fields` makes any key not declared here a deserialization error,
    /// so drift in the fixture file fails loudly instead of being ignored.
    #[derive(serde::Deserialize)]
    #[serde(deny_unknown_fields)]
    struct D9GoldenFixture {
        /// Fixture format version; the D9 golden test expects `1`.
        schema_version: u32,
        /// Golden scenarios; the D9 golden test requires their `scenario_id`s to be unique.
        golden_scenarios: Vec<D9GoldenScenario>,
    }

    /// One golden scenario entry. It also rejects unknown keys.
    #[derive(serde::Deserialize)]
    #[serde(deny_unknown_fields)]
    struct D9GoldenScenario {
        /// Scenario identifier, e.g. `d9.cartesian_mc.deterministic_replay`.
        scenario_id: String,
        /// Flow description entries (presumably ordered step names — confirm against the
        /// fixture); the test only requires this to be non-empty.
        flow: Vec<String>,
        /// Phases the mock run walks through; the test expects `fit_cv`, `refit`, `predict`.
        mock_phase_path: Vec<String>,
        /// Replay manifest that the test validates and inspects.
        representation_replay_manifest: RepresentationReplayManifest,
        /// Assertion descriptions (presumably human-readable); the test only requires
        /// this to be non-empty.
        assertions: Vec<String>,
    }
2175
2176 #[test]
2177 fn d9_golden_multisource_repetition_manifests_validate() {
2178 let fixture: D9GoldenFixture = serde_json::from_str(include_str!(
2179 "../../../examples/fixtures/runtime/d9_golden_multisource_scenarios.json"
2180 ))
2181 .unwrap();
2182 assert_eq!(fixture.schema_version, 1);
2183 assert_eq!(fixture.golden_scenarios.len(), 7);
2184
2185 let mut scenario_ids = BTreeSet::new();
2186 let mut has_same_repetition_replay = false;
2187 let mut has_changed_repetition_replay = false;
2188 let mut has_combo_meta_fit_influence = false;
2189 for scenario in &fixture.golden_scenarios {
2190 assert!(
2191 scenario_ids.insert(scenario.scenario_id.as_str()),
2192 "duplicate D9 scenario {}",
2193 scenario.scenario_id
2194 );
2195 assert!(!scenario.flow.is_empty());
2196 assert_eq!(scenario.mock_phase_path, ["fit_cv", "refit", "predict"]);
2197 assert!(!scenario.assertions.is_empty());
2198
2199 let manifest = &scenario.representation_replay_manifest;
2200 manifest.validate().unwrap();
2201 assert_eq!(
2202 manifest.final_output_unit_level,
2203 Some(EntityUnitLevel::PhysicalSample),
2204 "{} must publish sample-level outputs",
2205 scenario.scenario_id
2206 );
2207 if manifest.output_unit_level == EntityUnitLevel::Combo {
2208 assert!(
2209 manifest.final_reduction_id.is_some(),
2210 "{} must declare combo-to-sample reduction",
2211 scenario.scenario_id
2212 );
2213 assert!(
2214 !manifest.combo_selection.is_empty(),
2215 "{} must retain relation-backed combo identities",
2216 scenario.scenario_id
2217 );
2218 }
2219
2220 if let (Some(train), Some(predict)) = (
2221 &manifest.train_compatibility,
2222 &manifest.predict_compatibility,
2223 ) {
2224 has_same_repetition_replay |= train.train_unit_count == predict.predict_unit_count
2225 && train.train_relation_fingerprint == predict.predict_relation_fingerprint;
2226 has_changed_repetition_replay |= train.train_unit_count
2227 != predict.predict_unit_count
2228 || train.train_relation_fingerprint != predict.predict_relation_fingerprint;
2229 }
2230
2231 if scenario.scenario_id == "d9.combo_meta_post.relation_backed_adapters" {
2232 has_combo_meta_fit_influence = manifest
2233 .metadata
2234 .get("fit_influence_policy")
2235 .is_some_and(|value| value == "equal_sample_influence");
2236 }
2237 }
2238
2239 assert!(scenario_ids.contains("d9.per_source_aggregate.source_models.sample_reducer"));
2240 assert!(scenario_ids.contains("d9.late_fusion_by_source.prediction_join.meta_model"));
2241 assert!(scenario_ids.contains("d9.cartesian_full.model.combo_to_sample_reducer"));
2242 assert!(scenario_ids.contains("d9.cartesian_mc.deterministic_replay"));
2243 assert!(scenario_ids.contains("d9.stack_fixed.strict_cardinality"));
2244 assert!(scenario_ids.contains("d9.stack_padded_masked.missing_repetition"));
2245 assert!(scenario_ids.contains("d9.combo_meta_post.relation_backed_adapters"));
2246 assert!(has_same_repetition_replay);
2247 assert!(has_changed_repetition_replay);
2248 assert!(has_combo_meta_fit_influence);
2249 }
2250
2251 #[test]
2252 fn representation_plan_validates_cartesian_and_monte_carlo_contracts() {
2253 let cartesian = RepresentationPlan::CartesianProduct(CartesianProductRepresentation {
2254 combination_plan: cartesian_combination(),
2255 output_unit_level: EntityUnitLevel::Combo,
2256 cardinality: RepresentationCardinality::ManyToMany,
2257 preserve_provenance: true,
2258 });
2259 cartesian.validate().unwrap();
2260
2261 let monte_carlo =
2262 RepresentationPlan::MonteCarloCartesian(MonteCarloCartesianRepresentation {
2263 combination_plan: CombinationPlan {
2264 mode: CombinationMode::SampleK,
2265 component_source_ids: vec!["source:a".to_string(), "source:b".to_string()],
2266 component_unit_ids: Vec::new(),
2267 match_key: None,
2268 reference_source_id: None,
2269 seed: Some(42),
2270 cap: Some(8),
2271 budget: None,
2272 missing_source_policy: Some(RepresentationMissingSourcePolicy::Warn),
2273 metadata: BTreeMap::new(),
2274 },
2275 output_unit_level: EntityUnitLevel::Observation,
2276 cardinality: RepresentationCardinality::BoundedMany,
2277 preserve_provenance: true,
2278 });
2279 monte_carlo.validate().unwrap();
2280
2281 let mut bad = cartesian_combination();
2282 bad.mode = CombinationMode::SampleK;
2283 bad.seed = Some(7);
2284 bad.cap = Some(0);
2285 assert!(bad.validate().is_err());
2286 }
2287
2288 #[test]
2289 fn stack_representations_validate_cardinality_and_mask_policy() {
2290 let fixed = RepresentationPlan::StackFixed(StackFixedRepresentation {
2291 output_unit_level: EntityUnitLevel::SourceSample,
2292 cardinality: RepresentationCardinality::OneToMany,
2293 expected_cardinality: 3,
2294 component_source_ids: vec!["source:a".to_string(), "source:b".to_string()],
2295 });
2296 fixed.validate().unwrap();
2297
2298 let padded = RepresentationPlan::StackPaddedMasked(StackPaddedMaskedRepresentation {
2299 output_unit_level: EntityUnitLevel::SourceSample,
2300 cardinality: RepresentationCardinality::BoundedMany,
2301 expected_cardinality: 4,
2302 missing_source_policy: RepresentationMissingSourcePolicy::Mask,
2303 requires_missing_masks: true,
2304 component_source_ids: vec!["source:a".to_string()],
2305 });
2306 padded.validate().unwrap();
2307
2308 let bad = RepresentationPlan::StackPaddedMasked(StackPaddedMaskedRepresentation {
2309 output_unit_level: EntityUnitLevel::SourceSample,
2310 cardinality: RepresentationCardinality::BoundedMany,
2311 expected_cardinality: 4,
2312 missing_source_policy: RepresentationMissingSourcePolicy::ImputeDeclared,
2313 requires_missing_masks: false,
2314 component_source_ids: Vec::new(),
2315 });
2316 assert!(bad.validate().is_err());
2317 }
2318
2319 #[test]
2320 fn representation_compatibility_report_enforces_missingness_policy() {
2321 compatibility_report().validate().unwrap();
2322
2323 let strict = RepresentationCompatibilityReport {
2324 policy: RepresentationMissingSourcePolicy::Strict,
2325 outcome: RepresentationCompatibilityOutcome::Incompatible,
2326 fallback_used: None,
2327 warning_severity: None,
2328 affected_source_count: 1,
2329 affected_repetition_count: 0,
2330 affected_sample_count: 1,
2331 train_relation_fingerprint: None,
2332 predict_relation_fingerprint: None,
2333 train_unit_count: None,
2334 predict_unit_count: None,
2335 fixed_width_required: false,
2336 final_reducer_stabilizes_output: false,
2337 cartesian_combo_count_changed: false,
2338 late_fusion_branch_delta: false,
2339 messages: Vec::new(),
2340 metadata: BTreeMap::new(),
2341 };
2342 strict.validate().unwrap();
2343
2344 let mut bad_non_strict = compatibility_report();
2345 bad_non_strict.fallback_used = None;
2346 assert!(bad_non_strict.validate().is_err());
2347
2348 let mut bad_fixed_width = compatibility_report();
2349 bad_fixed_width.policy = RepresentationMissingSourcePolicy::ImputeDeclared;
2350 bad_fixed_width.fallback_used = Some("impute_declared".to_string());
2351 assert!(bad_fixed_width.validate().is_err());
2352
2353 let mut bad_cartesian = compatibility_report();
2354 bad_cartesian.final_reducer_stabilizes_output = false;
2355 assert!(bad_cartesian.validate().is_err());
2356
2357 let bad_relation_drift = RepresentationCompatibilityReport {
2358 policy: RepresentationMissingSourcePolicy::Strict,
2359 outcome: RepresentationCompatibilityOutcome::Compatible,
2360 fallback_used: None,
2361 warning_severity: None,
2362 affected_source_count: 0,
2363 affected_repetition_count: 0,
2364 affected_sample_count: 0,
2365 train_relation_fingerprint: Some("a".repeat(64)),
2366 predict_relation_fingerprint: Some("b".repeat(64)),
2367 train_unit_count: Some(3),
2368 predict_unit_count: Some(3),
2369 fixed_width_required: false,
2370 final_reducer_stabilizes_output: true,
2371 cartesian_combo_count_changed: false,
2372 late_fusion_branch_delta: false,
2373 messages: Vec::new(),
2374 metadata: BTreeMap::new(),
2375 };
2376 let error = bad_relation_drift.validate().unwrap_err().to_string();
2377 assert!(
2378 error.contains("relation fingerprint mismatch requires affected units"),
2379 "unexpected D9 relation drift error: {error}"
2380 );
2381
2382 let mut bad_unit_drift = bad_relation_drift;
2383 bad_unit_drift.predict_relation_fingerprint =
2384 bad_unit_drift.train_relation_fingerprint.clone();
2385 bad_unit_drift.predict_unit_count = Some(2);
2386 let error = bad_unit_drift.validate().unwrap_err().to_string();
2387 assert!(
2388 error.contains("unit count mismatch requires affected units"),
2389 "unexpected D9 unit-count drift error: {error}"
2390 );
2391 }
2392
2393 #[test]
2394 fn representation_replay_manifest_round_trips_and_validates() {
2395 let plan = RepresentationPlan::CartesianProduct(CartesianProductRepresentation {
2396 combination_plan: cartesian_combination(),
2397 output_unit_level: EntityUnitLevel::Combo,
2398 cardinality: RepresentationCardinality::ManyToMany,
2399 preserve_provenance: true,
2400 });
2401 let manifest = RepresentationReplayManifest {
2402 manifest_id: "repr:combo.ab".to_string(),
2403 representation_plan: plan,
2404 combination_plan: Some(cartesian_combination()),
2405 output_unit_level: EntityUnitLevel::Combo,
2406 output_representation: Some("combo_observation".to_string()),
2407 relation_fingerprint: Some("a".repeat(64)),
2408 feature_schema_fingerprint: Some("b".repeat(64)),
2409 final_reduction_id: Some("reduction:combo_to_sample".to_string()),
2410 sample_observation_mapping: vec![
2411 RepresentationSampleObservationMapping {
2412 physical_sample_id: "sample:1".to_string(),
2413 source_id: "source:a".to_string(),
2414 observation_ids: vec!["obs:a.1".to_string(), "obs:a.2".to_string()],
2415 },
2416 RepresentationSampleObservationMapping {
2417 physical_sample_id: "sample:1".to_string(),
2418 source_id: "source:b".to_string(),
2419 observation_ids: vec!["obs:b.1".to_string()],
2420 },
2421 ],
2422 combo_selection: vec![RepresentationComboSelectionRecord {
2423 combo_unit_id: "combo:sample1:a1:b1".to_string(),
2424 physical_sample_id: "sample:1".to_string(),
2425 component_observation_ids: vec!["obs:a.1".to_string(), "obs:b.1".to_string()],
2426 seed: Some(42),
2427 }],
2428 qc_policy_refs: vec!["qc:default".to_string()],
2429 outlier_policy_refs: vec!["outlier:none".to_string()],
2430 missing_source_policy: Some(RepresentationMissingSourcePolicy::Strict),
2431 missing_repetition_policy: Some(RepresentationMissingSourcePolicy::Warn),
2432 prediction_representation: Some("sample_prediction".to_string()),
2433 final_output_unit_level: Some(EntityUnitLevel::PhysicalSample),
2434 train_compatibility: Some(RepresentationCompatibilityReport {
2435 policy: RepresentationMissingSourcePolicy::Strict,
2436 outcome: RepresentationCompatibilityOutcome::Compatible,
2437 fallback_used: None,
2438 warning_severity: None,
2439 affected_source_count: 0,
2440 affected_repetition_count: 0,
2441 affected_sample_count: 0,
2442 train_relation_fingerprint: Some("a".repeat(64)),
2443 predict_relation_fingerprint: None,
2444 train_unit_count: Some(1),
2445 predict_unit_count: Some(1),
2446 fixed_width_required: false,
2447 final_reducer_stabilizes_output: true,
2448 cartesian_combo_count_changed: false,
2449 late_fusion_branch_delta: false,
2450 messages: Vec::new(),
2451 metadata: BTreeMap::new(),
2452 }),
2453 predict_compatibility: Some(compatibility_report()),
2454 metadata: BTreeMap::new(),
2455 };
2456
2457 manifest.validate().unwrap();
2458 let encoded = serde_json::to_string(&manifest).unwrap();
2459 let decoded: RepresentationReplayManifest = serde_json::from_str(&encoded).unwrap();
2460 assert_eq!(decoded, manifest);
2461 }
2462}