Skip to main content

dag_ml_core/
plan.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::controller::{
7    ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest,
8    ControllerRegistry, RngPolicy,
9};
10use crate::data::{BranchViewPlan, DataBinding, ExternalDataPlanEnvelope};
11use crate::error::{DagMlError, Result};
12use crate::fold::{FoldSet, NestedCvSpec};
13use crate::generation::{
14    enumerate_variants, generation_spec_fingerprint, GenerationSpec, VariantPlan,
15};
16use crate::graph::{GraphSpec, NodeKind};
17use crate::ids::{ControllerId, FoldId, NodeId, VariantId};
18use crate::phase::Phase;
19use crate::policy::{AggregationPolicy, DataModelShapePlan, LeakageUnitPolicy};
20
/// Version number of the campaign-spec schema; agrees with the `v1` suffix of
/// [`CAMPAIGN_SPEC_SCHEMA_ID`].
pub const CAMPAIGN_SPEC_SCHEMA_VERSION: u32 = 1;
/// JSON Schema identifier (URL) of the v1 campaign-spec schema.
pub const CAMPAIGN_SPEC_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/campaign_spec.v1.schema.json";
/// Version number of the execution-plan schema; agrees with the `v1` suffix
/// of [`EXECUTION_PLAN_SCHEMA_ID`].
pub const EXECUTION_PLAN_SCHEMA_VERSION: u32 = 1;
/// JSON Schema identifier (URL) of the v1 execution-plan schema.
pub const EXECUTION_PLAN_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/execution_plan.v1.schema.json";
27
/// The campaign's split step.
///
/// When `fold_set` is present, `build_execution_plan` copies it into
/// `ExecutionPlan::fold_set`, which drives FIT_CV scope fan-out.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SplitInvocation {
    /// Split identifier; `validate` rejects blank values.
    pub id: String,
    /// Optional controller for the split. NOTE(review): not checked by
    /// `validate` in this module.
    #[serde(default)]
    pub controller_id: Option<ControllerId>,
    /// Leakage-unit policy for this split. It is validated on its own and is
    /// also used when checking data-envelope relations against `fold_set`.
    #[serde(default)]
    pub leakage_policy: LeakageUnitPolicy,
    /// Free-form split parameters as opaque JSON values.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Optional precomputed fold set; validated when present.
    #[serde(default)]
    pub fold_set: Option<FoldSet>,
}
40
41impl SplitInvocation {
42    pub fn validate(&self) -> Result<()> {
43        if self.id.trim().is_empty() {
44            return Err(DagMlError::CampaignValidation(
45                "split invocation id is empty".to_string(),
46            ));
47        }
48        self.leakage_policy.validate()?;
49        if let Some(fold_set) = &self.fold_set {
50            fold_set.validate()?;
51        }
52        Ok(())
53    }
54}
55
/// User-level description of a campaign: policies, optional split, variant
/// generation, and per-node shape plans and data bindings.
///
/// The whole spec is fingerprinted into `ExecutionPlan::campaign_fingerprint`
/// by `build_execution_plan`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CampaignSpec {
    /// Campaign identifier; `validate` rejects blank values.
    pub id: String,
    /// Seed passed to `enumerate_variants` when expanding `generation`.
    pub root_seed: Option<u64>,
    /// Campaign-wide leakage-unit policy.
    #[serde(default)]
    pub leakage_policy: LeakageUnitPolicy,
    /// Campaign-wide aggregation policy.
    #[serde(default)]
    pub aggregation_policy: AggregationPolicy,
    /// Optional split step; its fold set (if any) becomes the plan's fold set.
    #[serde(default)]
    pub split_invocation: Option<SplitInvocation>,
    /// Variant-generation spec expanded into the plan's variants.
    #[serde(default)]
    pub generation: GenerationSpec,
    /// Shape plans keyed by node id; each key must equal the plan's `node_id`.
    #[serde(default)]
    pub shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
    /// Data bindings keyed by node id; every binding's `node_id` must equal
    /// its key.
    #[serde(default)]
    pub data_bindings: BTreeMap<NodeId, Vec<DataBinding>>,
    /// Branch view plans; `view_id`s must be unique within the campaign.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub branch_view_plans: Vec<BranchViewPlan>,
    /// Campaign-wide default nested (inner) CV policy. A node-level
    /// `NodePlan.inner_cv` overrides it; see [`crate::fold::resolve_inner_cv`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
    /// Free-form metadata. It is not interpreted by validation in this module,
    /// but it does contribute to the campaign fingerprint.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
81
82impl CampaignSpec {
83    pub fn validate(&self) -> Result<()> {
84        if self.id.trim().is_empty() {
85            return Err(DagMlError::CampaignValidation(
86                "campaign id is empty".to_string(),
87            ));
88        }
89        self.leakage_policy.validate()?;
90        self.aggregation_policy.validate()?;
91        if let Some(inner_cv) = &self.inner_cv {
92            inner_cv.validate()?;
93        }
94        if let Some(split) = &self.split_invocation {
95            split.validate()?;
96        }
97        self.generation.validate()?;
98        for (node_id, shape_plan) in &self.shape_plans {
99            if node_id != &shape_plan.node_id {
100                return Err(DagMlError::CampaignValidation(format!(
101                    "shape plan key `{node_id}` does not match node_id `{}`",
102                    shape_plan.node_id
103                )));
104            }
105            shape_plan.validate()?;
106        }
107        for (node_id, bindings) in &self.data_bindings {
108            for binding in bindings {
109                if node_id != &binding.node_id {
110                    return Err(DagMlError::CampaignValidation(format!(
111                        "data binding key `{node_id}` does not match node_id `{}`",
112                        binding.node_id
113                    )));
114                }
115                binding.validate()?;
116            }
117        }
118        let mut branch_views = BTreeSet::new();
119        for plan in &self.branch_view_plans {
120            plan.validate()?;
121            if !branch_views.insert(plan.view_id.as_str()) {
122                return Err(DagMlError::CampaignValidation(format!(
123                    "campaign `{}` contains duplicate branch view `{}`",
124                    self.id, plan.view_id
125                )));
126            }
127        }
128        Ok(())
129    }
130
131    pub fn validate_data_envelope_relations(
132        &self,
133        envelope: &ExternalDataPlanEnvelope,
134    ) -> Result<()> {
135        envelope.validate()?;
136        let Some(relations) = &envelope.coordinator_relations else {
137            return Ok(());
138        };
139        let Some(split) = &self.split_invocation else {
140            return Ok(());
141        };
142        let Some(fold_set) = &split.fold_set else {
143            return Ok(());
144        };
145        relations.validate_against_fold_set(fold_set, &self.leakage_policy)?;
146        relations.validate_against_fold_set(fold_set, &split.leakage_policy)
147    }
148}
149
/// A graph together with its cached topological order and, optionally, its
/// cached parallel levels.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphPlan {
    /// The planned graph.
    pub graph: GraphSpec,
    /// Node ids in the order produced by `GraphSpec::topological_order` at
    /// planning time.
    pub topological_order: Vec<NodeId>,
    /// Cached parallel levels. When empty (for example omitted from JSON), they
    /// are recomputed from `graph` by `GraphPlan::parallel_levels`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub parallel_levels: Vec<Vec<NodeId>>,
}
157
158impl GraphPlan {
159    pub fn from_graph(graph: GraphSpec) -> Result<Self> {
160        let topological_order = graph.topological_order()?;
161        let parallel_levels = graph.parallel_levels()?;
162        Ok(Self {
163            graph,
164            topological_order,
165            parallel_levels,
166        })
167    }
168
169    pub fn parallel_levels(&self) -> Result<Vec<Vec<NodeId>>> {
170        if self.parallel_levels.is_empty() {
171            return self.graph.parallel_levels();
172        }
173        Ok(self.parallel_levels.clone())
174    }
175}
176
/// Per-node planning record.
///
/// `build_execution_plan` fills the controller fields (`controller_id`,
/// `controller_version`, `supported_phases`, `controller_capabilities`,
/// `fit_scope`, `rng_policy`, `artifact_policy`) from the resolved controller
/// manifest. `ExecutionPlan::validate` re-checks the capability and policy
/// fields against that manifest.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct NodePlan {
    /// Id of the graph node this plan describes.
    pub node_id: NodeId,
    /// Node kind copied from the graph node.
    pub kind: NodeKind,
    /// Controller resolved for this node.
    pub controller_id: ControllerId,
    /// Version string of the resolved controller.
    pub controller_version: String,
    /// Phases in which this node is scheduled.
    pub supported_phases: BTreeSet<Phase>,
    /// Capabilities declared by the controller manifest.
    #[serde(default)]
    pub controller_capabilities: BTreeSet<ControllerCapability>,
    /// Fit scope declared by the controller manifest.
    pub fit_scope: ControllerFitScope,
    /// RNG policy declared by the controller manifest.
    pub rng_policy: RngPolicy,
    /// Artifact policy declared by the controller manifest.
    pub artifact_policy: ArtifactPolicy,
    /// Upstream node ids, from `GraphSpec::upstream_nodes`.
    pub input_nodes: Vec<NodeId>,
    /// Downstream node ids, from `GraphSpec::downstream_nodes`.
    pub output_nodes: Vec<NodeId>,
    /// Campaign shape plan for this node, if any.
    pub shape_plan: Option<DataModelShapePlan>,
    /// Campaign data bindings for this node; each binding's `node_id` must
    /// equal `node_id`.
    #[serde(default)]
    pub data_bindings: Vec<DataBinding>,
    /// Node parameters copied from the graph node.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Node-local nested (inner) CV policy (e.g. for a finetune/tuner or branch
    /// node); overrides the campaign-wide default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
    /// `stable_json_fingerprint(&params)`; re-verified by
    /// `ExecutionPlan::validate`.
    pub params_fingerprint: String,
}
202
/// A fully resolved plan: the graph, the campaign, one `NodePlan` per graph
/// node, the controller manifests they reference, the enumerated variants,
/// and fingerprints of the graph, campaign and manifests.
///
/// Produced by `build_execution_plan`. A plan loaded from JSON should be
/// checked with `ExecutionPlan::validate` before use.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExecutionPlan {
    /// Plan identifier; `build_execution_plan` rejects blank values.
    pub id: String,
    /// Graph plus cached topological order and parallel levels.
    pub graph_plan: GraphPlan,
    /// Campaign the plan was built from.
    pub campaign: CampaignSpec,
    /// Node plans keyed by node id.
    pub node_plans: BTreeMap<NodeId, NodePlan>,
    /// Manifests of every controller referenced by `node_plans`.
    pub controller_manifests: BTreeMap<ControllerId, ControllerManifest>,
    /// Variants enumerated from the campaign's generation spec.
    pub variants: Vec<VariantPlan>,
    /// Fold set copied from the campaign's split invocation, if any.
    pub fold_set: Option<FoldSet>,
    /// `stable_json_fingerprint` of the graph at build time.
    pub graph_fingerprint: String,
    /// `stable_json_fingerprint` of the campaign at build time.
    pub campaign_fingerprint: String,
    /// `stable_json_fingerprint` of `controller_manifests` at build time.
    pub controller_fingerprint: String,
}
216
/// One schedulable scope: a (phase, variant, fold) combination together with
/// the phase-filtered node levels to run.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ExecutionScopePlan {
    /// Id of the form `scope:<PHASE>:<variant id or "base">:<fold id or "nofold">`.
    pub scope_id: String,
    /// Phase this scope runs in.
    pub phase: Phase,
    /// Variant this scope runs, if any.
    pub variant_id: Option<VariantId>,
    /// Seed of that variant, if any.
    pub variant_seed: Option<u64>,
    /// Fold this scope runs (only set for FIT_CV with a fold set).
    pub fold_id: Option<FoldId>,
    /// Parallel node levels restricted to nodes supporting `phase`.
    pub node_levels: Vec<Vec<NodeId>>,
}
226
/// All scopes of one plan for a single phase, as produced by
/// `ExecutionPlan::campaign_phase_schedule`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PhaseExecutionSchedule {
    /// Id of the plan the schedule was derived from.
    pub plan_id: String,
    /// Phase being scheduled.
    pub phase: Phase,
    /// One scope per (variant, fold) pair.
    pub scopes: Vec<ExecutionScopePlan>,
}
233
234impl ExecutionPlan {
235    pub fn validate(&self) -> Result<()> {
236        self.graph_plan.graph.validate()?;
237        self.campaign.validate()?;
238        if !self.graph_plan.parallel_levels.is_empty()
239            && self.graph_plan.parallel_levels != self.graph_plan.graph.parallel_levels()?
240        {
241            return Err(DagMlError::Planning(
242                "graph plan parallel levels do not match graph".to_string(),
243            ));
244        }
245        if self.node_plans.len() != self.graph_plan.graph.nodes.len() {
246            return Err(DagMlError::Planning(
247                "execution plan node count does not match graph".to_string(),
248            ));
249        }
250        for node_id in &self.graph_plan.topological_order {
251            let plan = self.node_plans.get(node_id).ok_or_else(|| {
252                DagMlError::Planning(format!("missing node plan for `{node_id}`"))
253            })?;
254            let manifest = self
255                .controller_manifests
256                .get(&plan.controller_id)
257                .ok_or_else(|| {
258                    DagMlError::Planning(format!(
259                        "missing controller manifest `{}` for node `{node_id}`",
260                        plan.controller_id
261                    ))
262                })?;
263            if manifest.operator_kind != plan.kind {
264                return Err(DagMlError::Planning(format!(
265                    "node `{node_id}` planned with incompatible controller `{}`",
266                    manifest.controller_id
267                )));
268            }
269            if plan.controller_capabilities != manifest.capabilities {
270                return Err(DagMlError::Planning(format!(
271                    "node `{node_id}` controller capabilities do not match manifest `{}`",
272                    manifest.controller_id
273                )));
274            }
275            if plan.fit_scope != manifest.fit_scope
276                || plan.rng_policy != manifest.rng_policy
277                || plan.artifact_policy != manifest.artifact_policy
278            {
279                return Err(DagMlError::Planning(format!(
280                    "node `{node_id}` controller policy fields do not match manifest `{}`",
281                    manifest.controller_id
282                )));
283            }
284            for binding in &plan.data_bindings {
285                if binding.node_id != *node_id {
286                    return Err(DagMlError::Planning(format!(
287                        "node plan `{node_id}` contains data binding for `{}`",
288                        binding.node_id
289                    )));
290                }
291                binding.validate()?;
292            }
293            let actual_params_fingerprint = stable_json_fingerprint(&plan.params)?;
294            if actual_params_fingerprint != plan.params_fingerprint {
295                return Err(DagMlError::Planning(format!(
296                    "node plan `{node_id}` params fingerprint does not match params"
297                )));
298            }
299        }
300        // Validate every node-local inner_cv over ALL node plans (not just the
301        // cached topological order): a hand-loaded ExecutionPlan JSON with a
302        // stale/tampered order could omit a FIT_CV node from that order while
303        // still scheduling it via parallel levels, so a malformed inner_cv must
304        // be refused here rather than deferred to FIT_CV fold building.
305        for (node_id, plan) in &self.node_plans {
306            if let Some(inner_cv) = &plan.inner_cv {
307                inner_cv.validate().map_err(|error| {
308                    DagMlError::Planning(format!(
309                        "node plan `{node_id}` has invalid inner_cv: {error}"
310                    ))
311                })?;
312            }
313        }
314        self.validate_oof_controller_capabilities()?;
315        if let Some(fold_set) = &self.fold_set {
316            fold_set.validate()?;
317        }
318        if self.variants.is_empty() {
319            return Err(DagMlError::Planning(
320                "execution plan has no variants".to_string(),
321            ));
322        }
323        for variant in &self.variants {
324            variant.validate()?;
325        }
326        Ok(())
327    }
328
329    pub fn validate_parallel_controller_capabilities(
330        &self,
331        max_workers: usize,
332        phase: Phase,
333    ) -> Result<()> {
334        if max_workers <= 1 {
335            return Ok(());
336        }
337        let node_ids = self
338            .node_parallel_levels_for_phase(phase)?
339            .into_iter()
340            .flatten()
341            .collect::<Vec<_>>();
342        for node_id in node_ids {
343            let node_plan = self.node_plans.get(&node_id).ok_or_else(|| {
344                DagMlError::Planning(format!("missing node plan for `{node_id}`"))
345            })?;
346            let manifest = self
347                .controller_manifests
348                .get(&node_plan.controller_id)
349                .ok_or_else(|| {
350                    DagMlError::Planning(format!(
351                        "missing controller manifest `{}` for node `{}`",
352                        node_plan.controller_id, node_plan.node_id
353                    ))
354                })?;
355            if !manifest.supports_parallel_invocation() {
356                return Err(DagMlError::Planning(format!(
357                    "parallel scheduler with {max_workers} workers requires controller `{}` for node `{}` to declare thread_safe or process_safe",
358                    manifest.controller_id, node_plan.node_id
359                )));
360            }
361        }
362        Ok(())
363    }
364
365    fn validate_oof_controller_capabilities(&self) -> Result<()> {
366        for edge in &self.graph_plan.graph.edges {
367            if !edge.contract.requires_oof {
368                continue;
369            }
370            let source_plan = self.node_plans.get(&edge.source.node_id).ok_or_else(|| {
371                DagMlError::Planning(format!(
372                    "OOF edge source node `{}` has no node plan",
373                    edge.source.node_id
374                ))
375            })?;
376            if !source_plan
377                .controller_capabilities
378                .contains(&ControllerCapability::EmitsPredictions)
379            {
380                return Err(DagMlError::Planning(format!(
381                    "OOF edge `{}.{}` -> `{}.{}` requires source controller `{}` to declare emits_predictions",
382                    edge.source.node_id,
383                    edge.source.port_name,
384                    edge.target.node_id,
385                    edge.target.port_name,
386                    source_plan.controller_id
387                )));
388            }
389            let target_plan = self.node_plans.get(&edge.target.node_id).ok_or_else(|| {
390                DagMlError::Planning(format!(
391                    "OOF edge target node `{}` has no node plan",
392                    edge.target.node_id
393                ))
394            })?;
395            if !target_plan
396                .controller_capabilities
397                .contains(&ControllerCapability::ConsumesOofPredictions)
398            {
399                return Err(DagMlError::Planning(format!(
400                    "OOF edge `{}.{}` -> `{}.{}` requires target controller `{}` to declare consumes_oof_predictions",
401                    edge.source.node_id,
402                    edge.source.port_name,
403                    edge.target.node_id,
404                    edge.target.port_name,
405                    target_plan.controller_id
406                )));
407            }
408        }
409        Ok(())
410    }
411
412    pub fn node_parallel_levels_for_phase(&self, phase: Phase) -> Result<Vec<Vec<NodeId>>> {
413        let levels = self
414            .graph_plan
415            .parallel_levels()?
416            .into_iter()
417            .map(|level| {
418                level
419                    .into_iter()
420                    .filter(|node_id| {
421                        self.node_plans
422                            .get(node_id)
423                            .is_some_and(|node_plan| node_plan.supported_phases.contains(&phase))
424                    })
425                    .collect::<Vec<_>>()
426            })
427            .filter(|level| !level.is_empty())
428            .collect::<Vec<_>>();
429        Ok(levels)
430    }
431
432    pub fn campaign_phase_schedule(&self, phase: Phase) -> Result<PhaseExecutionSchedule> {
433        self.validate()?;
434        let node_levels = self.node_parallel_levels_for_phase(phase)?;
435        let fold_ids = if phase == Phase::FitCv {
436            self.fold_set
437                .as_ref()
438                .map(|fold_set| {
439                    fold_set
440                        .folds
441                        .iter()
442                        .map(|fold| Some(fold.fold_id.clone()))
443                        .collect::<Vec<_>>()
444                })
445                .unwrap_or_else(|| vec![None])
446        } else {
447            vec![None]
448        };
449        let mut scopes = Vec::new();
450        for variant in &self.variants {
451            for fold_id in &fold_ids {
452                scopes.push(ExecutionScopePlan {
453                    scope_id: execution_scope_id(
454                        phase,
455                        Some(&variant.variant_id),
456                        fold_id.as_ref(),
457                    ),
458                    phase,
459                    variant_id: Some(variant.variant_id.clone()),
460                    variant_seed: variant.seed,
461                    fold_id: fold_id.clone(),
462                    node_levels: node_levels.clone(),
463                });
464            }
465        }
466        Ok(PhaseExecutionSchedule {
467            plan_id: self.id.clone(),
468            phase,
469            scopes,
470        })
471    }
472
473    /// Returns the `BranchViewPlan` whose `branch_id` matches `branch_id`,
474    /// if any. The match is exact; callers that need fuzzy or prefix matching
475    /// must iterate `self.campaign.branch_view_plans` themselves.
476    pub fn branch_view_for(&self, branch_id: &str) -> Option<&BranchViewPlan> {
477        branch_view_for_in(&self.campaign.branch_view_plans, branch_id)
478    }
479
480    /// Returns the `BranchViewPlan` for the deepest branch in `branch_path`
481    /// that has a matching plan, if any. The path is walked tip-first so the
482    /// closest enclosing branch wins; an empty path returns `None`. The
483    /// returned reference borrows the plan from the campaign; the caller can
484    /// `.clone()` it into a `DataProviderViewSpec.branch_view` field when
485    /// constructing a provider view for an in-branch node.
486    pub fn branch_view_for_path(&self, branch_path: &[String]) -> Option<&BranchViewPlan> {
487        branch_view_for_path_in(&self.campaign.branch_view_plans, branch_path)
488    }
489}
490
491fn branch_view_for_in<'a>(
492    plans: &'a [BranchViewPlan],
493    branch_id: &str,
494) -> Option<&'a BranchViewPlan> {
495    plans.iter().find(|plan| plan.branch_id == branch_id)
496}
497
498fn branch_view_for_path_in<'a>(
499    plans: &'a [BranchViewPlan],
500    branch_path: &[String],
501) -> Option<&'a BranchViewPlan> {
502    for branch_id in branch_path.iter().rev() {
503        if let Some(plan) = branch_view_for_in(plans, branch_id) {
504            return Some(plan);
505        }
506    }
507    None
508}
509
510fn execution_scope_id(
511    phase: Phase,
512    variant_id: Option<&VariantId>,
513    fold_id: Option<&FoldId>,
514) -> String {
515    format!(
516        "scope:{}:{}:{}",
517        phase_scope_label(phase),
518        variant_id
519            .map(ToString::to_string)
520            .unwrap_or_else(|| "base".to_string()),
521        fold_id
522            .map(ToString::to_string)
523            .unwrap_or_else(|| "nofold".to_string())
524    )
525}
526
527fn phase_scope_label(phase: Phase) -> &'static str {
528    match phase {
529        Phase::Compile => "COMPILE",
530        Phase::Plan => "PLAN",
531        Phase::FitCv => "FIT_CV",
532        Phase::Select => "SELECT",
533        Phase::Refit => "REFIT",
534        Phase::Predict => "PREDICT",
535        Phase::Explain => "EXPLAIN",
536    }
537}
538
539pub fn build_execution_plan(
540    id: impl Into<String>,
541    graph: GraphSpec,
542    campaign: CampaignSpec,
543    registry: &ControllerRegistry,
544) -> Result<ExecutionPlan> {
545    let id = id.into();
546    if id.trim().is_empty() {
547        return Err(DagMlError::Planning(
548            "execution plan id is empty".to_string(),
549        ));
550    }
551    campaign.validate()?;
552    let graph_plan = GraphPlan::from_graph(graph)?;
553    validate_campaign_node_targets(&graph_plan.graph, &campaign)?;
554
555    let mut node_plans = BTreeMap::new();
556    let mut controller_manifests = BTreeMap::new();
557    for node_id in &graph_plan.topological_order {
558        let node = graph_plan
559            .graph
560            .nodes
561            .iter()
562            .find(|node| &node.id == node_id)
563            .expect("topological node exists");
564        let manifest = registry.resolve_for_node(node)?;
565        let params = node.params.clone();
566        let params_fingerprint = stable_json_fingerprint(&params)?;
567        // Lower a node-local nested-CV policy carried by the DSL compiler in the
568        // graph node metadata into the typed NodePlan field. Malformed metadata
569        // fails the plan rather than silently dropping nested CV.
570        let inner_cv = match node.metadata.get("dsl_inner_cv") {
571            Some(value) => {
572                let spec =
573                    serde_json::from_value::<NestedCvSpec>(value.clone()).map_err(|error| {
574                        DagMlError::Planning(format!(
575                            "node `{}` has invalid dsl_inner_cv metadata: {error}",
576                            node.id
577                        ))
578                    })?;
579                // Reject semantically malformed specs (e.g. n_splits < 2) here, at
580                // the plan boundary, rather than deferring to FIT_CV fold building.
581                spec.validate().map_err(|error| {
582                    DagMlError::Planning(format!(
583                        "node `{}` has invalid dsl_inner_cv metadata: {error}",
584                        node.id
585                    ))
586                })?;
587                Some(spec)
588            }
589            None => None,
590        };
591        let shape_plan = campaign.shape_plans.get(&node.id).cloned();
592        let data_bindings = campaign
593            .data_bindings
594            .get(&node.id)
595            .cloned()
596            .unwrap_or_default();
597        node_plans.insert(
598            node.id.clone(),
599            NodePlan {
600                inner_cv,
601                node_id: node.id.clone(),
602                kind: node.kind.clone(),
603                controller_id: manifest.controller_id.clone(),
604                controller_version: manifest.controller_version.clone(),
605                supported_phases: manifest.supported_phases.clone(),
606                controller_capabilities: manifest.capabilities.clone(),
607                fit_scope: manifest.fit_scope,
608                rng_policy: manifest.rng_policy,
609                artifact_policy: manifest.artifact_policy,
610                input_nodes: graph_plan.graph.upstream_nodes(&node.id),
611                output_nodes: graph_plan.graph.downstream_nodes(&node.id),
612                shape_plan,
613                data_bindings,
614                params,
615                params_fingerprint,
616            },
617        );
618        controller_manifests.insert(manifest.controller_id.clone(), manifest);
619    }
620
621    let fold_set = campaign
622        .split_invocation
623        .as_ref()
624        .and_then(|split| split.fold_set.clone());
625    validate_search_space_fingerprint(&graph_plan.graph, &campaign)?;
626    let variants = enumerate_variants(&campaign.generation, campaign.root_seed)?;
627    validate_generation_override_targets(&graph_plan.graph, &variants)?;
628    let graph_fingerprint = stable_json_fingerprint(&graph_plan.graph)?;
629    let campaign_fingerprint = stable_json_fingerprint(&campaign)?;
630    let controller_fingerprint = stable_json_fingerprint(&controller_manifests)?;
631    let plan = ExecutionPlan {
632        id,
633        graph_plan,
634        campaign,
635        node_plans,
636        controller_manifests,
637        variants,
638        fold_set,
639        graph_fingerprint,
640        campaign_fingerprint,
641        controller_fingerprint,
642    };
643    plan.validate()?;
644    Ok(plan)
645}
646
647fn validate_search_space_fingerprint(graph: &GraphSpec, campaign: &CampaignSpec) -> Result<()> {
648    let Some(expected_fingerprint) = &graph.search_space_fingerprint else {
649        return Ok(());
650    };
651    if expected_fingerprint.trim().is_empty() {
652        return Err(DagMlError::Planning(format!(
653            "graph `{}` has empty search_space_fingerprint",
654            graph.id
655        )));
656    }
657    let actual_fingerprint = generation_spec_fingerprint(&campaign.generation)?;
658    if expected_fingerprint != &actual_fingerprint {
659        return Err(DagMlError::Planning(format!(
660            "graph `{}` search_space_fingerprint does not match campaign generation spec",
661            graph.id
662        )));
663    }
664    Ok(())
665}
666
667fn validate_generation_override_targets(graph: &GraphSpec, variants: &[VariantPlan]) -> Result<()> {
668    let node_ids = graph
669        .nodes
670        .iter()
671        .map(|node| node.id.clone())
672        .collect::<BTreeSet<_>>();
673    for variant in variants {
674        for node_id in variant.param_override_targets()? {
675            if !node_ids.contains(&node_id) {
676                return Err(DagMlError::Planning(format!(
677                    "variant `{}` overrides params for unknown node `{node_id}`",
678                    variant.variant_id
679                )));
680            }
681        }
682    }
683    Ok(())
684}
685
686fn validate_campaign_node_targets(graph: &GraphSpec, campaign: &CampaignSpec) -> Result<()> {
687    let node_ids = graph
688        .nodes
689        .iter()
690        .map(|node| &node.id)
691        .collect::<BTreeSet<_>>();
692    for node_id in campaign.shape_plans.keys() {
693        if !node_ids.contains(node_id) {
694            return Err(DagMlError::Planning(format!(
695                "shape plan references unknown node `{node_id}`"
696            )));
697        }
698    }
699    for node_id in campaign.data_bindings.keys() {
700        if !node_ids.contains(node_id) {
701            return Err(DagMlError::Planning(format!(
702                "data binding references unknown node `{node_id}`"
703            )));
704        }
705    }
706    Ok(())
707}
708
709#[cfg(test)]
710mod tests {
711    use std::collections::{BTreeMap, BTreeSet};
712    use std::time::{Duration, Instant};
713
714    use super::*;
715    use crate::controller::{
716        ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest, RngPolicy,
717    };
718
719    #[test]
720    fn inner_cv_is_declarable_at_campaign_and_node_level() {
721        // Campaign-level (global) declaration round-trips through JSON.
722        let campaign_json = r#"{"id":"c","root_seed":null,"inner_cv":{"kind":"kfold","n_splits":3,"shuffle":false,"seed":5}}"#;
723        let campaign: CampaignSpec = serde_json::from_str(campaign_json).unwrap();
724        campaign.validate().unwrap();
725        assert!(campaign.inner_cv.is_some());
726
727        // A node-local declaration overrides the campaign default.
728        let node_inner = crate::fold::NestedCvSpec::KFold(crate::fold::KFoldSpec {
729            n_splits: 4,
730            shuffle: false,
731            seed: Some(6),
732        });
733        let resolved = crate::fold::resolve_inner_cv(Some(&node_inner), campaign.inner_cv.as_ref());
734        assert_eq!(resolved, Some(&node_inner));
735
736        // Absent on both campaign and node serializes away (skip_serializing_if).
737        let bare = r#"{"id":"c","root_seed":null}"#;
738        let bare_campaign: CampaignSpec = serde_json::from_str(bare).unwrap();
739        assert!(bare_campaign.inner_cv.is_none());
740        let reserialized = serde_json::to_string(&bare_campaign).unwrap();
741        assert!(!reserialized.contains("inner_cv"));
742
743        // A semantically-malformed campaign-global inner_cv (n_splits < 2) is
744        // rejected by CampaignSpec::validate (the plan boundary), not deferred.
745        let bad: CampaignSpec = serde_json::from_str(
746            r#"{"id":"c","root_seed":null,"inner_cv":{"kind":"kfold","n_splits":1,"shuffle":false,"seed":null}}"#,
747        )
748        .unwrap();
749        let error = bad.validate().unwrap_err();
750        assert!(error.to_string().contains("at least two splits"));
751    }
752
753    #[test]
754    fn execution_plan_validate_rejects_invalid_node_local_inner_cv() {
755        // A canonical ExecutionPlan loaded from JSON (bypassing DSL lowering) can
756        // carry a malformed node-local inner_cv; ExecutionPlan::validate must
757        // refuse it rather than deferring to FIT_CV fold building.
758        let campaign = CampaignSpec {
759            inner_cv: None,
760            id: "campaign:plan-validate".to_string(),
761            root_seed: Some(7),
762            leakage_policy: LeakageUnitPolicy::default(),
763            aggregation_policy: AggregationPolicy::default(),
764            split_invocation: None,
765            generation: Default::default(),
766            shape_plans: BTreeMap::new(),
767            data_bindings: BTreeMap::new(),
768            branch_view_plans: Vec::new(),
769            metadata: BTreeMap::new(),
770        };
771        let mut plan =
772            build_execution_plan("plan:validate", graph(), campaign, &registry()).unwrap();
773        plan.validate().unwrap();
774        plan.node_plans
775            .get_mut(&NodeId::new("model:pls").unwrap())
776            .unwrap()
777            .inner_cv = Some(crate::fold::NestedCvSpec::KFold(crate::fold::KFoldSpec {
778            n_splits: 1,
779            shuffle: false,
780            seed: None,
781        }));
782        let error = plan.validate().unwrap_err();
783        assert!(matches!(error, DagMlError::Planning(_)));
784        assert!(error.to_string().contains("invalid inner_cv"));
785        assert!(error.to_string().contains("at least two splits"));
786    }
787
788    #[test]
789    fn build_execution_plan_lowers_dsl_inner_cv_metadata_into_node_plan() {
790        let mut graph = graph();
791        graph
792            .nodes
793            .iter_mut()
794            .find(|node| node.id.as_str() == "model:pls")
795            .unwrap()
796            .metadata
797            .insert(
798                "dsl_inner_cv".to_string(),
799                serde_json::json!({"kind": "kfold", "n_splits": 3, "shuffle": false, "seed": 9}),
800            );
801
802        let campaign = CampaignSpec {
803            inner_cv: None,
804            id: "campaign:inner-cv".to_string(),
805            root_seed: Some(7),
806            leakage_policy: LeakageUnitPolicy::default(),
807            aggregation_policy: AggregationPolicy::default(),
808            split_invocation: None,
809            generation: Default::default(),
810            shape_plans: BTreeMap::new(),
811            data_bindings: BTreeMap::new(),
812            branch_view_plans: Vec::new(),
813            metadata: BTreeMap::new(),
814        };
815
816        let plan = build_execution_plan("plan:inner-cv", graph, campaign, &registry()).unwrap();
817        match &plan.node_plans[&NodeId::new("model:pls").unwrap()].inner_cv {
818            Some(crate::fold::NestedCvSpec::KFold(k)) => {
819                assert_eq!(k.n_splits, 3);
820                assert_eq!(k.seed, Some(9));
821            }
822            other => panic!("expected lowered KFold inner_cv, got {other:?}"),
823        }
824        assert!(plan.node_plans[&NodeId::new("transform:snv").unwrap()]
825            .inner_cv
826            .is_none());
827    }
828
829    #[test]
830    fn build_execution_plan_rejects_malformed_dsl_inner_cv_metadata() {
831        let mut graph = graph();
832        graph
833            .nodes
834            .iter_mut()
835            .find(|node| node.id.as_str() == "model:pls")
836            .unwrap()
837            .metadata
838            .insert(
839                "dsl_inner_cv".to_string(),
840                serde_json::json!({"kind": "not_a_real_kind"}),
841            );
842
843        let campaign = CampaignSpec {
844            inner_cv: None,
845            id: "campaign:inner-cv.bad".to_string(),
846            root_seed: Some(7),
847            leakage_policy: LeakageUnitPolicy::default(),
848            aggregation_policy: AggregationPolicy::default(),
849            split_invocation: None,
850            generation: Default::default(),
851            shape_plans: BTreeMap::new(),
852            data_bindings: BTreeMap::new(),
853            branch_view_plans: Vec::new(),
854            metadata: BTreeMap::new(),
855        };
856
857        let error =
858            build_execution_plan("plan:inner-cv.bad", graph, campaign, &registry()).unwrap_err();
859        assert!(matches!(error, DagMlError::Planning(_)));
860        assert!(error.to_string().contains("invalid dsl_inner_cv metadata"));
861    }
862
863    #[test]
864    fn build_execution_plan_rejects_semantically_invalid_dsl_inner_cv() {
865        // Right discriminator, invalid value: a single split is rejected at the
866        // plan boundary rather than deferred to FIT_CV fold building.
867        let mut graph = graph();
868        graph
869            .nodes
870            .iter_mut()
871            .find(|node| node.id.as_str() == "model:pls")
872            .unwrap()
873            .metadata
874            .insert(
875                "dsl_inner_cv".to_string(),
876                serde_json::json!({"kind": "kfold", "n_splits": 1, "shuffle": false, "seed": null}),
877            );
878
879        let campaign = CampaignSpec {
880            inner_cv: None,
881            id: "campaign:inner-cv.nsplits".to_string(),
882            root_seed: Some(7),
883            leakage_policy: LeakageUnitPolicy::default(),
884            aggregation_policy: AggregationPolicy::default(),
885            split_invocation: None,
886            generation: Default::default(),
887            shape_plans: BTreeMap::new(),
888            data_bindings: BTreeMap::new(),
889            branch_view_plans: Vec::new(),
890            metadata: BTreeMap::new(),
891        };
892
893        let error = build_execution_plan("plan:inner-cv.nsplits", graph, campaign, &registry())
894            .unwrap_err();
895        assert!(matches!(error, DagMlError::Planning(_)));
896        assert!(error.to_string().contains("at least two splits"));
897    }
898    use crate::data::DataBinding;
899    use crate::generation::{
900        GenerationChoice, GenerationDimension, GenerationParamOverride, GenerationStrategy,
901    };
902    use crate::graph::{
903        EdgeContract, EdgeSpec, GraphInterface, NodeSpec, PortCardinality, PortKind, PortRef,
904        PortSchema, PortSpec,
905    };
906    use crate::ids::{ControllerId, FoldId, ObservationId, SampleId, TargetId};
907    use crate::phase::Phase;
908    use crate::policy::{DataModelShapePlan, Granularity};
909    use crate::relation::{SampleRelation, SampleRelationSet};
910
911    fn port(name: &str, kind: PortKind) -> PortSpec {
912        PortSpec {
913            name: name.to_string(),
914            kind,
915            representation: None,
916            cardinality: PortCardinality::One,
917            unit_level: None,
918            alignment_key: None,
919            target_level: None,
920            description: String::new(),
921        }
922    }
923
924    fn node(id: &str, kind: NodeKind, inputs: Vec<PortSpec>, outputs: Vec<PortSpec>) -> NodeSpec {
925        NodeSpec {
926            id: NodeId::new(id).unwrap(),
927            kind,
928            operator: None,
929            params: BTreeMap::new(),
930            ports: PortSchema { inputs, outputs },
931            metadata: BTreeMap::new(),
932            seed_label: None,
933        }
934    }
935
936    fn graph() -> GraphSpec {
937        GraphSpec {
938            id: "g".to_string(),
939            interface: GraphInterface::default(),
940            nodes: vec![
941                node(
942                    "transform:snv",
943                    NodeKind::Transform,
944                    vec![],
945                    vec![port("x", PortKind::Data)],
946                ),
947                node(
948                    "model:pls",
949                    NodeKind::Model,
950                    vec![port("x", PortKind::Data)],
951                    vec![port("pred", PortKind::Prediction)],
952                ),
953            ],
954            edges: vec![EdgeSpec {
955                source: PortRef {
956                    node_id: NodeId::new("transform:snv").unwrap(),
957                    port_name: "x".to_string(),
958                },
959                target: PortRef {
960                    node_id: NodeId::new("model:pls").unwrap(),
961                    port_name: "x".to_string(),
962                },
963                contract: EdgeContract {
964                    requires_oof: false,
965                    requires_fold_alignment: false,
966                    ..EdgeContract::new(PortKind::Data, None)
967                },
968            }],
969            search_space_fingerprint: None,
970            metadata: BTreeMap::new(),
971        }
972    }
973
974    fn manifest(id: &str, kind: NodeKind) -> ControllerManifest {
975        let mut capabilities = BTreeSet::from([
976            ControllerCapability::Deterministic,
977            ControllerCapability::ThreadSafe,
978            ControllerCapability::ProcessSafe,
979        ]);
980        if kind == NodeKind::Model {
981            capabilities.insert(ControllerCapability::EmitsPredictions);
982            capabilities.insert(ControllerCapability::ConsumesOofPredictions);
983        }
984        ControllerManifest {
985            controller_id: ControllerId::new(id).unwrap(),
986            controller_version: "0.1.0".to_string(),
987            operator_kind: kind,
988            priority: 0,
989            supported_phases: BTreeSet::from([Phase::FitCv, Phase::Refit, Phase::Predict]),
990            input_ports: Vec::new(),
991            output_ports: Vec::new(),
992            data_requirements: None,
993            capabilities,
994            operator_selectors: Vec::new(),
995            fit_scope: ControllerFitScope::FoldTrain,
996            rng_policy: RngPolicy::UsesCoreSeed,
997            artifact_policy: ArtifactPolicy::Serializable,
998        }
999    }
1000
1001    fn registry() -> ControllerRegistry {
1002        let mut registry = ControllerRegistry::new();
1003        registry
1004            .register(manifest("controller:transform", NodeKind::Transform))
1005            .unwrap();
1006        registry
1007            .register(manifest("controller:model", NodeKind::Model))
1008            .unwrap();
1009        registry
1010    }
1011
1012    fn campaign(id: &str) -> CampaignSpec {
1013        CampaignSpec {
1014            id: id.to_string(),
1015            root_seed: Some(7),
1016            leakage_policy: LeakageUnitPolicy::default(),
1017            aggregation_policy: AggregationPolicy::default(),
1018            split_invocation: None,
1019            generation: Default::default(),
1020            shape_plans: BTreeMap::new(),
1021            data_bindings: BTreeMap::new(),
1022            branch_view_plans: Vec::new(),
1023            inner_cv: None,
1024            metadata: BTreeMap::new(),
1025        }
1026    }
1027
1028    fn large_linear_graph(transform_count: usize) -> GraphSpec {
1029        let mut nodes = Vec::new();
1030        let mut edges = Vec::new();
1031        for node_idx in 0..transform_count {
1032            let node_id = format!("transform:t{node_idx:04}");
1033            nodes.push(node(
1034                &node_id,
1035                NodeKind::Transform,
1036                vec![port("x", PortKind::Data)],
1037                vec![port("x", PortKind::Data)],
1038            ));
1039            if node_idx > 0 {
1040                edges.push(EdgeSpec {
1041                    source: PortRef {
1042                        node_id: NodeId::new(format!("transform:t{:04}", node_idx - 1)).unwrap(),
1043                        port_name: "x".to_string(),
1044                    },
1045                    target: PortRef {
1046                        node_id: NodeId::new(&node_id).unwrap(),
1047                        port_name: "x".to_string(),
1048                    },
1049                    contract: EdgeContract::new(PortKind::Data, None),
1050                });
1051            }
1052        }
1053        nodes.push(node(
1054            "model:final",
1055            NodeKind::Model,
1056            vec![port("x", PortKind::Data)],
1057            vec![port("pred", PortKind::Prediction)],
1058        ));
1059        edges.push(EdgeSpec {
1060            source: PortRef {
1061                node_id: NodeId::new(format!("transform:t{:04}", transform_count - 1)).unwrap(),
1062                port_name: "x".to_string(),
1063            },
1064            target: PortRef {
1065                node_id: NodeId::new("model:final").unwrap(),
1066                port_name: "x".to_string(),
1067            },
1068            contract: EdgeContract::new(PortKind::Data, None),
1069        });
1070
1071        GraphSpec {
1072            id: "g:perf.linear".to_string(),
1073            interface: GraphInterface::default(),
1074            nodes,
1075            edges,
1076            search_space_fingerprint: None,
1077            metadata: BTreeMap::new(),
1078        }
1079    }
1080
1081    fn oof_graph() -> GraphSpec {
1082        GraphSpec {
1083            id: "g:oof.capabilities".to_string(),
1084            interface: GraphInterface::default(),
1085            nodes: vec![
1086                node(
1087                    "model:base",
1088                    NodeKind::Model,
1089                    vec![],
1090                    vec![port("pred", PortKind::Prediction)],
1091                ),
1092                node(
1093                    "model:meta",
1094                    NodeKind::Model,
1095                    vec![port("pred", PortKind::Prediction)],
1096                    vec![port("pred", PortKind::Prediction)],
1097                ),
1098            ],
1099            edges: vec![EdgeSpec {
1100                source: PortRef {
1101                    node_id: NodeId::new("model:base").unwrap(),
1102                    port_name: "pred".to_string(),
1103                },
1104                target: PortRef {
1105                    node_id: NodeId::new("model:meta").unwrap(),
1106                    port_name: "pred".to_string(),
1107                },
1108                contract: EdgeContract {
1109                    requires_oof: true,
1110                    requires_fold_alignment: true,
1111                    ..EdgeContract::new(PortKind::Prediction, None)
1112                },
1113            }],
1114            search_space_fingerprint: None,
1115            metadata: BTreeMap::new(),
1116        }
1117    }
1118
1119    fn data_binding(node_id: &NodeId) -> DataBinding {
1120        DataBinding {
1121            node_id: node_id.clone(),
1122            input_name: "x".to_string(),
1123            request_id: "nir-to-tabular".to_string(),
1124            schema_fingerprint: "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b"
1125                .to_string(),
1126            plan_fingerprint: "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d"
1127                .to_string(),
1128            relation_fingerprint: Some(
1129                "a3a7e329df35db9f2883a17b8611b7fae6dcaa031875e3ec2c9be1b9e29cbe10".to_string(),
1130            ),
1131            output_representation: "tabular_numeric".to_string(),
1132            feature_set_id: Some("x".to_string()),
1133            source_ids: vec!["nir".to_string()],
1134            require_relations: true,
1135            view_policy: Default::default(),
1136            metadata: BTreeMap::new(),
1137        }
1138    }
1139
1140    fn levels_as_strings(levels: &[Vec<NodeId>]) -> Vec<Vec<String>> {
1141        levels
1142            .iter()
1143            .map(|level| level.iter().map(ToString::to_string).collect())
1144            .collect()
1145    }
1146
1147    #[test]
1148    fn published_campaign_spec_schema_declares_current_contract() {
1149        let schema: serde_json::Value = serde_json::from_str(include_str!(
1150            "../../../docs/contracts/campaign_spec.schema.json"
1151        ))
1152        .unwrap();
1153
1154        assert_eq!(schema["$id"], CAMPAIGN_SPEC_SCHEMA_ID);
1155        assert!(schema["required"]
1156            .as_array()
1157            .unwrap()
1158            .iter()
1159            .any(|field| field.as_str() == Some("id")));
1160        assert!(schema["$defs"]["split_invocation"]["properties"]
1161            .as_object()
1162            .unwrap()
1163            .contains_key("fold_set"));
1164        assert!(schema["$defs"]["aggregation_policy"]["properties"]
1165            .as_object()
1166            .unwrap()
1167            .contains_key("selection_metric_level"));
1168        assert!(schema["$defs"]["aggregation_policy"]["properties"]
1169            .as_object()
1170            .unwrap()
1171            .contains_key("custom_controller"));
1172        assert!(schema["$defs"]["data_binding"]["properties"]
1173            .as_object()
1174            .unwrap()
1175            .contains_key("view_policy"));
1176        assert!(schema["properties"]
1177            .as_object()
1178            .unwrap()
1179            .contains_key("branch_view_plans"));
1180        assert!(schema["$defs"]["branch_view_plan"]["properties"]
1181            .as_object()
1182            .unwrap()
1183            .contains_key("selector"));
1184    }
1185
1186    #[test]
1187    fn published_execution_plan_schema_declares_current_contract() {
1188        let schema: serde_json::Value = serde_json::from_str(include_str!(
1189            "../../../docs/contracts/execution_plan.schema.json"
1190        ))
1191        .unwrap();
1192
1193        assert_eq!(schema["$id"], EXECUTION_PLAN_SCHEMA_ID);
1194        assert!(schema["required"]
1195            .as_array()
1196            .unwrap()
1197            .iter()
1198            .any(|field| field.as_str() == Some("node_plans")));
1199        assert!(schema["properties"]
1200            .as_object()
1201            .unwrap()
1202            .contains_key("controller_fingerprint"));
1203        assert!(schema["$defs"]["node_plan"]["properties"]
1204            .as_object()
1205            .unwrap()
1206            .contains_key("shape_plan"));
1207        assert!(schema["$defs"]["variant_plan"]["properties"]
1208            .as_object()
1209            .unwrap()
1210            .contains_key("choices"));
1211    }
1212
1213    #[test]
1214    fn published_execution_plan_fixture_validates_current_contract() {
1215        let plan: ExecutionPlan = serde_json::from_str(include_str!(
1216            "../../../examples/fixtures/runtime/execution_plan_branch_merge_executable.json"
1217        ))
1218        .unwrap();
1219
1220        plan.validate().unwrap();
1221        assert_eq!(plan.id, "plan:fixture.execution.branch_merge");
1222        assert_eq!(plan.variants.len(), 2);
1223        assert_eq!(plan.node_plans.len(), plan.graph_plan.graph.nodes.len());
1224    }
1225
1226    #[test]
1227    #[ignore = "perf sanity probe; run with --release --ignored --nocapture"]
1228    fn build_execution_plan_large_linear_graph_under_1500ms() {
1229        let started = Instant::now();
1230        let plan = build_execution_plan(
1231            "plan:perf.linear",
1232            large_linear_graph(400),
1233            campaign("campaign:perf.linear"),
1234            &registry(),
1235        )
1236        .unwrap();
1237        let elapsed = started.elapsed();
1238
1239        assert_eq!(plan.graph_plan.topological_order.len(), 401);
1240        assert_eq!(plan.node_plans.len(), 401);
1241        assert!(
1242            elapsed <= Duration::from_millis(1_500),
1243            "large execution-plan build took {elapsed:?}"
1244        );
1245    }
1246
1247    #[test]
1248    fn builds_execution_plan_with_shape_and_fold_contracts() {
1249        let model_id = NodeId::new("model:pls").unwrap();
1250        let campaign = CampaignSpec {
1251            inner_cv: None,
1252            id: "campaign:oof".to_string(),
1253            root_seed: Some(7),
1254            leakage_policy: LeakageUnitPolicy::default(),
1255            aggregation_policy: AggregationPolicy::default(),
1256            split_invocation: Some(SplitInvocation {
1257                id: "split:outer".to_string(),
1258                controller_id: None,
1259                leakage_policy: LeakageUnitPolicy::default(),
1260                params: BTreeMap::new(),
1261                fold_set: Some(FoldSet {
1262                    id: "outer".to_string(),
1263                    sample_ids: vec![SampleId::new("s1").unwrap(), SampleId::new("s2").unwrap()],
1264                    folds: vec![
1265                        crate::fold::FoldAssignment {
1266                            fold_id: FoldId::new("fold0").unwrap(),
1267                            train_sample_ids: vec![SampleId::new("s2").unwrap()],
1268                            validation_sample_ids: vec![SampleId::new("s1").unwrap()],
1269                            metadata: BTreeMap::new(),
1270                        },
1271                        crate::fold::FoldAssignment {
1272                            fold_id: FoldId::new("fold1").unwrap(),
1273                            train_sample_ids: vec![SampleId::new("s1").unwrap()],
1274                            validation_sample_ids: vec![SampleId::new("s2").unwrap()],
1275                            metadata: BTreeMap::new(),
1276                        },
1277                    ],
1278                    sample_groups: BTreeMap::new(),
1279                }),
1280            }),
1281            generation: Default::default(),
1282            shape_plans: BTreeMap::from([(
1283                model_id.clone(),
1284                DataModelShapePlan {
1285                    node_id: model_id.clone(),
1286                    input_granularity: Granularity::Observation,
1287                    ..DataModelShapePlan {
1288                        node_id: model_id.clone(),
1289                        input_granularity: Granularity::Sample,
1290                        target_granularity: Granularity::Sample,
1291                        fit_rows: crate::policy::FitBoundary::FoldTrain,
1292                        predict_rows: crate::policy::FitBoundary::FoldValidation,
1293                        feature_namespace: None,
1294                        feature_schema_fingerprint: None,
1295                        target_space: "raw".to_string(),
1296                        aggregation_policy: AggregationPolicy::default(),
1297                        augmentation_policy: crate::policy::AugmentationPolicy::default(),
1298                        selection_policy: crate::policy::FeatureSelectionPolicy::default(),
1299                    }
1300                },
1301            )]),
1302            data_bindings: BTreeMap::from([(model_id.clone(), vec![data_binding(&model_id)])]),
1303            branch_view_plans: Vec::new(),
1304            metadata: BTreeMap::new(),
1305        };
1306
1307        let plan = build_execution_plan("plan:oof", graph(), campaign, &registry()).unwrap();
1308
1309        assert_eq!(
1310            plan.graph_plan
1311                .topological_order
1312                .iter()
1313                .map(ToString::to_string)
1314                .collect::<Vec<_>>(),
1315            vec!["transform:snv", "model:pls"]
1316        );
1317        assert_eq!(
1318            levels_as_strings(&plan.graph_plan.parallel_levels),
1319            vec![vec!["transform:snv"], vec!["model:pls"]]
1320        );
1321        assert!(plan.node_plans[&model_id]
1322            .controller_capabilities
1323            .contains(&ControllerCapability::EmitsPredictions));
1324        assert!(plan.fold_set.is_some());
1325        let schedule = plan.campaign_phase_schedule(Phase::FitCv).unwrap();
1326        assert_eq!(schedule.scopes.len(), 2);
1327        assert!(schedule.scopes[0].scope_id.starts_with("scope:FIT_CV:"));
1328        assert!(schedule
1329            .scopes
1330            .iter()
1331            .all(|scope| levels_as_strings(&scope.node_levels)
1332                == vec![vec!["transform:snv"], vec!["model:pls"]]));
1333        assert_eq!(
1334            schedule
1335                .scopes
1336                .iter()
1337                .filter_map(|scope| scope.fold_id.as_ref().map(ToString::to_string))
1338                .collect::<Vec<_>>(),
1339            vec!["fold0", "fold1"]
1340        );
1341        assert_eq!(
1342            plan.node_plans
1343                .get(&model_id)
1344                .unwrap()
1345                .controller_id
1346                .as_str(),
1347            "controller:model"
1348        );
1349        assert_eq!(
1350            plan.node_plans.get(&model_id).unwrap().data_bindings.len(),
1351            1
1352        );
1353
1354        let mut bad_plan = plan.clone();
1355        bad_plan.graph_plan.parallel_levels =
1356            vec![vec![model_id], vec![NodeId::new("transform:snv").unwrap()]];
1357        assert!(bad_plan
1358            .validate()
1359            .unwrap_err()
1360            .to_string()
1361            .contains("parallel levels"));
1362
1363        let bad_envelope = ExternalDataPlanEnvelope {
1364            schema_version: crate::data::EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION,
1365            schema_fingerprint: "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b"
1366                .to_string(),
1367            plan_fingerprint: "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d"
1368                .to_string(),
1369            relation_fingerprint: None,
1370            coordinator_relations: Some(SampleRelationSet {
1371                records: vec![{
1372                    let mut relation = SampleRelation::new(
1373                        ObservationId::new("obs:outside").unwrap(),
1374                        SampleId::new("sample:outside").unwrap(),
1375                    );
1376                    relation.target_id = Some(TargetId::new("target:outside").unwrap());
1377                    relation.source_id = Some("nir".to_string());
1378                    relation
1379                }],
1380            }),
1381        };
1382        assert!(plan
1383            .campaign
1384            .validate_data_envelope_relations(&bad_envelope)
1385            .unwrap_err()
1386            .to_string()
1387            .contains("outside fold set"));
1388    }
1389
1390    #[test]
1391    fn planning_refuses_shape_plan_for_unknown_node() {
1392        let campaign = CampaignSpec {
1393            inner_cv: None,
1394            id: "campaign:oof".to_string(),
1395            root_seed: Some(7),
1396            leakage_policy: LeakageUnitPolicy::default(),
1397            aggregation_policy: AggregationPolicy::default(),
1398            split_invocation: None,
1399            generation: Default::default(),
1400            shape_plans: BTreeMap::from([(
1401                NodeId::new("model:missing").unwrap(),
1402                DataModelShapePlan {
1403                    node_id: NodeId::new("model:missing").unwrap(),
1404                    input_granularity: Granularity::Sample,
1405                    target_granularity: Granularity::Sample,
1406                    fit_rows: crate::policy::FitBoundary::FoldTrain,
1407                    predict_rows: crate::policy::FitBoundary::FoldValidation,
1408                    feature_namespace: None,
1409                    feature_schema_fingerprint: None,
1410                    target_space: "raw".to_string(),
1411                    aggregation_policy: AggregationPolicy::default(),
1412                    augmentation_policy: crate::policy::AugmentationPolicy::default(),
1413                    selection_policy: crate::policy::FeatureSelectionPolicy::default(),
1414                },
1415            )]),
1416            data_bindings: BTreeMap::new(),
1417            branch_view_plans: Vec::new(),
1418            metadata: BTreeMap::new(),
1419        };
1420
1421        assert!(build_execution_plan("plan:oof", graph(), campaign, &registry()).is_err());
1422    }
1423
1424    #[test]
1425    fn planning_refuses_oof_edge_without_controller_capabilities() {
1426        let mut registry = ControllerRegistry::new();
1427        let mut model_manifest = manifest("controller:model", NodeKind::Model);
1428        model_manifest
1429            .capabilities
1430            .remove(&ControllerCapability::ConsumesOofPredictions);
1431        registry.register(model_manifest).unwrap();
1432
1433        let err = build_execution_plan(
1434            "plan:oof.capability",
1435            oof_graph(),
1436            CampaignSpec {
1437                inner_cv: None,
1438                id: "campaign:oof.capability".to_string(),
1439                root_seed: Some(11),
1440                leakage_policy: Default::default(),
1441                aggregation_policy: Default::default(),
1442                split_invocation: None,
1443                generation: Default::default(),
1444                shape_plans: BTreeMap::new(),
1445                data_bindings: BTreeMap::new(),
1446                branch_view_plans: Vec::new(),
1447                metadata: BTreeMap::new(),
1448            },
1449            &registry,
1450        )
1451        .unwrap_err();
1452
1453        assert!(err.to_string().contains("consumes_oof_predictions"));
1454    }
1455
1456    #[test]
1457    fn parallel_controller_capability_validation_requires_safe_manifest() {
1458        let mut registry = ControllerRegistry::new();
1459        let mut transform_manifest = manifest("controller:transform", NodeKind::Transform);
1460        transform_manifest
1461            .capabilities
1462            .remove(&ControllerCapability::ThreadSafe);
1463        transform_manifest
1464            .capabilities
1465            .remove(&ControllerCapability::ProcessSafe);
1466        registry.register(transform_manifest).unwrap();
1467        registry
1468            .register(manifest("controller:model", NodeKind::Model))
1469            .unwrap();
1470        let plan = build_execution_plan(
1471            "plan:parallel.capability",
1472            graph(),
1473            CampaignSpec {
1474                inner_cv: None,
1475                id: "campaign:parallel.capability".to_string(),
1476                root_seed: Some(11),
1477                leakage_policy: Default::default(),
1478                aggregation_policy: Default::default(),
1479                split_invocation: None,
1480                generation: Default::default(),
1481                shape_plans: BTreeMap::new(),
1482                data_bindings: BTreeMap::new(),
1483                branch_view_plans: Vec::new(),
1484                metadata: BTreeMap::new(),
1485            },
1486            &registry,
1487        )
1488        .unwrap();
1489
1490        assert!(plan
1491            .validate_parallel_controller_capabilities(1, Phase::FitCv)
1492            .is_ok());
1493        let err = plan
1494            .validate_parallel_controller_capabilities(2, Phase::FitCv)
1495            .unwrap_err();
1496        assert!(err.to_string().contains("thread_safe or process_safe"));
1497    }
1498
1499    #[test]
1500    fn planning_refuses_generation_override_for_unknown_node() {
1501        let campaign = CampaignSpec {
1502            inner_cv: None,
1503            id: "campaign:oof".to_string(),
1504            root_seed: Some(7),
1505            leakage_policy: LeakageUnitPolicy::default(),
1506            aggregation_policy: AggregationPolicy::default(),
1507            split_invocation: None,
1508            generation: GenerationSpec {
1509                strategy: GenerationStrategy::Cartesian,
1510                dimensions: vec![GenerationDimension {
1511                    name: "model_family".to_string(),
1512                    choices: vec![GenerationChoice {
1513                        label: "pls".to_string(),
1514                        value: serde_json::json!("pls"),
1515                        param_overrides: vec![GenerationParamOverride {
1516                            node_id: NodeId::new("model:missing").unwrap(),
1517                            params: BTreeMap::from([(
1518                                "n_components".to_string(),
1519                                serde_json::json!(8),
1520                            )]),
1521                        }],
1522                    }],
1523                }],
1524                max_variants: Some(1),
1525            },
1526            shape_plans: BTreeMap::new(),
1527            data_bindings: BTreeMap::new(),
1528            branch_view_plans: Vec::new(),
1529            metadata: BTreeMap::new(),
1530        };
1531
1532        let error = build_execution_plan("plan:oof", graph(), campaign, &registry())
1533            .unwrap_err()
1534            .to_string();
1535
1536        assert!(error.contains("overrides params for unknown node"));
1537    }
1538
1539    #[test]
1540    fn planning_validates_declared_search_space_fingerprint() {
1541        let campaign = CampaignSpec {
1542            inner_cv: None,
1543            id: "campaign:search.fingerprint".to_string(),
1544            root_seed: Some(7),
1545            leakage_policy: LeakageUnitPolicy::default(),
1546            aggregation_policy: AggregationPolicy::default(),
1547            split_invocation: None,
1548            generation: GenerationSpec {
1549                strategy: GenerationStrategy::Cartesian,
1550                dimensions: vec![GenerationDimension {
1551                    name: "model_family".to_string(),
1552                    choices: vec![GenerationChoice {
1553                        label: "pls".to_string(),
1554                        value: serde_json::json!("pls"),
1555                        param_overrides: vec![GenerationParamOverride {
1556                            node_id: NodeId::new("model:pls").unwrap(),
1557                            params: BTreeMap::from([(
1558                                "n_components".to_string(),
1559                                serde_json::json!(8),
1560                            )]),
1561                        }],
1562                    }],
1563                }],
1564                max_variants: Some(1),
1565            },
1566            shape_plans: BTreeMap::new(),
1567            data_bindings: BTreeMap::new(),
1568            branch_view_plans: Vec::new(),
1569            metadata: BTreeMap::new(),
1570        };
1571        let mut graph = graph();
1572        graph.search_space_fingerprint =
1573            Some(generation_spec_fingerprint(&campaign.generation).unwrap());
1574
1575        let plan = build_execution_plan(
1576            "plan:search.fingerprint",
1577            graph.clone(),
1578            campaign.clone(),
1579            &registry(),
1580        )
1581        .unwrap();
1582        assert_eq!(plan.variants.len(), 1);
1583
1584        graph.search_space_fingerprint = Some("sha256:not-the-generation-spec".to_string());
1585        let error = build_execution_plan("plan:search.fingerprint", graph, campaign, &registry())
1586            .unwrap_err()
1587            .to_string();
1588        assert!(error.contains("search_space_fingerprint"));
1589    }
1590
1591    #[test]
1592    fn branch_view_lookup_helpers_match_by_branch_id_and_innermost_path() {
1593        use crate::data::{BranchViewMode, DataViewSelector};
1594
1595        let outer = BranchViewPlan {
1596            view_id: "branch_view:outer".to_string(),
1597            branch_id: "branch:outer".to_string(),
1598            mode: BranchViewMode::BySource,
1599            selector: DataViewSelector {
1600                source_ids: vec!["nir".to_string()],
1601                ..Default::default()
1602            },
1603            allow_overlap: false,
1604            metadata: BTreeMap::new(),
1605        };
1606        let inner = BranchViewPlan {
1607            view_id: "branch_view:inner".to_string(),
1608            branch_id: "branch:inner".to_string(),
1609            mode: BranchViewMode::Separation,
1610            selector: DataViewSelector {
1611                source_ids: vec!["chem".to_string()],
1612                ..Default::default()
1613            },
1614            allow_overlap: false,
1615            metadata: BTreeMap::new(),
1616        };
1617        let plans = vec![outer.clone(), inner.clone()];
1618
1619        assert_eq!(
1620            super::branch_view_for_in(&plans, "branch:outer"),
1621            Some(&outer)
1622        );
1623        assert_eq!(
1624            super::branch_view_for_in(&plans, "branch:inner"),
1625            Some(&inner)
1626        );
1627        assert_eq!(super::branch_view_for_in(&plans, "branch:missing"), None);
1628
1629        let path = vec!["branch:outer".to_string(), "branch:inner".to_string()];
1630        // tip-first: innermost matching branch wins
1631        assert_eq!(super::branch_view_for_path_in(&plans, &path), Some(&inner));
1632
1633        let path_outer_only = vec!["branch:outer".to_string()];
1634        assert_eq!(
1635            super::branch_view_for_path_in(&plans, &path_outer_only),
1636            Some(&outer)
1637        );
1638
1639        let empty_path: Vec<String> = Vec::new();
1640        assert_eq!(super::branch_view_for_path_in(&plans, &empty_path), None);
1641
1642        let path_no_match = vec!["branch:other".to_string()];
1643        assert_eq!(super::branch_view_for_path_in(&plans, &path_no_match), None);
1644    }
1645}