Skip to main content

dag_ml_core/
policy.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{DagMlError, Result};
6use crate::ids::{ControllerId, NodeId};
7use crate::relation::EntityUnitLevel;
8
/// Unit named by [`LeakageUnitPolicy::split_unit`].
///
/// Variants are serialized with snake_case names (for example `physical_sample`).
// NOTE(review): presumably this is the entity level that is kept together when
// assigning folds — confirm against the splitter implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SplitUnit {
    PhysicalSample,
    Observation,
    Sample,
    Target,
    Group,
}
18
/// Policy describing the split unit and the leakage-related opt-ins around it.
///
/// Every field has a serde default, so an empty serialized object deserializes
/// successfully.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct LeakageUnitPolicy {
    /// Split unit; `physical_sample` when omitted from the serialized form.
    #[serde(default = "default_split_unit")]
    pub split_unit: SplitUnit,
    /// `true` when omitted from the serialized form.
    // NOTE(review): not inspected by `validate` in this file; enforcement
    // presumably lives in the consumer — confirm.
    #[serde(default = "default_true")]
    pub forbid_origin_cross_fold: bool,
    /// Explicit opt-in that `validate` requires when `split_unit` is
    /// `observation`; `false` when omitted.
    #[serde(default)]
    pub allow_observation_split_with_shared_target: bool,
    /// When `true`, `validate` requires `split_unit` to be `group`; `false`
    /// when omitted.
    #[serde(default)]
    pub require_group_ids: bool,
    /// Free-form flag strings; empty when omitted. `validate` does not inspect
    /// them.
    #[serde(default)]
    pub unsafe_flags: BTreeSet<String>,
}
32
33impl Default for LeakageUnitPolicy {
34    fn default() -> Self {
35        Self {
36            split_unit: SplitUnit::PhysicalSample,
37            forbid_origin_cross_fold: true,
38            allow_observation_split_with_shared_target: false,
39            require_group_ids: false,
40            unsafe_flags: BTreeSet::new(),
41        }
42    }
43}
44
45impl LeakageUnitPolicy {
46    pub fn validate(&self) -> Result<()> {
47        if self.split_unit == SplitUnit::Observation
48            && !self.allow_observation_split_with_shared_target
49        {
50            return Err(DagMlError::CampaignValidation(
51                "observation-level splitting is unsafe for repeated X / shared Y unless explicitly allowed".to_string(),
52            ));
53        }
54        if self.require_group_ids && self.split_unit != SplitUnit::Group {
55            return Err(DagMlError::CampaignValidation(
56                "require_group_ids=true requires split_unit=group".to_string(),
57            ));
58        }
59        Ok(())
60    }
61}
62
/// Serde default for `LeakageUnitPolicy::split_unit`.
fn default_split_unit() -> SplitUnit {
    SplitUnit::PhysicalSample
}
66
/// Level used by [`AggregationPolicy`] for `aggregation_level` and
/// `selection_metric_level`.
///
/// Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredictionLevel {
    Observation,
    Sample,
    Target,
    Group,
}
75
/// Fit-influence policy selector; `UniformRows` is the `Default` variant.
///
/// Variants are serialized with snake_case names.
// NOTE(review): the meaning of each variant is not established in this file —
// confirm against the code that consumes this policy.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FitInfluencePolicy {
    Auto,
    #[default]
    UniformRows,
    EqualSampleInfluence,
    ResampleEqualized,
    BackendLossWeight,
    ScorerOnly,
    StrictWeightSupport,
}
88
/// Aggregation method selected by [`AggregationPolicy::method`].
///
/// Variants are serialized with snake_case names. Converts into
/// [`ReductionMethod`] via `From`; `None` and `Mean` both convert to
/// `ReductionMethod::Mean`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AggregationMethod {
    None,
    Mean,
    WeightedMean,
    Median,
    Vote,
    RobustMean,
    ExcludeOutliers,
    CustomController,
}
101
/// Weight source for aggregation and reduction.
///
/// `None` is the serde default wherever this type is used in this file. The
/// `validate` methods of [`AggregationPolicy`] and [`ReductionPlan`] require a
/// non-`None` value for weighted-mean methods and reject a non-`None` value for
/// methods other than weighted-mean and custom.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AggregationWeights {
    None,
    Quality,
    RepetitionCount,
    ControllerEmitted,
}
110
/// Reference to a custom aggregation/reduction controller plus its parameters.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct AggregationControllerSpec {
    /// Identifier of the controller; required in the serialized form.
    pub controller_id: ControllerId,
    /// Controller parameters; an empty JSON object when omitted. An explicit
    /// JSON `null` is rejected by `validate`.
    #[serde(default = "default_json_object")]
    pub params: serde_json::Value,
}
117
118impl AggregationControllerSpec {
119    pub fn validate(&self) -> Result<()> {
120        if self.params.is_null() {
121            return Err(DagMlError::CampaignValidation(format!(
122                "custom aggregation controller `{}` params cannot be null",
123                self.controller_id
124            )));
125        }
126        Ok(())
127    }
128}
129
/// Role assigned to a [`ReductionPlan`]; `FinalOutput` is the plan's default.
///
/// Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReductionRole {
    Score,
    Persist,
    FoldEnsemble,
    MetaFeature,
    FinalOutput,
}
139
/// Axis assigned to a [`ReductionPlan`]; `Unit` is the plan's default.
///
/// Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReductionAxis {
    Unit,
    Fold,
    Model,
    Metric,
}
148
/// Reduction method used by [`ReductionPlan::method`]; `Mean` is the plan's
/// default.
///
/// Variants are serialized with snake_case names. Unlike
/// [`AggregationMethod`], there is no `None` variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReductionMethod {
    Mean,
    WeightedMean,
    Median,
    Vote,
    RobustMean,
    ExcludeOutliers,
    Custom,
}
160
161impl From<AggregationMethod> for ReductionMethod {
162    fn from(method: AggregationMethod) -> Self {
163        match method {
164            AggregationMethod::None | AggregationMethod::Mean => Self::Mean,
165            AggregationMethod::WeightedMean => Self::WeightedMean,
166            AggregationMethod::Median => Self::Median,
167            AggregationMethod::Vote => Self::Vote,
168            AggregationMethod::RobustMean => Self::RobustMean,
169            AggregationMethod::ExcludeOutliers => Self::ExcludeOutliers,
170            AggregationMethod::CustomController => Self::Custom,
171        }
172    }
173}
174
/// Task compatibility declared on a [`ReductionPlan`]; `Any` is the plan's
/// default.
///
/// `ReductionPlan::validate` rejects the combination of `Regression` with the
/// `vote` method. Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReductionTaskCompatibility {
    Any,
    Regression,
    Classification,
}
182
/// Description of a single reduction step.
///
/// Every field has a serde default, so an empty serialized object
/// deserializes to the same value as `ReductionPlan::default()`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ReductionPlan {
    /// Defaults to `final_output`.
    #[serde(default = "default_reduction_role")]
    pub role: ReductionRole,
    /// Defaults to `unit`.
    #[serde(default = "default_reduction_axis")]
    pub axis: ReductionAxis,
    /// Defaults to `EntityUnitLevel::Observation`.
    #[serde(default = "default_reduction_input_unit_level")]
    pub input_unit_level: EntityUnitLevel,
    /// Defaults to `EntityUnitLevel::PhysicalSample`.
    #[serde(default = "default_reduction_output_unit_level")]
    pub output_unit_level: EntityUnitLevel,
    /// Defaults to `mean`.
    #[serde(default = "default_reduction_method")]
    pub method: ReductionMethod,
    /// Defaults to `none`; `validate` requires a non-`none` value for
    /// `weighted_mean` and allows one only for `weighted_mean` or `custom`.
    #[serde(default = "default_aggregation_weights")]
    pub weight_source: AggregationWeights,
    /// Defaults to `any`.
    #[serde(default = "default_reduction_task_compatibility")]
    pub task_compatibility: ReductionTaskCompatibility,
    /// Required by `validate` when `method` is `custom` and rejected
    /// otherwise; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub custom_controller: Option<AggregationControllerSpec>,
    /// Extra parameters; omitted from serialized output when empty. `validate`
    /// range-checks the `trim_fraction` and `threshold` keys when present.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub params: BTreeMap<String, serde_json::Value>,
}
204
205impl Default for ReductionPlan {
206    fn default() -> Self {
207        Self {
208            role: default_reduction_role(),
209            axis: default_reduction_axis(),
210            input_unit_level: default_reduction_input_unit_level(),
211            output_unit_level: default_reduction_output_unit_level(),
212            method: default_reduction_method(),
213            weight_source: default_aggregation_weights(),
214            task_compatibility: default_reduction_task_compatibility(),
215            custom_controller: None,
216            params: BTreeMap::new(),
217        }
218    }
219}
220
221impl ReductionPlan {
222    pub fn validate(&self) -> Result<()> {
223        if self.method == ReductionMethod::WeightedMean
224            && self.weight_source == AggregationWeights::None
225        {
226            return Err(DagMlError::CampaignValidation(
227                "weighted_mean reduction requires an explicit weight_source".to_string(),
228            ));
229        }
230        if self.method != ReductionMethod::WeightedMean
231            && self.method != ReductionMethod::Custom
232            && self.weight_source != AggregationWeights::None
233        {
234            return Err(DagMlError::CampaignValidation(format!(
235                "reduction weight_source {:?} is only valid with weighted_mean or custom",
236                self.weight_source
237            )));
238        }
239        match (&self.method, &self.custom_controller) {
240            (ReductionMethod::Custom, Some(controller)) => controller.validate()?,
241            (ReductionMethod::Custom, None) => {
242                return Err(DagMlError::CampaignValidation(
243                    "custom reduction requires a custom_controller spec".to_string(),
244                ));
245            }
246            (_, Some(controller)) => {
247                return Err(DagMlError::CampaignValidation(format!(
248                    "reduction controller `{}` is only valid with custom method",
249                    controller.controller_id
250                )));
251            }
252            (_, None) => {}
253        }
254        if self.method == ReductionMethod::Vote
255            && self.task_compatibility == ReductionTaskCompatibility::Regression
256        {
257            return Err(DagMlError::CampaignValidation(
258                "vote reduction is not compatible with regression tasks".to_string(),
259            ));
260        }
261        validate_trim_fraction(self.params.get("trim_fraction"))?;
262        validate_outlier_threshold(self.params.get("threshold"))?;
263        Ok(())
264    }
265}
266
/// Policy describing how predictions are aggregated and which results are
/// stored.
///
/// Every field has a serde default, so an empty serialized object
/// deserializes successfully.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct AggregationPolicy {
    /// Defaults to `sample`.
    #[serde(default = "default_prediction_level")]
    pub aggregation_level: PredictionLevel,
    /// Defaults to `mean`. `none` is accepted by `validate` only when
    /// `aggregation_level` is `observation`.
    #[serde(default = "default_aggregation_method")]
    pub method: AggregationMethod,
    /// Defaults to `none`; `validate` requires a non-`none` value for
    /// `weighted_mean` and allows one only for `weighted_mean` or
    /// `custom_controller`.
    #[serde(default = "default_aggregation_weights")]
    pub weights: AggregationWeights,
    /// Required by `validate` when `method` is `custom_controller` and
    /// rejected otherwise; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub custom_controller: Option<AggregationControllerSpec>,
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub emit_parallel_metrics: bool,
    /// Defaults to `sample`.
    #[serde(default = "default_prediction_level")]
    pub selection_metric_level: PredictionLevel,
    /// Defaults to `true`. `validate` rejects a policy where both storage
    /// flags are `false`.
    #[serde(default = "default_true")]
    pub store_raw_predictions: bool,
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub store_aggregated_predictions: bool,
}
286
287impl Default for AggregationPolicy {
288    fn default() -> Self {
289        Self {
290            aggregation_level: PredictionLevel::Sample,
291            method: AggregationMethod::Mean,
292            weights: AggregationWeights::None,
293            custom_controller: None,
294            emit_parallel_metrics: true,
295            selection_metric_level: PredictionLevel::Sample,
296            store_raw_predictions: true,
297            store_aggregated_predictions: true,
298        }
299    }
300}
301
302impl AggregationPolicy {
303    pub fn validate(&self) -> Result<()> {
304        if self.method == AggregationMethod::None
305            && self.aggregation_level != PredictionLevel::Observation
306        {
307            return Err(DagMlError::CampaignValidation(
308                "aggregation method none is only valid at observation level".to_string(),
309            ));
310        }
311        if self.method == AggregationMethod::WeightedMean
312            && self.weights == AggregationWeights::None
313        {
314            return Err(DagMlError::CampaignValidation(
315                "weighted_mean aggregation requires an explicit weights policy".to_string(),
316            ));
317        }
318        if self.method != AggregationMethod::WeightedMean
319            && self.method != AggregationMethod::CustomController
320            && self.weights != AggregationWeights::None
321        {
322            return Err(DagMlError::CampaignValidation(format!(
323                "aggregation weights {:?} are only valid with weighted_mean",
324                self.weights
325            )));
326        }
327        match (&self.method, &self.custom_controller) {
328            (AggregationMethod::CustomController, Some(controller)) => controller.validate()?,
329            (AggregationMethod::CustomController, None) => {
330                return Err(DagMlError::CampaignValidation(
331                    "custom_controller aggregation requires a custom_controller spec".to_string(),
332                ));
333            }
334            (_, Some(controller)) => {
335                return Err(DagMlError::CampaignValidation(format!(
336                    "aggregation controller `{}` is only valid with custom_controller method",
337                    controller.controller_id
338                )));
339            }
340            (_, None) => {}
341        }
342        if !self.store_raw_predictions && !self.store_aggregated_predictions {
343            return Err(DagMlError::CampaignValidation(
344                "aggregation policy must store raw and/or aggregated predictions".to_string(),
345            ));
346        }
347        Ok(())
348    }
349}
350
/// Serde default for `AggregationPolicy::aggregation_level` and
/// `AggregationPolicy::selection_metric_level`.
fn default_prediction_level() -> PredictionLevel {
    PredictionLevel::Sample
}
354
/// Serde default for `AggregationPolicy::method`.
fn default_aggregation_method() -> AggregationMethod {
    AggregationMethod::Mean
}
358
/// Serde default for `AggregationPolicy::weights` and
/// `ReductionPlan::weight_source`.
fn default_aggregation_weights() -> AggregationWeights {
    AggregationWeights::None
}
362
/// Serde default for `ReductionPlan::role`.
fn default_reduction_role() -> ReductionRole {
    ReductionRole::FinalOutput
}
366
/// Serde default for `ReductionPlan::axis`.
fn default_reduction_axis() -> ReductionAxis {
    ReductionAxis::Unit
}
370
/// Serde default for `ReductionPlan::input_unit_level`.
fn default_reduction_input_unit_level() -> EntityUnitLevel {
    EntityUnitLevel::Observation
}
374
/// Serde default for `ReductionPlan::output_unit_level`.
fn default_reduction_output_unit_level() -> EntityUnitLevel {
    EntityUnitLevel::PhysicalSample
}
378
/// Serde default for `ReductionPlan::method`.
fn default_reduction_method() -> ReductionMethod {
    ReductionMethod::Mean
}
382
/// Serde default for `ReductionPlan::task_compatibility`.
fn default_reduction_task_compatibility() -> ReductionTaskCompatibility {
    ReductionTaskCompatibility::Any
}
386
387fn validate_trim_fraction(value: Option<&serde_json::Value>) -> Result<()> {
388    let Some(value) = value else {
389        return Ok(());
390    };
391    let Some(trim_fraction) = value.as_f64() else {
392        return Err(DagMlError::CampaignValidation(
393            "reduction trim_fraction must be numeric".to_string(),
394        ));
395    };
396    if trim_fraction.is_finite() && (0.0..0.5).contains(&trim_fraction) {
397        Ok(())
398    } else {
399        Err(DagMlError::CampaignValidation(
400            "reduction trim_fraction must be finite and in [0.0, 0.5)".to_string(),
401        ))
402    }
403}
404
405fn validate_outlier_threshold(value: Option<&serde_json::Value>) -> Result<()> {
406    let Some(value) = value else {
407        return Ok(());
408    };
409    let Some(threshold) = value.as_f64() else {
410        return Err(DagMlError::CampaignValidation(
411            "reduction threshold must be numeric".to_string(),
412        ));
413    };
414    if threshold.is_finite() && threshold > 0.0 && threshold < 1.0 {
415        Ok(())
416    } else {
417        Err(DagMlError::CampaignValidation(
418            "reduction threshold must be finite and in (0.0, 1.0)".to_string(),
419        ))
420    }
421}
422
423fn default_json_object() -> serde_json::Value {
424    serde_json::Value::Object(serde_json::Map::new())
425}
426
/// Serde default for boolean fields that are `true` when omitted.
fn default_true() -> bool {
    true
}
430
/// Granularity used by [`DataModelShapePlan`] for its input and target;
/// `Sample` is the plan's serde default.
///
/// Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Granularity {
    Observation,
    Sample,
    Target,
    Group,
}
439
/// Row boundary used by [`DataModelShapePlan`] for `fit_rows` and
/// `predict_rows`.
///
/// The plan's serde defaults are `FoldTrain` for fitting and `FoldValidation`
/// for prediction. Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FitBoundary {
    FoldTrain,
    FoldValidation,
    FullTrain,
    Predict,
}
448
/// Scope over which augmentation applies; `TrainOnly` is the serde default
/// in [`AugmentationPolicy`].
///
/// `AugmentationPolicy::validate` requires an explicit unsafe flag before
/// accepting `AllPartitions` as the sample scope. Variants are serialized with
/// snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AugmentationScope {
    None,
    TrainOnly,
    AllPartitions,
}
456
/// Policy describing sample and feature augmentation and the safety
/// requirements around sample augmentation.
///
/// Every field has a serde default, so an empty serialized object
/// deserializes successfully.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct AugmentationPolicy {
    /// Defaults to `train_only`. `all_partitions` needs an explicit unsafe
    /// flag to pass `validate`.
    #[serde(default = "default_augmentation_scope")]
    pub sample_scope: AugmentationScope,
    /// Defaults to `train_only`. Not inspected by `validate`.
    #[serde(default = "default_augmentation_scope")]
    pub feature_scope: AugmentationScope,
    /// Defaults to `true`. Disabling it while sample augmentation is active
    /// needs an explicit unsafe flag to pass `validate`.
    #[serde(default = "default_true")]
    pub require_origin_id: bool,
    /// Defaults to `true`. Disabling it while sample augmentation is active
    /// needs an explicit unsafe flag to pass `validate`.
    #[serde(default = "default_true")]
    pub inherit_group: bool,
    /// Defaults to `true`. Disabling it while sample augmentation is active
    /// needs an explicit unsafe flag to pass `validate`.
    #[serde(default = "default_true")]
    pub inherit_target: bool,
    /// Explicit opt-ins checked by `validate`; see the `ALLOW_*` constants on
    /// this type. Omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "BTreeSet::is_empty")]
    pub unsafe_flags: BTreeSet<String>,
}
472
473impl Default for AugmentationPolicy {
474    fn default() -> Self {
475        Self {
476            sample_scope: AugmentationScope::TrainOnly,
477            feature_scope: AugmentationScope::TrainOnly,
478            require_origin_id: true,
479            inherit_group: true,
480            inherit_target: true,
481            unsafe_flags: BTreeSet::new(),
482        }
483    }
484}
485
486impl AugmentationPolicy {
487    pub const ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS: &'static str =
488        "allow_sample_augmentation_all_partitions";
489    pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_ORIGIN: &'static str =
490        "allow_sample_augmentation_without_origin";
491    pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_GROUP_INHERITANCE: &'static str =
492        "allow_sample_augmentation_without_group_inheritance";
493    pub const ALLOW_SAMPLE_AUGMENTATION_WITHOUT_TARGET_INHERITANCE: &'static str =
494        "allow_sample_augmentation_without_target_inheritance";
495
496    pub fn validate(&self) -> Result<()> {
497        for unsafe_flag in &self.unsafe_flags {
498            if unsafe_flag.trim().is_empty() {
499                return Err(DagMlError::CampaignValidation(
500                    "augmentation policy contains an empty unsafe flag".to_string(),
501                ));
502            }
503        }
504        if self.sample_scope == AugmentationScope::AllPartitions
505            && !self
506                .unsafe_flags
507                .contains(Self::ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS)
508        {
509            return Err(DagMlError::CampaignValidation(
510                "sample augmentation over all partitions can leak validation/test origins; add explicit unsafe flag allow_sample_augmentation_all_partitions".to_string(),
511            ));
512        }
513        if self.sample_scope != AugmentationScope::None {
514            if !self.require_origin_id
515                && !self
516                    .unsafe_flags
517                    .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_ORIGIN)
518            {
519                return Err(DagMlError::CampaignValidation(
520                    "sample augmentation requires origin ids unless explicit unsafe flag allow_sample_augmentation_without_origin is present".to_string(),
521                ));
522            }
523            if !self.inherit_group
524                && !self
525                    .unsafe_flags
526                    .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_GROUP_INHERITANCE)
527            {
528                return Err(DagMlError::CampaignValidation(
529                    "sample augmentation must inherit groups unless explicit unsafe flag allow_sample_augmentation_without_group_inheritance is present".to_string(),
530                ));
531            }
532            if !self.inherit_target
533                && !self
534                    .unsafe_flags
535                    .contains(Self::ALLOW_SAMPLE_AUGMENTATION_WITHOUT_TARGET_INHERITANCE)
536            {
537                return Err(DagMlError::CampaignValidation(
538                    "sample augmentation must inherit targets unless explicit unsafe flag allow_sample_augmentation_without_target_inheritance is present".to_string(),
539                ));
540            }
541        }
542        Ok(())
543    }
544}
545
/// Serde default for `AugmentationPolicy::sample_scope` and
/// `AugmentationPolicy::feature_scope`.
fn default_augmentation_scope() -> AugmentationScope {
    AugmentationScope::TrainOnly
}
549
/// Scope of feature selection; `None` is the serde default in
/// [`FeatureSelectionPolicy`].
///
/// `SupervisedFoldTrain` is the only variant the validators in this file
/// constrain. Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FeatureSelectionScope {
    None,
    Unsupervised,
    SupervisedFoldTrain,
}
557
/// Policy describing feature selection and whether its masks are stored.
///
/// Every field has a serde default, so an empty serialized object
/// deserializes successfully.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct FeatureSelectionPolicy {
    /// Defaults to `none`.
    #[serde(default = "default_feature_selection_scope")]
    pub scope: FeatureSelectionScope,
    /// Defaults to `true`. Must stay `true` for `supervised_fold_train` to
    /// pass `validate`.
    #[serde(default = "default_true")]
    pub store_masks: bool,
    /// Defaults to `false`. Not inspected by `validate`.
    #[serde(default)]
    pub allow_schema_mismatch_on_join: bool,
}
567
568impl Default for FeatureSelectionPolicy {
569    fn default() -> Self {
570        Self {
571            scope: FeatureSelectionScope::None,
572            store_masks: true,
573            allow_schema_mismatch_on_join: false,
574        }
575    }
576}
577
578impl FeatureSelectionPolicy {
579    pub fn validate(&self) -> Result<()> {
580        if self.scope == FeatureSelectionScope::SupervisedFoldTrain && !self.store_masks {
581            return Err(DagMlError::CampaignValidation(
582                "supervised feature selection must store fold/refit masks for replay and leakage audit".to_string(),
583            ));
584        }
585        Ok(())
586    }
587}
588
/// Serde default for `FeatureSelectionPolicy::scope`.
fn default_feature_selection_scope() -> FeatureSelectionScope {
    FeatureSelectionScope::None
}
592
/// Shape plan for a single node: granularities, row boundaries, feature and
/// target identification, and the nested aggregation, augmentation and
/// feature-selection policies.
///
/// Only `node_id` is required in the serialized form; every other field has a
/// serde default.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DataModelShapePlan {
    /// Node this plan belongs to; used in validation error messages.
    pub node_id: NodeId,
    /// Defaults to `sample`.
    #[serde(default = "default_granularity")]
    pub input_granularity: Granularity,
    /// Defaults to `sample`.
    #[serde(default = "default_granularity")]
    pub target_granularity: Granularity,
    /// Defaults to `fold_train`. Must be `fold_train` when the selection
    /// policy scope is `supervised_fold_train`.
    #[serde(default = "default_fit_boundary")]
    pub fit_rows: FitBoundary,
    /// Defaults to `fold_validation`.
    #[serde(default = "default_predict_boundary")]
    pub predict_rows: FitBoundary,
    /// Optional namespace; when present it must not be blank.
    #[serde(default)]
    pub feature_namespace: Option<String>,
    /// Optional fingerprint; when present it must be exactly 64 ASCII hex
    /// digits.
    #[serde(default)]
    pub feature_schema_fingerprint: Option<String>,
    /// Defaults to `"raw"`; must not be blank.
    #[serde(default = "default_target_space")]
    pub target_space: String,
    /// Defaults to `AggregationPolicy::default()`; validated with the plan.
    #[serde(default)]
    pub aggregation_policy: AggregationPolicy,
    /// Defaults to `AugmentationPolicy::default()`; validated with the plan.
    #[serde(default)]
    pub augmentation_policy: AugmentationPolicy,
    /// Defaults to `FeatureSelectionPolicy::default()`; validated with the
    /// plan.
    #[serde(default)]
    pub selection_policy: FeatureSelectionPolicy,
}
617
618impl DataModelShapePlan {
619    pub fn validate(&self) -> Result<()> {
620        if self.target_space.trim().is_empty() {
621            return Err(DagMlError::CampaignValidation(format!(
622                "shape plan for `{}` has empty target_space",
623                self.node_id
624            )));
625        }
626        if self
627            .feature_namespace
628            .as_ref()
629            .is_some_and(|namespace| namespace.trim().is_empty())
630        {
631            return Err(DagMlError::CampaignValidation(format!(
632                "shape plan for `{}` has empty feature_namespace",
633                self.node_id
634            )));
635        }
636        if self
637            .feature_schema_fingerprint
638            .as_ref()
639            .is_some_and(|fingerprint| !is_hex_fingerprint(fingerprint))
640        {
641            return Err(DagMlError::CampaignValidation(format!(
642                "shape plan for `{}` has invalid feature_schema_fingerprint",
643                self.node_id
644            )));
645        }
646        self.aggregation_policy.validate()?;
647        self.augmentation_policy.validate()?;
648        self.selection_policy.validate()?;
649        if self.selection_policy.scope == FeatureSelectionScope::SupervisedFoldTrain
650            && self.fit_rows != FitBoundary::FoldTrain
651        {
652            return Err(DagMlError::CampaignValidation(format!(
653                "supervised feature selection for `{}` must fit on fold_train",
654                self.node_id
655            )));
656        }
657        Ok(())
658    }
659}
660
/// Returns `true` when `value` is exactly 64 bytes long and every byte is an
/// ASCII hex digit (`0-9`, `a-f`, `A-F`).
fn is_hex_fingerprint(value: &str) -> bool {
    let bytes = value.as_bytes();
    // Any non-ASCII character contributes bytes >= 0x80, which are never hex
    // digits, so checking bytes matches a per-character check.
    bytes.len() == 64 && bytes.iter().all(u8::is_ascii_hexdigit)
}
664
/// Serde default for `DataModelShapePlan::input_granularity` and
/// `DataModelShapePlan::target_granularity`.
fn default_granularity() -> Granularity {
    Granularity::Sample
}
668
/// Serde default for `DataModelShapePlan::fit_rows`.
fn default_fit_boundary() -> FitBoundary {
    FitBoundary::FoldTrain
}
672
/// Serde default for `DataModelShapePlan::predict_rows`.
fn default_predict_boundary() -> FitBoundary {
    FitBoundary::FoldValidation
}
676
/// Serde default for `DataModelShapePlan::target_space`: the string `"raw"`.
fn default_target_space() -> String {
    String::from("raw")
}
680
/// Kind recorded on a [`ShapeDelta`].
///
/// Variants are serialized with snake_case names.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShapeDeltaKind {
    Row,
    Feature,
    Target,
    Prediction,
}
689
/// Record of a fingerprint change attributed to a node.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ShapeDelta {
    /// Node the delta is attributed to; used in validation error messages.
    pub node_id: NodeId,
    /// Kind of shape that changed.
    pub kind: ShapeDeltaKind,
    /// Fingerprint before the change; must not be blank and must differ from
    /// `after_fingerprint`.
    // NOTE(review): unlike `feature_schema_fingerprint`, no hex/length format is
    // enforced here — confirm whether that is intentional.
    pub before_fingerprint: String,
    /// Fingerprint after the change; must not be blank.
    pub after_fingerprint: String,
    /// Free-form metadata; empty when omitted. Keys must not be blank.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
699
700impl ShapeDelta {
701    pub fn validate(&self) -> Result<()> {
702        if self.before_fingerprint.trim().is_empty() || self.after_fingerprint.trim().is_empty() {
703            return Err(DagMlError::RuntimeValidation(format!(
704                "shape delta for `{}` has empty fingerprint",
705                self.node_id
706            )));
707        }
708        if self.before_fingerprint == self.after_fingerprint {
709            return Err(DagMlError::RuntimeValidation(format!(
710                "shape delta for `{}` does not change fingerprint",
711                self.node_id
712            )));
713        }
714        for key in self.metadata.keys() {
715            if key.trim().is_empty() {
716                return Err(DagMlError::RuntimeValidation(format!(
717                    "shape delta for `{}` contains an empty metadata key",
718                    self.node_id
719                )));
720            }
721        }
722        Ok(())
723    }
724}
725
// Unit tests for the policy defaults and the `validate` contracts in this
// module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::NodeId;

    // Defaults: physical-sample splitting, sample-level aggregation, parallel
    // metrics enabled.
    #[test]
    fn repeated_measurements_default_to_sample_level_aggregation() {
        let leakage = LeakageUnitPolicy::default();
        let aggregation = AggregationPolicy::default();

        assert_eq!(leakage.split_unit, SplitUnit::PhysicalSample);
        assert_eq!(aggregation.aggregation_level, PredictionLevel::Sample);
        assert!(aggregation.emit_parallel_metrics);
    }

    // Observation-level splitting without the explicit opt-in is rejected.
    #[test]
    fn observation_split_requires_explicit_unsafe_policy() {
        let policy = LeakageUnitPolicy {
            split_unit: SplitUnit::Observation,
            ..LeakageUnitPolicy::default()
        };

        assert!(policy.validate().is_err());
    }

    // Weighted mean needs weights; plain mean must not carry weights.
    #[test]
    fn weighted_aggregation_requires_explicit_weight_policy() {
        let missing_weights = AggregationPolicy {
            method: AggregationMethod::WeightedMean,
            weights: AggregationWeights::None,
            ..AggregationPolicy::default()
        };
        assert!(missing_weights.validate().is_err());

        let stray_weights = AggregationPolicy {
            method: AggregationMethod::Mean,
            weights: AggregationWeights::ControllerEmitted,
            ..AggregationPolicy::default()
        };
        assert!(stray_weights.validate().is_err());

        let valid = AggregationPolicy {
            method: AggregationMethod::WeightedMean,
            weights: AggregationWeights::ControllerEmitted,
            ..AggregationPolicy::default()
        };
        valid.validate().unwrap();
    }

    // The custom-controller method and the controller spec must appear
    // together.
    #[test]
    fn custom_aggregation_requires_controller_spec() {
        let missing_controller = AggregationPolicy {
            method: AggregationMethod::CustomController,
            ..AggregationPolicy::default()
        };
        assert!(missing_controller.validate().is_err());

        // Default method is mean, so a controller spec here is stray.
        let stray_controller = AggregationPolicy {
            custom_controller: Some(AggregationControllerSpec {
                controller_id: ControllerId::new("controller:agg").unwrap(),
                params: serde_json::json!({}),
            }),
            ..AggregationPolicy::default()
        };
        assert!(stray_controller.validate().is_err());

        // Custom controllers may also carry a weights policy.
        let valid = AggregationPolicy {
            method: AggregationMethod::CustomController,
            weights: AggregationWeights::ControllerEmitted,
            custom_controller: Some(AggregationControllerSpec {
                controller_id: ControllerId::new("controller:agg").unwrap(),
                params: serde_json::json!({ "trim": 0.1 }),
            }),
            ..AggregationPolicy::default()
        };
        valid.validate().unwrap();
    }

    // Covers the weight-source, controller, task-compatibility and
    // trim-fraction rules of `ReductionPlan::validate`.
    #[test]
    fn reduction_plan_validates_weight_controller_and_task_contracts() {
        let weighted = ReductionPlan {
            method: ReductionMethod::WeightedMean,
            weight_source: AggregationWeights::Quality,
            ..ReductionPlan::default()
        };
        weighted.validate().unwrap();

        let fold_ensemble = ReductionPlan {
            role: ReductionRole::FoldEnsemble,
            axis: ReductionAxis::Fold,
            input_unit_level: EntityUnitLevel::PhysicalSample,
            output_unit_level: EntityUnitLevel::PhysicalSample,
            ..ReductionPlan::default()
        };
        fold_ensemble.validate().unwrap();

        let model_meta_feature = ReductionPlan {
            role: ReductionRole::MetaFeature,
            axis: ReductionAxis::Model,
            input_unit_level: EntityUnitLevel::PhysicalSample,
            output_unit_level: EntityUnitLevel::PhysicalSample,
            ..ReductionPlan::default()
        };
        model_meta_feature.validate().unwrap();

        // Default weight source is none, which weighted mean rejects.
        let missing_weight_source = ReductionPlan {
            method: ReductionMethod::WeightedMean,
            ..ReductionPlan::default()
        };
        assert!(missing_weight_source.validate().is_err());

        let invalid_vote = ReductionPlan {
            method: ReductionMethod::Vote,
            task_compatibility: ReductionTaskCompatibility::Regression,
            ..ReductionPlan::default()
        };
        assert!(invalid_vote.validate().is_err());

        let custom = ReductionPlan {
            method: ReductionMethod::Custom,
            custom_controller: Some(AggregationControllerSpec {
                controller_id: ControllerId::new("controller:agg.robust").unwrap(),
                params: serde_json::json!({ "trim_fraction": 0.2 }),
            }),
            params: BTreeMap::from([("trim_fraction".to_string(), serde_json::json!(0.2))]),
            ..ReductionPlan::default()
        };
        custom.validate().unwrap();

        // 0.75 is outside the accepted [0.0, 0.5) range.
        let invalid_trim = ReductionPlan {
            method: ReductionMethod::RobustMean,
            params: BTreeMap::from([("trim_fraction".to_string(), serde_json::json!(0.75))]),
            ..ReductionPlan::default()
        };
        assert!(invalid_trim.validate().is_err());
    }

    // Supervised selection combined with `fit_rows = full_train` is rejected.
    #[test]
    fn supervised_selection_must_fit_on_fold_train() {
        let plan = DataModelShapePlan {
            node_id: NodeId::new("model:pls").unwrap(),
            fit_rows: FitBoundary::FullTrain,
            selection_policy: FeatureSelectionPolicy {
                scope: FeatureSelectionScope::SupervisedFoldTrain,
                ..FeatureSelectionPolicy::default()
            },
            ..DataModelShapePlan {
                node_id: NodeId::new("model:pls").unwrap(),
                input_granularity: Granularity::Observation,
                target_granularity: Granularity::Sample,
                fit_rows: FitBoundary::FoldTrain,
                predict_rows: FitBoundary::FoldValidation,
                feature_namespace: None,
                feature_schema_fingerprint: None,
                target_space: "raw".to_string(),
                aggregation_policy: AggregationPolicy::default(),
                augmentation_policy: AugmentationPolicy::default(),
                selection_policy: FeatureSelectionPolicy::default(),
            }
        };

        assert!(plan.validate().is_err());
    }

    // All-partition sample augmentation and missing origin ids each need an
    // explicit unsafe flag.
    #[test]
    fn augmentation_policy_requires_explicit_unsafe_flags_for_leaky_sample_augmentation() {
        let policy = AugmentationPolicy {
            sample_scope: AugmentationScope::AllPartitions,
            ..AugmentationPolicy::default()
        };
        assert!(policy.validate().is_err());

        let mut allowed = policy;
        allowed.unsafe_flags = BTreeSet::from([
            AugmentationPolicy::ALLOW_SAMPLE_AUGMENTATION_ALL_PARTITIONS.to_string(),
        ]);
        allowed.validate().unwrap();

        let no_origin = AugmentationPolicy {
            require_origin_id: false,
            ..AugmentationPolicy::default()
        };
        assert!(no_origin.validate().is_err());
    }

    // Blank namespace, malformed fingerprint, and supervised selection without
    // stored masks are each rejected.
    #[test]
    fn shape_plan_validates_feature_and_selection_audit_contracts() {
        let node_id = NodeId::new("model:pls").unwrap();
        let base = DataModelShapePlan {
            node_id: node_id.clone(),
            input_granularity: Granularity::Sample,
            target_granularity: Granularity::Sample,
            fit_rows: FitBoundary::FoldTrain,
            predict_rows: FitBoundary::FoldValidation,
            feature_namespace: None,
            feature_schema_fingerprint: None,
            target_space: "raw".to_string(),
            aggregation_policy: AggregationPolicy::default(),
            augmentation_policy: AugmentationPolicy::default(),
            selection_policy: FeatureSelectionPolicy::default(),
        };

        let mut empty_namespace = base.clone();
        empty_namespace.feature_namespace = Some(" ".to_string());
        assert!(empty_namespace.validate().is_err());

        let mut bad_fingerprint = base.clone();
        bad_fingerprint.feature_schema_fingerprint = Some("short".to_string());
        assert!(bad_fingerprint.validate().is_err());

        let mut supervised_without_masks = base;
        supervised_without_masks.selection_policy = FeatureSelectionPolicy {
            scope: FeatureSelectionScope::SupervisedFoldTrain,
            store_masks: false,
            allow_schema_mismatch_on_join: false,
        };
        assert!(supervised_without_masks.validate().is_err());
    }

    // Identical fingerprints and blank metadata keys are each rejected.
    #[test]
    fn shape_delta_requires_a_real_fingerprint_change() {
        let delta = ShapeDelta {
            node_id: NodeId::new("transform:select").unwrap(),
            kind: ShapeDeltaKind::Feature,
            before_fingerprint: "a".repeat(64),
            after_fingerprint: "a".repeat(64),
            metadata: BTreeMap::new(),
        };
        assert!(delta.validate().is_err());

        // Make the fingerprints differ so only the blank key can fail.
        let mut bad_metadata = delta;
        bad_metadata.after_fingerprint = "b".repeat(64);
        bad_metadata
            .metadata
            .insert(" ".to_string(), serde_json::Value::Bool(true));
        assert!(bad_metadata.validate().is_err());
    }
}
963}