1use std::cmp::Ordering;
2use std::collections::{BTreeMap, BTreeSet};
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{DagMlError, Result};
7use crate::oof::PredictionPartition;
8use crate::policy::PredictionLevel;
9use crate::relation::EntityUnitLevel;
10
/// Version of the selection-policy contract (the `v1` in its schema `$id`).
pub const SELECTION_POLICY_SCHEMA_VERSION: u32 = 1;
/// `$id` of the published selection-policy JSON schema. The module tests assert that
/// `docs/contracts/selection_policy.schema.json` declares exactly this value.
pub const SELECTION_POLICY_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/selection_policy.v1.schema.json";
/// Version of the selection-decision contract (the `v1` in its schema `$id`).
pub const SELECTION_DECISION_SCHEMA_VERSION: u32 = 1;
/// `$id` of the published selection-decision JSON schema. The module tests assert that
/// `docs/contracts/selection_decision.schema.json` declares exactly this value.
pub const SELECTION_DECISION_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/selection_decision.v1.schema.json";
17
/// Direction in which a selection metric is optimized when ranking candidates.
///
/// `Minimize` ranks lower scores first and `Maximize` ranks higher scores first. Serialized in
/// snake_case (`"minimize"` / `"maximize"`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetricObjective {
    Minimize,
    Maximize,
}
24
/// Names the metric used to rank candidates and the direction in which it is optimized.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SelectionMetric {
    /// Key looked up in [`CandidateScore::metrics`]; must not be blank.
    pub name: String,
    /// Whether lower or higher values of the metric are better.
    pub objective: MetricObjective,
}
30
31impl SelectionMetric {
32 pub fn validate(&self) -> Result<()> {
33 if self.name.trim().is_empty() {
34 return Err(DagMlError::CampaignValidation(
35 "selection metric name is empty".to_string(),
36 ));
37 }
38 Ok(())
39 }
40}
41
/// One candidate's measured metrics plus free-form metadata, as consumed by [`select_candidate`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CandidateScore {
    /// Identifier of the candidate; must be non-blank and unique within one selection.
    pub candidate_id: String,
    /// Metric values keyed by metric name. NaN values are rejected by `validate`; infinities are
    /// only rejected at selection time when the policy sets `require_finite`.
    #[serde(default)]
    pub metrics: BTreeMap<String, f64>,
    /// Arbitrary metadata. The `metric_level` key (a string) is checked against
    /// `SelectionPolicy::required_metric_level` when that requirement is set.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
50
51impl CandidateScore {
52 pub fn validate(&self) -> Result<()> {
53 if self.candidate_id.trim().is_empty() {
54 return Err(DagMlError::CampaignValidation(
55 "candidate id is empty".to_string(),
56 ));
57 }
58 for (name, value) in &self.metrics {
59 if name.trim().is_empty() {
60 return Err(DagMlError::CampaignValidation(format!(
61 "candidate `{}` has an empty metric name",
62 self.candidate_id
63 )));
64 }
65 if value.is_nan() {
66 return Err(DagMlError::CampaignValidation(format!(
67 "candidate `{}` metric `{name}` is NaN",
68 self.candidate_id
69 )));
70 }
71 }
72 Ok(())
73 }
74}
75
/// Tag describing which evaluation context a metric was computed in. It is echoed from
/// `SelectionPolicy::evaluation_scope` into `SelectionDecision::evaluation_scope`.
///
/// NOTE(review): the variant meanings are inferred from their names and are not established
/// here. `Oof` presumably means out-of-fold; confirm against the evaluation documentation.
/// Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EvaluationScope {
    Oof,
    Holdout,
    Final,
    Train,
    Refit,
}
85
/// Describes how a metric was evaluated: which metric, on which prediction partition, in which
/// scope, and optionally through which reduction and at which entity unit level.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// The evaluated metric; validated by `SelectionMetric::validate`.
    pub metric: SelectionMetric,
    /// Prediction partition the metric was computed on (type from `crate::oof`).
    pub partition: PredictionPartition,
    /// Evaluation context tag.
    pub scope: EvaluationScope,
    /// Optional reduction identifier. When present it must not be blank; omitted from output
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
    /// Optional entity unit level (type from `crate::relation`); omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
}
96
97impl EvaluationResult {
98 pub fn validate(&self) -> Result<()> {
99 self.metric.validate()?;
100 validate_optional_id("evaluation reduction_id", self.reduction_id.as_deref())
101 }
102}
103
/// How a selected slot is refit.
///
/// `RefitSlotPlan::validate` requires `member_count == 1` for `RefitOne` and
/// `member_count >= 2` for `RefitEnsemble`. Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RefitStrategy {
    RefitOne,
    RefitEnsemble,
}
110
/// Plan describing how the selected candidate(s) of a slot are refit. When present on a policy,
/// it is validated and echoed into the resulting `SelectionDecision`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RefitSlotPlan {
    /// Single-member or ensemble refit.
    pub strategy: RefitStrategy,
    /// Prediction level the slot was selected at.
    pub selection_level: PredictionLevel,
    /// Number of members to refit. It must be positive and must agree with `strategy`.
    pub member_count: usize,
    /// Metric used to pick the slot members; validated like any `SelectionMetric`.
    pub selection_metric: SelectionMetric,
    /// Optional reduction identifier. When present it must not be blank.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
}
120
121impl RefitSlotPlan {
122 pub fn validate(&self) -> Result<()> {
123 self.selection_metric.validate()?;
124 if self.member_count == 0 {
125 return Err(DagMlError::CampaignValidation(
126 "refit slot member_count must be positive".to_string(),
127 ));
128 }
129 match self.strategy {
130 RefitStrategy::RefitOne if self.member_count != 1 => {
131 return Err(DagMlError::CampaignValidation(
132 "refit_one slot requires member_count=1".to_string(),
133 ));
134 }
135 RefitStrategy::RefitEnsemble if self.member_count < 2 => {
136 return Err(DagMlError::CampaignValidation(
137 "refit_ensemble slot requires member_count>=2".to_string(),
138 ));
139 }
140 _ => {}
141 }
142 validate_optional_id("refit slot reduction_id", self.reduction_id.as_deref())
143 }
144}
145
/// Row domain of the stacking meta-model's training table.
///
/// `StackingFitContract::validate` requires a `final_reduction_id` when this is `Combo`.
/// NOTE(review): what a "combo" row represents is not established here; confirm against the
/// stacking documentation. Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetaRowDomain {
    Sample,
    Combo,
}
152
/// Feature source used to train the stacking meta-model. `Oof` is currently the only variant
/// (presumably out-of-fold base predictions; confirm). Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MetaTrainingFeatures {
    Oof,
}
158
/// Feature source fed to the stacking meta-model at inference time. `RefitBasePredictions` is
/// currently the only variant. Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InferenceFeatures {
    RefitBasePredictions,
}
164
/// Protocol used to select the stacking meta-model.
///
/// `ReuseOof` is only accepted by `StackingFitContract::validate` when
/// `unsafe_allow_reuse_oof` is set. Serialized in snake_case.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SelectionProtocol {
    Nested,
    Holdout,
    ReuseOof,
}
172
/// Declares how a stacking meta-model is trained, fed at inference, and selected.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct StackingFitContract {
    /// Feature source for meta-model training.
    pub meta_training_features: MetaTrainingFeatures,
    /// Feature source for meta-model inference.
    pub inference_features: InferenceFeatures,
    /// Meta-model selection protocol; `ReuseOof` needs `unsafe_allow_reuse_oof`.
    pub selection_protocol: SelectionProtocol,
    /// Row domain of the meta training table; `Combo` needs `final_reduction_id`.
    pub meta_row_domain: MetaRowDomain,
    /// Optional final reduction identifier. When present it must not be blank.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_reduction_id: Option<String>,
    /// Explicit opt-in for `SelectionProtocol::ReuseOof`; defaults to `false` when absent.
    #[serde(default)]
    pub unsafe_allow_reuse_oof: bool,
}
184
185impl StackingFitContract {
186 pub fn validate(&self) -> Result<()> {
187 if self.selection_protocol == SelectionProtocol::ReuseOof && !self.unsafe_allow_reuse_oof {
188 return Err(DagMlError::CampaignValidation(
189 "reuse_oof stacking selection requires unsafe_allow_reuse_oof=true".to_string(),
190 ));
191 }
192 if self.meta_row_domain == MetaRowDomain::Combo && self.final_reduction_id.is_none() {
193 return Err(DagMlError::CampaignValidation(
194 "combo meta_row_domain requires final_reduction_id".to_string(),
195 ));
196 }
197 validate_optional_id(
198 "stacking final_reduction_id",
199 self.final_reduction_id.as_deref(),
200 )
201 }
202}
203
/// Declarative rule for choosing one candidate out of a set, consumed by [`select_candidate`]
/// and [`select_candidate_groups`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct SelectionPolicy {
    /// Policy identifier; must not be blank. Copied into `SelectionDecision::policy_id`.
    pub id: String,
    /// Metric (and direction) used to rank candidates.
    pub metric: SelectionMetric,
    /// When set, every candidate must carry a string `metric_level` metadata entry equal to this
    /// level, otherwise selection fails. Copied into `SelectionDecision::metric_level`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub required_metric_level: Option<PredictionLevel>,
    /// When true, a non-finite selection-metric value makes selection fail. Defaults to `true`
    /// when absent from serialized input.
    #[serde(default = "default_true")]
    pub require_finite: bool,
    /// Optional scope tag, copied into the decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub evaluation_scope: Option<EvaluationScope>,
    /// Optional refit plan; validated with the policy and copied into the decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refit_slot_plan: Option<RefitSlotPlan>,
    /// Optional stacking contract; validated with the policy (not copied into the decision).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stacking_fit_contract: Option<StackingFitContract>,
    /// Optional reduction identifier. It must not be blank and is copied into the decision.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
}
221
222impl SelectionPolicy {
223 pub fn validate(&self) -> Result<()> {
224 if self.id.trim().is_empty() {
225 return Err(DagMlError::CampaignValidation(
226 "selection policy id is empty".to_string(),
227 ));
228 }
229 self.metric.validate()?;
230 if let Some(refit_slot_plan) = &self.refit_slot_plan {
231 refit_slot_plan.validate()?;
232 }
233 if let Some(stacking_fit_contract) = &self.stacking_fit_contract {
234 stacking_fit_contract.validate()?;
235 }
236 validate_optional_id(
237 "selection policy reduction_id",
238 self.reduction_id.as_deref(),
239 )
240 }
241}
242
/// Serde default for `SelectionPolicy::require_finite`: a policy deserialized without that
/// field requires finite scores.
fn default_true() -> bool {
    true
}
246
/// One entry of a decision's ranking.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// Identifier of the ranked candidate.
    pub candidate_id: String,
    /// The candidate's value for the policy's selection metric.
    pub score: f64,
    /// 1-based position in the ranking (1 = selected).
    pub rank: usize,
}
253
/// Outcome of applying a [`SelectionPolicy`] to a set of candidates, as produced by
/// [`select_candidate`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SelectionDecision {
    /// Id of the policy that produced this decision.
    pub policy_id: String,
    /// Id of the winning candidate; expected to equal the first ranked candidate's id.
    pub selected_candidate_id: String,
    /// Name of the selection metric.
    pub metric_name: String,
    /// Direction the metric was optimized in.
    pub objective: MetricObjective,
    /// Copy of the policy's `required_metric_level`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric_level: Option<PredictionLevel>,
    /// Copy of the policy's `evaluation_scope`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub evaluation_scope: Option<EvaluationScope>,
    /// Copy of the policy's `refit_slot_plan`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refit_slot_plan: Option<RefitSlotPlan>,
    /// Copy of the policy's `reduction_id`, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reduction_id: Option<String>,
    /// Metric value of the selected candidate; must be finite.
    pub selected_score: f64,
    /// All candidates in rank order, best first.
    #[serde(default)]
    pub ranked_candidates: Vec<RankedCandidate>,
}
272
273impl SelectionDecision {
274 pub fn validate(&self) -> Result<()> {
275 if self.policy_id.trim().is_empty() {
276 return Err(DagMlError::CampaignValidation(
277 "selection decision policy_id is empty".to_string(),
278 ));
279 }
280 if self.selected_candidate_id.trim().is_empty() {
281 return Err(DagMlError::CampaignValidation(
282 "selection decision selected_candidate_id is empty".to_string(),
283 ));
284 }
285 if self.metric_name.trim().is_empty() {
286 return Err(DagMlError::CampaignValidation(
287 "selection decision metric_name is empty".to_string(),
288 ));
289 }
290 if !self.selected_score.is_finite() {
291 return Err(DagMlError::CampaignValidation(format!(
292 "selection `{}` selected score is not finite",
293 self.policy_id
294 )));
295 }
296 if self.ranked_candidates.is_empty() {
297 return Err(DagMlError::CampaignValidation(format!(
298 "selection `{}` has no ranked candidates",
299 self.policy_id
300 )));
301 }
302 if self.ranked_candidates[0].candidate_id != self.selected_candidate_id {
303 return Err(DagMlError::CampaignValidation(format!(
304 "selection `{}` first ranked candidate does not match selected candidate",
305 self.policy_id
306 )));
307 }
308 if let Some(refit_slot_plan) = &self.refit_slot_plan {
309 refit_slot_plan.validate()?;
310 }
311 validate_optional_id(
312 "selection decision reduction_id",
313 self.reduction_id.as_deref(),
314 )?;
315 let mut seen = BTreeSet::new();
316 for (idx, candidate) in self.ranked_candidates.iter().enumerate() {
317 if candidate.rank != idx + 1 {
318 return Err(DagMlError::CampaignValidation(format!(
319 "selection `{}` candidate `{}` has rank {}, expected {}",
320 self.policy_id,
321 candidate.candidate_id,
322 candidate.rank,
323 idx + 1
324 )));
325 }
326 if !seen.insert(candidate.candidate_id.as_str()) {
327 return Err(DagMlError::CampaignValidation(format!(
328 "selection `{}` contains duplicate candidate `{}`",
329 self.policy_id, candidate.candidate_id
330 )));
331 }
332 }
333 Ok(())
334 }
335}
336
337pub fn select_candidate(
338 policy: &SelectionPolicy,
339 candidates: &[CandidateScore],
340) -> Result<SelectionDecision> {
341 policy.validate()?;
342 if candidates.is_empty() {
343 return Err(DagMlError::CampaignValidation(format!(
344 "selection policy `{}` has no candidates",
345 policy.id
346 )));
347 }
348
349 let mut scored = Vec::with_capacity(candidates.len());
350 let mut seen = BTreeSet::new();
351 for candidate in candidates {
352 candidate.validate()?;
353 if !seen.insert(candidate.candidate_id.as_str()) {
354 return Err(DagMlError::CampaignValidation(format!(
355 "selection policy `{}` has duplicate candidate `{}`",
356 policy.id, candidate.candidate_id
357 )));
358 }
359 validate_candidate_metric_level(policy, candidate)?;
360 let score = candidate
361 .metrics
362 .get(&policy.metric.name)
363 .copied()
364 .ok_or_else(|| {
365 DagMlError::CampaignValidation(format!(
366 "candidate `{}` is missing selection metric `{}`",
367 candidate.candidate_id, policy.metric.name
368 ))
369 })?;
370 if policy.require_finite && !score.is_finite() {
371 return Err(DagMlError::CampaignValidation(format!(
372 "candidate `{}` metric `{}` is not finite",
373 candidate.candidate_id, policy.metric.name
374 )));
375 }
376 scored.push((candidate.candidate_id.clone(), score));
377 }
378
379 scored.sort_by(|left, right| compare_scores(policy.metric.objective, left, right));
380 let ranked_candidates = scored
381 .iter()
382 .enumerate()
383 .map(|(idx, (candidate_id, score))| RankedCandidate {
384 candidate_id: candidate_id.clone(),
385 score: *score,
386 rank: idx + 1,
387 })
388 .collect::<Vec<_>>();
389 let selected = ranked_candidates
390 .first()
391 .expect("candidates were checked as non-empty");
392 let decision = SelectionDecision {
393 policy_id: policy.id.clone(),
394 selected_candidate_id: selected.candidate_id.clone(),
395 metric_name: policy.metric.name.clone(),
396 objective: policy.metric.objective,
397 metric_level: policy.required_metric_level,
398 evaluation_scope: policy.evaluation_scope,
399 refit_slot_plan: policy.refit_slot_plan.clone(),
400 reduction_id: policy.reduction_id.clone(),
401 selected_score: selected.score,
402 ranked_candidates,
403 };
404 decision.validate()?;
405 Ok(decision)
406}
407
408pub fn select_candidate_groups(
409 policy: &SelectionPolicy,
410 candidates: &[CandidateScore],
411 groups: &BTreeMap<String, Vec<String>>,
412) -> Result<BTreeMap<String, SelectionDecision>> {
413 policy.validate()?;
414 let mut by_id = BTreeMap::new();
415 for candidate in candidates {
416 candidate.validate()?;
417 if by_id
418 .insert(candidate.candidate_id.as_str(), candidate)
419 .is_some()
420 {
421 return Err(DagMlError::CampaignValidation(format!(
422 "selection policy `{}` has duplicate candidate `{}`",
423 policy.id, candidate.candidate_id
424 )));
425 }
426 }
427 let mut decisions = BTreeMap::new();
428 for (group_id, candidate_ids) in groups {
429 if group_id.trim().is_empty() {
430 return Err(DagMlError::CampaignValidation(
431 "selection group id is empty".to_string(),
432 ));
433 }
434 if candidate_ids.is_empty() {
435 return Err(DagMlError::CampaignValidation(format!(
436 "selection group `{group_id}` has no candidates"
437 )));
438 }
439 let group_candidates = candidate_ids
440 .iter()
441 .map(|candidate_id| {
442 by_id
443 .get(candidate_id.as_str())
444 .cloned()
445 .cloned()
446 .ok_or_else(|| {
447 DagMlError::CampaignValidation(format!(
448 "selection group `{group_id}` references unknown candidate `{candidate_id}`"
449 ))
450 })
451 })
452 .collect::<Result<Vec<_>>>()?;
453 decisions.insert(
454 group_id.clone(),
455 select_candidate(policy, &group_candidates)?,
456 );
457 }
458 Ok(decisions)
459}
460
461fn compare_scores(
462 objective: MetricObjective,
463 left: &(String, f64),
464 right: &(String, f64),
465) -> Ordering {
466 let score_order = match objective {
467 MetricObjective::Minimize => left.1.total_cmp(&right.1),
468 MetricObjective::Maximize => right.1.total_cmp(&left.1),
469 };
470 score_order.then_with(|| left.0.cmp(&right.0))
471}
472
473fn validate_candidate_metric_level(
474 policy: &SelectionPolicy,
475 candidate: &CandidateScore,
476) -> Result<()> {
477 let Some(required_level) = policy.required_metric_level else {
478 return Ok(());
479 };
480 let Some(raw_level) = candidate.metadata.get("metric_level") else {
481 return Err(DagMlError::CampaignValidation(format!(
482 "candidate `{}` is missing required metric_level `{}`",
483 candidate.candidate_id,
484 prediction_level_name(required_level)
485 )));
486 };
487 let actual_level = match raw_level {
488 serde_json::Value::String(value) => parse_prediction_level(value).ok_or_else(|| {
489 DagMlError::CampaignValidation(format!(
490 "candidate `{}` has invalid metric_level `{value}`",
491 candidate.candidate_id
492 ))
493 })?,
494 _ => {
495 return Err(DagMlError::CampaignValidation(format!(
496 "candidate `{}` metric_level must be a string",
497 candidate.candidate_id
498 )));
499 }
500 };
501 if actual_level != required_level {
502 return Err(DagMlError::CampaignValidation(format!(
503 "candidate `{}` metric_level `{}` does not match required `{}`",
504 candidate.candidate_id,
505 prediction_level_name(actual_level),
506 prediction_level_name(required_level)
507 )));
508 }
509 Ok(())
510}
511
512fn parse_prediction_level(value: &str) -> Option<PredictionLevel> {
513 match value {
514 "observation" => Some(PredictionLevel::Observation),
515 "sample" => Some(PredictionLevel::Sample),
516 "target" => Some(PredictionLevel::Target),
517 "group" => Some(PredictionLevel::Group),
518 _ => None,
519 }
520}
521
/// Returns the lowercase wire name of a [`PredictionLevel`] (used in error messages); the
/// inverse of `parse_prediction_level`.
fn prediction_level_name(level: PredictionLevel) -> &'static str {
    // Exhaustive on purpose: adding a level must update both this table and the parser.
    match level {
        PredictionLevel::Observation => "observation",
        PredictionLevel::Sample => "sample",
        PredictionLevel::Target => "target",
        PredictionLevel::Group => "group",
    }
}
530
531fn validate_optional_id(label: &str, value: Option<&str>) -> Result<()> {
532 if value.is_some_and(|value| value.trim().is_empty()) {
533 return Err(DagMlError::CampaignValidation(format!(
534 "{label} must not be empty"
535 )));
536 }
537 Ok(())
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    // Baseline policy: minimize `rmse`, require finite scores, and impose no level, scope,
    // refit, stacking, or reduction constraints.
    fn rmse_policy() -> SelectionPolicy {
        SelectionPolicy {
            id: "select:rmse".to_string(),
            metric: SelectionMetric {
                name: "rmse".to_string(),
                objective: MetricObjective::Minimize,
            },
            required_metric_level: None,
            require_finite: true,
            evaluation_scope: None,
            refit_slot_plan: None,
            stacking_fit_contract: None,
            reduction_id: None,
        }
    }

    // Candidate with a single `rmse` metric and empty metadata.
    fn candidate(id: &str, rmse: f64) -> CandidateScore {
        CandidateScore {
            candidate_id: id.to_string(),
            metrics: BTreeMap::from([("rmse".to_string(), rmse)]),
            metadata: BTreeMap::new(),
        }
    }

    // Candidate with a single `rmse` metric plus a string `metric_level` metadata entry
    // (the key read by `validate_candidate_metric_level`).
    fn candidate_with_level(id: &str, rmse: f64, level: &str) -> CandidateScore {
        CandidateScore {
            candidate_id: id.to_string(),
            metrics: BTreeMap::from([("rmse".to_string(), rmse)]),
            metadata: BTreeMap::from([(
                "metric_level".to_string(),
                serde_json::Value::String(level.to_string()),
            )]),
        }
    }

    // Equal scores are ordered by candidate id, so `model:a` wins the 1.0 tie regardless of
    // input order.
    #[test]
    fn selects_lowest_metric_with_deterministic_tie_break() {
        let decision = select_candidate(
            &rmse_policy(),
            &[
                candidate("model:b", 1.0),
                candidate("model:a", 1.0),
                candidate("model:c", 2.0),
            ],
        )
        .unwrap();

        assert_eq!(decision.selected_candidate_id, "model:a");
        assert_eq!(decision.ranked_candidates[0].rank, 1);
    }

    // Duplicate ids in the candidate pool are rejected even if a group lists the id once.
    #[test]
    fn grouped_selection_rejects_duplicate_candidate_ids() {
        assert!(select_candidate_groups(
            &rmse_policy(),
            &[candidate("model:a", 1.0), candidate("model:a", 2.0)],
            &BTreeMap::from([("branch:b0".to_string(), vec!["model:a".to_string()])]),
        )
        .is_err());
    }

    // With a required level, every candidate must declare a matching `metric_level`; one
    // mismatched or missing level fails the whole selection.
    #[test]
    fn selection_policy_can_require_metric_level() {
        let mut policy = rmse_policy();
        policy.required_metric_level = Some(PredictionLevel::Sample);

        let decision = select_candidate(
            &policy,
            &[
                candidate_with_level("model:a", 1.0, "sample"),
                candidate_with_level("model:b", 2.0, "sample"),
            ],
        )
        .unwrap();
        assert_eq!(decision.selected_candidate_id, "model:a");
        assert_eq!(decision.metric_level, Some(PredictionLevel::Sample));

        assert!(select_candidate(
            &policy,
            &[
                candidate_with_level("model:a", 1.0, "sample"),
                candidate_with_level("model:b", 2.0, "target"),
            ],
        )
        .is_err());
        assert!(select_candidate(&policy, &[candidate("model:a", 1.0)]).is_err());
    }

    // An observation-level metric must not satisfy a sample-level requirement, even when it
    // is the only (and best-scoring) candidate. The error text is pinned.
    #[test]
    fn d9_negative_row_level_metric_cannot_drive_sample_refit() {
        let mut policy = rmse_policy();
        policy.required_metric_level = Some(PredictionLevel::Sample);

        let error = select_candidate(
            &policy,
            &[candidate_with_level("model:row_metric", 0.1, "observation")],
        )
        .unwrap_err()
        .to_string();

        assert!(
            error.contains("metric_level `observation` does not match required `sample`"),
            "unexpected D9 row-vs-sample metric error: {error}"
        );
    }

    // The decision echoes the policy's scope, refit plan and reduction id. An invalid refit
    // plan (ensemble with one member) makes selection fail at policy validation.
    #[test]
    fn selection_policy_echoes_evaluation_and_refit_contracts() {
        let mut policy = rmse_policy();
        policy.evaluation_scope = Some(EvaluationScope::Oof);
        policy.reduction_id = Some("reduction:obs_to_sample".to_string());
        policy.refit_slot_plan = Some(RefitSlotPlan {
            strategy: RefitStrategy::RefitOne,
            selection_level: PredictionLevel::Sample,
            member_count: 1,
            selection_metric: policy.metric.clone(),
            reduction_id: Some("reduction:obs_to_sample".to_string()),
        });

        let decision = select_candidate(
            &policy,
            &[candidate("model:a", 1.0), candidate("model:b", 2.0)],
        )
        .unwrap();

        assert_eq!(decision.evaluation_scope, Some(EvaluationScope::Oof));
        assert_eq!(
            decision.refit_slot_plan.as_ref().unwrap().strategy,
            RefitStrategy::RefitOne
        );
        assert_eq!(
            decision.reduction_id.as_deref(),
            Some("reduction:obs_to_sample")
        );

        let mut invalid_policy = policy;
        invalid_policy.refit_slot_plan = Some(RefitSlotPlan {
            strategy: RefitStrategy::RefitEnsemble,
            selection_level: PredictionLevel::Sample,
            member_count: 1,
            selection_metric: invalid_policy.metric.clone(),
            reduction_id: None,
        });
        assert!(select_candidate(&invalid_policy, &[candidate("model:a", 1.0)]).is_err());
    }

    // `Combo` rows require a final reduction id, and `ReuseOof` requires the explicit
    // `unsafe_allow_reuse_oof` opt-in.
    #[test]
    fn stacking_fit_contract_guards_oof_reuse_and_combo_reduction() {
        let valid = StackingFitContract {
            meta_training_features: MetaTrainingFeatures::Oof,
            inference_features: InferenceFeatures::RefitBasePredictions,
            selection_protocol: SelectionProtocol::Nested,
            meta_row_domain: MetaRowDomain::Combo,
            final_reduction_id: Some("reduction:combo_to_sample".to_string()),
            unsafe_allow_reuse_oof: false,
        };
        valid.validate().unwrap();

        let missing_reduction = StackingFitContract {
            final_reduction_id: None,
            ..valid.clone()
        };
        assert!(missing_reduction.validate().is_err());

        let unsafe_reuse_required = StackingFitContract {
            selection_protocol: SelectionProtocol::ReuseOof,
            meta_row_domain: MetaRowDomain::Sample,
            final_reduction_id: None,
            unsafe_allow_reuse_oof: false,
            ..valid
        };
        assert!(unsafe_reuse_required.validate().is_err());
    }

    // The published JSON schemas must carry the `$id`s held in the `SELECTION_*_SCHEMA_ID`
    // constants and must expose the fields this module serializes.
    #[test]
    fn published_selection_schemas_declare_current_contracts() {
        let policy_schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/selection_policy.schema.json"
        ))
        .unwrap();
        assert_eq!(policy_schema["$id"], SELECTION_POLICY_SCHEMA_ID);
        assert!(policy_schema["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("metric")));
        assert!(policy_schema["properties"]
            .get("evaluation_scope")
            .is_some());
        assert!(policy_schema["properties"].get("refit_slot_plan").is_some());
        assert!(policy_schema["properties"]
            .get("stacking_fit_contract")
            .is_some());

        let decision_schema: serde_json::Value = serde_json::from_str(include_str!(
            "../../../docs/contracts/selection_decision.schema.json"
        ))
        .unwrap();
        assert_eq!(decision_schema["$id"], SELECTION_DECISION_SCHEMA_ID);
        assert!(decision_schema["$defs"]["prediction_level"]["enum"]
            .as_array()
            .unwrap()
            .iter()
            .any(|level| level.as_str() == Some("group")));
        assert!(decision_schema["$defs"]["ranked_candidate"]["required"]
            .as_array()
            .unwrap()
            .iter()
            .any(|field| field.as_str() == Some("rank")));
        assert!(decision_schema["properties"]
            .get("evaluation_scope")
            .is_some());
        assert!(decision_schema["properties"]
            .get("refit_slot_plan")
            .is_some());
    }

    // End-to-end check against the generated sklearn demo report. Each object in
    // `branch_variant_metrics` / `merge_variant_metrics` becomes a candidate whose metrics are
    // the object's numeric fields. Per-branch selection picks `rf_select_k40` for `branch:b1`,
    // and the merge selection picks the ridge meta-learner.
    #[test]
    fn selects_sklearn_demo_branch_and_merge_variants() {
        let report: serde_json::Value = serde_json::from_str(include_str!(
            "../../../examples/generated/sklearn_complex_report.json"
        ))
        .unwrap();
        let branch_metrics = report["branch_variant_metrics"].as_object().unwrap();
        let candidates = branch_metrics
            .iter()
            .map(|(candidate_id, metrics)| CandidateScore {
                candidate_id: candidate_id.clone(),
                metrics: metrics
                    .as_object()
                    .unwrap()
                    .iter()
                    .map(|(name, value)| (name.clone(), value.as_f64().unwrap()))
                    .collect(),
                metadata: BTreeMap::new(),
            })
            .collect::<Vec<_>>();
        let groups = BTreeMap::from([
            (
                "branch:b0".to_string(),
                vec![
                    "branch:b0.variant:pca10_ridge_a03".to_string(),
                    "branch:b0.variant:pca16_ridge_a12".to_string(),
                ],
            ),
            (
                "branch:b1".to_string(),
                vec![
                    "branch:b1.variant:rf_select_k28".to_string(),
                    "branch:b1.variant:rf_select_k40".to_string(),
                ],
            ),
            (
                "branch:b2".to_string(),
                vec![
                    "branch:b2.variant:poly_extra_k45".to_string(),
                    "branch:b2.variant:poly_extra_k80".to_string(),
                ],
            ),
        ]);

        let decisions = select_candidate_groups(&rmse_policy(), &candidates, &groups).unwrap();
        assert_eq!(
            decisions["branch:b1"].selected_candidate_id,
            "branch:b1.variant:rf_select_k40"
        );

        let merge_metrics = report["merge_variant_metrics"].as_object().unwrap();
        let merge_candidates = merge_metrics
            .iter()
            .map(|(candidate_id, metrics)| CandidateScore {
                candidate_id: candidate_id.clone(),
                metrics: metrics
                    .as_object()
                    .unwrap()
                    .iter()
                    .map(|(name, value)| (name.clone(), value.as_f64().unwrap()))
                    .collect(),
                metadata: BTreeMap::new(),
            })
            .collect::<Vec<_>>();
        let merge_decision = select_candidate(&rmse_policy(), &merge_candidates).unwrap();
        assert_eq!(
            merge_decision.selected_candidate_id,
            "merge:m1.pred_meta_original.meta:ridge"
        );
    }
}