1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::controller::{
7 ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest,
8 ControllerRegistry, RngPolicy,
9};
10use crate::data::{BranchViewPlan, DataBinding, ExternalDataPlanEnvelope};
11use crate::error::{DagMlError, Result};
12use crate::fold::{FoldSet, NestedCvSpec};
13use crate::generation::{
14 enumerate_variants, generation_spec_fingerprint, GenerationSpec, VariantPlan,
15};
16use crate::graph::{GraphSpec, NodeKind};
17use crate::ids::{ControllerId, FoldId, NodeId, VariantId};
18use crate::phase::Phase;
19use crate::policy::{AggregationPolicy, DataModelShapePlan, LeakageUnitPolicy};
20
/// Schema version of the serialized [`CampaignSpec`] contract.
pub const CAMPAIGN_SPEC_SCHEMA_VERSION: u32 = 1;
/// JSON-schema `$id` under which the [`CampaignSpec`] contract is published.
pub const CAMPAIGN_SPEC_SCHEMA_ID: &str = concat!(
    "https://github.com/GBeurier/dag-ml/schemas/",
    "campaign_spec.v1.schema.json"
);
/// Schema version of the serialized [`ExecutionPlan`] contract.
pub const EXECUTION_PLAN_SCHEMA_VERSION: u32 = 1;
/// JSON-schema `$id` under which the [`ExecutionPlan`] contract is published.
pub const EXECUTION_PLAN_SCHEMA_ID: &str = concat!(
    "https://github.com/GBeurier/dag-ml/schemas/",
    "execution_plan.v1.schema.json"
);
27
/// A split invocation declared by a campaign.
///
/// Checked by [`SplitInvocation::validate`]. When `fold_set` is present,
/// [`build_execution_plan`] copies it into [`ExecutionPlan::fold_set`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SplitInvocation {
    /// Identifier of this split invocation; must not be blank.
    pub id: String,
    /// Optional controller identifier for the split; not read by this module.
    #[serde(default)]
    pub controller_id: Option<ControllerId>,
    /// Leakage-unit policy of the split. It is validated here and also used by
    /// [`CampaignSpec::validate_data_envelope_relations`].
    #[serde(default)]
    pub leakage_policy: LeakageUnitPolicy,
    /// Free-form split parameters; not interpreted by this module.
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Optional pre-computed fold set.
    #[serde(default)]
    pub fold_set: Option<FoldSet>,
}
40
41impl SplitInvocation {
42 pub fn validate(&self) -> Result<()> {
43 if self.id.trim().is_empty() {
44 return Err(DagMlError::CampaignValidation(
45 "split invocation id is empty".to_string(),
46 ));
47 }
48 self.leakage_policy.validate()?;
49 if let Some(fold_set) = &self.fold_set {
50 fold_set.validate()?;
51 }
52 Ok(())
53 }
54}
55
/// Campaign-level configuration that is combined with a [`GraphSpec`] by
/// [`build_execution_plan`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CampaignSpec {
    /// Campaign identifier; must not be blank.
    pub id: String,
    /// Root seed handed to `enumerate_variants` when variants are built.
    pub root_seed: Option<u64>,
    /// Campaign-wide leakage-unit policy.
    #[serde(default)]
    pub leakage_policy: LeakageUnitPolicy,
    /// Aggregation policy; validated by [`CampaignSpec::validate`].
    #[serde(default)]
    pub aggregation_policy: AggregationPolicy,
    /// Optional split step. Its `fold_set`, if any, becomes the plan's fold set.
    #[serde(default)]
    pub split_invocation: Option<SplitInvocation>,
    /// Generation spec that is enumerated into variants. Its fingerprint must
    /// match the graph's `search_space_fingerprint` when the graph declares one.
    #[serde(default)]
    pub generation: GenerationSpec,
    /// Shape plans keyed by node id. Each key must equal the plan's `node_id`
    /// and reference a graph node.
    #[serde(default)]
    pub shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
    /// Data bindings keyed by node id. Each binding's `node_id` must equal its
    /// key, and the key must reference a graph node.
    #[serde(default)]
    pub data_bindings: BTreeMap<NodeId, Vec<DataBinding>>,
    /// Branch view plans. View ids must be unique; they are looked up by
    /// `branch_id`. Omitted from serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub branch_view_plans: Vec<BranchViewPlan>,
    /// Optional campaign-level nested CV spec. `crate::fold::resolve_inner_cv`
    /// prefers a node-level spec over this one. Omitted from serialization
    /// when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
    /// Free-form metadata; not interpreted by this module.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
81
82impl CampaignSpec {
83 pub fn validate(&self) -> Result<()> {
84 if self.id.trim().is_empty() {
85 return Err(DagMlError::CampaignValidation(
86 "campaign id is empty".to_string(),
87 ));
88 }
89 self.leakage_policy.validate()?;
90 self.aggregation_policy.validate()?;
91 if let Some(inner_cv) = &self.inner_cv {
92 inner_cv.validate()?;
93 }
94 if let Some(split) = &self.split_invocation {
95 split.validate()?;
96 }
97 self.generation.validate()?;
98 for (node_id, shape_plan) in &self.shape_plans {
99 if node_id != &shape_plan.node_id {
100 return Err(DagMlError::CampaignValidation(format!(
101 "shape plan key `{node_id}` does not match node_id `{}`",
102 shape_plan.node_id
103 )));
104 }
105 shape_plan.validate()?;
106 }
107 for (node_id, bindings) in &self.data_bindings {
108 for binding in bindings {
109 if node_id != &binding.node_id {
110 return Err(DagMlError::CampaignValidation(format!(
111 "data binding key `{node_id}` does not match node_id `{}`",
112 binding.node_id
113 )));
114 }
115 binding.validate()?;
116 }
117 }
118 let mut branch_views = BTreeSet::new();
119 for plan in &self.branch_view_plans {
120 plan.validate()?;
121 if !branch_views.insert(plan.view_id.as_str()) {
122 return Err(DagMlError::CampaignValidation(format!(
123 "campaign `{}` contains duplicate branch view `{}`",
124 self.id, plan.view_id
125 )));
126 }
127 }
128 Ok(())
129 }
130
131 pub fn validate_data_envelope_relations(
132 &self,
133 envelope: &ExternalDataPlanEnvelope,
134 ) -> Result<()> {
135 envelope.validate()?;
136 let Some(relations) = &envelope.coordinator_relations else {
137 return Ok(());
138 };
139 let Some(split) = &self.split_invocation else {
140 return Ok(());
141 };
142 let Some(fold_set) = &split.fold_set else {
143 return Ok(());
144 };
145 relations.validate_against_fold_set(fold_set, &self.leakage_policy)?;
146 relations.validate_against_fold_set(fold_set, &split.leakage_policy)
147 }
148}
149
/// A graph together with its precomputed scheduling orders.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GraphPlan {
    /// The planned graph.
    pub graph: GraphSpec,
    /// Node ids in the order produced by `GraphSpec::topological_order`.
    pub topological_order: Vec<NodeId>,
    /// Cached parallel levels. When empty, [`GraphPlan::parallel_levels`]
    /// recomputes them from `graph`. Omitted from serialization when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub parallel_levels: Vec<Vec<NodeId>>,
}
157
158impl GraphPlan {
159 pub fn from_graph(graph: GraphSpec) -> Result<Self> {
160 let topological_order = graph.topological_order()?;
161 let parallel_levels = graph.parallel_levels()?;
162 Ok(Self {
163 graph,
164 topological_order,
165 parallel_levels,
166 })
167 }
168
169 pub fn parallel_levels(&self) -> Result<Vec<Vec<NodeId>>> {
170 if self.parallel_levels.is_empty() {
171 return self.graph.parallel_levels();
172 }
173 Ok(self.parallel_levels.clone())
174 }
175}
176
/// Per-node planning record produced by [`build_execution_plan`].
///
/// Controller-derived fields are copied from the resolved manifest and are
/// re-checked against it by [`ExecutionPlan::validate`].
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct NodePlan {
    /// Id of the graph node this plan describes.
    pub node_id: NodeId,
    /// Operator kind of the node; must equal the manifest's `operator_kind`.
    pub kind: NodeKind,
    /// Resolved controller; key into `ExecutionPlan::controller_manifests`.
    pub controller_id: ControllerId,
    /// Controller version copied from the manifest.
    pub controller_version: String,
    /// Phases the controller supports; used to filter nodes per phase.
    pub supported_phases: BTreeSet<Phase>,
    /// Controller capabilities copied from the manifest.
    #[serde(default)]
    pub controller_capabilities: BTreeSet<ControllerCapability>,
    /// Fit scope copied from the manifest.
    pub fit_scope: ControllerFitScope,
    /// RNG policy copied from the manifest.
    pub rng_policy: RngPolicy,
    /// Artifact policy copied from the manifest.
    pub artifact_policy: ArtifactPolicy,
    /// Upstream node ids, from `GraphSpec::upstream_nodes`.
    pub input_nodes: Vec<NodeId>,
    /// Downstream node ids, from `GraphSpec::downstream_nodes`.
    pub output_nodes: Vec<NodeId>,
    /// Shape plan taken from the campaign, if one exists for this node.
    pub shape_plan: Option<DataModelShapePlan>,
    /// Data bindings taken from the campaign for this node.
    #[serde(default)]
    pub data_bindings: Vec<DataBinding>,
    /// Node params cloned from the graph node. Omitted from serialization
    /// when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Node-local nested CV spec, lowered from `dsl_inner_cv` node metadata.
    /// Omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
    /// `stable_json_fingerprint` of `params`; re-verified on validation.
    pub params_fingerprint: String,
}
202
/// A fully resolved plan: graph, campaign, per-node plans, controller manifests,
/// variants, and fingerprints. Built by [`build_execution_plan`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ExecutionPlan {
    /// Plan identifier; must not be blank when built.
    pub id: String,
    /// The graph and its scheduling orders.
    pub graph_plan: GraphPlan,
    /// The campaign the plan was built from.
    pub campaign: CampaignSpec,
    /// One node plan per graph node, keyed by node id.
    pub node_plans: BTreeMap<NodeId, NodePlan>,
    /// Manifests of every controller referenced by `node_plans`.
    pub controller_manifests: BTreeMap<ControllerId, ControllerManifest>,
    /// Variants enumerated from the campaign's generation spec.
    pub variants: Vec<VariantPlan>,
    /// Fold set copied from the campaign's split invocation, if any.
    pub fold_set: Option<FoldSet>,
    /// `stable_json_fingerprint` of the graph.
    pub graph_fingerprint: String,
    /// `stable_json_fingerprint` of the campaign.
    pub campaign_fingerprint: String,
    /// `stable_json_fingerprint` of `controller_manifests`.
    pub controller_fingerprint: String,
}
216
/// One schedulable scope: a phase for a given variant and, optionally, fold.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ExecutionScopePlan {
    /// Scope identifier of the form `scope:{PHASE}:{variant|base}:{fold|nofold}`.
    pub scope_id: String,
    /// Phase this scope runs.
    pub phase: Phase,
    /// Variant this scope belongs to.
    pub variant_id: Option<VariantId>,
    /// Seed of that variant, if any.
    pub variant_seed: Option<u64>,
    /// Fold this scope runs on; `None` outside fold-level fit-CV.
    pub fold_id: Option<FoldId>,
    /// Parallel node levels restricted to nodes supporting `phase`.
    pub node_levels: Vec<Vec<NodeId>>,
}
226
/// All scopes to run for one phase of an execution plan.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PhaseExecutionSchedule {
    /// Id of the execution plan the schedule was derived from.
    pub plan_id: String,
    /// Phase being scheduled.
    pub phase: Phase,
    /// Scopes in variant-major, then fold, order.
    pub scopes: Vec<ExecutionScopePlan>,
}
233
234impl ExecutionPlan {
235 pub fn validate(&self) -> Result<()> {
236 self.graph_plan.graph.validate()?;
237 self.campaign.validate()?;
238 if !self.graph_plan.parallel_levels.is_empty()
239 && self.graph_plan.parallel_levels != self.graph_plan.graph.parallel_levels()?
240 {
241 return Err(DagMlError::Planning(
242 "graph plan parallel levels do not match graph".to_string(),
243 ));
244 }
245 if self.node_plans.len() != self.graph_plan.graph.nodes.len() {
246 return Err(DagMlError::Planning(
247 "execution plan node count does not match graph".to_string(),
248 ));
249 }
250 for node_id in &self.graph_plan.topological_order {
251 let plan = self.node_plans.get(node_id).ok_or_else(|| {
252 DagMlError::Planning(format!("missing node plan for `{node_id}`"))
253 })?;
254 let manifest = self
255 .controller_manifests
256 .get(&plan.controller_id)
257 .ok_or_else(|| {
258 DagMlError::Planning(format!(
259 "missing controller manifest `{}` for node `{node_id}`",
260 plan.controller_id
261 ))
262 })?;
263 if manifest.operator_kind != plan.kind {
264 return Err(DagMlError::Planning(format!(
265 "node `{node_id}` planned with incompatible controller `{}`",
266 manifest.controller_id
267 )));
268 }
269 if plan.controller_capabilities != manifest.capabilities {
270 return Err(DagMlError::Planning(format!(
271 "node `{node_id}` controller capabilities do not match manifest `{}`",
272 manifest.controller_id
273 )));
274 }
275 if plan.fit_scope != manifest.fit_scope
276 || plan.rng_policy != manifest.rng_policy
277 || plan.artifact_policy != manifest.artifact_policy
278 {
279 return Err(DagMlError::Planning(format!(
280 "node `{node_id}` controller policy fields do not match manifest `{}`",
281 manifest.controller_id
282 )));
283 }
284 for binding in &plan.data_bindings {
285 if binding.node_id != *node_id {
286 return Err(DagMlError::Planning(format!(
287 "node plan `{node_id}` contains data binding for `{}`",
288 binding.node_id
289 )));
290 }
291 binding.validate()?;
292 }
293 let actual_params_fingerprint = stable_json_fingerprint(&plan.params)?;
294 if actual_params_fingerprint != plan.params_fingerprint {
295 return Err(DagMlError::Planning(format!(
296 "node plan `{node_id}` params fingerprint does not match params"
297 )));
298 }
299 }
300 for (node_id, plan) in &self.node_plans {
306 if let Some(inner_cv) = &plan.inner_cv {
307 inner_cv.validate().map_err(|error| {
308 DagMlError::Planning(format!(
309 "node plan `{node_id}` has invalid inner_cv: {error}"
310 ))
311 })?;
312 }
313 }
314 self.validate_oof_controller_capabilities()?;
315 if let Some(fold_set) = &self.fold_set {
316 fold_set.validate()?;
317 }
318 if self.variants.is_empty() {
319 return Err(DagMlError::Planning(
320 "execution plan has no variants".to_string(),
321 ));
322 }
323 for variant in &self.variants {
324 variant.validate()?;
325 }
326 Ok(())
327 }
328
329 pub fn validate_parallel_controller_capabilities(
330 &self,
331 max_workers: usize,
332 phase: Phase,
333 ) -> Result<()> {
334 if max_workers <= 1 {
335 return Ok(());
336 }
337 let node_ids = self
338 .node_parallel_levels_for_phase(phase)?
339 .into_iter()
340 .flatten()
341 .collect::<Vec<_>>();
342 for node_id in node_ids {
343 let node_plan = self.node_plans.get(&node_id).ok_or_else(|| {
344 DagMlError::Planning(format!("missing node plan for `{node_id}`"))
345 })?;
346 let manifest = self
347 .controller_manifests
348 .get(&node_plan.controller_id)
349 .ok_or_else(|| {
350 DagMlError::Planning(format!(
351 "missing controller manifest `{}` for node `{}`",
352 node_plan.controller_id, node_plan.node_id
353 ))
354 })?;
355 if !manifest.supports_parallel_invocation() {
356 return Err(DagMlError::Planning(format!(
357 "parallel scheduler with {max_workers} workers requires controller `{}` for node `{}` to declare thread_safe or process_safe",
358 manifest.controller_id, node_plan.node_id
359 )));
360 }
361 }
362 Ok(())
363 }
364
365 fn validate_oof_controller_capabilities(&self) -> Result<()> {
366 for edge in &self.graph_plan.graph.edges {
367 if !edge.contract.requires_oof {
368 continue;
369 }
370 let source_plan = self.node_plans.get(&edge.source.node_id).ok_or_else(|| {
371 DagMlError::Planning(format!(
372 "OOF edge source node `{}` has no node plan",
373 edge.source.node_id
374 ))
375 })?;
376 if !source_plan
377 .controller_capabilities
378 .contains(&ControllerCapability::EmitsPredictions)
379 {
380 return Err(DagMlError::Planning(format!(
381 "OOF edge `{}.{}` -> `{}.{}` requires source controller `{}` to declare emits_predictions",
382 edge.source.node_id,
383 edge.source.port_name,
384 edge.target.node_id,
385 edge.target.port_name,
386 source_plan.controller_id
387 )));
388 }
389 let target_plan = self.node_plans.get(&edge.target.node_id).ok_or_else(|| {
390 DagMlError::Planning(format!(
391 "OOF edge target node `{}` has no node plan",
392 edge.target.node_id
393 ))
394 })?;
395 if !target_plan
396 .controller_capabilities
397 .contains(&ControllerCapability::ConsumesOofPredictions)
398 {
399 return Err(DagMlError::Planning(format!(
400 "OOF edge `{}.{}` -> `{}.{}` requires target controller `{}` to declare consumes_oof_predictions",
401 edge.source.node_id,
402 edge.source.port_name,
403 edge.target.node_id,
404 edge.target.port_name,
405 target_plan.controller_id
406 )));
407 }
408 }
409 Ok(())
410 }
411
412 pub fn node_parallel_levels_for_phase(&self, phase: Phase) -> Result<Vec<Vec<NodeId>>> {
413 let levels = self
414 .graph_plan
415 .parallel_levels()?
416 .into_iter()
417 .map(|level| {
418 level
419 .into_iter()
420 .filter(|node_id| {
421 self.node_plans
422 .get(node_id)
423 .is_some_and(|node_plan| node_plan.supported_phases.contains(&phase))
424 })
425 .collect::<Vec<_>>()
426 })
427 .filter(|level| !level.is_empty())
428 .collect::<Vec<_>>();
429 Ok(levels)
430 }
431
432 pub fn campaign_phase_schedule(&self, phase: Phase) -> Result<PhaseExecutionSchedule> {
433 self.validate()?;
434 let node_levels = self.node_parallel_levels_for_phase(phase)?;
435 let fold_ids = if phase == Phase::FitCv {
436 self.fold_set
437 .as_ref()
438 .map(|fold_set| {
439 fold_set
440 .folds
441 .iter()
442 .map(|fold| Some(fold.fold_id.clone()))
443 .collect::<Vec<_>>()
444 })
445 .unwrap_or_else(|| vec![None])
446 } else {
447 vec![None]
448 };
449 let mut scopes = Vec::new();
450 for variant in &self.variants {
451 for fold_id in &fold_ids {
452 scopes.push(ExecutionScopePlan {
453 scope_id: execution_scope_id(
454 phase,
455 Some(&variant.variant_id),
456 fold_id.as_ref(),
457 ),
458 phase,
459 variant_id: Some(variant.variant_id.clone()),
460 variant_seed: variant.seed,
461 fold_id: fold_id.clone(),
462 node_levels: node_levels.clone(),
463 });
464 }
465 }
466 Ok(PhaseExecutionSchedule {
467 plan_id: self.id.clone(),
468 phase,
469 scopes,
470 })
471 }
472
473 pub fn branch_view_for(&self, branch_id: &str) -> Option<&BranchViewPlan> {
477 branch_view_for_in(&self.campaign.branch_view_plans, branch_id)
478 }
479
480 pub fn branch_view_for_path(&self, branch_path: &[String]) -> Option<&BranchViewPlan> {
487 branch_view_for_path_in(&self.campaign.branch_view_plans, branch_path)
488 }
489}
490
491fn branch_view_for_in<'a>(
492 plans: &'a [BranchViewPlan],
493 branch_id: &str,
494) -> Option<&'a BranchViewPlan> {
495 plans.iter().find(|plan| plan.branch_id == branch_id)
496}
497
498fn branch_view_for_path_in<'a>(
499 plans: &'a [BranchViewPlan],
500 branch_path: &[String],
501) -> Option<&'a BranchViewPlan> {
502 for branch_id in branch_path.iter().rev() {
503 if let Some(plan) = branch_view_for_in(plans, branch_id) {
504 return Some(plan);
505 }
506 }
507 None
508}
509
510fn execution_scope_id(
511 phase: Phase,
512 variant_id: Option<&VariantId>,
513 fold_id: Option<&FoldId>,
514) -> String {
515 format!(
516 "scope:{}:{}:{}",
517 phase_scope_label(phase),
518 variant_id
519 .map(ToString::to_string)
520 .unwrap_or_else(|| "base".to_string()),
521 fold_id
522 .map(ToString::to_string)
523 .unwrap_or_else(|| "nofold".to_string())
524 )
525}
526
527fn phase_scope_label(phase: Phase) -> &'static str {
528 match phase {
529 Phase::Compile => "COMPILE",
530 Phase::Plan => "PLAN",
531 Phase::FitCv => "FIT_CV",
532 Phase::Select => "SELECT",
533 Phase::Refit => "REFIT",
534 Phase::Predict => "PREDICT",
535 Phase::Explain => "EXPLAIN",
536 }
537}
538
539pub fn build_execution_plan(
540 id: impl Into<String>,
541 graph: GraphSpec,
542 campaign: CampaignSpec,
543 registry: &ControllerRegistry,
544) -> Result<ExecutionPlan> {
545 let id = id.into();
546 if id.trim().is_empty() {
547 return Err(DagMlError::Planning(
548 "execution plan id is empty".to_string(),
549 ));
550 }
551 campaign.validate()?;
552 let graph_plan = GraphPlan::from_graph(graph)?;
553 validate_campaign_node_targets(&graph_plan.graph, &campaign)?;
554
555 let mut node_plans = BTreeMap::new();
556 let mut controller_manifests = BTreeMap::new();
557 for node_id in &graph_plan.topological_order {
558 let node = graph_plan
559 .graph
560 .nodes
561 .iter()
562 .find(|node| &node.id == node_id)
563 .expect("topological node exists");
564 let manifest = registry.resolve_for_node(node)?;
565 let params = node.params.clone();
566 let params_fingerprint = stable_json_fingerprint(¶ms)?;
567 let inner_cv = match node.metadata.get("dsl_inner_cv") {
571 Some(value) => {
572 let spec =
573 serde_json::from_value::<NestedCvSpec>(value.clone()).map_err(|error| {
574 DagMlError::Planning(format!(
575 "node `{}` has invalid dsl_inner_cv metadata: {error}",
576 node.id
577 ))
578 })?;
579 spec.validate().map_err(|error| {
582 DagMlError::Planning(format!(
583 "node `{}` has invalid dsl_inner_cv metadata: {error}",
584 node.id
585 ))
586 })?;
587 Some(spec)
588 }
589 None => None,
590 };
591 let shape_plan = campaign.shape_plans.get(&node.id).cloned();
592 let data_bindings = campaign
593 .data_bindings
594 .get(&node.id)
595 .cloned()
596 .unwrap_or_default();
597 node_plans.insert(
598 node.id.clone(),
599 NodePlan {
600 inner_cv,
601 node_id: node.id.clone(),
602 kind: node.kind.clone(),
603 controller_id: manifest.controller_id.clone(),
604 controller_version: manifest.controller_version.clone(),
605 supported_phases: manifest.supported_phases.clone(),
606 controller_capabilities: manifest.capabilities.clone(),
607 fit_scope: manifest.fit_scope,
608 rng_policy: manifest.rng_policy,
609 artifact_policy: manifest.artifact_policy,
610 input_nodes: graph_plan.graph.upstream_nodes(&node.id),
611 output_nodes: graph_plan.graph.downstream_nodes(&node.id),
612 shape_plan,
613 data_bindings,
614 params,
615 params_fingerprint,
616 },
617 );
618 controller_manifests.insert(manifest.controller_id.clone(), manifest);
619 }
620
621 let fold_set = campaign
622 .split_invocation
623 .as_ref()
624 .and_then(|split| split.fold_set.clone());
625 validate_search_space_fingerprint(&graph_plan.graph, &campaign)?;
626 let variants = enumerate_variants(&campaign.generation, campaign.root_seed)?;
627 validate_generation_override_targets(&graph_plan.graph, &variants)?;
628 let graph_fingerprint = stable_json_fingerprint(&graph_plan.graph)?;
629 let campaign_fingerprint = stable_json_fingerprint(&campaign)?;
630 let controller_fingerprint = stable_json_fingerprint(&controller_manifests)?;
631 let plan = ExecutionPlan {
632 id,
633 graph_plan,
634 campaign,
635 node_plans,
636 controller_manifests,
637 variants,
638 fold_set,
639 graph_fingerprint,
640 campaign_fingerprint,
641 controller_fingerprint,
642 };
643 plan.validate()?;
644 Ok(plan)
645}
646
647fn validate_search_space_fingerprint(graph: &GraphSpec, campaign: &CampaignSpec) -> Result<()> {
648 let Some(expected_fingerprint) = &graph.search_space_fingerprint else {
649 return Ok(());
650 };
651 if expected_fingerprint.trim().is_empty() {
652 return Err(DagMlError::Planning(format!(
653 "graph `{}` has empty search_space_fingerprint",
654 graph.id
655 )));
656 }
657 let actual_fingerprint = generation_spec_fingerprint(&campaign.generation)?;
658 if expected_fingerprint != &actual_fingerprint {
659 return Err(DagMlError::Planning(format!(
660 "graph `{}` search_space_fingerprint does not match campaign generation spec",
661 graph.id
662 )));
663 }
664 Ok(())
665}
666
667fn validate_generation_override_targets(graph: &GraphSpec, variants: &[VariantPlan]) -> Result<()> {
668 let node_ids = graph
669 .nodes
670 .iter()
671 .map(|node| node.id.clone())
672 .collect::<BTreeSet<_>>();
673 for variant in variants {
674 for node_id in variant.param_override_targets()? {
675 if !node_ids.contains(&node_id) {
676 return Err(DagMlError::Planning(format!(
677 "variant `{}` overrides params for unknown node `{node_id}`",
678 variant.variant_id
679 )));
680 }
681 }
682 }
683 Ok(())
684}
685
686fn validate_campaign_node_targets(graph: &GraphSpec, campaign: &CampaignSpec) -> Result<()> {
687 let node_ids = graph
688 .nodes
689 .iter()
690 .map(|node| &node.id)
691 .collect::<BTreeSet<_>>();
692 for node_id in campaign.shape_plans.keys() {
693 if !node_ids.contains(node_id) {
694 return Err(DagMlError::Planning(format!(
695 "shape plan references unknown node `{node_id}`"
696 )));
697 }
698 }
699 for node_id in campaign.data_bindings.keys() {
700 if !node_ids.contains(node_id) {
701 return Err(DagMlError::Planning(format!(
702 "data binding references unknown node `{node_id}`"
703 )));
704 }
705 }
706 Ok(())
707}
708
709#[cfg(test)]
710mod tests {
711 use std::collections::{BTreeMap, BTreeSet};
712 use std::time::{Duration, Instant};
713
714 use super::*;
715 use crate::controller::{
716 ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest, RngPolicy,
717 };
718
719 #[test]
720 fn inner_cv_is_declarable_at_campaign_and_node_level() {
721 let campaign_json = r#"{"id":"c","root_seed":null,"inner_cv":{"kind":"kfold","n_splits":3,"shuffle":false,"seed":5}}"#;
723 let campaign: CampaignSpec = serde_json::from_str(campaign_json).unwrap();
724 campaign.validate().unwrap();
725 assert!(campaign.inner_cv.is_some());
726
727 let node_inner = crate::fold::NestedCvSpec::KFold(crate::fold::KFoldSpec {
729 n_splits: 4,
730 shuffle: false,
731 seed: Some(6),
732 });
733 let resolved = crate::fold::resolve_inner_cv(Some(&node_inner), campaign.inner_cv.as_ref());
734 assert_eq!(resolved, Some(&node_inner));
735
736 let bare = r#"{"id":"c","root_seed":null}"#;
738 let bare_campaign: CampaignSpec = serde_json::from_str(bare).unwrap();
739 assert!(bare_campaign.inner_cv.is_none());
740 let reserialized = serde_json::to_string(&bare_campaign).unwrap();
741 assert!(!reserialized.contains("inner_cv"));
742
743 let bad: CampaignSpec = serde_json::from_str(
746 r#"{"id":"c","root_seed":null,"inner_cv":{"kind":"kfold","n_splits":1,"shuffle":false,"seed":null}}"#,
747 )
748 .unwrap();
749 let error = bad.validate().unwrap_err();
750 assert!(error.to_string().contains("at least two splits"));
751 }
752
753 #[test]
754 fn execution_plan_validate_rejects_invalid_node_local_inner_cv() {
755 let campaign = CampaignSpec {
759 inner_cv: None,
760 id: "campaign:plan-validate".to_string(),
761 root_seed: Some(7),
762 leakage_policy: LeakageUnitPolicy::default(),
763 aggregation_policy: AggregationPolicy::default(),
764 split_invocation: None,
765 generation: Default::default(),
766 shape_plans: BTreeMap::new(),
767 data_bindings: BTreeMap::new(),
768 branch_view_plans: Vec::new(),
769 metadata: BTreeMap::new(),
770 };
771 let mut plan =
772 build_execution_plan("plan:validate", graph(), campaign, ®istry()).unwrap();
773 plan.validate().unwrap();
774 plan.node_plans
775 .get_mut(&NodeId::new("model:pls").unwrap())
776 .unwrap()
777 .inner_cv = Some(crate::fold::NestedCvSpec::KFold(crate::fold::KFoldSpec {
778 n_splits: 1,
779 shuffle: false,
780 seed: None,
781 }));
782 let error = plan.validate().unwrap_err();
783 assert!(matches!(error, DagMlError::Planning(_)));
784 assert!(error.to_string().contains("invalid inner_cv"));
785 assert!(error.to_string().contains("at least two splits"));
786 }
787
788 #[test]
789 fn build_execution_plan_lowers_dsl_inner_cv_metadata_into_node_plan() {
790 let mut graph = graph();
791 graph
792 .nodes
793 .iter_mut()
794 .find(|node| node.id.as_str() == "model:pls")
795 .unwrap()
796 .metadata
797 .insert(
798 "dsl_inner_cv".to_string(),
799 serde_json::json!({"kind": "kfold", "n_splits": 3, "shuffle": false, "seed": 9}),
800 );
801
802 let campaign = CampaignSpec {
803 inner_cv: None,
804 id: "campaign:inner-cv".to_string(),
805 root_seed: Some(7),
806 leakage_policy: LeakageUnitPolicy::default(),
807 aggregation_policy: AggregationPolicy::default(),
808 split_invocation: None,
809 generation: Default::default(),
810 shape_plans: BTreeMap::new(),
811 data_bindings: BTreeMap::new(),
812 branch_view_plans: Vec::new(),
813 metadata: BTreeMap::new(),
814 };
815
816 let plan = build_execution_plan("plan:inner-cv", graph, campaign, ®istry()).unwrap();
817 match &plan.node_plans[&NodeId::new("model:pls").unwrap()].inner_cv {
818 Some(crate::fold::NestedCvSpec::KFold(k)) => {
819 assert_eq!(k.n_splits, 3);
820 assert_eq!(k.seed, Some(9));
821 }
822 other => panic!("expected lowered KFold inner_cv, got {other:?}"),
823 }
824 assert!(plan.node_plans[&NodeId::new("transform:snv").unwrap()]
825 .inner_cv
826 .is_none());
827 }
828
829 #[test]
830 fn build_execution_plan_rejects_malformed_dsl_inner_cv_metadata() {
831 let mut graph = graph();
832 graph
833 .nodes
834 .iter_mut()
835 .find(|node| node.id.as_str() == "model:pls")
836 .unwrap()
837 .metadata
838 .insert(
839 "dsl_inner_cv".to_string(),
840 serde_json::json!({"kind": "not_a_real_kind"}),
841 );
842
843 let campaign = CampaignSpec {
844 inner_cv: None,
845 id: "campaign:inner-cv.bad".to_string(),
846 root_seed: Some(7),
847 leakage_policy: LeakageUnitPolicy::default(),
848 aggregation_policy: AggregationPolicy::default(),
849 split_invocation: None,
850 generation: Default::default(),
851 shape_plans: BTreeMap::new(),
852 data_bindings: BTreeMap::new(),
853 branch_view_plans: Vec::new(),
854 metadata: BTreeMap::new(),
855 };
856
857 let error =
858 build_execution_plan("plan:inner-cv.bad", graph, campaign, ®istry()).unwrap_err();
859 assert!(matches!(error, DagMlError::Planning(_)));
860 assert!(error.to_string().contains("invalid dsl_inner_cv metadata"));
861 }
862
863 #[test]
864 fn build_execution_plan_rejects_semantically_invalid_dsl_inner_cv() {
865 let mut graph = graph();
868 graph
869 .nodes
870 .iter_mut()
871 .find(|node| node.id.as_str() == "model:pls")
872 .unwrap()
873 .metadata
874 .insert(
875 "dsl_inner_cv".to_string(),
876 serde_json::json!({"kind": "kfold", "n_splits": 1, "shuffle": false, "seed": null}),
877 );
878
879 let campaign = CampaignSpec {
880 inner_cv: None,
881 id: "campaign:inner-cv.nsplits".to_string(),
882 root_seed: Some(7),
883 leakage_policy: LeakageUnitPolicy::default(),
884 aggregation_policy: AggregationPolicy::default(),
885 split_invocation: None,
886 generation: Default::default(),
887 shape_plans: BTreeMap::new(),
888 data_bindings: BTreeMap::new(),
889 branch_view_plans: Vec::new(),
890 metadata: BTreeMap::new(),
891 };
892
893 let error = build_execution_plan("plan:inner-cv.nsplits", graph, campaign, ®istry())
894 .unwrap_err();
895 assert!(matches!(error, DagMlError::Planning(_)));
896 assert!(error.to_string().contains("at least two splits"));
897 }
898 use crate::data::DataBinding;
899 use crate::generation::{
900 GenerationChoice, GenerationDimension, GenerationParamOverride, GenerationStrategy,
901 };
902 use crate::graph::{
903 EdgeContract, EdgeSpec, GraphInterface, NodeSpec, PortCardinality, PortKind, PortRef,
904 PortSchema, PortSpec,
905 };
906 use crate::ids::{ControllerId, FoldId, ObservationId, SampleId, TargetId};
907 use crate::phase::Phase;
908 use crate::policy::{DataModelShapePlan, Granularity};
909 use crate::relation::{SampleRelation, SampleRelationSet};
910
911 fn port(name: &str, kind: PortKind) -> PortSpec {
912 PortSpec {
913 name: name.to_string(),
914 kind,
915 representation: None,
916 cardinality: PortCardinality::One,
917 unit_level: None,
918 alignment_key: None,
919 target_level: None,
920 description: String::new(),
921 }
922 }
923
924 fn node(id: &str, kind: NodeKind, inputs: Vec<PortSpec>, outputs: Vec<PortSpec>) -> NodeSpec {
925 NodeSpec {
926 id: NodeId::new(id).unwrap(),
927 kind,
928 operator: None,
929 params: BTreeMap::new(),
930 ports: PortSchema { inputs, outputs },
931 metadata: BTreeMap::new(),
932 seed_label: None,
933 }
934 }
935
936 fn graph() -> GraphSpec {
937 GraphSpec {
938 id: "g".to_string(),
939 interface: GraphInterface::default(),
940 nodes: vec![
941 node(
942 "transform:snv",
943 NodeKind::Transform,
944 vec![],
945 vec![port("x", PortKind::Data)],
946 ),
947 node(
948 "model:pls",
949 NodeKind::Model,
950 vec![port("x", PortKind::Data)],
951 vec![port("pred", PortKind::Prediction)],
952 ),
953 ],
954 edges: vec![EdgeSpec {
955 source: PortRef {
956 node_id: NodeId::new("transform:snv").unwrap(),
957 port_name: "x".to_string(),
958 },
959 target: PortRef {
960 node_id: NodeId::new("model:pls").unwrap(),
961 port_name: "x".to_string(),
962 },
963 contract: EdgeContract {
964 requires_oof: false,
965 requires_fold_alignment: false,
966 ..EdgeContract::new(PortKind::Data, None)
967 },
968 }],
969 search_space_fingerprint: None,
970 metadata: BTreeMap::new(),
971 }
972 }
973
974 fn manifest(id: &str, kind: NodeKind) -> ControllerManifest {
975 let mut capabilities = BTreeSet::from([
976 ControllerCapability::Deterministic,
977 ControllerCapability::ThreadSafe,
978 ControllerCapability::ProcessSafe,
979 ]);
980 if kind == NodeKind::Model {
981 capabilities.insert(ControllerCapability::EmitsPredictions);
982 capabilities.insert(ControllerCapability::ConsumesOofPredictions);
983 }
984 ControllerManifest {
985 controller_id: ControllerId::new(id).unwrap(),
986 controller_version: "0.1.0".to_string(),
987 operator_kind: kind,
988 priority: 0,
989 supported_phases: BTreeSet::from([Phase::FitCv, Phase::Refit, Phase::Predict]),
990 input_ports: Vec::new(),
991 output_ports: Vec::new(),
992 data_requirements: None,
993 capabilities,
994 operator_selectors: Vec::new(),
995 fit_scope: ControllerFitScope::FoldTrain,
996 rng_policy: RngPolicy::UsesCoreSeed,
997 artifact_policy: ArtifactPolicy::Serializable,
998 }
999 }
1000
1001 fn registry() -> ControllerRegistry {
1002 let mut registry = ControllerRegistry::new();
1003 registry
1004 .register(manifest("controller:transform", NodeKind::Transform))
1005 .unwrap();
1006 registry
1007 .register(manifest("controller:model", NodeKind::Model))
1008 .unwrap();
1009 registry
1010 }
1011
1012 fn campaign(id: &str) -> CampaignSpec {
1013 CampaignSpec {
1014 id: id.to_string(),
1015 root_seed: Some(7),
1016 leakage_policy: LeakageUnitPolicy::default(),
1017 aggregation_policy: AggregationPolicy::default(),
1018 split_invocation: None,
1019 generation: Default::default(),
1020 shape_plans: BTreeMap::new(),
1021 data_bindings: BTreeMap::new(),
1022 branch_view_plans: Vec::new(),
1023 inner_cv: None,
1024 metadata: BTreeMap::new(),
1025 }
1026 }
1027
1028 fn large_linear_graph(transform_count: usize) -> GraphSpec {
1029 let mut nodes = Vec::new();
1030 let mut edges = Vec::new();
1031 for node_idx in 0..transform_count {
1032 let node_id = format!("transform:t{node_idx:04}");
1033 nodes.push(node(
1034 &node_id,
1035 NodeKind::Transform,
1036 vec![port("x", PortKind::Data)],
1037 vec![port("x", PortKind::Data)],
1038 ));
1039 if node_idx > 0 {
1040 edges.push(EdgeSpec {
1041 source: PortRef {
1042 node_id: NodeId::new(format!("transform:t{:04}", node_idx - 1)).unwrap(),
1043 port_name: "x".to_string(),
1044 },
1045 target: PortRef {
1046 node_id: NodeId::new(&node_id).unwrap(),
1047 port_name: "x".to_string(),
1048 },
1049 contract: EdgeContract::new(PortKind::Data, None),
1050 });
1051 }
1052 }
1053 nodes.push(node(
1054 "model:final",
1055 NodeKind::Model,
1056 vec![port("x", PortKind::Data)],
1057 vec![port("pred", PortKind::Prediction)],
1058 ));
1059 edges.push(EdgeSpec {
1060 source: PortRef {
1061 node_id: NodeId::new(format!("transform:t{:04}", transform_count - 1)).unwrap(),
1062 port_name: "x".to_string(),
1063 },
1064 target: PortRef {
1065 node_id: NodeId::new("model:final").unwrap(),
1066 port_name: "x".to_string(),
1067 },
1068 contract: EdgeContract::new(PortKind::Data, None),
1069 });
1070
1071 GraphSpec {
1072 id: "g:perf.linear".to_string(),
1073 interface: GraphInterface::default(),
1074 nodes,
1075 edges,
1076 search_space_fingerprint: None,
1077 metadata: BTreeMap::new(),
1078 }
1079 }
1080
1081 fn oof_graph() -> GraphSpec {
1082 GraphSpec {
1083 id: "g:oof.capabilities".to_string(),
1084 interface: GraphInterface::default(),
1085 nodes: vec![
1086 node(
1087 "model:base",
1088 NodeKind::Model,
1089 vec![],
1090 vec![port("pred", PortKind::Prediction)],
1091 ),
1092 node(
1093 "model:meta",
1094 NodeKind::Model,
1095 vec![port("pred", PortKind::Prediction)],
1096 vec![port("pred", PortKind::Prediction)],
1097 ),
1098 ],
1099 edges: vec![EdgeSpec {
1100 source: PortRef {
1101 node_id: NodeId::new("model:base").unwrap(),
1102 port_name: "pred".to_string(),
1103 },
1104 target: PortRef {
1105 node_id: NodeId::new("model:meta").unwrap(),
1106 port_name: "pred".to_string(),
1107 },
1108 contract: EdgeContract {
1109 requires_oof: true,
1110 requires_fold_alignment: true,
1111 ..EdgeContract::new(PortKind::Prediction, None)
1112 },
1113 }],
1114 search_space_fingerprint: None,
1115 metadata: BTreeMap::new(),
1116 }
1117 }
1118
1119 fn data_binding(node_id: &NodeId) -> DataBinding {
1120 DataBinding {
1121 node_id: node_id.clone(),
1122 input_name: "x".to_string(),
1123 request_id: "nir-to-tabular".to_string(),
1124 schema_fingerprint: "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b"
1125 .to_string(),
1126 plan_fingerprint: "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d"
1127 .to_string(),
1128 relation_fingerprint: Some(
1129 "a3a7e329df35db9f2883a17b8611b7fae6dcaa031875e3ec2c9be1b9e29cbe10".to_string(),
1130 ),
1131 output_representation: "tabular_numeric".to_string(),
1132 feature_set_id: Some("x".to_string()),
1133 source_ids: vec!["nir".to_string()],
1134 require_relations: true,
1135 view_policy: Default::default(),
1136 metadata: BTreeMap::new(),
1137 }
1138 }
1139
1140 fn levels_as_strings(levels: &[Vec<NodeId>]) -> Vec<Vec<String>> {
1141 levels
1142 .iter()
1143 .map(|level| level.iter().map(ToString::to_string).collect())
1144 .collect()
1145 }
1146
/// The published campaign-spec JSON schema must carry the current `$id`, require
/// `id`, and declare the properties this crate relies on.
#[test]
fn published_campaign_spec_schema_declares_current_contract() {
    let raw = include_str!("../../../docs/contracts/campaign_spec.schema.json");
    let schema: serde_json::Value = serde_json::from_str(raw).unwrap();
    // True when `node["properties"]` (which must be a JSON object) has `key`.
    let declares = |node: &serde_json::Value, key: &str| {
        node["properties"].as_object().unwrap().contains_key(key)
    };
    let defs = &schema["$defs"];

    assert_eq!(schema["$id"], CAMPAIGN_SPEC_SCHEMA_ID);
    let required = schema["required"].as_array().unwrap();
    assert!(required.iter().any(|field| field.as_str() == Some("id")));
    assert!(declares(&defs["split_invocation"], "fold_set"));
    assert!(declares(&defs["aggregation_policy"], "selection_metric_level"));
    assert!(declares(&defs["aggregation_policy"], "custom_controller"));
    assert!(declares(&defs["data_binding"], "view_policy"));
    assert!(declares(&schema, "branch_view_plans"));
    assert!(declares(&defs["branch_view_plan"], "selector"));
}
1185
/// The published execution-plan JSON schema must carry the current `$id`, require
/// `node_plans`, and declare the properties this crate relies on.
#[test]
fn published_execution_plan_schema_declares_current_contract() {
    let raw = include_str!("../../../docs/contracts/execution_plan.schema.json");
    let schema: serde_json::Value = serde_json::from_str(raw).unwrap();
    // True when `node["properties"]` (which must be a JSON object) has `key`.
    let declares = |node: &serde_json::Value, key: &str| {
        node["properties"].as_object().unwrap().contains_key(key)
    };

    assert_eq!(schema["$id"], EXECUTION_PLAN_SCHEMA_ID);
    let required = schema["required"].as_array().unwrap();
    assert!(required
        .iter()
        .any(|field| field.as_str() == Some("node_plans")));
    assert!(declares(&schema, "controller_fingerprint"));
    assert!(declares(&schema["$defs"]["node_plan"], "shape_plan"));
    assert!(declares(&schema["$defs"]["variant_plan"], "choices"));
}
1212
/// The checked-in branch/merge execution-plan fixture must deserialize, validate,
/// and contain one node plan per graph node.
#[test]
fn published_execution_plan_fixture_validates_current_contract() {
    const FIXTURE: &str = include_str!(
        "../../../examples/fixtures/runtime/execution_plan_branch_merge_executable.json"
    );
    let plan = serde_json::from_str::<ExecutionPlan>(FIXTURE).unwrap();

    plan.validate().unwrap();
    assert_eq!(plan.id, "plan:fixture.execution.branch_merge");
    assert_eq!(plan.variants.len(), 2);
    let graph_node_count = plan.graph_plan.graph.nodes.len();
    assert_eq!(plan.node_plans.len(), graph_node_count);
}
1225
/// Perf probe: planning a 400-transform linear chain (graph construction included)
/// must finish within the time budget.
#[test]
#[ignore = "perf sanity probe; run with --release --ignored --nocapture"]
fn build_execution_plan_large_linear_graph_under_1500ms() {
    // Chain length; the graph additionally holds the terminal `model:final` node.
    const TRANSFORM_COUNT: usize = 400;
    const BUDGET: Duration = Duration::from_millis(1_500);
    let expected_nodes = TRANSFORM_COUNT + 1;

    let started = Instant::now();
    let plan = build_execution_plan(
        "plan:perf.linear",
        large_linear_graph(TRANSFORM_COUNT),
        campaign("campaign:perf.linear"),
        &registry(),
    )
    .unwrap();
    let elapsed = started.elapsed();

    assert_eq!(plan.graph_plan.topological_order.len(), expected_nodes);
    assert_eq!(plan.node_plans.len(), expected_nodes);
    assert!(
        elapsed <= BUDGET,
        "large execution-plan build took {elapsed:?}"
    );
}
1246
/// End-to-end planning of `graph()` with an explicit two-fold outer split, a shape
/// plan and one data binding on `model:pls`. It also checks two rejections:
/// `validate` refuses swapped parallel levels, and envelope relations outside the
/// fold set are refused.
#[test]
fn builds_execution_plan_with_shape_and_fold_contracts() {
    let model_id = NodeId::new("model:pls").unwrap();
    let sample_id = |raw: &str| SampleId::new(raw).unwrap();

    // Two folds over {s1, s2}: each trains on one sample and validates on the other.
    let outer_folds = FoldSet {
        id: "outer".to_string(),
        sample_ids: vec![sample_id("s1"), sample_id("s2")],
        folds: vec![
            crate::fold::FoldAssignment {
                fold_id: FoldId::new("fold0").unwrap(),
                train_sample_ids: vec![sample_id("s2")],
                validation_sample_ids: vec![sample_id("s1")],
                metadata: BTreeMap::new(),
            },
            crate::fold::FoldAssignment {
                fold_id: FoldId::new("fold1").unwrap(),
                train_sample_ids: vec![sample_id("s1")],
                validation_sample_ids: vec![sample_id("s2")],
                metadata: BTreeMap::new(),
            },
        ],
        sample_groups: BTreeMap::new(),
    };

    // Observation-granularity inputs with sample-granularity targets.
    let shape_plan = DataModelShapePlan {
        node_id: model_id.clone(),
        input_granularity: Granularity::Observation,
        target_granularity: Granularity::Sample,
        fit_rows: crate::policy::FitBoundary::FoldTrain,
        predict_rows: crate::policy::FitBoundary::FoldValidation,
        feature_namespace: None,
        feature_schema_fingerprint: None,
        target_space: "raw".to_string(),
        aggregation_policy: AggregationPolicy::default(),
        augmentation_policy: crate::policy::AugmentationPolicy::default(),
        selection_policy: crate::policy::FeatureSelectionPolicy::default(),
    };

    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:oof".to_string(),
        root_seed: Some(7),
        leakage_policy: LeakageUnitPolicy::default(),
        aggregation_policy: AggregationPolicy::default(),
        split_invocation: Some(SplitInvocation {
            id: "split:outer".to_string(),
            controller_id: None,
            leakage_policy: LeakageUnitPolicy::default(),
            params: BTreeMap::new(),
            fold_set: Some(outer_folds),
        }),
        generation: Default::default(),
        shape_plans: BTreeMap::from([(model_id.clone(), shape_plan)]),
        data_bindings: BTreeMap::from([(model_id.clone(), vec![data_binding(&model_id)])]),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };

    let plan = build_execution_plan("plan:oof", graph(), campaign, &registry()).unwrap();

    assert_eq!(
        plan.graph_plan
            .topological_order
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>(),
        vec!["transform:snv", "model:pls"]
    );
    assert_eq!(
        levels_as_strings(&plan.graph_plan.parallel_levels),
        vec![vec!["transform:snv"], vec!["model:pls"]]
    );
    assert!(plan.node_plans[&model_id]
        .controller_capabilities
        .contains(&ControllerCapability::EmitsPredictions));
    assert!(plan.fold_set.is_some());

    // One FIT_CV scope per outer fold, each running the full two-level graph.
    let schedule = plan.campaign_phase_schedule(Phase::FitCv).unwrap();
    assert_eq!(schedule.scopes.len(), 2);
    assert!(schedule.scopes[0].scope_id.starts_with("scope:FIT_CV:"));
    assert!(schedule
        .scopes
        .iter()
        .all(|scope| levels_as_strings(&scope.node_levels)
            == vec![vec!["transform:snv"], vec!["model:pls"]]));
    assert_eq!(
        schedule
            .scopes
            .iter()
            .filter_map(|scope| scope.fold_id.as_ref().map(ToString::to_string))
            .collect::<Vec<_>>(),
        vec!["fold0", "fold1"]
    );

    let model_plan = plan.node_plans.get(&model_id).unwrap();
    assert_eq!(model_plan.controller_id.as_str(), "controller:model");
    assert_eq!(model_plan.data_bindings.len(), 1);

    // Swapping the levels puts the model before its upstream transform.
    let mut bad_plan = plan.clone();
    bad_plan.graph_plan.parallel_levels =
        vec![vec![model_id], vec![NodeId::new("transform:snv").unwrap()]];
    assert!(bad_plan
        .validate()
        .unwrap_err()
        .to_string()
        .contains("parallel levels"));

    // A coordinator relation naming a sample that is not part of the fold set.
    let bad_envelope = ExternalDataPlanEnvelope {
        schema_version: crate::data::EXTERNAL_DATA_PLAN_ENVELOPE_SCHEMA_VERSION,
        schema_fingerprint: "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b"
            .to_string(),
        plan_fingerprint: "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d"
            .to_string(),
        relation_fingerprint: None,
        coordinator_relations: Some(SampleRelationSet {
            records: vec![{
                let mut relation = SampleRelation::new(
                    ObservationId::new("obs:outside").unwrap(),
                    sample_id("sample:outside"),
                );
                relation.target_id = Some(TargetId::new("target:outside").unwrap());
                relation.source_id = Some("nir".to_string());
                relation
            }],
        }),
    };
    assert!(plan
        .campaign
        .validate_data_envelope_relations(&bad_envelope)
        .unwrap_err()
        .to_string()
        .contains("outside fold set"));
}
1389
/// A shape plan keyed on a node that `graph()` does not contain must make
/// planning fail.
#[test]
fn planning_refuses_shape_plan_for_unknown_node() {
    let missing = NodeId::new("model:missing").unwrap();
    let orphan_shape_plan = DataModelShapePlan {
        node_id: missing.clone(),
        input_granularity: Granularity::Sample,
        target_granularity: Granularity::Sample,
        fit_rows: crate::policy::FitBoundary::FoldTrain,
        predict_rows: crate::policy::FitBoundary::FoldValidation,
        feature_namespace: None,
        feature_schema_fingerprint: None,
        target_space: "raw".to_string(),
        aggregation_policy: AggregationPolicy::default(),
        augmentation_policy: crate::policy::AugmentationPolicy::default(),
        selection_policy: crate::policy::FeatureSelectionPolicy::default(),
    };
    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:oof".to_string(),
        root_seed: Some(7),
        leakage_policy: LeakageUnitPolicy::default(),
        aggregation_policy: AggregationPolicy::default(),
        split_invocation: None,
        generation: Default::default(),
        shape_plans: BTreeMap::from([(missing, orphan_shape_plan)]),
        data_bindings: BTreeMap::new(),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };

    assert!(build_execution_plan("plan:oof", graph(), campaign, &registry()).is_err());
}
1423
/// The only registered controller (`controller:model`) lacks
/// `ConsumesOofPredictions`. Planning `oof_graph()`, whose edge requires OOF
/// predictions, must fail and name the missing capability.
#[test]
fn planning_refuses_oof_edge_without_controller_capabilities() {
    let mut registry = ControllerRegistry::new();
    let mut model_manifest = manifest("controller:model", NodeKind::Model);
    model_manifest
        .capabilities
        .remove(&ControllerCapability::ConsumesOofPredictions);
    registry.register(model_manifest).unwrap();

    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:oof.capability".to_string(),
        root_seed: Some(11),
        leakage_policy: Default::default(),
        aggregation_policy: Default::default(),
        split_invocation: None,
        generation: Default::default(),
        shape_plans: BTreeMap::new(),
        data_bindings: BTreeMap::new(),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };
    let err = build_execution_plan("plan:oof.capability", oof_graph(), campaign, &registry)
        .unwrap_err();

    assert!(err.to_string().contains("consumes_oof_predictions"));
}
1455
/// The transform controller is stripped of both `ThreadSafe` and `ProcessSafe`.
/// `validate_parallel_controller_capabilities` must accept `1` and reject `2`
/// with a thread/process-safety error.
#[test]
fn parallel_controller_capability_validation_requires_safe_manifest() {
    let mut registry = ControllerRegistry::new();
    let mut transform_manifest = manifest("controller:transform", NodeKind::Transform);
    for capability in [
        ControllerCapability::ThreadSafe,
        ControllerCapability::ProcessSafe,
    ] {
        transform_manifest.capabilities.remove(&capability);
    }
    registry.register(transform_manifest).unwrap();
    registry
        .register(manifest("controller:model", NodeKind::Model))
        .unwrap();

    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:parallel.capability".to_string(),
        root_seed: Some(11),
        leakage_policy: Default::default(),
        aggregation_policy: Default::default(),
        split_invocation: None,
        generation: Default::default(),
        shape_plans: BTreeMap::new(),
        data_bindings: BTreeMap::new(),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };
    let plan =
        build_execution_plan("plan:parallel.capability", graph(), campaign, &registry).unwrap();

    assert!(plan
        .validate_parallel_controller_capabilities(1, Phase::FitCv)
        .is_ok());
    let err = plan
        .validate_parallel_controller_capabilities(2, Phase::FitCv)
        .unwrap_err();
    assert!(err.to_string().contains("thread_safe or process_safe"));
}
1498
/// A generation choice that overrides params on a node absent from `graph()`
/// (`model:missing`) must make planning fail with a descriptive error.
#[test]
fn planning_refuses_generation_override_for_unknown_node() {
    let override_missing = GenerationParamOverride {
        node_id: NodeId::new("model:missing").unwrap(),
        params: BTreeMap::from([("n_components".to_string(), serde_json::json!(8))]),
    };
    let generation = GenerationSpec {
        strategy: GenerationStrategy::Cartesian,
        dimensions: vec![GenerationDimension {
            name: "model_family".to_string(),
            choices: vec![GenerationChoice {
                label: "pls".to_string(),
                value: serde_json::json!("pls"),
                param_overrides: vec![override_missing],
            }],
        }],
        max_variants: Some(1),
    };
    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:oof".to_string(),
        root_seed: Some(7),
        leakage_policy: LeakageUnitPolicy::default(),
        aggregation_policy: AggregationPolicy::default(),
        split_invocation: None,
        generation,
        shape_plans: BTreeMap::new(),
        data_bindings: BTreeMap::new(),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };

    let error = build_execution_plan("plan:oof", graph(), campaign, &registry())
        .unwrap_err()
        .to_string();

    assert!(error.contains("overrides params for unknown node"));
}
1538
/// A graph's declared `search_space_fingerprint` must match the fingerprint of the
/// campaign's generation spec. A match plans successfully; a mismatch is rejected.
#[test]
fn planning_validates_declared_search_space_fingerprint() {
    let pls_override = GenerationParamOverride {
        node_id: NodeId::new("model:pls").unwrap(),
        params: BTreeMap::from([("n_components".to_string(), serde_json::json!(8))]),
    };
    let campaign = CampaignSpec {
        inner_cv: None,
        id: "campaign:search.fingerprint".to_string(),
        root_seed: Some(7),
        leakage_policy: LeakageUnitPolicy::default(),
        aggregation_policy: AggregationPolicy::default(),
        split_invocation: None,
        generation: GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![GenerationDimension {
                name: "model_family".to_string(),
                choices: vec![GenerationChoice {
                    label: "pls".to_string(),
                    value: serde_json::json!("pls"),
                    param_overrides: vec![pls_override],
                }],
            }],
            max_variants: Some(1),
        },
        shape_plans: BTreeMap::new(),
        data_bindings: BTreeMap::new(),
        branch_view_plans: Vec::new(),
        metadata: BTreeMap::new(),
    };

    // Matching fingerprint: planning succeeds with the single generated variant.
    let mut graph = graph();
    graph.search_space_fingerprint =
        Some(generation_spec_fingerprint(&campaign.generation).unwrap());
    let plan = build_execution_plan(
        "plan:search.fingerprint",
        graph.clone(),
        campaign.clone(),
        &registry(),
    )
    .unwrap();
    assert_eq!(plan.variants.len(), 1);

    // Any other declared fingerprint must be rejected.
    graph.search_space_fingerprint = Some("sha256:not-the-generation-spec".to_string());
    let error = build_execution_plan("plan:search.fingerprint", graph, campaign, &registry())
        .unwrap_err()
        .to_string();
    assert!(error.contains("search_space_fingerprint"));
}
1590
/// `branch_view_for_in` matches on exact branch id. `branch_view_for_path_in`
/// resolves `[outer, inner]` to the inner plan and `[outer]` to the outer plan.
/// Empty or unknown paths yield `None`.
#[test]
fn branch_view_lookup_helpers_match_by_branch_id_and_innermost_path() {
    use crate::data::{BranchViewMode, DataViewSelector};

    // View plan selecting a single source, with `allow_overlap` disabled.
    let view = |view_id: &str, branch_id: &str, mode: BranchViewMode, source: &str| {
        BranchViewPlan {
            view_id: view_id.to_string(),
            branch_id: branch_id.to_string(),
            mode,
            selector: DataViewSelector {
                source_ids: vec![source.to_string()],
                ..Default::default()
            },
            allow_overlap: false,
            metadata: BTreeMap::new(),
        }
    };
    let outer = view(
        "branch_view:outer",
        "branch:outer",
        BranchViewMode::BySource,
        "nir",
    );
    let inner = view(
        "branch_view:inner",
        "branch:inner",
        BranchViewMode::Separation,
        "chem",
    );
    let plans = vec![outer.clone(), inner.clone()];

    // Lookup by exact branch id.
    for (branch_id, expected) in [
        ("branch:outer", Some(&outer)),
        ("branch:inner", Some(&inner)),
        ("branch:missing", None),
    ] {
        assert_eq!(super::branch_view_for_in(&plans, branch_id), expected);
    }

    // Lookup by branch path.
    let path_cases: [(Vec<String>, Option<&BranchViewPlan>); 4] = [
        (
            vec!["branch:outer".to_string(), "branch:inner".to_string()],
            Some(&inner),
        ),
        (vec!["branch:outer".to_string()], Some(&outer)),
        (Vec::new(), None),
        (vec!["branch:other".to_string()], None),
    ];
    for (path, expected) in &path_cases {
        assert_eq!(super::branch_view_for_path_in(&plans, path), *expected);
    }
}
1645}