1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{DagMlError, Result};
6use crate::ids::{ControllerId, NodeId};
7use crate::relation::EntityUnitLevel;
8
9#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum SplitUnit {
12 PhysicalSample,
13 Observation,
14 Sample,
15 Target,
16 Group,
17}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct LeakageUnitPolicy {
21 #[serde(default = "default_split_unit")]
22 pub split_unit: SplitUnit,
23 #[serde(default = "default_true")]
24 pub forbid_origin_cross_fold: bool,
25 #[serde(default)]
26 pub allow_observation_split_with_shared_target: bool,
27 #[serde(default)]
28 pub require_group_ids: bool,
29 #[serde(default)]
30 pub unsafe_flags: BTreeSet<String>,
31}
32
33impl Default for LeakageUnitPolicy {
34 fn default() -> Self {
35 Self {
36 split_unit: SplitUnit::PhysicalSample,
37 forbid_origin_cross_fold: true,
38 allow_observation_split_with_shared_target: false,
39 require_group_ids: false,
40 unsafe_flags: BTreeSet::new(),
41 }
42 }
43}
44
45impl LeakageUnitPolicy {
46 pub fn validate(&self) -> Result<()> {
47 if self.split_unit == SplitUnit::Observation
48 && !self.allow_observation_split_with_shared_target
49 {
50 return Err(DagMlError::CampaignValidation(
51 "observation-level splitting is unsafe for repeated X / shared Y unless explicitly allowed".to_string(),
52 ));
53 }
54 if self.require_group_ids && self.split_unit != SplitUnit::Group {
55 return Err(DagMlError::CampaignValidation(
56 "require_group_ids=true requires split_unit=group".to_string(),
57 ));
58 }
59 Ok(())
60 }
61}
62
63fn default_split_unit() -> SplitUnit {
64 SplitUnit::PhysicalSample
65}
66
67#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
68#[serde(rename_all = "snake_case")]
69pub enum PredictionLevel {
70 Observation,
71 Sample,
72 Target,
73 Group,
74}
75
76#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
77#[serde(rename_all = "snake_case")]
78pub enum FitInfluencePolicy {
79 Auto,
80 #[default]
81 UniformRows,
82 EqualSampleInfluence,
83 ResampleEqualized,
84 BackendLossWeight,
85 ScorerOnly,
86 StrictWeightSupport,
87}
88
89#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
90#[serde(rename_all = "snake_case")]
91pub enum AggregationMethod {
92 None,
93 Mean,
94 WeightedMean,
95 Median,
96 Vote,
97 RobustMean,
98 ExcludeOutliers,
99 CustomController,
100}
101
102#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
103#[serde(rename_all = "snake_case")]
104pub enum AggregationWeights {
105 None,
106 Quality,
107 RepetitionCount,
108 ControllerEmitted,
109}
110
111#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
112pub struct AggregationControllerSpec {
113 pub controller_id: ControllerId,
114 #[serde(default = "default_json_object")]
115 pub params: serde_json::Value,
116}
117
118impl AggregationControllerSpec {
119 pub fn validate(&self) -> Result<()> {
120 if self.params.is_null() {
121 return Err(DagMlError::CampaignValidation(format!(
122 "custom aggregation controller `{}` params cannot be null",
123 self.controller_id
124 )));
125 }
126 Ok(())
127 }
128}
129
130#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
131#[serde(rename_all = "snake_case")]
132pub enum ReductionRole {
133 Score,
134 Persist,
135 FoldEnsemble,
136 MetaFeature,
137 FinalOutput,
138}
139
140#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
141#[serde(rename_all = "snake_case")]
142pub enum ReductionAxis {
143 Unit,
144 Fold,
145 Model,
146 Metric,
147}
148
149#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
150#[serde(rename_all = "snake_case")]
151pub enum ReductionMethod {
152 Mean,
153 WeightedMean,
154 Median,
155 Vote,
156 RobustMean,
157 ExcludeOutliers,
158 Custom,
159}
160
161impl From<AggregationMethod> for ReductionMethod {
162 fn from(method: AggregationMethod) -> Self {
163 match method {
164 AggregationMethod::None | AggregationMethod::Mean => Self::Mean,
165 AggregationMethod::WeightedMean => Self::WeightedMean,
166 AggregationMethod::Median => Self::Median,
167 AggregationMethod::Vote => Self::Vote,
168 AggregationMethod::RobustMean => Self::RobustMean,
169 AggregationMethod::ExcludeOutliers => Self::ExcludeOutliers,
170 AggregationMethod::CustomController => Self::Custom,
171 }
172 }
173}
174
175#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
176#[serde(rename_all = "snake_case")]
177pub enum ReductionTaskCompatibility {
178 Any,
179 Regression,
180 Classification,
181}
182
183#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
184pub struct ReductionPlan {
185 #[serde(default = "default_reduction_role")]
186 pub role: ReductionRole,
187 #[serde(default = "default_reduction_axis")]
188 pub axis: ReductionAxis,
189 #[serde(default = "default_reduction_input_unit_level")]
190 pub input_unit_level: EntityUnitLevel,
191 #[serde(default = "default_reduction_output_unit_level")]
192 pub output_unit_level: EntityUnitLevel,
193 #[serde(default = "default_reduction_method")]
194 pub method: ReductionMethod,
195 #[serde(default = "default_aggregation_weights")]
196 pub weight_source: AggregationWeights,
197 #[serde(default = "default_reduction_task_compatibility")]
198 pub task_compatibility: ReductionTaskCompatibility,
199 #[serde(default, skip_serializing_if = "Option::is_none")]
200 pub custom_controller: Option<AggregationControllerSpec>,
201 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
202 pub params: BTreeMap<String, serde_json::Value>,
203}
204
205impl Default for ReductionPlan {
206 fn default() -> Self {
207 Self {
208 role: default_reduction_role(),
209 axis: default_reduction_axis(),
210 input_unit_level: default_reduction_input_unit_level(),
211 output_unit_level: default_reduction_output_unit_level(),
212 method: default_reduction_method(),
213 weight_source: default_aggregation_weights(),
214 task_compatibility: default_reduction_task_compatibility(),
215 custom_controller: None,
216 params: BTreeMap::new(),
217 }
218 }
219}
220
221impl ReductionPlan {
222 pub fn validate(&self) -> Result<()> {
223 if self.method == ReductionMethod::WeightedMean
224 && self.weight_source == AggregationWeights::None
225 {
226 return Err(DagMlError::CampaignValidation(
227 "weighted_mean reduction requires an explicit weight_source".to_string(),
228 ));
229 }
230 if self.method != ReductionMethod::WeightedMean
231 && self.method != ReductionMethod::Custom
232 && self.weight_source != AggregationWeights::None
233 {
234 return Err(DagMlError::CampaignValidation(format!(
235 "reduction weight_source {:?} is only valid with weighted_mean or custom",
236 self.weight_source
237 )));
238 }
239 match (&self.method, &self.custom_controller) {
240 (ReductionMethod::Custom, Some(controller)) => controller.validate()?,
241 (ReductionMethod::Custom, None) => {
242 return Err(DagMlError::CampaignValidation(
243 "custom reduction requires a custom_controller spec".to_string(),
244 ));
245 }
246 (_, Some(controller)) => {
247 return Err(DagMlError::CampaignValidation(format!(
248 "reduction controller `{}` is only valid with custom method",
249 controller.controller_id
250 )));
251 }
252 (_, None) => {}
253 }
254 if self.method == ReductionMethod::Vote
255 && self.task_compatibility == ReductionTaskCompatibility::Regression
256 {
257 return Err(DagMlError::CampaignValidation(
258 "vote reduction is not compatible with regression tasks".to_string(),
259 ));
260 }
261 validate_trim_fraction(self.params.get("trim_fraction"))?;
262 validate_outlier_threshold(self.params.get("threshold"))?;
263 Ok(())
264 }
265}
266
267#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
268pub struct AggregationPolicy {
269 #[serde(default = "default_prediction_level")]
270 pub aggregation_level: PredictionLevel,
271 #[serde(default = "default_aggregation_method")]
272 pub method: AggregationMethod,
273 #[serde(default = "default_aggregation_weights")]
274 pub weights: AggregationWeights,
275 #[serde(default, skip_serializing_if = "Option::is_none")]
276 pub custom_controller: Option<AggregationControllerSpec>,
277 #[serde(default = "default_true")]
278 pub emit_parallel_metrics: bool,
279 #[serde(default = "default_prediction_level")]
280 pub selection_metric_level: PredictionLevel,
281 #[serde(default = "default_true")]
282 pub store_raw_predictions: bool,
283 #[serde(default = "default_true")]
284 pub store_aggregated_predictions: bool,
285}
286
287impl Default for AggregationPolicy {
288 fn default() -> Self {
289 Self {
290 aggregation_level: PredictionLevel::Sample,
291 method: AggregationMethod::Mean,
292 weights: AggregationWeights::None,
293 custom_controller: None,
294 emit_parallel_metrics: true,
295 selection_metric_level: PredictionLevel::Sample,
296 store_raw_predictions: true,
297 store_aggregated_predictions: true,
298 }
299 }
300}
301
302impl AggregationPolicy {
303 pub fn validate(&self) -> Result<()> {
304 if self.method == AggregationMethod::None
305 && self.aggregation_level != PredictionLevel::Observation
306 {
307 return Err(DagMlError::CampaignValidation(
308 "aggregation method none is only valid at observation level".to_string(),
309 ));
310 }
311 if self.method == AggregationMethod::WeightedMean
312 && self.weights == AggregationWeights::None
313 {
314 return Err(DagMlError::CampaignValidation(
315 "weighted_mean aggregation requires an explicit weights policy".to_string(),
316 ));
317 }
318 if self.method != AggregationMethod::WeightedMean
319 && self.method != AggregationMethod::CustomController
320 && self.weights != AggregationWeights::None
321 {
322 return Err(DagMlError::CampaignValidation(format!(
323 "aggregation weights {:?} are only valid with weighted_mean",
324 self.weights
325 )));
326 }
327 match (&self.method, &self.custom_controller) {
328 (AggregationMethod::CustomController, Some(controller)) => controller.validate()?,
329 (AggregationMethod::CustomController, None) => {
330 return Err(DagMlError::CampaignValidation(
331 "custom_controller aggregation requires a custom_controller spec".to_string(),
332 ));
333 }
334 (_, Some(controller)) => {
335 return Err(DagMlError::CampaignValidation(format!(
336 "aggregation controller `{}` is only valid with custom_controller method",
337 controller.controller_id
338 )));
339 }
340 (_, None) => {}
341 }
342 if !self.store_raw_predictions && !self.store_aggregated_predictions {
343 return Err(DagMlError::CampaignValidation(
344 "aggregation policy must store raw and/or aggregated predictions".to_string(),
345 ));
346 }
347 Ok(())
348 }
349}
350
351fn default_prediction_level() -> PredictionLevel {
352 PredictionLevel::Sample
353}
354
355fn default_aggregation_method() -> AggregationMethod {
356 AggregationMethod::Mean
357}
358
359fn default_aggregation_weights() -> AggregationWeights {
360 AggregationWeights::None
361}
362
363fn default_reduction_role() -> ReductionRole {
364 ReductionRole::FinalOutput
365}
366
367fn default_reduction_axis() -> ReductionAxis {
368 ReductionAxis::Unit
369}
370
371fn default_reduction_input_unit_level() -> EntityUnitLevel {
372 EntityUnitLevel::Observation
373}
374
375fn default_reduction_output_unit_level() -> EntityUnitLevel {
376 EntityUnitLevel::PhysicalSample
377}
378
379fn default_reduction_method() -> ReductionMethod {
380 ReductionMethod::Mean
381}
382
383fn default_reduction_task_compatibility() -> ReductionTaskCompatibility {
384 ReductionTaskCompatibility::Any
385}
386
387fn validate_trim_fraction(value: Option<&serde_json::Value>) -> Result<()> {
388 let Some(value) = value else {
389 return Ok(());
390 };
391 let Some(trim_fraction) = value.as_f64() else {
392 return Err(DagMlError::CampaignValidation(
393 "reduction trim_fraction must be numeric".to_string(),
394 ));
395 };
396 if trim_fraction.is_finite() && (0.0..0.5).contains(&trim_fraction) {
397 Ok(())
398 } else {
399 Err(DagMlError::CampaignValidation(
400 "reduction trim_fraction must be finite and in [0.0, 0.5)".to_string(),
401 ))
402 }
403}
404
405fn validate_outlier_threshold(value: Option<&serde_json::Value>) -> Result<()> {
406 let Some(value) = value else {
407 return Ok(());
408 };
409 let Some(threshold) = value.as_f64() else {
410 return Err(DagMlError::CampaignValidation(
411 "reduction threshold must be numeric".to_string(),
412 ));
413 };
414 if threshold.is_finite() && threshold > 0.0 && threshold < 1.0 {
415 Ok(())
416 } else {
417 Err(DagMlError::CampaignValidation(
418 "reduction threshold must be finite and in (0.0, 1.0)".to_string(),
419 ))
420 }
421}
422
423fn default_json_object() -> serde_json::Value {
424 serde_json::Value::Object(serde_json::Map::new())
425}
426
/// Serde default for boolean fields that are enabled unless stated otherwise.
fn default_true() -> bool {
    true
}
430
431#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
432#[serde(rename_all = "snake_case")]
433pub enum Granularity {
434 Observation,
435 Sample,
436 Target,
437 Group,
438}
439
440#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
441#[serde(rename_all = "snake_case")]
442pub enum FitBoundary {
443 FoldTrain,
444 FoldValidation,
445 FullTrain,
446 Predict,
447}
448
449#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
450#[serde(rename_all = "snake_case")]
451pub enum AugmentationScope {
452 None,
453 TrainOnly,
454 AllPartitions,
455}
456
457#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
458pub struct AugmentationPolicy {
459 #[serde(default = "default_augmentation_scope")]
460 pub sample_scope: AugmentationScope,
461 #[serde(default = "default_augmentation_scope")]
462 pub feature_scope: AugmentationScope,
463 #[serde(default = "default_true")]
464 pub require_origin_id: bool,
465 #[serde(default = "default_true")]
466 pub inherit_group: bool,
467 #[serde(default = "default_true")]
468 pub inherit_target: bool,
469 #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
470 pub unsafe_flags: BTreeSet<String>,
471}
472
473impl Default for AugmentationPolicy {
474 fn default() -> Self {
475 Self {
476 sample_scope: AugmentationScope::TrainOnly,
477 feature_scope: AugmentationScope::TrainOnly,
478 require_origin_id: true,
479 inherit_group: true,
480 inherit_target: true,
481 unsafe_flags: BTreeSet::new(),
482 }
483 }
484}
485
486impl AugmentationPolicy {
487 pub const ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS: &'static str =
488 "allow_sample_augmentation_all_partitions";
489 pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_ORIGIN: &'static str =
490 "allow_sample_augmentation_without_origin";
491 pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_GROUP_INHERITANCE: &'static str =
492 "allow_sample_augmentation_without_group_inheritance";
493 pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_TARGET_INHERITANCE: &'static str =
494 "allow_sample_augmentation_without_target_inheritance";
495
496 pub fn validate(&self) -> Result<()> {
497 for unsafe_flag in &self.unsafe_flags {
498 if unsafe_flag.trim().is_empty() {
499 return Err(DagMlError::CampaignValidation(
500 "augmentation policy contains an empty unsafe flag".to_string(),
501 ));
502 }
503 }
504 if self.sample_scope == AugmentationScope::AllPartitions
505 && !self
506 .unsafe_flags
507 .contains(Self::ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS)
508 {
509 return Err(DagMlError::CampaignValidation(
510 "sample augmentation over all partitions can leak validation/test origins; add explicit unsafe flag allow_sample_augmentation_all_partitions".to_string(),
511 ));
512 }
513 if self.sample_scope != AugmentationScope::None {
514 if !self.require_origin_id
515 && !self
516 .unsafe_flags
517 .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_ORIGIN)
518 {
519 return Err(DagMlError::CampaignValidation(
520 "sample augmentation requires origin ids unless explicit unsafe flag allow_sample_augmentation_without_origin is present".to_string(),
521 ));
522 }
523 if !self.inherit_group
524 && !self
525 .unsafe_flags
526 .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_GROUP_INHERITANCE)
527 {
528 return Err(DagMlError::CampaignValidation(
529 "sample augmentation must inherit groups unless explicit unsafe flag allow_sample_augmentation_without_group_inheritance is present".to_string(),
530 ));
531 }
532 if !self.inherit_target
533 && !self
534 .unsafe_flags
535 .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_TARGET_INHERITANCE)
536 {
537 return Err(DagMlError::CampaignValidation(
538 "sample augmentation must inherit targets unless explicit unsafe flag allow_sample_augmentation_without_target_inheritance is present".to_string(),
539 ));
540 }
541 }
542 Ok(())
543 }
544}
545
546fn default_augmentation_scope() -> AugmentationScope {
547 AugmentationScope::TrainOnly
548}
549
550#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
551#[serde(rename_all = "snake_case")]
552pub enum FeatureSelectionScope {
553 None,
554 Unsupervised,
555 SupervisedFoldTrain,
556}
557
558#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
559pub struct FeatureSelectionPolicy {
560 #[serde(default = "default_feature_selection_scope")]
561 pub scope: FeatureSelectionScope,
562 #[serde(default = "default_true")]
563 pub store_masks: bool,
564 #[serde(default)]
565 pub allow_schema_mismatch_on_join: bool,
566}
567
568impl Default for FeatureSelectionPolicy {
569 fn default() -> Self {
570 Self {
571 scope: FeatureSelectionScope::None,
572 store_masks: true,
573 allow_schema_mismatch_on_join: false,
574 }
575 }
576}
577
578impl FeatureSelectionPolicy {
579 pub fn validate(&self) -> Result<()> {
580 if self.scope == FeatureSelectionScope::SupervisedFoldTrain && !self.store_masks {
581 return Err(DagMlError::CampaignValidation(
582 "supervised feature selection must store fold/refit masks for replay and leakage audit".to_string(),
583 ));
584 }
585 Ok(())
586 }
587}
588
589fn default_feature_selection_scope() -> FeatureSelectionScope {
590 FeatureSelectionScope::None
591}
592
593#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
594pub struct DataModelShapePlan {
595 pub node_id: NodeId,
596 #[serde(default = "default_granularity")]
597 pub input_granularity: Granularity,
598 #[serde(default = "default_granularity")]
599 pub target_granularity: Granularity,
600 #[serde(default = "default_fit_boundary")]
601 pub fit_rows: FitBoundary,
602 #[serde(default = "default_predict_boundary")]
603 pub predict_rows: FitBoundary,
604 #[serde(default)]
605 pub feature_namespace: Option<String>,
606 #[serde(default)]
607 pub feature_schema_fingerprint: Option<String>,
608 #[serde(default = "default_target_space")]
609 pub target_space: String,
610 #[serde(default)]
611 pub aggregation_policy: AggregationPolicy,
612 #[serde(default)]
613 pub augmentation_policy: AugmentationPolicy,
614 #[serde(default)]
615 pub selection_policy: FeatureSelectionPolicy,
616}
617
618impl DataModelShapePlan {
619 pub fn validate(&self) -> Result<()> {
620 if self.target_space.trim().is_empty() {
621 return Err(DagMlError::CampaignValidation(format!(
622 "shape plan for `{}` has empty target_space",
623 self.node_id
624 )));
625 }
626 if self
627 .feature_namespace
628 .as_ref()
629 .is_some_and(|namespace| namespace.trim().is_empty())
630 {
631 return Err(DagMlError::CampaignValidation(format!(
632 "shape plan for `{}` has empty feature_namespace",
633 self.node_id
634 )));
635 }
636 if self
637 .feature_schema_fingerprint
638 .as_ref()
639 .is_some_and(|fingerprint| !is_hex_fingerprint(fingerprint))
640 {
641 return Err(DagMlError::CampaignValidation(format!(
642 "shape plan for `{}` has invalid feature_schema_fingerprint",
643 self.node_id
644 )));
645 }
646 self.aggregation_policy.validate()?;
647 self.augmentation_policy.validate()?;
648 self.selection_policy.validate()?;
649 if self.selection_policy.scope == FeatureSelectionScope::SupervisedFoldTrain
650 && self.fit_rows != FitBoundary::FoldTrain
651 {
652 return Err(DagMlError::CampaignValidation(format!(
653 "supervised feature selection for `{}` must fit on fold_train",
654 self.node_id
655 )));
656 }
657 Ok(())
658 }
659}
660
/// Returns `true` when `value` is exactly 64 ASCII hexadecimal digits
/// (`0-9`, `a-f`, `A-F`).
fn is_hex_fingerprint(value: &str) -> bool {
    const EXPECTED_LEN: usize = 64;
    if value.len() != EXPECTED_LEN {
        return false;
    }
    // Bytes of a non-ASCII character are never ASCII hex digits, so checking
    // bytes rejects the same inputs as checking chars.
    value.bytes().all(|byte| byte.is_ascii_hexdigit())
}
664
665fn default_granularity() -> Granularity {
666 Granularity::Sample
667}
668
669fn default_fit_boundary() -> FitBoundary {
670 FitBoundary::FoldTrain
671}
672
673fn default_predict_boundary() -> FitBoundary {
674 FitBoundary::FoldValidation
675}
676
/// Serde default for [`DataModelShapePlan::target_space`]: the string `"raw"`.
fn default_target_space() -> String {
    String::from("raw")
}
680
681#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
682#[serde(rename_all = "snake_case")]
683pub enum ShapeDeltaKind {
684 Row,
685 Feature,
686 Target,
687 Prediction,
688}
689
690#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
691pub struct ShapeDelta {
692 pub node_id: NodeId,
693 pub kind: ShapeDeltaKind,
694 pub before_fingerprint: String,
695 pub after_fingerprint: String,
696 #[serde(default)]
697 pub metadata: BTreeMap<String, serde_json::Value>,
698}
699
700impl ShapeDelta {
701 pub fn validate(&self) -> Result<()> {
702 if self.before_fingerprint.trim().is_empty() || self.after_fingerprint.trim().is_empty() {
703 return Err(DagMlError::RuntimeValidation(format!(
704 "shape delta for `{}` has empty fingerprint",
705 self.node_id
706 )));
707 }
708 if self.before_fingerprint == self.after_fingerprint {
709 return Err(DagMlError::RuntimeValidation(format!(
710 "shape delta for `{}` does not change fingerprint",
711 self.node_id
712 )));
713 }
714 for key in self.metadata.keys() {
715 if key.trim().is_empty() {
716 return Err(DagMlError::RuntimeValidation(format!(
717 "shape delta for `{}` contains an empty metadata key",
718 self.node_id
719 )));
720 }
721 }
722 Ok(())
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::NodeId;

    /// Controller spec fixture with the given id and params.
    fn controller_spec(id: &str, params: serde_json::Value) -> AggregationControllerSpec {
        AggregationControllerSpec {
            controller_id: ControllerId::new(id).unwrap(),
            params,
        }
    }

    /// Fully spelled-out, valid shape plan for `model:pls`.
    fn pls_shape_plan() -> DataModelShapePlan {
        DataModelShapePlan {
            node_id: NodeId::new("model:pls").unwrap(),
            input_granularity: Granularity::Sample,
            target_granularity: Granularity::Sample,
            fit_rows: FitBoundary::FoldTrain,
            predict_rows: FitBoundary::FoldValidation,
            feature_namespace: None,
            feature_schema_fingerprint: None,
            target_space: "raw".to_string(),
            aggregation_policy: AggregationPolicy::default(),
            augmentation_policy: AugmentationPolicy::default(),
            selection_policy: FeatureSelectionPolicy::default(),
        }
    }

    #[test]
    fn repeated_measurements_default_to_sample_level_aggregation() {
        let LeakageUnitPolicy { split_unit, .. } = LeakageUnitPolicy::default();
        let AggregationPolicy {
            aggregation_level,
            emit_parallel_metrics,
            ..
        } = AggregationPolicy::default();

        assert_eq!(split_unit, SplitUnit::PhysicalSample);
        assert_eq!(aggregation_level, PredictionLevel::Sample);
        assert!(emit_parallel_metrics);
    }

    #[test]
    fn observation_split_requires_explicit_unsafe_policy() {
        let observation_split = LeakageUnitPolicy {
            split_unit: SplitUnit::Observation,
            ..LeakageUnitPolicy::default()
        };

        assert!(observation_split.validate().is_err());
    }

    #[test]
    fn weighted_aggregation_requires_explicit_weight_policy() {
        let policy_with = |method: AggregationMethod, weights: AggregationWeights| {
            AggregationPolicy {
                method,
                weights,
                ..AggregationPolicy::default()
            }
        };

        // weighted_mean with no weights is rejected.
        let missing_weights =
            policy_with(AggregationMethod::WeightedMean, AggregationWeights::None);
        assert!(missing_weights.validate().is_err());

        // Weights on a plain mean are rejected.
        let stray_weights =
            policy_with(AggregationMethod::Mean, AggregationWeights::ControllerEmitted);
        assert!(stray_weights.validate().is_err());

        // weighted_mean with weights is accepted.
        let weighted = policy_with(
            AggregationMethod::WeightedMean,
            AggregationWeights::ControllerEmitted,
        );
        weighted.validate().unwrap();
    }

    #[test]
    fn custom_aggregation_requires_controller_spec() {
        let without_spec = AggregationPolicy {
            method: AggregationMethod::CustomController,
            ..AggregationPolicy::default()
        };
        assert!(without_spec.validate().is_err());

        let spec_on_default_method = AggregationPolicy {
            custom_controller: Some(controller_spec("controller:agg", serde_json::json!({}))),
            ..AggregationPolicy::default()
        };
        assert!(spec_on_default_method.validate().is_err());

        let custom_with_spec = AggregationPolicy {
            method: AggregationMethod::CustomController,
            weights: AggregationWeights::ControllerEmitted,
            custom_controller: Some(controller_spec(
                "controller:agg",
                serde_json::json!({ "trim": 0.1 }),
            )),
            ..AggregationPolicy::default()
        };
        custom_with_spec.validate().unwrap();
    }

    #[test]
    fn reduction_plan_validates_weight_controller_and_task_contracts() {
        let trim_params =
            |fraction: f64| BTreeMap::from([("trim_fraction".to_string(), serde_json::json!(fraction))]);

        // Accepted: weighted mean with an explicit weight source.
        ReductionPlan {
            method: ReductionMethod::WeightedMean,
            weight_source: AggregationWeights::Quality,
            ..ReductionPlan::default()
        }
        .validate()
        .unwrap();

        // Accepted: fold ensemble over physical samples.
        ReductionPlan {
            role: ReductionRole::FoldEnsemble,
            axis: ReductionAxis::Fold,
            input_unit_level: EntityUnitLevel::PhysicalSample,
            output_unit_level: EntityUnitLevel::PhysicalSample,
            ..ReductionPlan::default()
        }
        .validate()
        .unwrap();

        // Accepted: model-axis meta feature over physical samples.
        ReductionPlan {
            role: ReductionRole::MetaFeature,
            axis: ReductionAxis::Model,
            input_unit_level: EntityUnitLevel::PhysicalSample,
            output_unit_level: EntityUnitLevel::PhysicalSample,
            ..ReductionPlan::default()
        }
        .validate()
        .unwrap();

        // Rejected: weighted mean without a weight source.
        let no_weight_source = ReductionPlan {
            method: ReductionMethod::WeightedMean,
            ..ReductionPlan::default()
        };
        assert!(no_weight_source.validate().is_err());

        // Rejected: vote on a regression task.
        let vote_on_regression = ReductionPlan {
            method: ReductionMethod::Vote,
            task_compatibility: ReductionTaskCompatibility::Regression,
            ..ReductionPlan::default()
        };
        assert!(vote_on_regression.validate().is_err());

        // Accepted: custom method with a controller and an in-range trim fraction.
        ReductionPlan {
            method: ReductionMethod::Custom,
            custom_controller: Some(controller_spec(
                "controller:agg.robust",
                serde_json::json!({ "trim_fraction": 0.2 }),
            )),
            params: trim_params(0.2),
            ..ReductionPlan::default()
        }
        .validate()
        .unwrap();

        // Rejected: trim fraction outside [0.0, 0.5).
        let trim_too_large = ReductionPlan {
            method: ReductionMethod::RobustMean,
            params: trim_params(0.75),
            ..ReductionPlan::default()
        };
        assert!(trim_too_large.validate().is_err());
    }

    #[test]
    fn supervised_selection_must_fit_on_fold_train() {
        let mut plan = pls_shape_plan();
        plan.input_granularity = Granularity::Observation;
        plan.fit_rows = FitBoundary::FullTrain;
        plan.selection_policy = FeatureSelectionPolicy {
            scope: FeatureSelectionScope::SupervisedFoldTrain,
            ..FeatureSelectionPolicy::default()
        };

        assert!(plan.validate().is_err());
    }

    #[test]
    fn augmentation_policy_requires_explicit_unsafe_flags_for_leaky_sample_augmentation() {
        let mut all_partitions = AugmentationPolicy {
            sample_scope: AugmentationScope::AllPartitions,
            ..AugmentationPolicy::default()
        };
        assert!(all_partitions.validate().is_err());

        // Adding the explicit flag makes the same policy acceptable.
        all_partitions
            .unsafe_flags
            .insert(AugmentationPolicy::ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS.to_string());
        all_partitions.validate().unwrap();

        let origin_not_required = AugmentationPolicy {
            require_origin_id: false,
            ..AugmentationPolicy::default()
        };
        assert!(origin_not_required.validate().is_err());
    }

    #[test]
    fn shape_plan_validates_feature_and_selection_audit_contracts() {
        let mut blank_namespace = pls_shape_plan();
        blank_namespace.feature_namespace = Some(" ".to_string());
        assert!(blank_namespace.validate().is_err());

        let mut short_fingerprint = pls_shape_plan();
        short_fingerprint.feature_schema_fingerprint = Some("short".to_string());
        assert!(short_fingerprint.validate().is_err());

        let mut supervised_without_masks = pls_shape_plan();
        supervised_without_masks.selection_policy = FeatureSelectionPolicy {
            scope: FeatureSelectionScope::SupervisedFoldTrain,
            store_masks: false,
            allow_schema_mismatch_on_join: false,
        };
        assert!(supervised_without_masks.validate().is_err());
    }

    #[test]
    fn shape_delta_requires_a_real_fingerprint_change() {
        let unchanged = ShapeDelta {
            node_id: NodeId::new("transform:select").unwrap(),
            kind: ShapeDeltaKind::Feature,
            before_fingerprint: "a".repeat(64),
            after_fingerprint: "a".repeat(64),
            metadata: BTreeMap::new(),
        };
        assert!(unchanged.validate().is_err());

        // A real fingerprint change still fails when a metadata key is blank.
        let mut blank_metadata_key = unchanged;
        blank_metadata_key.after_fingerprint = "b".repeat(64);
        blank_metadata_key
            .metadata
            .insert(" ".to_string(), serde_json::Value::Bool(true));
        assert!(blank_metadata_key.validate().is_err());
    }
}