Skip to main content

dag_ml_core/
selection.rs

1use std::cmp::Ordering;
2use std::collections::{BTreeMap, BTreeSet};
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{DagMlError, Result};
7use crate::oof::PredictionPartition;
8use crate::policy::PredictionLevel;
9use crate::relation::EntityUnitLevel;
10
/// Version number of the selection policy schema.
pub const SELECTION_POLICY_SCHEMA_VERSION: u32 = 1;
/// Identifier of the selection policy JSON schema; the unit tests compare it
/// with the `$id` field of the published schema document.
pub const SELECTION_POLICY_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/selection_policy.v1.schema.json";
/// Version number of the selection decision schema.
pub const SELECTION_DECISION_SCHEMA_VERSION: u32 = 1;
/// Identifier of the selection decision JSON schema; the unit tests compare it
/// with the `$id` field of the published schema document.
pub const SELECTION_DECISION_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/selection_decision.v1.schema.json";
17
/// Direction in which a selection metric is optimized.
///
/// Serialized in `snake_case` (`minimize` / `maximize`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetricObjective {
    /// Lower scores rank first.
    Minimize,
    /// Higher scores rank first.
    Maximize,
}
24
/// A named metric together with the direction in which it is optimized.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SelectionMetric {
    /// Metric name; must not be blank (see [`SelectionMetric::validate`]).
    pub name: String,
    /// Whether lower or higher values of the metric are preferred.
    pub objective: MetricObjective,
}
30
31impl SelectionMetric {
32    pub fn validate(&self) -> Result<()> {
33        if self.name.trim().is_empty() {
34            return Err(DagMlError::CampaignValidation(
35                "selection metric name is empty".to_string(),
36            ));
37        }
38        Ok(())
39    }
40}
41
/// Metric values and free-form metadata reported for one candidate.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CandidateScore {
    /// Identifier of the candidate; must not be blank.
    pub candidate_id: String,
    /// Metric values keyed by metric name; defaults to an empty map when the
    /// field is absent from the serialized form.
    #[serde(default)]
    pub metrics: BTreeMap<String, f64>,
    /// Arbitrary JSON metadata; selection reads the `metric_level` key when
    /// the policy requires a metric level. Defaults to an empty map.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
50
51impl CandidateScore {
52    pub fn validate(&self) -> Result<()> {
53        if self.candidate_id.trim().is_empty() {
54            return Err(DagMlError::CampaignValidation(
55                "candidate id is empty".to_string(),
56            ));
57        }
58        for (name, value) in &self.metrics {
59            if name.trim().is_empty() {
60                return Err(DagMlError::CampaignValidation(format!(
61                    "candidate `{}` has an empty metric name",
62                    self.candidate_id
63                )));
64            }
65            if value.is_nan() {
66                return Err(DagMlError::CampaignValidation(format!(
67                    "candidate `{}` metric `{name}` is NaN",
68                    self.candidate_id
69                )));
70            }
71        }
72        Ok(())
73    }
74}
75
/// Scope tag attached to evaluation results, selection policies and
/// selection decisions. Serialized in `snake_case`.
// NOTE(review): the meaning of each variant is not defined in this file;
// presumably they name the data split or fit stage a metric was computed on —
// confirm against the evaluation code.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EvaluationScope {
    Oof,
    Holdout,
    Final,
    Train,
    Refit,
}
85
/// Descriptor of an evaluation: which metric, on which partition and scope.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// Metric that was evaluated.
    pub metric: SelectionMetric,
    /// Prediction partition the evaluation refers to.
    pub partition: PredictionPartition,
    /// Scope tag of the evaluation.
    pub scope: EvaluationScope,
    /// Optional reduction identifier; must not be blank when present.
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
    /// Optional entity unit level. Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
}
96
97impl EvaluationResult {
98    pub fn validate(&self) -> Result<()> {
99        self.metric.validate()?;
100        validate_optional_id("evaluation reduction_id", self.reduction_id.as_deref())
101    }
102}
103
/// How many members a refit slot holds. Serialized in `snake_case`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RefitStrategy {
    /// Single-member slot; `RefitSlotPlan::validate` requires `member_count == 1`.
    RefitOne,
    /// Multi-member slot; `RefitSlotPlan::validate` requires `member_count >= 2`.
    RefitEnsemble,
}
110
/// Plan describing a refit slot: strategy, member count and the metric used
/// for selection.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RefitSlotPlan {
    /// Refit strategy; constrains the allowed `member_count`.
    pub strategy: RefitStrategy,
    /// Prediction level associated with the selection.
    pub selection_level: PredictionLevel,
    /// Number of members in the slot; must be positive and consistent with
    /// `strategy`.
    pub member_count: usize,
    /// Metric used for selection; must itself be valid.
    pub selection_metric: SelectionMetric,
    /// Optional reduction identifier; must not be blank when present.
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
}
120
121impl RefitSlotPlan {
122    pub fn validate(&self) -> Result<()> {
123        self.selection_metric.validate()?;
124        if self.member_count == 0 {
125            return Err(DagMlError::CampaignValidation(
126                "refit slot member_count must be positive".to_string(),
127            ));
128        }
129        match self.strategy {
130            RefitStrategy::RefitOne if self.member_count != 1 => {
131                return Err(DagMlError::CampaignValidation(
132                    "refit_one slot requires member_count=1".to_string(),
133                ));
134            }
135            RefitStrategy::RefitEnsemble if self.member_count < 2 => {
136                return Err(DagMlError::CampaignValidation(
137                    "refit_ensemble slot requires member_count>=2".to_string(),
138                ));
139            }
140            _ => {}
141        }
142        validate_optional_id("refit slot reduction_id", self.reduction_id.as_deref())
143    }
144}
145
/// Row domain of the meta-learner in a stacking contract. Serialized in
/// `snake_case`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetaRowDomain {
    /// No extra requirement is imposed by `StackingFitContract::validate`.
    Sample,
    /// `StackingFitContract::validate` requires a `final_reduction_id` for
    /// this domain.
    Combo,
}
152
/// Kind of features the meta-learner is trained on; only one option exists.
/// Serialized in `snake_case`.
// NOTE(review): `Oof` presumably stands for out-of-fold predictions — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetaTrainingFeatures {
    Oof,
}
158
/// Kind of features the meta-learner receives at inference time; only one
/// option exists. Serialized in `snake_case`.
// NOTE(review): presumably predictions of the refit base models — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InferenceFeatures {
    RefitBasePredictions,
}
164
/// Selection protocol declared by a stacking contract. Serialized in
/// `snake_case`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SelectionProtocol {
    Nested,
    Holdout,
    /// Only accepted by `StackingFitContract::validate` when
    /// `unsafe_allow_reuse_oof` is set.
    ReuseOof,
}
172
/// Contract describing how a stacking meta-learner is fitted and selected.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct StackingFitContract {
    /// Features used to train the meta-learner.
    pub meta_training_features: MetaTrainingFeatures,
    /// Features used by the meta-learner at inference time.
    pub inference_features: InferenceFeatures,
    /// Selection protocol; `reuse_oof` needs `unsafe_allow_reuse_oof`.
    pub selection_protocol: SelectionProtocol,
    /// Row domain of the meta-learner; `combo` needs `final_reduction_id`.
    pub meta_row_domain: MetaRowDomain,
    /// Optional final reduction identifier; must not be blank when present.
    /// Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_reduction_id: Option<String>,
    /// Explicit opt-in for the `reuse_oof` protocol; defaults to `false`.
    #[serde(default)]
    pub unsafe_allow_reuse_oof: bool,
}
184
185impl StackingFitContract {
186    pub fn validate(&self) -> Result<()> {
187        if self.selection_protocol == SelectionProtocol::ReuseOof && !self.unsafe_allow_reuse_oof {
188            return Err(DagMlError::CampaignValidation(
189                "reuse_oof stacking selection requires unsafe_allow_reuse_oof=true".to_string(),
190            ));
191        }
192        if self.meta_row_domain == MetaRowDomain::Combo && self.final_reduction_id.is_none() {
193            return Err(DagMlError::CampaignValidation(
194                "combo meta_row_domain requires final_reduction_id".to_string(),
195            ));
196        }
197        validate_optional_id(
198            "stacking final_reduction_id",
199            self.final_reduction_id.as_deref(),
200        )
201    }
202}
203
/// Policy describing how one candidate is selected from a set of scored
/// candidates.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SelectionPolicy {
    /// Policy identifier; must not be blank.
    pub id: String,
    /// Metric looked up on every candidate; its objective decides the ranking
    /// direction.
    pub metric: SelectionMetric,
    /// When set, every candidate must declare this level under the
    /// `metric_level` metadata key. Omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub required_metric_level: Option<PredictionLevel>,
    /// When `true`, candidates whose selection metric is not finite are
    /// rejected. Defaults to `true` when absent from the serialized form.
    #[serde(default = "default_true")]
    pub require_finite: bool,
    /// Optional scope tag, copied into the resulting decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub evaluation_scope: Option<EvaluationScope>,
    /// Optional refit slot plan; validated with the policy and copied into
    /// the resulting decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refit_slot_plan: Option<RefitSlotPlan>,
    /// Optional stacking contract; validated with the policy.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stacking_fit_contract: Option<StackingFitContract>,
    /// Optional reduction identifier; must not be blank when present and is
    /// copied into the resulting decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
}
221
222impl SelectionPolicy {
223    pub fn validate(&self) -> Result<()> {
224        if self.id.trim().is_empty() {
225            return Err(DagMlError::CampaignValidation(
226                "selection policy id is empty".to_string(),
227            ));
228        }
229        self.metric.validate()?;
230        if let Some(refit_slot_plan) = &self.refit_slot_plan {
231            refit_slot_plan.validate()?;
232        }
233        if let Some(stacking_fit_contract) = &self.stacking_fit_contract {
234            stacking_fit_contract.validate()?;
235        }
236        validate_optional_id(
237            "selection policy reduction_id",
238            self.reduction_id.as_deref(),
239        )
240    }
241}
242
/// Serde default provider for `SelectionPolicy::require_finite`: always `true`.
fn default_true() -> bool {
    true
}
246
/// One entry of a selection ranking.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// Identifier of the ranked candidate.
    pub candidate_id: String,
    /// Value of the selection metric for this candidate.
    pub score: f64,
    /// 1-based position in the ranking (`1` is the selected candidate).
    pub rank: usize,
}
253
/// Outcome of applying a [`SelectionPolicy`] to a set of candidates.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SelectionDecision {
    /// Identifier of the policy that produced the decision; must not be blank.
    pub policy_id: String,
    /// Identifier of the winning candidate; must equal the first ranked entry.
    pub selected_candidate_id: String,
    /// Name of the metric used for ranking; must not be blank.
    pub metric_name: String,
    /// Optimization direction used for ranking.
    pub objective: MetricObjective,
    /// Metric level required by the policy, if any. Omitted from the
    /// serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric_level: Option<PredictionLevel>,
    /// Scope tag copied from the policy, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub evaluation_scope: Option<EvaluationScope>,
    /// Refit slot plan copied from the policy, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refit_slot_plan: Option<RefitSlotPlan>,
    /// Reduction identifier copied from the policy, if any; must not be blank.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
    /// Score of the winning candidate; must be finite.
    pub selected_score: f64,
    /// All candidates in rank order; must be non-empty for a valid decision.
    /// Defaults to an empty list when absent from the serialized form.
    #[serde(default)]
    pub ranked_candidates: Vec<RankedCandidate>,
}
272
273impl SelectionDecision {
274    pub fn validate(&self) -> Result<()> {
275        if self.policy_id.trim().is_empty() {
276            return Err(DagMlError::CampaignValidation(
277                "selection decision policy_id is empty".to_string(),
278            ));
279        }
280        if self.selected_candidate_id.trim().is_empty() {
281            return Err(DagMlError::CampaignValidation(
282                "selection decision selected_candidate_id is empty".to_string(),
283            ));
284        }
285        if self.metric_name.trim().is_empty() {
286            return Err(DagMlError::CampaignValidation(
287                "selection decision metric_name is empty".to_string(),
288            ));
289        }
290        if !self.selected_score.is_finite() {
291            return Err(DagMlError::CampaignValidation(format!(
292                "selection `{}` selected score is not finite",
293                self.policy_id
294            )));
295        }
296        if self.ranked_candidates.is_empty() {
297            return Err(DagMlError::CampaignValidation(format!(
298                "selection `{}` has no ranked candidates",
299                self.policy_id
300            )));
301        }
302        if self.ranked_candidates[0].candidate_id != self.selected_candidate_id {
303            return Err(DagMlError::CampaignValidation(format!(
304                "selection `{}` first ranked candidate does not match selected candidate",
305                self.policy_id
306            )));
307        }
308        if let Some(refit_slot_plan) = &self.refit_slot_plan {
309            refit_slot_plan.validate()?;
310        }
311        validate_optional_id(
312            "selection decision reduction_id",
313            self.reduction_id.as_deref(),
314        )?;
315        let mut seen = BTreeSet::new();
316        for (idx, candidate) in self.ranked_candidates.iter().enumerate() {
317            if candidate.rank != idx + 1 {
318                return Err(DagMlError::CampaignValidation(format!(
319                    "selection `{}` candidate `{}` has rank {}, expected {}",
320                    self.policy_id,
321                    candidate.candidate_id,
322                    candidate.rank,
323                    idx + 1
324                )));
325            }
326            if !seen.insert(candidate.candidate_id.as_str()) {
327                return Err(DagMlError::CampaignValidation(format!(
328                    "selection `{}` contains duplicate candidate `{}`",
329                    self.policy_id, candidate.candidate_id
330                )));
331            }
332        }
333        Ok(())
334    }
335}
336
337pub fn select_candidate(
338    policy: &SelectionPolicy,
339    candidates: &[CandidateScore],
340) -> Result<SelectionDecision> {
341    policy.validate()?;
342    if candidates.is_empty() {
343        return Err(DagMlError::CampaignValidation(format!(
344            "selection policy `{}` has no candidates",
345            policy.id
346        )));
347    }
348
349    let mut scored = Vec::with_capacity(candidates.len());
350    let mut seen = BTreeSet::new();
351    for candidate in candidates {
352        candidate.validate()?;
353        if !seen.insert(candidate.candidate_id.as_str()) {
354            return Err(DagMlError::CampaignValidation(format!(
355                "selection policy `{}` has duplicate candidate `{}`",
356                policy.id, candidate.candidate_id
357            )));
358        }
359        validate_candidate_metric_level(policy, candidate)?;
360        let score = candidate
361            .metrics
362            .get(&policy.metric.name)
363            .copied()
364            .ok_or_else(|| {
365                DagMlError::CampaignValidation(format!(
366                    "candidate `{}` is missing selection metric `{}`",
367                    candidate.candidate_id, policy.metric.name
368                ))
369            })?;
370        if policy.require_finite && !score.is_finite() {
371            return Err(DagMlError::CampaignValidation(format!(
372                "candidate `{}` metric `{}` is not finite",
373                candidate.candidate_id, policy.metric.name
374            )));
375        }
376        scored.push((candidate.candidate_id.clone(), score));
377    }
378
379    scored.sort_by(|left, right| compare_scores(policy.metric.objective, left, right));
380    let ranked_candidates = scored
381        .iter()
382        .enumerate()
383        .map(|(idx, (candidate_id, score))| RankedCandidate {
384            candidate_id: candidate_id.clone(),
385            score: *score,
386            rank: idx + 1,
387        })
388        .collect::<Vec<_>>();
389    let selected = ranked_candidates
390        .first()
391        .expect("candidates were checked as non-empty");
392    let decision = SelectionDecision {
393        policy_id: policy.id.clone(),
394        selected_candidate_id: selected.candidate_id.clone(),
395        metric_name: policy.metric.name.clone(),
396        objective: policy.metric.objective,
397        metric_level: policy.required_metric_level,
398        evaluation_scope: policy.evaluation_scope,
399        refit_slot_plan: policy.refit_slot_plan.clone(),
400        reduction_id: policy.reduction_id.clone(),
401        selected_score: selected.score,
402        ranked_candidates,
403    };
404    decision.validate()?;
405    Ok(decision)
406}
407
408pub fn select_candidate_groups(
409    policy: &SelectionPolicy,
410    candidates: &[CandidateScore],
411    groups: &BTreeMap<String, Vec<String>>,
412) -> Result<BTreeMap<String, SelectionDecision>> {
413    policy.validate()?;
414    let mut by_id = BTreeMap::new();
415    for candidate in candidates {
416        candidate.validate()?;
417        if by_id
418            .insert(candidate.candidate_id.as_str(), candidate)
419            .is_some()
420        {
421            return Err(DagMlError::CampaignValidation(format!(
422                "selection policy `{}` has duplicate candidate `{}`",
423                policy.id, candidate.candidate_id
424            )));
425        }
426    }
427    let mut decisions = BTreeMap::new();
428    for (group_id, candidate_ids) in groups {
429        if group_id.trim().is_empty() {
430            return Err(DagMlError::CampaignValidation(
431                "selection group id is empty".to_string(),
432            ));
433        }
434        if candidate_ids.is_empty() {
435            return Err(DagMlError::CampaignValidation(format!(
436                "selection group `{group_id}` has no candidates"
437            )));
438        }
439        let group_candidates = candidate_ids
440            .iter()
441            .map(|candidate_id| {
442                by_id
443                    .get(candidate_id.as_str())
444                    .cloned()
445                    .cloned()
446                    .ok_or_else(|| {
447                        DagMlError::CampaignValidation(format!(
448                        "selection group `{group_id}` references unknown candidate `{candidate_id}`"
449                    ))
450                    })
451            })
452            .collect::<Result<Vec<_>>>()?;
453        decisions.insert(
454            group_id.clone(),
455            select_candidate(policy, &group_candidates)?,
456        );
457    }
458    Ok(decisions)
459}
460
461fn compare_scores(
462    objective: MetricObjective,
463    left: &(String, f64),
464    right: &(String, f64),
465) -> Ordering {
466    let score_order = match objective {
467        MetricObjective::Minimize => left.1.total_cmp(&right.1),
468        MetricObjective::Maximize => right.1.total_cmp(&left.1),
469    };
470    score_order.then_with(|| left.0.cmp(&right.0))
471}
472
473fn validate_candidate_metric_level(
474    policy: &SelectionPolicy,
475    candidate: &CandidateScore,
476) -> Result<()> {
477    let Some(required_level) = policy.required_metric_level else {
478        return Ok(());
479    };
480    let Some(raw_level) = candidate.metadata.get("metric_level") else {
481        return Err(DagMlError::CampaignValidation(format!(
482            "candidate `{}` is missing required metric_level `{}`",
483            candidate.candidate_id,
484            prediction_level_name(required_level)
485        )));
486    };
487    let actual_level = match raw_level {
488        serde_json::Value::String(value) => parse_prediction_level(value).ok_or_else(|| {
489            DagMlError::CampaignValidation(format!(
490                "candidate `{}` has invalid metric_level `{value}`",
491                candidate.candidate_id
492            ))
493        })?,
494        _ => {
495            return Err(DagMlError::CampaignValidation(format!(
496                "candidate `{}` metric_level must be a string",
497                candidate.candidate_id
498            )));
499        }
500    };
501    if actual_level != required_level {
502        return Err(DagMlError::CampaignValidation(format!(
503            "candidate `{}` metric_level `{}` does not match required `{}`",
504            candidate.candidate_id,
505            prediction_level_name(actual_level),
506            prediction_level_name(required_level)
507        )));
508    }
509    Ok(())
510}
511
512fn parse_prediction_level(value: &str) -> Option<PredictionLevel> {
513    match value {
514        "observation" => Some(PredictionLevel::Observation),
515        "sample" => Some(PredictionLevel::Sample),
516        "target" => Some(PredictionLevel::Target),
517        "group" => Some(PredictionLevel::Group),
518        _ => None,
519    }
520}
521
522fn prediction_level_name(level: PredictionLevel) -> &'static str {
523    match level {
524        PredictionLevel::Observation => "observation",
525        PredictionLevel::Sample => "sample",
526        PredictionLevel::Target => "target",
527        PredictionLevel::Group => "group",
528    }
529}
530
531fn validate_optional_id(label: &str, value: Option<&str>) -> Result<()> {
532    if value.is_some_and(|value| value.trim().is_empty()) {
533        return Err(DagMlError::CampaignValidation(format!(
534            "{label} must not be empty"
535        )));
536    }
537    Ok(())
538}
539
// Unit tests for candidate selection, contract validation and the published
// JSON schemas.
#[cfg(test)]
mod tests {
    use super::*;

    // Baseline policy: minimize `rmse`, require finite scores, no optional
    // contracts.
    fn rmse_policy() -> SelectionPolicy {
        SelectionPolicy {
            id: "select:rmse".to_string(),
            metric: SelectionMetric {
                name: "rmse".to_string(),
                objective: MetricObjective::Minimize,
            },
            required_metric_level: None,
            require_finite: true,
            evaluation_scope: None,
            refit_slot_plan: None,
            stacking_fit_contract: None,
            reduction_id: None,
        }
    }

    // Candidate with a single `rmse` metric and no metadata.
    fn candidate(id: &str, rmse: f64) -> CandidateScore {
        CandidateScore {
            candidate_id: id.to_string(),
            metrics: BTreeMap::from([("rmse".to_string(), rmse)]),
            metadata: BTreeMap::new(),
        }
    }

    // Candidate with a single `rmse` metric and a `metric_level` metadata
    // string.
    fn candidate_with_level(id: &str, rmse: f64, level: &str) -> CandidateScore {
        CandidateScore {
            candidate_id: id.to_string(),
            metrics: BTreeMap::from([("rmse".to_string(), rmse)]),
            metadata: BTreeMap::from([(
                "metric_level".to_string(),
                serde_json::Value::String(level.to_string()),
            )]),
        }
    }

    // Two candidates tie on the best score; the smaller candidate id wins.
    #[test]
    fn selects_lowest_metric_with_deterministic_tie_break() {
        let decision = select_candidate(
            &rmse_policy(),
            &[
                candidate("model:b", 1.0),
                candidate("model:a", 1.0),
                candidate("model:c", 2.0),
            ],
        )
        .unwrap();

        assert_eq!(decision.selected_candidate_id, "model:a");
        assert_eq!(decision.ranked_candidates[0].rank, 1);
    }

    // Duplicate ids in the candidate list are rejected before any group is
    // processed.
    #[test]
    fn grouped_selection_rejects_duplicate_candidate_ids() {
        assert!(select_candidate_groups(
            &rmse_policy(),
            &[candidate("model:a", 1.0), candidate("model:a", 2.0)],
            &BTreeMap::from([("branch:b0".to_string(), vec!["model:a".to_string()])]),
        )
        .is_err());
    }

    // A required metric level is echoed in the decision; mismatching or
    // missing `metric_level` metadata fails the selection.
    #[test]
    fn selection_policy_can_require_metric_level() {
        let mut policy = rmse_policy();
        policy.required_metric_level = Some(PredictionLevel::Sample);

        let decision = select_candidate(
            &policy,
            &[
                candidate_with_level("model:a", 1.0, "sample"),
                candidate_with_level("model:b", 2.0, "sample"),
            ],
        )
        .unwrap();
        assert_eq!(decision.selected_candidate_id, "model:a");
        assert_eq!(decision.metric_level, Some(PredictionLevel::Sample));

        assert!(select_candidate(
            &policy,
            &[
                candidate_with_level("model:a", 1.0, "sample"),
                candidate_with_level("model:b", 2.0, "target"),
            ],
        )
        .is_err());
        assert!(select_candidate(&policy, &[candidate("model:a", 1.0)]).is_err());
    }

    // An observation-level metric must not satisfy a sample-level requirement;
    // the exact error text is pinned.
    // NOTE(review): "D9" presumably refers to a design-decision identifier
    // defined outside this file — confirm.
    #[test]
    fn d9_negative_row_level_metric_cannot_drive_sample_refit() {
        let mut policy = rmse_policy();
        policy.required_metric_level = Some(PredictionLevel::Sample);

        let error = select_candidate(
            &policy,
            &[candidate_with_level("model:row_metric", 0.1, "observation")],
        )
        .unwrap_err()
        .to_string();

        assert!(
            error.contains("metric_level `observation` does not match required `sample`"),
            "unexpected D9 row-vs-sample metric error: {error}"
        );
    }

    // Evaluation scope, refit slot plan and reduction id are copied from the
    // policy into the decision; an invalid refit slot plan fails selection.
    #[test]
    fn selection_policy_echoes_evaluation_and_refit_contracts() {
        let mut policy = rmse_policy();
        policy.evaluation_scope = Some(EvaluationScope::Oof);
        policy.reduction_id = Some("reduction:obs_to_sample".to_string());
        policy.refit_slot_plan = Some(RefitSlotPlan {
            strategy: RefitStrategy::RefitOne,
            selection_level: PredictionLevel::Sample,
            member_count: 1,
            selection_metric: policy.metric.clone(),
            reduction_id: Some("reduction:obs_to_sample".to_string()),
        });

        let decision = select_candidate(
            &policy,
            &[candidate("model:a", 1.0), candidate("model:b", 2.0)],
        )
        .unwrap();

        assert_eq!(decision.evaluation_scope, Some(EvaluationScope::Oof));
        assert_eq!(
            decision.refit_slot_plan.as_ref().unwrap().strategy,
            RefitStrategy::RefitOne
        );
        assert_eq!(
            decision.reduction_id.as_deref(),
            Some("reduction:obs_to_sample")
        );

        // `refit_ensemble` with a single member violates the slot plan rules.
        let mut invalid_policy = policy;
        invalid_policy.refit_slot_plan = Some(RefitSlotPlan {
            strategy: RefitStrategy::RefitEnsemble,
            selection_level: PredictionLevel::Sample,
            member_count: 1,
            selection_metric: invalid_policy.metric.clone(),
            reduction_id: None,
        });
        assert!(select_candidate(&invalid_policy, &[candidate("model:a", 1.0)]).is_err());
    }

    // `combo` rows need a final reduction id, and `reuse_oof` needs the
    // explicit unsafe opt-in.
    #[test]
    fn stacking_fit_contract_guards_oof_reuse_and_combo_reduction() {
        let valid = StackingFitContract {
            meta_training_features: MetaTrainingFeatures::Oof,
            inference_features: InferenceFeatures::RefitBasePredictions,
            selection_protocol: SelectionProtocol::Nested,
            meta_row_domain: MetaRowDomain::Combo,
            final_reduction_id: Some("reduction:combo_to_sample".to_string()),
            unsafe_allow_reuse_oof: false,
        };
        valid.validate().unwrap();

        let missing_reduction = StackingFitContract {
            final_reduction_id: None,
            ..valid.clone()
        };
        assert!(missing_reduction.validate().is_err());

        let unsafe_reuse_required = StackingFitContract {
            selection_protocol: SelectionProtocol::ReuseOof,
            meta_row_domain: MetaRowDomain::Sample,
            final_reduction_id: None,
            unsafe_allow_reuse_oof: false,
            ..valid
        };
        assert!(unsafe_reuse_required.validate().is_err());
    }

    // The schema documents embedded at compile time must carry the schema ids
    // declared in this module and expose the fields the structs rely on.
    #[test]
    fn published_selection_schemas_declare_current_contracts() {
        let policy_schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/selection_policy.schema.json"
        ))
        .unwrap();
        assert_eq!(policy_schema["$id"], SELECTION_POLICY_SCHEMA_ID);
        assert!(policy_schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("metric")));
        assert!(policy_schema["properties"]
            .get("evaluation_scope")
            .is_some());
        assert!(policy_schema["properties"].get("refit_slot_plan").is_some());
        assert!(policy_schema["properties"]
            .get("stacking_fit_contract")
            .is_some());

        let decision_schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/selection_decision.schema.json"
        ))
        .unwrap();
        assert_eq!(decision_schema["$id"], SELECTION_DECISION_SCHEMA_ID);
        assert!(decision_schema["$defs"]["prediction_level"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|level| level.as_str() == Some("group")));
        assert!(decision_schema["$defs"]["ranked_candidate"]["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("rank")));
        assert!(decision_schema["properties"]
            .get("evaluation_scope")
            .is_some());
        assert!(decision_schema["properties"]
            .get("refit_slot_plan")
            .is_some());
    }

    // Runs grouped and flat selection over the metrics stored in the
    // generated sklearn example report and pins the expected winners.
    #[test]
    fn selects_sklearn_demo_branch_and_merge_variants() {
        let report: serde_json::Value = serde_json::from_str(include_str!(
            "../../../examples/generated/sklearn_complex_report.json"
        ))
        .unwrap();
        let branch_metrics = report["branch_variant_metrics"].as_object().unwrap();
        let candidates = branch_metrics
            .iter()
            .map(|(candidate_id, metrics)| CandidateScore {
                candidate_id: candidate_id.clone(),
                metrics: metrics
                    .as_object()
                    .unwrap()
                    .iter()
                    .map(|(name, value)| (name.clone(), value.as_f64().unwrap()))
                    .collect(),
                metadata: BTreeMap::new(),
            })
            .collect::<Vec<_>>();
        let groups = BTreeMap::from([
            (
                "branch:b0".to_string(),
                vec![
                    "branch:b0.variant:pca10_ridge_a03".to_string(),
                    "branch:b0.variant:pca16_ridge_a12".to_string(),
                ],
            ),
            (
                "branch:b1".to_string(),
                vec![
                    "branch:b1.variant:rf_select_k28".to_string(),
                    "branch:b1.variant:rf_select_k40".to_string(),
                ],
            ),
            (
                "branch:b2".to_string(),
                vec![
                    "branch:b2.variant:poly_extra_k45".to_string(),
                    "branch:b2.variant:poly_extra_k80".to_string(),
                ],
            ),
        ]);

        let decisions = select_candidate_groups(&rmse_policy(), &candidates, &groups).unwrap();
        assert_eq!(
            decisions["branch:b1"].selected_candidate_id,
            "branch:b1.variant:rf_select_k40"
        );

        // Merge variants are ranked as one flat candidate list.
        let merge_metrics = report["merge_variant_metrics"].as_object().unwrap();
        let merge_candidates = merge_metrics
            .iter()
            .map(|(candidate_id, metrics)| CandidateScore {
                candidate_id: candidate_id.clone(),
                metrics: metrics
                    .as_object()
                    .unwrap()
                    .iter()
                    .map(|(name, value)| (name.clone(), value.as_f64().unwrap()))
                    .collect(),
                metadata: BTreeMap::new(),
            })
            .collect::<Vec<_>>();
        let merge_decision = select_candidate(&rmse_policy(), &merge_candidates).unwrap();
        assert_eq!(
            merge_decision.selected_candidate_id,
            "merge:m1.pred_meta_original.meta:ridge"
        );
    }
}