Skip to main content

dag_ml_core/
dsl.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4
5use crate::controller::ControllerRegistry;
6use crate::data::{BranchViewMode, BranchViewPlan, DataBinding, DataViewSelector};
7use crate::error::{DagMlError, Result};
8use crate::fold::NestedCvSpec;
9use crate::generation::{
10    generation_spec_fingerprint, GenerationChoice, GenerationDimension, GenerationParamOverride,
11    GenerationSpec, GenerationStrategy,
12};
13use crate::graph::{
14    EdgeContract, EdgeSpec, GraphInterface, GraphSpec, NodeKind, NodeSpec, PortCardinality,
15    PortKind, PortRef, PortSchema, PortSpec,
16};
17use crate::ids::NodeId;
18use crate::plan::{CampaignSpec, SplitInvocation};
19use crate::policy::{
20    AggregationPolicy, AugmentationPolicy, DataModelShapePlan, FeatureSelectionPolicy, FitBoundary,
21    Granularity, LeakageUnitPolicy,
22};
23use crate::relation::EntityUnitLevel;
24
/// Version number of the pipeline DSL schema.
pub const PIPELINE_DSL_SCHEMA_VERSION: u32 = 1;
/// Identifier (URL) of the pipeline DSL JSON schema document.
pub const PIPELINE_DSL_SCHEMA_ID: &str =
    "https://github.com/GBeurier/dag-ml/schemas/pipeline_dsl.v1.schema.json";
/// Step-metadata key: a boolean `true` stored under it marks an operator step
/// as a "minimal operator alias" (see `is_minimal_operator_alias`).
const DSL_MINIMAL_OPERATOR_ALIAS: &str = "dsl_minimal_operator_alias";
/// Step-metadata key under which the registry-inferred node kind is recorded
/// (see `annotate_registry_inferred_operator_step`).
const DSL_REGISTRY_INFERRED_KIND: &str = "dsl_registry_inferred_kind";
/// Step-metadata key that keeps the `dsl_compat_keyword` value a step carried
/// before registry inference overwrote it.
const DSL_COMPAT_ORIGINAL_KEYWORD: &str = "dsl_compat_original_keyword";
31
/// Root document of the pipeline DSL.
///
/// Only `id` is mandatory when deserializing; every other field falls back to
/// its default. `compile_pipeline_dsl_with_generation` turns a spec into a
/// graph plus generation/campaign artifacts.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslSpec {
    /// Pipeline identifier; reused as the id of the compiled graph.
    pub id: String,
    /// Declaration of the graph-level data input port.
    #[serde(default)]
    pub input: PipelineDslDataPort,
    /// Declaration of the graph-level prediction output port.
    #[serde(default)]
    pub output: PipelineDslPredictionPort,
    /// Optional generation strategy, forwarded to `build_generation_spec`.
    #[serde(default)]
    pub generation_strategy: Option<GenerationStrategy>,
    /// Optional variant cap, forwarded to `build_generation_spec`.
    #[serde(default)]
    pub max_variants: Option<usize>,
    /// Explicitly declared generation dimensions; the compiler places them
    /// ahead of the dimensions it collects while compiling steps.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generation_dimensions: Vec<PipelineDslGenerationDimension>,
    // The campaign-level fields below are not read directly by the visible
    // compile function; the whole spec is handed to `build_campaign_template`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub campaign_id: Option<String>,
    #[serde(default)]
    pub root_seed: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub leakage_policy: Option<LeakageUnitPolicy>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub aggregation_policy: Option<AggregationPolicy>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub split_invocation: Option<SplitInvocation>,
    /// Campaign-wide default nested (inner) CV policy; a per-step `inner_cv`
    /// overrides it (compiled to `CampaignSpec.inner_cv`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub campaign_metadata: BTreeMap<String, serde_json::Value>,
    /// Data bindings; compiled against the finished graph by
    /// `compile_data_bindings`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub data_bindings: Vec<DataBinding>,
    /// Top-level steps, compiled in declaration order.
    #[serde(default)]
    pub steps: Vec<PipelineDslStep>,
    /// Free-form metadata; cloned into the compiled graph's metadata.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
68
/// Declaration of the pipeline's data input port.
///
/// All fields are optional in the serialized form; `name` and
/// `representation` fall back to `default_input_name` and
/// `default_data_representation`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslDataPort {
    /// Port name; also used as the port name of the external data source
    /// during compilation.
    #[serde(default = "default_input_name")]
    pub name: String,
    /// Data representation label carried by the input port.
    #[serde(default = "default_data_representation")]
    pub representation: String,
    // The three optional fields below are omitted from serialized output when
    // unset. NOTE(review): presumably consumed by `apply_data_unit_contract`,
    // which receives this struct — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Human-readable description passed to the interface port.
    #[serde(default)]
    pub description: String,
}
84
85impl Default for PipelineDslDataPort {
86    fn default() -> Self {
87        Self {
88            name: default_input_name(),
89            representation: default_data_representation(),
90            unit_level: None,
91            alignment_key: None,
92            target_level: None,
93            description: String::new(),
94        }
95    }
96}
97
/// Declaration of the pipeline's prediction output port.
///
/// All fields are optional in the serialized form; `name` falls back to
/// `default_output_name`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslPredictionPort {
    /// Port name used for the graph interface output.
    #[serde(default = "default_output_name")]
    pub name: String,
    // Optional fields below are omitted from serialized output when unset.
    // NOTE(review): presumably consumed by `apply_prediction_unit_contract`,
    // which receives this struct — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub representation: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub unit_level: Option<EntityUnitLevel>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub alignment_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_level: Option<EntityUnitLevel>,
    /// Human-readable description passed to the interface port.
    #[serde(default)]
    pub description: String,
}
113
114impl Default for PipelineDslPredictionPort {
115    fn default() -> Self {
116        Self {
117            name: default_output_name(),
118            representation: None,
119            unit_level: None,
120            alignment_key: None,
121            target_level: None,
122            description: String::new(),
123        }
124    }
125}
126
/// One step of a pipeline, discriminated by a `kind` field in the serialized
/// form (internally tagged, variant names in `snake_case`).
///
/// Variants wrapping [`PipelineDslOperatorStep`] are the "operator steps"
/// recognised by `operator_step_node_kind`; the remaining variants carry their
/// own structural payloads.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PipelineDslStep {
    Transform(PipelineDslOperatorStep),
    YTransform(PipelineDslOperatorStep),
    Tag(PipelineDslOperatorStep),
    Exclude(PipelineDslOperatorStep),
    // `Filter` and `SampleFilter` are mapped to `NodeKind::Exclude` by
    // `operator_step_node_kind`.
    Filter(PipelineDslOperatorStep),
    SampleFilter(PipelineDslOperatorStep),
    // All three augmentation variants are mapped to `NodeKind::Augmentation`.
    Augmentation(PipelineDslOperatorStep),
    FeatureAugmentation(PipelineDslOperatorStep),
    SampleAugmentation(PipelineDslOperatorStep),
    /// Also accepted as `kind: "generation"`; mapped to `NodeKind::Generator`.
    #[serde(alias = "generation")]
    DataGeneration(PipelineDslOperatorStep),
    ConcatTransform(PipelineDslConcatTransformStep),
    Model(PipelineDslOperatorStep),
    /// Also accepted as `kind: "finetune"`.
    #[serde(alias = "finetune")]
    Tuner(PipelineDslOperatorStep),
    Branch(PipelineDslBranchStep),
    Generator(PipelineDslGeneratorStep),
    Sequential(PipelineDslSequenceStep),
    Merge(PipelineDslMergeStep),
    MergeModel(PipelineDslMergeModelStep),
    Chart(PipelineDslOperatorStep),
}
152
/// Payload shared by every operator-style step (transform, model, tuner, ...).
///
/// `id` and `operator` are required; everything else defaults when absent.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslOperatorStep {
    /// Node identifier of the step.
    pub id: NodeId,
    /// Operator description as raw JSON; handed to the controller registry
    /// when the step kind is inferred for minimal aliases.
    pub operator: serde_json::Value,
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    /// Free-form metadata. Keys read or written in this file include
    /// `dsl_minimal_operator_alias`, `dsl_compat_keyword`,
    /// `dsl_compat_original_keyword` and `dsl_registry_inferred_kind`.
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub seed_label: Option<String>,
    #[serde(default)]
    pub representation: Option<String>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub train_params: BTreeMap<String, serde_json::Value>,
    /// Tuning configuration; also accepted under the key `finetune_params`.
    #[serde(
        default,
        alias = "finetune_params",
        skip_serializing_if = "Option::is_none"
    )]
    pub tuning: Option<PipelineDslTuningSpec>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators; also accepted under the key `generators`.
    #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
    pub param_generators: Vec<PipelineDslParamGenerator>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shape: Option<PipelineDslShapePlan>,
    /// Node-local nested (inner) CV policy (e.g. for a finetune/tuner step);
    /// overrides the campaign-wide default. Compiled to `NodePlan.inner_cv` via
    /// the node's `dsl_inner_cv` metadata.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
}
185
/// Optional tuning (finetune) settings attached to an operator or merge-model
/// step.
///
/// Every field is optional and is omitted from serialized output when unset
/// or empty. The string fields are free-form here; their accepted values are
/// not constrained in this part of the file.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslTuningSpec {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_trials: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub approach: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub eval_mode: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sampler: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric: Option<String>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub model_params: BTreeMap<String, serde_json::Value>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub train_params: BTreeMap<String, serde_json::Value>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
205
/// One labelled variant of a step: a required `label`, optional parameter
/// values and an optional free-form JSON `value`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslVariantChoice {
    pub label: String,
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub value: Option<serde_json::Value>,
}
214
/// Declarative parameter generator attached to a step, discriminated by a
/// `kind` field (internally tagged, variant names in `snake_case`).
///
/// Every variant has an optional `name`. The expansion logic for each variant
/// is not part of this section of the file, so the per-variant notes below
/// only describe the serialized shape.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PipelineDslParamGenerator {
    /// A single `param` with an explicit list of candidate `values`.
    Or {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        param: String,
        values: Vec<PipelineDslGeneratorValue>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        count: Option<usize>,
    },
    /// A single `param` described by `start`, `stop` and `step`; `inclusive`
    /// defaults via `default_true`.
    Range {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        param: String,
        start: f64,
        stop: f64,
        step: f64,
        #[serde(default = "default_true")]
        inclusive: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        count: Option<usize>,
    },
    /// A single `param` described by `start`, `stop` and a required `count`;
    /// `base` defaults via `default_log_base`.
    LogRange {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        param: String,
        start: f64,
        stop: f64,
        count: usize,
        #[serde(default = "default_log_base")]
        base: f64,
    },
    /// Several parameters at once, each with its own list of values.
    Grid {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        params: BTreeMap<String, Vec<PipelineDslGeneratorValue>>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        count: Option<usize>,
    },
    /// A single `param` with candidate `values` and a list of `sizes`.
    Pick {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        param: String,
        values: Vec<PipelineDslGeneratorValue>,
        sizes: Vec<usize>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        count: Option<usize>,
    },
    /// Same serialized shape as `Pick`.
    Arrange {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        param: String,
        values: Vec<PipelineDslGeneratorValue>,
        sizes: Vec<usize>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        count: Option<usize>,
    },
}
274
/// A candidate value for a parameter generator.
///
/// Untagged: deserialization tries `Labeled` (an object with `label` and
/// `value`) first and falls back to treating the JSON as a bare `Value`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PipelineDslGeneratorValue {
    Labeled {
        label: String,
        value: serde_json::Value,
    },
    Value(serde_json::Value),
}
284
/// An explicitly declared generation dimension: a required `name` plus its
/// list of choices (empty when omitted).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslGenerationDimension {
    pub name: String,
    #[serde(default)]
    pub choices: Vec<PipelineDslGenerationChoice>,
}
291
/// One choice within a generation dimension: a required `label`, an optional
/// free-form JSON `value`, and per-node parameter overrides.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslGenerationChoice {
    pub label: String,
    #[serde(default)]
    pub value: Option<serde_json::Value>,
    #[serde(default)]
    pub param_overrides: Vec<PipelineDslGenerationParamOverride>,
}
300
/// Parameter values attached to a specific node (`node_id`) by a generation
/// choice.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslGenerationParamOverride {
    pub node_id: NodeId,
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
}
307
/// Payload of a `branch` step: a branching `mode`, an optional free-form
/// `selector`, metadata, and the required list of branches.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslBranchStep {
    /// Defaults to `PipelineDslBranchMode::Duplication` when omitted.
    #[serde(default)]
    pub mode: PipelineDslBranchMode,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub selector: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Required: there is no serde default for this field.
    pub branches: Vec<PipelineDslBranch>,
}
318
/// Branching mode of a `branch` step, serialized in `snake_case`
/// (e.g. `by_source`). `Duplication` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PipelineDslBranchMode {
    #[default]
    Duplication,
    Separation,
    BySource,
    ByMetadata,
    ByTag,
    ByFilter,
}
330
/// A single branch: a required `id`, an optional free-form `selector`,
/// metadata, and the nested steps of the branch.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslBranch {
    pub id: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub selector: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Nested steps; `resolve_step_minimal_aliases` recurses into them.
    #[serde(default)]
    pub steps: Vec<PipelineDslStep>,
}
341
/// Payload of a `sequential` step: an optional node id, metadata, and the
/// nested steps.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslSequenceStep {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<NodeId>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
    /// Nested steps; `resolve_step_minimal_aliases` recurses into them.
    #[serde(default)]
    pub steps: Vec<PipelineDslStep>,
}
351
/// Payload of a `generator` step.
///
/// Alternatives are given either directly as `branches` or grouped into
/// `stages`; both are visited when minimal aliases are resolved. The
/// selection fields (`pick`, `arrange`, `then_pick`, `then_arrange`, `count`)
/// are optional and omitted from serialized output when unset; how they are
/// applied is not shown in this part of the file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslGeneratorStep {
    pub id: NodeId,
    /// Defaults to `PipelineDslGeneratorMode::Or` when omitted.
    #[serde(default)]
    pub mode: PipelineDslGeneratorMode,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub branches: Vec<PipelineDslBranch>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<PipelineDslGeneratorStage>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pick: Option<PipelineDslSelectionSpec>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arrange: Option<PipelineDslSelectionSpec>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub then_pick: Option<PipelineDslSelectionSpec>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub then_arrange: Option<PipelineDslSelectionSpec>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub count: Option<usize>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
374
/// Mode of a `generator` step, serialized in `snake_case`; `Or` is the
/// default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PipelineDslGeneratorMode {
    #[default]
    Or,
    Cartesian,
}
382
/// One stage of a generator step: a required `id`, an optional free-form
/// `selector`, metadata, and the branches belonging to the stage.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslGeneratorStage {
    pub id: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub selector: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub branches: Vec<PipelineDslBranch>,
}
393
/// Selection size for generator steps.
///
/// Untagged: a bare integer deserializes to `Single`, a two-element integer
/// array to `Range`. NOTE(review): the array is presumably `[min, max]` —
/// confirm against the code that consumes it.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PipelineDslSelectionSpec {
    Single(usize),
    Range([usize; 2]),
}
400
/// Payload of a `concat_transform` step: a node id plus a list of branches,
/// each holding a chain of operator steps, along with the usual optional
/// step attributes.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslConcatTransformStep {
    pub id: NodeId,
    #[serde(default)]
    pub branches: Vec<PipelineDslConcatBranch>,
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub seed_label: Option<String>,
    #[serde(default)]
    pub representation: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators; also accepted under the key `generators`.
    #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
    pub param_generators: Vec<PipelineDslParamGenerator>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shape: Option<PipelineDslShapePlan>,
}
419
/// One branch of a concat-transform step. Unlike [`PipelineDslBranch`], its
/// steps are plain operator steps rather than arbitrary pipeline steps.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslConcatBranch {
    pub id: String,
    #[serde(default)]
    pub steps: Vec<PipelineDslOperatorStep>,
}
426
/// Payload of a `merge` step.
///
/// Only `id` is required. `merge_mode` defaults via `default_merge_mode`,
/// `output_as` defaults to `PipelineDslMergeOutput::Features`, and
/// `include_original_data` defaults via `default_true`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslMergeStep {
    pub id: NodeId,
    #[serde(default = "default_merge_mode")]
    pub merge_mode: String,
    #[serde(default)]
    pub output_as: PipelineDslMergeOutput,
    #[serde(default = "default_true")]
    pub include_original_data: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub on_missing: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub selectors: Vec<PipelineDslMergeSelector>,
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub seed_label: Option<String>,
    #[serde(default)]
    pub representation: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators; also accepted under the key `generators`.
    #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
    pub param_generators: Vec<PipelineDslParamGenerator>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shape: Option<PipelineDslShapePlan>,
}
453
/// What a merge step outputs, serialized in `snake_case`; `Features` is the
/// default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PipelineDslMergeOutput {
    #[default]
    Features,
    Predictions,
    Sources,
}
462
/// Selector entry of a merge step.
///
/// Every field is optional and is omitted from serialized output when unset
/// or empty; how the fields combine is decided by code outside this section.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslMergeSelector {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub branch: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<NodeId>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub select: Option<serde_json::Value>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metric: Option<String>,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, serde_json::Value>,
}
478
/// Payload of a `merge_model` step: an operator (meta-model) combined with
/// merge settings.
///
/// `id` and `operator` are required. `include_original_data` defaults via
/// `default_true` and `merge_mode` via `default_merge_mode`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslMergeModelStep {
    pub id: NodeId,
    /// Operator description as raw JSON.
    pub operator: serde_json::Value,
    #[serde(default)]
    pub params: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub metadata: BTreeMap<String, serde_json::Value>,
    #[serde(default)]
    pub seed_label: Option<String>,
    #[serde(default = "default_true")]
    pub include_original_data: bool,
    #[serde(default = "default_merge_mode")]
    pub merge_mode: String,
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub train_params: BTreeMap<String, serde_json::Value>,
    /// Tuning configuration; also accepted under the key `finetune_params`.
    #[serde(
        default,
        alias = "finetune_params",
        skip_serializing_if = "Option::is_none"
    )]
    pub tuning: Option<PipelineDslTuningSpec>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators; also accepted under the key `generators`.
    #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
    pub param_generators: Vec<PipelineDslParamGenerator>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub shape: Option<PipelineDslShapePlan>,
    /// Node-local nested (inner) CV policy for this meta-model (the meta-stacker's
    /// inner CV is nested inside the outer CV); overrides the campaign default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inner_cv: Option<NestedCvSpec>,
}
512
/// Optional per-step shape settings.
///
/// Every field is optional and defaults to `None`. Unlike most optional
/// fields in this file, unset values are not skipped on serialization.
/// NOTE(review): presumably lowered into a `DataModelShapePlan` by the step
/// compiler — confirm, the conversion is not in this part of the file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PipelineDslShapePlan {
    #[serde(default)]
    pub input_granularity: Option<Granularity>,
    #[serde(default)]
    pub target_granularity: Option<Granularity>,
    #[serde(default)]
    pub fit_rows: Option<FitBoundary>,
    #[serde(default)]
    pub predict_rows: Option<FitBoundary>,
    #[serde(default)]
    pub feature_namespace: Option<String>,
    #[serde(default)]
    pub feature_schema_fingerprint: Option<String>,
    #[serde(default)]
    pub target_space: Option<String>,
    #[serde(default)]
    pub aggregation_policy: Option<AggregationPolicy>,
    #[serde(default)]
    pub augmentation_policy: Option<AugmentationPolicy>,
    #[serde(default)]
    pub selection_policy: Option<FeatureSelectionPolicy>,
}
536
/// Everything produced by `compile_pipeline_dsl_with_generation` for one
/// [`PipelineDslSpec`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompiledPipelineDsl {
    /// The validated graph built from the spec's steps.
    pub graph: GraphSpec,
    /// Generation spec built from explicit and step-collected dimensions.
    pub generation: GenerationSpec,
    /// Shape plans collected by the compiler, keyed by node id.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
    /// Data bindings compiled against the graph, keyed by node id.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub data_bindings: BTreeMap<NodeId, Vec<DataBinding>>,
    /// Branch view plans collected by the compiler.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub branch_view_plans: Vec<BranchViewPlan>,
    /// Campaign template produced by `build_campaign_template`.
    pub campaign_template: CampaignSpec,
    /// Fingerprint of `generation`; `None` when the generation strategy is
    /// `GenerationStrategy::None`. The same value is stored in the graph's
    /// `search_space_fingerprint`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generation_fingerprint: Option<String>,
}
551
552pub fn compile_pipeline_dsl(spec: &PipelineDslSpec) -> Result<GraphSpec> {
553    Ok(compile_pipeline_dsl_with_generation(spec)?.graph)
554}
555
556pub fn compile_pipeline_dsl_with_controller_registry(
557    spec: &PipelineDslSpec,
558    registry: &ControllerRegistry,
559) -> Result<GraphSpec> {
560    Ok(compile_pipeline_dsl_with_generation_and_controller_registry(spec, registry)?.graph)
561}
562
563pub fn parse_pipeline_dsl_json(data: &[u8]) -> Result<PipelineDslSpec> {
564    match serde_json::from_slice::<PipelineDslSpec>(data) {
565        Ok(spec) if validate_pipeline_dsl(&spec).is_ok() => Ok(spec),
566        Ok(spec) => {
567            let strict_error = validate_pipeline_dsl(&spec)
568                .err()
569                .map(|error| error.to_string())
570                .unwrap_or_else(|| "unknown validation error".to_string());
571            let value = serde_json::from_slice::<serde_json::Value>(data).map_err(|error| {
572                DagMlError::GraphValidation(format!("failed to parse pipeline DSL JSON: {error}"))
573            })?;
574            lower_nirs4all_compat_pipeline_dsl(&value).map_err(|compat_error| {
575                DagMlError::GraphValidation(format!(
576                    "failed to parse pipeline DSL as valid canonical PipelineDslSpec ({strict_error}) or nirs4all-compatible JSON ({compat_error})"
577                ))
578            })
579        }
580        Err(strict_error) => {
581            let value = serde_json::from_slice::<serde_json::Value>(data).map_err(|error| {
582                DagMlError::GraphValidation(format!("failed to parse pipeline DSL JSON: {error}"))
583            })?;
584            lower_nirs4all_compat_pipeline_dsl(&value).map_err(|compat_error| {
585                DagMlError::GraphValidation(format!(
586                    "failed to parse pipeline DSL as canonical PipelineDslSpec ({strict_error}) or nirs4all-compatible JSON ({compat_error})"
587                ))
588            })
589        }
590    }
591}
592
/// Lowers a nirs4all-compatible JSON document into a [`PipelineDslSpec`].
///
/// Delegates to `lower_root` on a freshly default-constructed
/// `CompatDslLowerer`, so no lowering state is shared between calls.
pub fn lower_nirs4all_compat_pipeline_dsl(value: &serde_json::Value) -> Result<PipelineDslSpec> {
    CompatDslLowerer::default().lower_root(value)
}
596
597pub fn resolve_pipeline_dsl_minimal_aliases(
598    spec: &PipelineDslSpec,
599    registry: &ControllerRegistry,
600) -> Result<PipelineDslSpec> {
601    let mut resolved = spec.clone();
602    for step in &mut resolved.steps {
603        resolve_step_minimal_aliases(step, registry)?;
604    }
605    validate_pipeline_dsl(&resolved)?;
606    Ok(resolved)
607}
608
609pub fn compile_pipeline_dsl_with_generation_and_controller_registry(
610    spec: &PipelineDslSpec,
611    registry: &ControllerRegistry,
612) -> Result<CompiledPipelineDsl> {
613    let resolved = resolve_pipeline_dsl_minimal_aliases(spec, registry)?;
614    compile_pipeline_dsl_with_generation(&resolved)
615}
616
/// Compiles a pipeline DSL spec into a graph together with its generation
/// spec, shape plans, data bindings, branch view plans and campaign template.
///
/// # Errors
///
/// Returns the first error raised by spec validation, step compilation,
/// generation-spec construction or fingerprinting, graph validation,
/// shape-plan target validation, data-binding compilation, or campaign
/// template construction.
pub fn compile_pipeline_dsl_with_generation(spec: &PipelineDslSpec) -> Result<CompiledPipelineDsl> {
    // Reject invalid specs before any compilation work.
    validate_pipeline_dsl(spec)?;
    let input_representation = Some(spec.input.representation.clone());
    // Data source with no producing node (`node_id: None`), named after the
    // spec's input port.
    let external_data = DataSource {
        node_id: None,
        port_name: spec.input.name.clone(),
        representation: input_representation.clone(),
    };
    // Accumulator for everything the step compiler emits.
    let mut compiler = PipelineCompiler {
        graph_id: spec.id.clone(),
        input_representation: input_representation.clone(),
        nodes: Vec::new(),
        edges: Vec::new(),
        generation_dimensions: Vec::new(),
        shape_plans: BTreeMap::new(),
        branch_view_plans: Vec::new(),
    };
    let mut sequence_state = SequenceCompileState::new(external_data.clone());

    // Steps are compiled in declaration order, sharing one sequence state.
    for step in &spec.steps {
        compiler.compile_top_level_step(step, &external_data, &mut sequence_state)?;
    }

    // Explicit dimensions come first; dimensions collected while compiling
    // steps are appended after them.
    let mut generation_dimensions =
        compile_explicit_generation_dimensions(&spec.generation_dimensions, &compiler.nodes)?;
    generation_dimensions.extend(compiler.generation_dimensions);
    let generation = build_generation_spec(
        spec.generation_strategy,
        spec.max_variants,
        generation_dimensions,
    )?;
    // No fingerprint is computed when the resolved strategy is `None`.
    let generation_fingerprint = if generation.strategy == GenerationStrategy::None {
        None
    } else {
        Some(generation_spec_fingerprint(&generation)?)
    };
    // Build the graph interface ports and apply the unit-level contracts
    // declared on the spec's input/output.
    let mut interface_input = data_port(
        &spec.input.name,
        input_representation.clone(),
        &spec.input.description,
    );
    apply_data_unit_contract(&mut interface_input, &spec.input);
    let mut interface_output = prediction_port(&spec.output.name, &spec.output.description);
    apply_prediction_unit_contract(&mut interface_output, &spec.output);

    // `nodes` and `edges` are moved out of the compiler here; its remaining
    // fields (`shape_plans`, `branch_view_plans`) are still used below.
    let graph = GraphSpec {
        id: spec.id.clone(),
        interface: GraphInterface {
            inputs: vec![interface_input],
            outputs: vec![interface_output],
        },
        nodes: compiler.nodes,
        edges: compiler.edges,
        search_space_fingerprint: generation_fingerprint.clone(),
        metadata: spec.metadata.clone(),
    };
    graph.validate()?;
    // These checks need the finished graph, so they run after graph validation.
    validate_shape_plan_targets(&compiler.shape_plans, &graph)?;
    let data_bindings = compile_data_bindings(&spec.data_bindings, &graph)?;
    let campaign_template = build_campaign_template(
        spec,
        &generation,
        &compiler.shape_plans,
        &data_bindings,
        &compiler.branch_view_plans,
    )?;
    Ok(CompiledPipelineDsl {
        graph,
        generation,
        shape_plans: compiler.shape_plans,
        data_bindings,
        branch_view_plans: compiler.branch_view_plans,
        campaign_template,
        generation_fingerprint,
    })
}
693
694fn resolve_step_minimal_aliases(
695    step: &mut PipelineDslStep,
696    registry: &ControllerRegistry,
697) -> Result<()> {
698    if let Some(resolved) = resolve_operator_step_minimal_alias(step, registry)? {
699        *step = resolved;
700    }
701    match step {
702        PipelineDslStep::Branch(branch) => {
703            for branch in &mut branch.branches {
704                for child in &mut branch.steps {
705                    resolve_step_minimal_aliases(child, registry)?;
706                }
707            }
708        }
709        PipelineDslStep::Generator(generator) => {
710            for branch in &mut generator.branches {
711                for child in &mut branch.steps {
712                    resolve_step_minimal_aliases(child, registry)?;
713                }
714            }
715            for stage in &mut generator.stages {
716                for branch in &mut stage.branches {
717                    for child in &mut branch.steps {
718                        resolve_step_minimal_aliases(child, registry)?;
719                    }
720                }
721            }
722        }
723        PipelineDslStep::Sequential(sequence) => {
724            for child in &mut sequence.steps {
725                resolve_step_minimal_aliases(child, registry)?;
726            }
727        }
728        _ => {}
729    }
730    Ok(())
731}
732
733fn resolve_operator_step_minimal_alias(
734    step: &PipelineDslStep,
735    registry: &ControllerRegistry,
736) -> Result<Option<PipelineDslStep>> {
737    let Some((current_kind, operator_step)) = operator_step_node_kind(step) else {
738        return Ok(None);
739    };
740    if !is_minimal_operator_alias(operator_step) {
741        return Ok(None);
742    }
743    let Some(inferred_kind) = registry.infer_operator_kind(&operator_step.operator)? else {
744        return Ok(None);
745    };
746    if inferred_kind == current_kind {
747        return Ok(None);
748    }
749    let mut resolved = operator_step.clone();
750    annotate_registry_inferred_operator_step(&mut resolved, &inferred_kind)?;
751    Ok(Some(operator_pipeline_step_for_node_kind(
752        inferred_kind,
753        resolved,
754    )?))
755}
756
757fn operator_step_node_kind(step: &PipelineDslStep) -> Option<(NodeKind, &PipelineDslOperatorStep)> {
758    match step {
759        PipelineDslStep::Transform(step) => Some((NodeKind::Transform, step)),
760        PipelineDslStep::YTransform(step) => Some((NodeKind::YTransform, step)),
761        PipelineDslStep::Tag(step) => Some((NodeKind::Tag, step)),
762        PipelineDslStep::Exclude(step) => Some((NodeKind::Exclude, step)),
763        PipelineDslStep::Filter(step) | PipelineDslStep::SampleFilter(step) => {
764            Some((NodeKind::Exclude, step))
765        }
766        PipelineDslStep::Augmentation(step)
767        | PipelineDslStep::FeatureAugmentation(step)
768        | PipelineDslStep::SampleAugmentation(step) => Some((NodeKind::Augmentation, step)),
769        PipelineDslStep::DataGeneration(step) => Some((NodeKind::Generator, step)),
770        PipelineDslStep::Model(step) => Some((NodeKind::Model, step)),
771        PipelineDslStep::Tuner(step) => Some((NodeKind::Tuner, step)),
772        PipelineDslStep::Chart(step) => Some((NodeKind::Chart, step)),
773        _ => None,
774    }
775}
776
777fn is_minimal_operator_alias(step: &PipelineDslOperatorStep) -> bool {
778    step.metadata
779        .get(DSL_MINIMAL_OPERATOR_ALIAS)
780        .and_then(serde_json::Value::as_bool)
781        .unwrap_or(false)
782}
783
784fn annotate_registry_inferred_operator_step(
785    step: &mut PipelineDslOperatorStep,
786    inferred_kind: &NodeKind,
787) -> Result<()> {
788    if let Some(keyword) = step.metadata.get("dsl_compat_keyword").cloned() {
789        step.metadata
790            .entry(DSL_COMPAT_ORIGINAL_KEYWORD.to_string())
791            .or_insert(keyword);
792    }
793    step.metadata.insert(
794        "dsl_compat_keyword".to_string(),
795        serde_json::Value::String(compat_keyword_for_node_kind(inferred_kind)?.to_string()),
796    );
797    step.metadata.insert(
798        DSL_REGISTRY_INFERRED_KIND.to_string(),
799        serde_json::to_value(inferred_kind).map_err(|error| {
800            DagMlError::GraphValidation(format!(
801                "failed to serialize registry-inferred operator kind: {error}"
802            ))
803        })?,
804    );
805    Ok(())
806}
807
808fn operator_pipeline_step_for_node_kind(
809    kind: NodeKind,
810    step: PipelineDslOperatorStep,
811) -> Result<PipelineDslStep> {
812    match kind {
813        NodeKind::Transform => Ok(PipelineDslStep::Transform(step)),
814        NodeKind::YTransform => Ok(PipelineDslStep::YTransform(step)),
815        NodeKind::Tag => Ok(PipelineDslStep::Tag(step)),
816        NodeKind::Exclude => Ok(PipelineDslStep::Exclude(step)),
817        NodeKind::Augmentation => Ok(PipelineDslStep::Augmentation(step)),
818        NodeKind::Generator => Ok(PipelineDslStep::DataGeneration(step)),
819        NodeKind::Model => Ok(PipelineDslStep::Model(step)),
820        NodeKind::Tuner => Ok(PipelineDslStep::Tuner(step)),
821        NodeKind::Chart => Ok(PipelineDslStep::Chart(step)),
822        unsupported => Err(DagMlError::GraphValidation(format!(
823            "minimal operator alias matched unsupported node kind {:?}; use explicit DSL syntax",
824            unsupported
825        ))),
826    }
827}
828
829fn compat_keyword_for_node_kind(kind: &NodeKind) -> Result<&'static str> {
830    match kind {
831        NodeKind::Transform => Ok("preprocessing"),
832        NodeKind::YTransform => Ok("y_processing"),
833        NodeKind::Tag => Ok("tag"),
834        NodeKind::Exclude => Ok("exclude"),
835        NodeKind::Augmentation => Ok("augmentation"),
836        NodeKind::Generator => Ok("data_generation"),
837        NodeKind::Model => Ok("model"),
838        NodeKind::Tuner => Ok("tuner"),
839        NodeKind::Chart => Ok("chart"),
840        unsupported => Err(DagMlError::GraphValidation(format!(
841            "minimal operator alias matched unsupported node kind {:?}; use explicit DSL syntax",
842            unsupported
843        ))),
844    }
845}
846
847fn validate_pipeline_dsl(spec: &PipelineDslSpec) -> Result<()> {
848    if spec.id.trim().is_empty() {
849        return Err(DagMlError::GraphValidation(
850            "pipeline DSL graph id must not be empty".to_string(),
851        ));
852    }
853    if spec.input.name.trim().is_empty() {
854        return Err(DagMlError::GraphValidation(
855            "pipeline DSL input name must not be empty".to_string(),
856        ));
857    }
858    if spec.input.representation.trim().is_empty() {
859        return Err(DagMlError::GraphValidation(
860            "pipeline DSL input representation must not be empty".to_string(),
861        ));
862    }
863    if spec.output.name.trim().is_empty() {
864        return Err(DagMlError::GraphValidation(
865            "pipeline DSL output name must not be empty".to_string(),
866        ));
867    }
868    if spec.steps.is_empty() {
869        return Err(DagMlError::GraphValidation(
870            "pipeline DSL must contain at least one step".to_string(),
871        ));
872    }
873    Ok(())
874}
875
/// State gathered while compiling a pipeline DSL.
///
/// NOTE(review): the methods that fill and consume these fields are outside this
/// view; the field notes below are read off the declared types only.
struct PipelineCompiler {
    /// Identifier of the graph being compiled.
    graph_id: String,
    /// Representation of the pipeline input, when one is known.
    input_representation: Option<String>,
    /// Node specs collected so far.
    nodes: Vec<NodeSpec>,
    /// Edge specs collected so far.
    edges: Vec<EdgeSpec>,
    /// Generation dimensions collected so far.
    generation_dimensions: Vec<GenerationDimension>,
    /// Shape plans keyed by node id.
    shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
    /// Branch view plans collected so far.
    branch_view_plans: Vec<BranchViewPlan>,
}
885
/// Reference to a port that supplies data.
#[derive(Clone, Debug)]
struct DataSource {
    /// Producing node, if any.
    /// NOTE(review): `None` presumably denotes the graph-level input — confirm.
    node_id: Option<NodeId>,
    /// Name of the port on the producer.
    port_name: String,
    /// Representation carried by the port, when known.
    representation: Option<String>,
}
892
/// Reference to a node port that supplies predictions.
#[derive(Clone, Debug)]
struct PredictionSource {
    /// Node producing the predictions.
    node_id: NodeId,
    /// Name of the producing port.
    port_name: String,
    /// Input name associated with this source.
    /// NOTE(review): presumably the consumer-side input port name — confirm.
    input_name: String,
    /// Branch this source belongs to, if any.
    branch_id: Option<String>,
}
900
/// A [`DataSource`] paired with an input name and an optional branch id.
#[derive(Clone, Debug)]
struct BranchDataSource {
    /// The underlying data source.
    source: DataSource,
    /// Input name associated with this source.
    /// NOTE(review): presumably the consumer-side input port name — confirm.
    input_name: String,
    /// Branch this source belongs to, if any.
    branch_id: Option<String>,
}
907
/// Prediction and data sources collected for a branch; both lists default to empty.
#[derive(Clone, Debug, Default)]
struct BranchCompileOutput {
    /// Prediction sources collected for the branch.
    predictions: Vec<PredictionSource>,
    /// Data sources collected for the branch.
    data_sources: Vec<BranchDataSource>,
}
913
/// State carried while compiling a sequence of steps.
#[derive(Clone, Debug)]
struct SequenceCompileState {
    /// Data source currently in effect.
    current_data: DataSource,
    /// Prediction sources still pending; emptied by `clear_pending`.
    pending_predictions: Vec<PredictionSource>,
    /// Branch data sources still pending; emptied by `clear_pending`.
    pending_branch_data: Vec<BranchDataSource>,
}
920
921impl SequenceCompileState {
922    fn new(current_data: DataSource) -> Self {
923        Self {
924            current_data,
925            pending_predictions: Vec::new(),
926            pending_branch_data: Vec::new(),
927        }
928    }
929
930    fn clear_pending(&mut self) {
931        self.pending_predictions.clear();
932        self.pending_branch_data.clear();
933    }
934}
935
/// A merge output source: either a data source or a prediction source.
#[derive(Clone, Debug)]
enum MergeOutputSource {
    /// Output backed by a [`DataSource`].
    Data(DataSource),
    /// Output backed by a [`PredictionSource`].
    Prediction(PredictionSource),
}
941
/// A sequence of DSL steps together with an id, labels and metadata.
///
/// NOTE(review): presumably one expanded generator choice — the producer is
/// outside this view.
#[derive(Clone, Debug)]
struct GeneratedSequence {
    /// Identifier of the sequence.
    id: String,
    /// Labels attached to the sequence.
    labels: Vec<String>,
    /// Steps making up the sequence.
    steps: Vec<PipelineDslStep>,
    /// Free-form JSON metadata.
    metadata: BTreeMap<String, serde_json::Value>,
}
949
/// Generation information that `CompatDslLowerer::lower_steps` parses from a
/// compat value and hands to `lower_value_with_attachment`, which passes it on
/// to the operator step it builds. Defaults to two empty lists.
#[derive(Clone, Debug, Default)]
struct CompatGenerationAttachment {
    /// Variant choices to attach.
    variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators to attach.
    param_generators: Vec<PipelineDslParamGenerator>,
}
955
/// Lowers nirs4all-compatible JSON into a [`PipelineDslSpec`] (see `lower_root`).
#[derive(Default)]
struct CompatDslLowerer {
    /// NOTE(review): presumably used to number generated nodes; its users are
    /// outside this view.
    node_counter: usize,
    /// NOTE(review): presumably used to number generated generators; its users
    /// are outside this view.
    generator_counter: usize,
    /// Split invocation captured from a pipeline split step. `lower_root` rejects
    /// documents that also declare a root-level `split_invocation`.
    split_invocation: Option<SplitInvocation>,
    /// Metadata gathered during lowering (for example `compat_sources`);
    /// `lower_root` merges it into the spec metadata.
    metadata: BTreeMap<String, serde_json::Value>,
}
963
/// Classification returned by `compat_plain_operator_kind` for a plain operator
/// value (a bare string or a plain operator object).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompatPlainOperatorKind {
    /// Lowered as a `preprocessing` transform step.
    Transform,
    /// Lowered as a `model` step.
    Model,
    /// Lowered as a `tuner` step.
    Tuner,
    /// Consumed as the campaign split invocation rather than lowered to a step.
    Split,
    /// Lowered as a `chart` step.
    Chart,
}
972
973impl CompatDslLowerer {
974    fn lower_root(mut self, value: &serde_json::Value) -> Result<PipelineDslSpec> {
975        let root = value.as_object();
976        let pipeline = match value {
977            serde_json::Value::Array(_) => value,
978            serde_json::Value::Object(object) => object
979                .get("pipeline")
980                .or_else(|| object.get("steps"))
981                .ok_or_else(|| {
982                    DagMlError::GraphValidation(
983                        "nirs4all-compatible pipeline DSL must be a JSON array or an object with `pipeline`/`steps`".to_string(),
984                    )
985                })?,
986            _ => {
987                return Err(DagMlError::GraphValidation(
988                    "nirs4all-compatible pipeline DSL must be a JSON array or object".to_string(),
989                ));
990            }
991        };
992        let pipeline = pipeline.as_array().ok_or_else(|| {
993            DagMlError::GraphValidation(
994                "nirs4all-compatible pipeline field must be an array".to_string(),
995            )
996        })?;
997        let steps = self.lower_steps(pipeline, "pipeline")?;
998        let id = root
999            .and_then(|object| object.get("id"))
1000            .and_then(serde_json::Value::as_str)
1001            .unwrap_or("dsl-nirs4all-compat")
1002            .to_string();
1003        let mut metadata: BTreeMap<String, serde_json::Value> =
1004            optional_root_field(root, "metadata")?.unwrap_or_default();
1005        metadata.extend(std::mem::take(&mut self.metadata));
1006        metadata.insert(
1007            "dsl_compat_profile".to_string(),
1008            serde_json::Value::String("nirs4all_json_v1".to_string()),
1009        );
1010        let root_split = optional_root_field(root, "split_invocation")?;
1011        let split_invocation = match (root_split, self.split_invocation) {
1012            (Some(_), Some(_)) => {
1013                return Err(DagMlError::GraphValidation(
1014                    "nirs4all-compatible pipeline declares split_invocation and a pipeline split step".to_string(),
1015                ));
1016            }
1017            (Some(split), None) | (None, Some(split)) => Some(split),
1018            (None, None) => None,
1019        };
1020        Ok(PipelineDslSpec {
1021            inner_cv: optional_root_field(root, "inner_cv")?,
1022            id,
1023            input: optional_root_field(root, "input")?.unwrap_or_default(),
1024            output: optional_root_field(root, "output")?.unwrap_or_default(),
1025            generation_strategy: optional_root_field(root, "generation_strategy")?,
1026            max_variants: optional_root_field(root, "max_variants")?,
1027            generation_dimensions: optional_root_field(root, "generation_dimensions")?
1028                .unwrap_or_default(),
1029            campaign_id: optional_root_field(root, "campaign_id")?,
1030            root_seed: optional_root_field(root, "root_seed")?,
1031            leakage_policy: optional_root_field(root, "leakage_policy")?,
1032            aggregation_policy: optional_root_field(root, "aggregation_policy")?,
1033            split_invocation,
1034            campaign_metadata: optional_root_field(root, "campaign_metadata")?.unwrap_or_default(),
1035            data_bindings: optional_root_field(root, "data_bindings")?.unwrap_or_default(),
1036            steps,
1037            metadata,
1038        })
1039    }
1040
1041    fn lower_steps(
1042        &mut self,
1043        values: &[serde_json::Value],
1044        path: &str,
1045    ) -> Result<Vec<PipelineDslStep>> {
1046        let mut lowered = Vec::new();
1047        let mut index = 0usize;
1048        while index < values.len() {
1049            let current_path = format!("{path}[{index}]");
1050            if self.consume_side_effect_step(&values[index], &current_path)? {
1051                index += 1;
1052                continue;
1053            }
1054            if let Some(attachment) =
1055                self.parse_attached_generation(&values[index], &current_path)?
1056            {
1057                if value_can_receive_generation_attachment(&values[index]) {
1058                    let mut attached = self.lower_value_with_attachment(
1059                        &values[index],
1060                        &current_path,
1061                        attachment,
1062                    )?;
1063                    lowered.append(&mut attached);
1064                    index += 1;
1065                    continue;
1066                }
1067                let next = values.get(index + 1).ok_or_else(|| {
1068                    DagMlError::GraphValidation(format!(
1069                        "{current_path} declares a parameter generator but has no following operator/model step"
1070                    ))
1071                })?;
1072                let mut attached = self.lower_value_with_attachment(
1073                    next,
1074                    &format!("{path}[{}]", index + 1),
1075                    attachment,
1076                )?;
1077                lowered.append(&mut attached);
1078                index += 2;
1079                continue;
1080            }
1081            if let Some(merge_model) =
1082                self.lower_merge_followed_by_model(values, index, &current_path)?
1083            {
1084                lowered.push(PipelineDslStep::MergeModel(merge_model));
1085                index += 2;
1086                continue;
1087            }
1088
1089            let steps = self.lower_value_as_steps(&values[index], &current_path)?;
1090            if let [PipelineDslStep::Generator(generator)] = steps.as_slice() {
1091                if !generator_step_has_prediction(generator) {
1092                    if let Some((combined, consumed)) = self.combine_data_generator_with_following(
1093                        generator.clone(),
1094                        &values[index + 1..],
1095                        path,
1096                        index + 1,
1097                    )? {
1098                        lowered.push(PipelineDslStep::Generator(combined));
1099                        index += consumed + 1;
1100                        continue;
1101                    }
1102                }
1103            }
1104            lowered.extend(steps);
1105            index += 1;
1106        }
1107        Ok(lowered)
1108    }
1109
1110    fn consume_side_effect_step(&mut self, value: &serde_json::Value, path: &str) -> Result<bool> {
1111        if compat_plain_operator_kind(value) == CompatPlainOperatorKind::Split {
1112            self.set_split_invocation(self.lower_plain_split_invocation(value, path)?, path)?;
1113            return Ok(true);
1114        }
1115        let Some(object) = value.as_object() else {
1116            return Ok(false);
1117        };
1118        if is_comment_only_object(object) {
1119            return Ok(true);
1120        }
1121        if let Some(split) = object.get("split") {
1122            self.set_split_invocation(self.lower_split_invocation(split, object, path)?, path)?;
1123            return Ok(true);
1124        }
1125        if let Some(sources) = object.get("sources") {
1126            self.metadata
1127                .insert("compat_sources".to_string(), sources.clone());
1128            return Ok(true);
1129        }
1130        Ok(false)
1131    }
1132
1133    fn lower_value_as_steps(
1134        &mut self,
1135        value: &serde_json::Value,
1136        path: &str,
1137    ) -> Result<Vec<PipelineDslStep>> {
1138        match value {
1139            serde_json::Value::Null => Ok(Vec::new()),
1140            serde_json::Value::Array(children) => {
1141                Ok(vec![PipelineDslStep::Sequential(PipelineDslSequenceStep {
1142                    id: None,
1143                    metadata: BTreeMap::new(),
1144                    steps: self.lower_steps(children, path)?,
1145                })])
1146            }
1147            serde_json::Value::String(_) => {
1148                let step = match compat_plain_operator_kind(value) {
1149                    CompatPlainOperatorKind::Transform => PipelineDslStep::Transform(
1150                        self.compat_operator_step(None, "preprocessing", value, None, None)?,
1151                    ),
1152                    CompatPlainOperatorKind::Model => PipelineDslStep::Model(
1153                        self.compat_operator_step(None, "model", value, None, None)?,
1154                    ),
1155                    CompatPlainOperatorKind::Tuner => PipelineDslStep::Tuner(
1156                        self.compat_operator_step(None, "tuner", value, None, None)?,
1157                    ),
1158                    CompatPlainOperatorKind::Chart => PipelineDslStep::Chart(
1159                        self.compat_operator_step(None, "chart", value, None, None)?,
1160                    ),
1161                    CompatPlainOperatorKind::Split => {
1162                        return Err(DagMlError::GraphValidation(format!(
1163                            "{path} splitter alias was not consumed as a campaign split"
1164                        )));
1165                    }
1166                };
1167                Ok(vec![step])
1168            }
1169            serde_json::Value::Object(object) => {
1170                if object.contains_key("kind") {
1171                    let step = serde_json::from_value::<PipelineDslStep>(value.clone()).map_err(
1172                        |error| {
1173                            DagMlError::GraphValidation(format!(
1174                                "failed to parse canonical DSL step at {path}: {error}"
1175                            ))
1176                        },
1177                    )?;
1178                    return Ok(vec![step]);
1179                }
1180                if self.consume_side_effect_step(value, path)? {
1181                    return Ok(Vec::new());
1182                }
1183                if let Some(operator) =
1184                    first_object_value(object, &["preprocessing", "processing", "transform"])
1185                {
1186                    return Ok(vec![PipelineDslStep::Transform(
1187                        self.compat_operator_step(
1188                            Some(object),
1189                            "preprocessing",
1190                            operator,
1191                            None,
1192                            None,
1193                        )?,
1194                    )]);
1195                }
1196                if let Some(operator) = first_object_value(object, &["y_processing", "y_transform"])
1197                {
1198                    return Ok(vec![PipelineDslStep::YTransform(
1199                        self.compat_operator_step(
1200                            Some(object),
1201                            "y_processing",
1202                            operator,
1203                            None,
1204                            None,
1205                        )?,
1206                    )]);
1207                }
1208                if let Some(operator) = object.get("tag") {
1209                    return Ok(vec![PipelineDslStep::Tag(self.compat_operator_step(
1210                        Some(object),
1211                        "tag",
1212                        operator,
1213                        None,
1214                        None,
1215                    )?)]);
1216                }
1217                if let Some(operator) = object.get("exclude") {
1218                    return Ok(vec![PipelineDslStep::Exclude(self.compat_operator_step(
1219                        Some(object),
1220                        "exclude",
1221                        operator,
1222                        None,
1223                        None,
1224                    )?)]);
1225                }
1226                if let Some(operator) = object.get("filter") {
1227                    return Ok(vec![PipelineDslStep::Filter(self.compat_operator_step(
1228                        Some(object),
1229                        "filter",
1230                        operator,
1231                        None,
1232                        None,
1233                    )?)]);
1234                }
1235                if let Some(operator) = object.get("sample_filter") {
1236                    return Ok(vec![PipelineDslStep::SampleFilter(
1237                        self.compat_operator_step(
1238                            Some(object),
1239                            "sample_filter",
1240                            operator,
1241                            None,
1242                            None,
1243                        )?,
1244                    )]);
1245                }
1246                if let Some(operator) = object.get("sample_augmentation") {
1247                    return Ok(vec![PipelineDslStep::SampleAugmentation(
1248                        self.compat_operator_step(
1249                            Some(object),
1250                            "sample_augmentation",
1251                            operator,
1252                            None,
1253                            Some(compat_augmentation_shape("sample", object)?),
1254                        )?,
1255                    )]);
1256                }
1257                if let Some(operator) = object.get("feature_augmentation") {
1258                    return Ok(vec![PipelineDslStep::FeatureAugmentation(
1259                        self.compat_operator_step(
1260                            Some(object),
1261                            "feature_augmentation",
1262                            operator,
1263                            None,
1264                            Some(compat_augmentation_shape("feature", object)?),
1265                        )?,
1266                    )]);
1267                }
1268                if let Some(operator) = object.get("augmentation") {
1269                    return Ok(vec![PipelineDslStep::Augmentation(
1270                        self.compat_operator_step(
1271                            Some(object),
1272                            "augmentation",
1273                            operator,
1274                            None,
1275                            Some(compat_augmentation_shape("both", object)?),
1276                        )?,
1277                    )]);
1278                }
1279                if let Some(operator) =
1280                    first_object_value(object, &["data_generation", "generation"])
1281                {
1282                    return Ok(vec![PipelineDslStep::DataGeneration(
1283                        self.compat_operator_step(
1284                            Some(object),
1285                            "data_generation",
1286                            operator,
1287                            None,
1288                            None,
1289                        )?,
1290                    )]);
1291                }
1292                if let Some(operator) = object.get("model") {
1293                    return Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
1294                        Some(object),
1295                        "model",
1296                        operator,
1297                        None,
1298                        None,
1299                    )?)]);
1300                }
1301                if let Some(operator) = first_object_value(object, &["tuner", "finetune"]) {
1302                    return Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
1303                        Some(object),
1304                        "tuner",
1305                        operator,
1306                        None,
1307                        None,
1308                    )?)]);
1309                }
1310                if let Some(operator) = object.get("chart") {
1311                    return Ok(vec![PipelineDslStep::Chart(self.compat_operator_step(
1312                        Some(object),
1313                        "chart",
1314                        operator,
1315                        None,
1316                        None,
1317                    )?)]);
1318                }
1319                if object.contains_key("branch") {
1320                    return Ok(vec![PipelineDslStep::Branch(
1321                        self.lower_branch_step(object, path)?,
1322                    )]);
1323                }
1324                if object.contains_key("concat_transform") {
1325                    return Ok(vec![PipelineDslStep::ConcatTransform(
1326                        self.lower_concat_transform_step(object, path)?,
1327                    )]);
1328                }
1329                if object.contains_key("merge") {
1330                    return Ok(vec![PipelineDslStep::Merge(
1331                        self.lower_merge_step(object, path)?,
1332                    )]);
1333                }
1334                if let Some(step_value) = object.get("step") {
1335                    let mut steps =
1336                        self.lower_pipeline_fragment(step_value, &format!("{path}.step"))?;
1337                    if let Some(name) = object.get("name").and_then(serde_json::Value::as_str) {
1338                        annotate_named_steps(&mut steps, name);
1339                    }
1340                    return Ok(steps);
1341                }
1342                if object.contains_key("_or_") {
1343                    return Ok(vec![PipelineDslStep::Generator(
1344                        self.lower_or_generator(object, "_or_", path)?,
1345                    )]);
1346                }
1347                if object.contains_key("_chain_") {
1348                    return Ok(vec![PipelineDslStep::Generator(
1349                        self.lower_or_generator(object, "_chain_", path)?,
1350                    )]);
1351                }
1352                if object.contains_key("_cartesian_") {
1353                    return Ok(vec![PipelineDslStep::Generator(
1354                        self.lower_cartesian_generator(object, path)?,
1355                    )]);
1356                }
1357                if object.contains_key("_grid_") {
1358                    return Ok(vec![PipelineDslStep::Generator(
1359                        self.lower_grid_generator(object, path)?,
1360                    )]);
1361                }
1362                if object.contains_key("_sample_") {
1363                    return Ok(vec![PipelineDslStep::Generator(
1364                        self.lower_sample_generator(object, path)?,
1365                    )]);
1366                }
1367                if compat_plain_operator_ref(value).is_some() {
1368                    let operator = compat_plain_operator_value(value)?;
1369                    return match compat_plain_operator_kind(value) {
1370                        CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
1371                            self.compat_operator_step(
1372                                Some(object),
1373                                "preprocessing",
1374                                &operator,
1375                                None,
1376                                None,
1377                            )?,
1378                        )]),
1379                        CompatPlainOperatorKind::Model => {
1380                            Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
1381                                Some(object),
1382                                "model",
1383                                &operator,
1384                                None,
1385                                None,
1386                            )?)])
1387                        }
1388                        CompatPlainOperatorKind::Tuner => {
1389                            Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
1390                                Some(object),
1391                                "tuner",
1392                                &operator,
1393                                None,
1394                                None,
1395                            )?)])
1396                        }
1397                        CompatPlainOperatorKind::Chart => {
1398                            Ok(vec![PipelineDslStep::Chart(self.compat_operator_step(
1399                                Some(object),
1400                                "chart",
1401                                &operator,
1402                                None,
1403                                None,
1404                            )?)])
1405                        }
1406                        CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(
1407                            format!("{path} splitter object was not consumed as a campaign split"),
1408                        )),
1409                    };
1410                }
1411                if object.contains_key("type") || object.contains_key("ref") {
1412                    return Ok(vec![PipelineDslStep::Transform(
1413                        self.compat_operator_step(None, "preprocessing", value, None, None)?,
1414                    )]);
1415                }
1416                Err(DagMlError::GraphValidation(format!(
1417                    "unsupported nirs4all-compatible DSL object at {path}"
1418                )))
1419            }
1420            _ => Err(DagMlError::GraphValidation(format!(
1421                "unsupported nirs4all-compatible DSL value at {path}"
1422            ))),
1423        }
1424    }
1425
1426    fn lower_value_with_attachment(
1427        &mut self,
1428        value: &serde_json::Value,
1429        path: &str,
1430        attachment: CompatGenerationAttachment,
1431    ) -> Result<Vec<PipelineDslStep>> {
1432        match value {
1433            serde_json::Value::String(_) => match compat_plain_operator_kind(value) {
1434                CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
1435                    self.compat_operator_step(
1436                        None,
1437                        "preprocessing",
1438                        value,
1439                        Some(attachment),
1440                        None,
1441                    )?,
1442                )]),
1443                CompatPlainOperatorKind::Model => Ok(vec![PipelineDslStep::Model(
1444                    self.compat_operator_step(None, "model", value, Some(attachment), None)?,
1445                )]),
1446                CompatPlainOperatorKind::Tuner => Ok(vec![PipelineDslStep::Tuner(
1447                    self.compat_operator_step(None, "tuner", value, Some(attachment), None)?,
1448                )]),
1449                CompatPlainOperatorKind::Chart => Ok(vec![PipelineDslStep::Chart(
1450                    self.compat_operator_step(None, "chart", value, Some(attachment), None)?,
1451                )]),
1452                CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(format!(
1453                    "{path} splitter alias cannot receive a parameter generator"
1454                ))),
1455            },
1456            serde_json::Value::Object(object) => {
1457                if let Some(operator) = object.get("model") {
1458                    return Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
1459                        Some(object),
1460                        "model",
1461                        operator,
1462                        Some(attachment),
1463                        None,
1464                    )?)]);
1465                }
1466                if let Some(operator) = first_object_value(object, &["tuner", "finetune"]) {
1467                    return Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
1468                        Some(object),
1469                        "tuner",
1470                        operator,
1471                        Some(attachment),
1472                        None,
1473                    )?)]);
1474                }
1475                if let Some(operator) =
1476                    first_object_value(object, &["preprocessing", "processing", "transform"])
1477                {
1478                    return Ok(vec![PipelineDslStep::Transform(self.compat_operator_step(
1479                        Some(object),
1480                        "preprocessing",
1481                        operator,
1482                        Some(attachment),
1483                        None,
1484                    )?)]);
1485                }
1486                if compat_plain_operator_ref(value).is_some() {
1487                    let operator = compat_plain_operator_value(value)?;
1488                    return match compat_plain_operator_kind(value) {
1489                        CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
1490                            self.compat_operator_step(
1491                                Some(object),
1492                                "preprocessing",
1493                                &operator,
1494                                Some(attachment),
1495                                None,
1496                            )?,
1497                        )]),
1498                        CompatPlainOperatorKind::Model => Ok(vec![PipelineDslStep::Model(
1499                            self.compat_operator_step(
1500                                Some(object),
1501                                "model",
1502                                &operator,
1503                                Some(attachment),
1504                                None,
1505                            )?,
1506                        )]),
1507                        CompatPlainOperatorKind::Tuner => Ok(vec![PipelineDslStep::Tuner(
1508                            self.compat_operator_step(
1509                                Some(object),
1510                                "tuner",
1511                                &operator,
1512                                Some(attachment),
1513                                None,
1514                            )?,
1515                        )]),
1516                        CompatPlainOperatorKind::Chart => Ok(vec![PipelineDslStep::Chart(
1517                            self.compat_operator_step(
1518                                Some(object),
1519                                "chart",
1520                                &operator,
1521                                Some(attachment),
1522                                None,
1523                            )?,
1524                        )]),
1525                        CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(
1526                            format!("{path} splitter object cannot receive a parameter generator"),
1527                        )),
1528                    };
1529                }
1530                Err(DagMlError::GraphValidation(format!(
1531                    "{path} cannot receive a preceding nirs4all parameter generator; expected model, tuner or preprocessing"
1532                )))
1533            }
1534            _ => Err(DagMlError::GraphValidation(format!(
1535                "{path} cannot receive a preceding nirs4all parameter generator; expected model, tuner or preprocessing"
1536            ))),
1537        }
1538    }
1539
1540    fn lower_merge_followed_by_model(
1541        &mut self,
1542        values: &[serde_json::Value],
1543        index: usize,
1544        _path: &str,
1545    ) -> Result<Option<PipelineDslMergeModelStep>> {
1546        let Some(merge_object) = values[index].as_object() else {
1547            return Ok(None);
1548        };
1549        if !merge_object.contains_key("merge") {
1550            return Ok(None);
1551        }
1552        let Some(next) = values.get(index + 1).and_then(serde_json::Value::as_object) else {
1553            return Ok(None);
1554        };
1555        let Some(operator) = next.get("model") else {
1556            return Ok(None);
1557        };
1558        let (merge_mode, include_original_data, _) = compat_merge_modes(merge_object)?;
1559        let operator_step = self.compat_operator_step(Some(next), "model", operator, None, None)?;
1560        Ok(Some(PipelineDslMergeModelStep {
1561            inner_cv: operator_step.inner_cv,
1562            id: operator_step.id,
1563            operator: operator_step.operator,
1564            params: operator_step.params,
1565            metadata: operator_step.metadata,
1566            seed_label: operator_step.seed_label,
1567            include_original_data,
1568            merge_mode,
1569            train_params: operator_step.train_params,
1570            tuning: operator_step.tuning,
1571            variants: operator_step.variants,
1572            param_generators: operator_step.param_generators,
1573            shape: operator_step.shape,
1574        }))
1575    }
1576
1577    fn combine_data_generator_with_following(
1578        &mut self,
1579        generator: PipelineDslGeneratorStep,
1580        remaining: &[serde_json::Value],
1581        path: &str,
1582        absolute_start: usize,
1583    ) -> Result<Option<(PipelineDslGeneratorStep, usize)>> {
1584        let fused_id = generator.id.clone();
1585        let mut stages = generator_to_cartesian_stages(generator)?;
1586        let mut prefix_steps = Vec::new();
1587        let mut consumed = 0usize;
1588        while consumed < remaining.len() {
1589            let current_path = format!("{path}[{}]", absolute_start + consumed);
1590            if self.consume_side_effect_step(&remaining[consumed], &current_path)? {
1591                consumed += 1;
1592                continue;
1593            }
1594            let steps = if let Some(attachment) =
1595                self.parse_attached_generation(&remaining[consumed], &current_path)?
1596            {
1597                let next = remaining.get(consumed + 1).ok_or_else(|| {
1598                    DagMlError::GraphValidation(format!(
1599                        "{current_path} declares a parameter generator but has no following operator/model step"
1600                    ))
1601                })?;
1602                consumed += 1;
1603                self.lower_value_with_attachment(
1604                    next,
1605                    &format!("{path}[{}]", absolute_start + consumed),
1606                    attachment,
1607                )?
1608            } else if let Some(merge_model) =
1609                self.lower_merge_followed_by_model(remaining, consumed, &current_path)?
1610            {
1611                consumed += 1;
1612                vec![PipelineDslStep::MergeModel(merge_model)]
1613            } else {
1614                self.lower_value_as_steps(&remaining[consumed], &current_path)?
1615            };
1616            consumed += 1;
1617            if steps.is_empty() {
1618                continue;
1619            }
1620            if let [PipelineDslStep::Generator(next_generator)] = steps.as_slice() {
1621                if !prefix_steps.is_empty() {
1622                    stages.push(single_stage(
1623                        format!("stage{}", stages.len()),
1624                        "prefix",
1625                        std::mem::take(&mut prefix_steps),
1626                    ));
1627                }
1628                let next_has_prediction = generator_step_has_prediction(next_generator);
1629                stages.extend(generator_to_cartesian_stages(next_generator.clone())?);
1630                if next_has_prediction {
1631                    return Ok(Some((
1632                        combined_cartesian_generator(fused_id.clone(), stages),
1633                        consumed,
1634                    )));
1635                }
1636                continue;
1637            }
1638            let has_prediction = steps.iter().any(step_has_prediction);
1639            prefix_steps.extend(steps);
1640            if has_prediction {
1641                stages.push(single_stage(
1642                    format!("stage{}", stages.len()),
1643                    "then",
1644                    std::mem::take(&mut prefix_steps),
1645                ));
1646                return Ok(Some((
1647                    combined_cartesian_generator(fused_id.clone(), stages),
1648                    consumed,
1649                )));
1650            }
1651        }
1652        Ok(None)
1653    }
1654
1655    fn lower_branch_step(
1656        &mut self,
1657        object: &serde_json::Map<String, serde_json::Value>,
1658        path: &str,
1659    ) -> Result<PipelineDslBranchStep> {
1660        let branch_value = object.get("branch").expect("checked by caller");
1661        let mode = optional_object_field(object, "mode")?.unwrap_or_default();
1662        let selector = object.get("selector").cloned();
1663        let metadata = optional_object_field(object, "metadata")?.unwrap_or_default();
1664        let branches = match branch_value {
1665            serde_json::Value::Array(values) => values
1666                .iter()
1667                .enumerate()
1668                .map(|(index, value)| {
1669                    let id = compat_branch_id(value, index);
1670                    Ok(PipelineDslBranch {
1671                        id,
1672                        selector: None,
1673                        metadata: BTreeMap::new(),
1674                        steps: self
1675                            .lower_pipeline_fragment(value, &format!("{path}.branch[{index}]"))?,
1676                    })
1677                })
1678                .collect::<Result<Vec<_>>>()?,
1679            serde_json::Value::Object(branch_object) => {
1680                if let Some(values) = branch_object
1681                    .get("branches")
1682                    .and_then(serde_json::Value::as_array)
1683                {
1684                    values
1685                        .iter()
1686                        .enumerate()
1687                        .map(|(index, value)| {
1688                            self.lower_named_branch(
1689                                value,
1690                                index,
1691                                &format!("{path}.branch.branches[{index}]"),
1692                            )
1693                        })
1694                        .collect::<Result<Vec<_>>>()?
1695                } else {
1696                    branch_object
1697                        .iter()
1698                        .filter(|(key, _)| {
1699                            !matches!(key.as_str(), "mode" | "selector" | "metadata")
1700                        })
1701                        .enumerate()
1702                        .map(|(index, (key, value))| {
1703                            Ok(PipelineDslBranch {
1704                                id: sanitize_branch_id(key, index),
1705                                selector: None,
1706                                metadata: BTreeMap::new(),
1707                                steps: self.lower_pipeline_fragment(
1708                                    value,
1709                                    &format!("{path}.branch.{key}"),
1710                                )?,
1711                            })
1712                        })
1713                        .collect::<Result<Vec<_>>>()?
1714                }
1715            }
1716            _ => {
1717                return Err(DagMlError::GraphValidation(format!(
1718                    "{path}.branch must be an array or object"
1719                )));
1720            }
1721        };
1722        Ok(PipelineDslBranchStep {
1723            mode,
1724            selector,
1725            metadata,
1726            branches,
1727        })
1728    }
1729
1730    fn lower_named_branch(
1731        &mut self,
1732        value: &serde_json::Value,
1733        index: usize,
1734        path: &str,
1735    ) -> Result<PipelineDslBranch> {
1736        if let Some(object) = value.as_object() {
1737            if object.contains_key("steps") || object.contains_key("pipeline") {
1738                let id = object
1739                    .get("id")
1740                    .and_then(serde_json::Value::as_str)
1741                    .map(|id| sanitize_branch_id(id, index))
1742                    .unwrap_or_else(|| format!("branch{index}"));
1743                let selector = object.get("selector").cloned();
1744                let metadata = optional_object_field(object, "metadata")?.unwrap_or_default();
1745                let steps_value = object
1746                    .get("steps")
1747                    .or_else(|| object.get("pipeline"))
1748                    .ok_or_else(|| {
1749                        DagMlError::GraphValidation(format!(
1750                            "{path} branch object must contain steps or pipeline"
1751                        ))
1752                    })?;
1753                return Ok(PipelineDslBranch {
1754                    id,
1755                    selector,
1756                    metadata,
1757                    steps: self.lower_pipeline_fragment(steps_value, path)?,
1758                });
1759            }
1760        }
1761        Ok(PipelineDslBranch {
1762            id: compat_branch_id(value, index),
1763            selector: None,
1764            metadata: BTreeMap::new(),
1765            steps: self.lower_pipeline_fragment(value, path)?,
1766        })
1767    }
1768
1769    fn lower_concat_transform_step(
1770        &mut self,
1771        object: &serde_json::Map<String, serde_json::Value>,
1772        path: &str,
1773    ) -> Result<PipelineDslConcatTransformStep> {
1774        let value = object.get("concat_transform").expect("checked by caller");
1775        let branches = match value {
1776            serde_json::Value::Array(values) => values
1777                .iter()
1778                .enumerate()
1779                .map(|(index, value)| {
1780                    Ok(PipelineDslConcatBranch {
1781                        id: compat_branch_id(value, index),
1782                        steps: self.lower_concat_operator_steps(
1783                            value,
1784                            &format!("{path}.concat_transform[{index}]"),
1785                        )?,
1786                    })
1787                })
1788                .collect::<Result<Vec<_>>>()?,
1789            serde_json::Value::Object(map) => map
1790                .iter()
1791                .enumerate()
1792                .map(|(index, (key, value))| {
1793                    Ok(PipelineDslConcatBranch {
1794                        id: sanitize_branch_id(key, index),
1795                        steps: self.lower_concat_operator_steps(
1796                            value,
1797                            &format!("{path}.concat_transform.{key}"),
1798                        )?,
1799                    })
1800                })
1801                .collect::<Result<Vec<_>>>()?,
1802            _ => {
1803                return Err(DagMlError::GraphValidation(format!(
1804                    "{path}.concat_transform must be an array or object"
1805                )));
1806            }
1807        };
1808        Ok(PipelineDslConcatTransformStep {
1809            id: explicit_or_generated_node_id(object, "id", || self.next_node_id("join"))?,
1810            branches,
1811            metadata: optional_object_field(object, "metadata")?.unwrap_or_default(),
1812            seed_label: optional_object_field(object, "seed_label")?,
1813            representation: optional_object_field(object, "representation")?,
1814            variants: Vec::new(),
1815            param_generators: Vec::new(),
1816            shape: optional_object_field(object, "shape")?,
1817        })
1818    }
1819
1820    fn lower_concat_operator_steps(
1821        &mut self,
1822        value: &serde_json::Value,
1823        path: &str,
1824    ) -> Result<Vec<PipelineDslOperatorStep>> {
1825        let steps = self.lower_pipeline_fragment(value, path)?;
1826        steps
1827            .into_iter()
1828            .map(|step| match step {
1829                PipelineDslStep::Transform(step) => Ok(step),
1830                _ => Err(DagMlError::GraphValidation(format!(
1831                    "{path} concat_transform branches currently accept only preprocessing/transform steps"
1832                ))),
1833            })
1834            .collect()
1835    }
1836
1837    fn lower_merge_step(
1838        &mut self,
1839        object: &serde_json::Map<String, serde_json::Value>,
1840        _path: &str,
1841    ) -> Result<PipelineDslMergeStep> {
1842        let (merge_mode, include_original_data, output_as) = compat_merge_modes(object)?;
1843        let mut metadata: BTreeMap<String, serde_json::Value> =
1844            optional_object_field(object, "metadata")?.unwrap_or_default();
1845        if let Some(merge) = object.get("merge").filter(|merge| merge.is_object()) {
1846            metadata.insert("dsl_compat_merge".to_string(), merge.clone());
1847        }
1848        Ok(PipelineDslMergeStep {
1849            id: explicit_or_generated_node_id(object, "id", || self.next_node_id("merge"))?,
1850            merge_mode,
1851            output_as,
1852            include_original_data,
1853            on_missing: compat_merge_field(object, "on_missing")?,
1854            selectors: compat_merge_field(object, "selectors")?.unwrap_or_default(),
1855            metadata,
1856            seed_label: optional_object_field(object, "seed_label")?,
1857            representation: optional_object_field(object, "representation")?,
1858            variants: Vec::new(),
1859            param_generators: Vec::new(),
1860            shape: optional_object_field(object, "shape")?,
1861        })
1862    }
1863
1864    fn lower_or_generator(
1865        &mut self,
1866        object: &serde_json::Map<String, serde_json::Value>,
1867        key: &str,
1868        path: &str,
1869    ) -> Result<PipelineDslGeneratorStep> {
1870        let values = object
1871            .get(key)
1872            .and_then(serde_json::Value::as_array)
1873            .ok_or_else(|| DagMlError::GraphValidation(format!("{path}.{key} must be an array")))?;
1874        let branches = values
1875            .iter()
1876            .enumerate()
1877            .map(|(index, value)| {
1878                Ok(PipelineDslBranch {
1879                    id: compat_branch_id(value, index),
1880                    selector: None,
1881                    metadata: BTreeMap::new(),
1882                    steps: self
1883                        .lower_pipeline_fragment(value, &format!("{path}.{key}[{index}]"))?,
1884                })
1885            })
1886            .collect::<Result<Vec<_>>>()?;
1887        Ok(PipelineDslGeneratorStep {
1888            id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1889            mode: PipelineDslGeneratorMode::Or,
1890            branches,
1891            stages: Vec::new(),
1892            pick: optional_object_field(object, "pick")?,
1893            arrange: optional_object_field(object, "arrange")?,
1894            then_pick: optional_object_field(object, "then_pick")?,
1895            then_arrange: optional_object_field(object, "then_arrange")?,
1896            count: optional_object_field(object, "count")?,
1897            metadata: compat_generator_metadata(object, key)?,
1898        })
1899    }
1900
1901    fn lower_cartesian_generator(
1902        &mut self,
1903        object: &serde_json::Map<String, serde_json::Value>,
1904        path: &str,
1905    ) -> Result<PipelineDslGeneratorStep> {
1906        let values = object
1907            .get("_cartesian_")
1908            .and_then(serde_json::Value::as_array)
1909            .ok_or_else(|| {
1910                DagMlError::GraphValidation(format!("{path}._cartesian_ must be an array"))
1911            })?;
1912        let stages = values
1913            .iter()
1914            .enumerate()
1915            .map(|(index, value)| {
1916                self.lower_cartesian_stage(value, index, &format!("{path}._cartesian_[{index}]"))
1917            })
1918            .collect::<Result<Vec<_>>>()?;
1919        Ok(PipelineDslGeneratorStep {
1920            id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1921            mode: PipelineDslGeneratorMode::Cartesian,
1922            branches: Vec::new(),
1923            stages,
1924            pick: None,
1925            arrange: None,
1926            then_pick: None,
1927            then_arrange: None,
1928            count: optional_object_field(object, "count")?,
1929            metadata: compat_generator_metadata(object, "_cartesian_")?,
1930        })
1931    }
1932
1933    fn lower_cartesian_stage(
1934        &mut self,
1935        value: &serde_json::Value,
1936        index: usize,
1937        path: &str,
1938    ) -> Result<PipelineDslGeneratorStage> {
1939        if let Some(object) = value.as_object() {
1940            if object.contains_key("_or_") {
1941                let generator = self.lower_or_generator(object, "_or_", path)?;
1942                return Ok(PipelineDslGeneratorStage {
1943                    id: format!("stage{index}"),
1944                    selector: None,
1945                    metadata: BTreeMap::new(),
1946                    branches: generator.branches,
1947                });
1948            }
1949            if object.contains_key("_chain_") {
1950                let generator = self.lower_or_generator(object, "_chain_", path)?;
1951                return Ok(PipelineDslGeneratorStage {
1952                    id: format!("stage{index}"),
1953                    selector: None,
1954                    metadata: BTreeMap::new(),
1955                    branches: generator.branches,
1956                });
1957            }
1958            if object.contains_key("_grid_") {
1959                return Ok(PipelineDslGeneratorStage {
1960                    id: format!("stage{index}"),
1961                    selector: None,
1962                    metadata: BTreeMap::new(),
1963                    branches: self.lower_grid_branches(object.get("_grid_").unwrap(), path)?,
1964                });
1965            }
1966            if object.contains_key("_sample_") {
1967                let generator = self.lower_sample_generator(object, path)?;
1968                return Ok(PipelineDslGeneratorStage {
1969                    id: format!("stage{index}"),
1970                    selector: None,
1971                    metadata: BTreeMap::new(),
1972                    branches: generator.branches,
1973                });
1974            }
1975        }
1976        Ok(PipelineDslGeneratorStage {
1977            id: format!("stage{index}"),
1978            selector: None,
1979            metadata: BTreeMap::new(),
1980            branches: vec![PipelineDslBranch {
1981                id: "option0".to_string(),
1982                selector: None,
1983                metadata: BTreeMap::new(),
1984                steps: self.lower_pipeline_fragment(value, path)?,
1985            }],
1986        })
1987    }
1988
1989    fn lower_grid_generator(
1990        &mut self,
1991        object: &serde_json::Map<String, serde_json::Value>,
1992        path: &str,
1993    ) -> Result<PipelineDslGeneratorStep> {
1994        Ok(PipelineDslGeneratorStep {
1995            id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1996            mode: PipelineDslGeneratorMode::Or,
1997            branches: self.lower_grid_branches(object.get("_grid_").unwrap(), path)?,
1998            stages: Vec::new(),
1999            pick: None,
2000            arrange: None,
2001            then_pick: None,
2002            then_arrange: None,
2003            count: optional_object_field(object, "count")?,
2004            metadata: compat_generator_metadata(object, "_grid_")?,
2005        })
2006    }
2007
2008    fn lower_sample_generator(
2009        &mut self,
2010        object: &serde_json::Map<String, serde_json::Value>,
2011        path: &str,
2012    ) -> Result<PipelineDslGeneratorStep> {
2013        Ok(PipelineDslGeneratorStep {
2014            id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
2015            mode: PipelineDslGeneratorMode::Or,
2016            branches: self.lower_sample_branches(object.get("_sample_").unwrap(), path)?,
2017            stages: Vec::new(),
2018            pick: None,
2019            arrange: None,
2020            then_pick: None,
2021            then_arrange: None,
2022            count: optional_object_field(object, "count")?,
2023            metadata: compat_generator_metadata(object, "_sample_")?,
2024        })
2025    }
2026
2027    fn lower_sample_branches(
2028        &mut self,
2029        value: &serde_json::Value,
2030        path: &str,
2031    ) -> Result<Vec<PipelineDslBranch>> {
2032        let sample = value.as_object().ok_or_else(|| {
2033            DagMlError::GraphValidation(format!("{path}._sample_ must be an object"))
2034        })?;
2035        let rows = compat_sample_rows(sample, path)?;
2036        let operator = sample
2037            .get("model")
2038            .or_else(|| sample.get("tuner"))
2039            .or_else(|| sample.get("finetune"))
2040            .or_else(|| sample.get("preprocessing"))
2041            .or_else(|| sample.get("transform"))
2042            .ok_or_else(|| {
2043                DagMlError::GraphValidation(format!(
2044                    "{path}._sample_ structural lowering requires `model`, `tuner`, `preprocessing` or `transform`"
2045                ))
2046            })?
2047            .clone();
2048        let keyword = if sample.contains_key("model") {
2049            "model"
2050        } else if sample.contains_key("tuner") || sample.contains_key("finetune") {
2051            "tuner"
2052        } else {
2053            "preprocessing"
2054        };
2055        let fixed_params = sample
2056            .iter()
2057            .filter(|(key, _)| {
2058                !matches!(
2059                    key.as_str(),
2060                    "model"
2061                        | "tuner"
2062                        | "finetune"
2063                        | "preprocessing"
2064                        | "transform"
2065                        | "distribution"
2066                        | "from"
2067                        | "to"
2068                        | "num"
2069                        | "count"
2070                        | "param"
2071                        | "tune"
2072                )
2073            })
2074            .map(|(key, value)| (key.clone(), value.clone()))
2075            .collect::<BTreeMap<_, _>>();
2076        rows.into_iter()
2077            .enumerate()
2078            .map(|(index, mut row)| {
2079                row.extend(fixed_params.clone());
2080                let step = self.compat_operator_step_from_parts(
2081                    keyword,
2082                    operator.clone(),
2083                    row,
2084                    None,
2085                    None,
2086                )?;
2087                Ok(PipelineDslBranch {
2088                    id: format!("sample{index}"),
2089                    selector: None,
2090                    metadata: BTreeMap::new(),
2091                    steps: vec![if keyword == "model" {
2092                        PipelineDslStep::Model(step)
2093                    } else if keyword == "tuner" {
2094                        PipelineDslStep::Tuner(step)
2095                    } else {
2096                        PipelineDslStep::Transform(step)
2097                    }],
2098                })
2099            })
2100            .collect()
2101    }
2102
2103    fn lower_grid_branches(
2104        &mut self,
2105        value: &serde_json::Value,
2106        path: &str,
2107    ) -> Result<Vec<PipelineDslBranch>> {
2108        let rows = compat_grid_rows(value, path)?;
2109        rows.into_iter()
2110            .enumerate()
2111            .map(|(index, row)| {
2112                let metadata = BTreeMap::from([(
2113                    "compat_grid_row".to_string(),
2114                    serde_json::to_value(&row).map_err(|error| {
2115                        DagMlError::GraphValidation(format!(
2116                            "failed to serialize grid row at {path}: {error}"
2117                        ))
2118                    })?,
2119                )]);
2120                Ok(PipelineDslBranch {
2121                    id: format!("grid{index}"),
2122                    selector: None,
2123                    metadata,
2124                    steps: self.lower_grid_row(row, path)?,
2125                })
2126            })
2127            .collect()
2128    }
2129
2130    fn lower_grid_row(
2131        &mut self,
2132        mut row: BTreeMap<String, serde_json::Value>,
2133        path: &str,
2134    ) -> Result<Vec<PipelineDslStep>> {
2135        if let Some(operator) = row.remove("model") {
2136            return Ok(vec![PipelineDslStep::Model(
2137                self.compat_operator_step_from_parts("model", operator, row, None, None)?,
2138            )]);
2139        }
2140        if let Some(operator) = row.remove("tuner").or_else(|| row.remove("finetune")) {
2141            return Ok(vec![PipelineDslStep::Tuner(
2142                self.compat_operator_step_from_parts("tuner", operator, row, None, None)?,
2143            )]);
2144        }
2145        if let Some(operator) = row
2146            .remove("preprocessing")
2147            .or_else(|| row.remove("processing"))
2148            .or_else(|| row.remove("transform"))
2149        {
2150            return Ok(vec![PipelineDslStep::Transform(
2151                self.compat_operator_step_from_parts("preprocessing", operator, row, None, None)?,
2152            )]);
2153        }
2154        Err(DagMlError::GraphValidation(format!(
2155            "{path}._grid_ rows must contain `model`, `tuner`, `preprocessing` or `transform` for structural lowering"
2156        )))
2157    }
2158
2159    fn lower_pipeline_fragment(
2160        &mut self,
2161        value: &serde_json::Value,
2162        path: &str,
2163    ) -> Result<Vec<PipelineDslStep>> {
2164        match value {
2165            serde_json::Value::Null => Ok(Vec::new()),
2166            serde_json::Value::Array(values) => self.lower_steps(values, path),
2167            _ => self.lower_value_as_steps(value, path),
2168        }
2169    }
2170
2171    fn parse_attached_generation(
2172        &mut self,
2173        value: &serde_json::Value,
2174        path: &str,
2175    ) -> Result<Option<CompatGenerationAttachment>> {
2176        let Some(object) = value.as_object() else {
2177            return Ok(None);
2178        };
2179        if let Some(range) = object.get("_range_") {
2180            return Ok(Some(CompatGenerationAttachment {
2181                variants: Vec::new(),
2182                param_generators: vec![compat_range_generator(range, object, path)?],
2183            }));
2184        }
2185        if let Some(range) = object.get("_log_range_") {
2186            return Ok(Some(CompatGenerationAttachment {
2187                variants: Vec::new(),
2188                param_generators: vec![compat_log_range_generator(range, object, path)?],
2189            }));
2190        }
2191        if let Some(grid) = object.get("_grid_") {
2192            if grid.as_object().is_some_and(|grid| {
2193                !grid.contains_key("model")
2194                    && !grid.contains_key("preprocessing")
2195                    && !grid.contains_key("transform")
2196            }) {
2197                return Ok(Some(CompatGenerationAttachment {
2198                    variants: Vec::new(),
2199                    param_generators: vec![compat_grid_param_generator(grid, object, path)?],
2200                }));
2201            }
2202        }
2203        if let Some(zip) = object.get("_zip_") {
2204            return Ok(Some(CompatGenerationAttachment {
2205                variants: compat_zip_variants(zip, path)?,
2206                param_generators: Vec::new(),
2207            }));
2208        }
2209        if let Some(sample) = object.get("_sample_") {
2210            if sample.as_object().is_some_and(|sample| {
2211                sample.contains_key("model")
2212                    || sample.contains_key("tuner")
2213                    || sample.contains_key("finetune")
2214                    || sample.contains_key("preprocessing")
2215                    || sample.contains_key("transform")
2216            }) {
2217                return Ok(None);
2218            }
2219            return Ok(Some(CompatGenerationAttachment {
2220                variants: compat_sample_variants(sample, path)?,
2221                param_generators: Vec::new(),
2222            }));
2223        }
2224        Ok(None)
2225    }
2226
2227    fn compat_operator_step(
2228        &mut self,
2229        object: Option<&serde_json::Map<String, serde_json::Value>>,
2230        keyword: &str,
2231        operator: &serde_json::Value,
2232        attachment: Option<CompatGenerationAttachment>,
2233        fallback_shape: Option<PipelineDslShapePlan>,
2234    ) -> Result<PipelineDslOperatorStep> {
2235        let id_prefix = compat_node_prefix(keyword);
2236        let mut params = object
2237            .and_then(|object| object_value_as_map(object.get("params")))
2238            .unwrap_or_default();
2239        if let Some(object) = object {
2240            for alias in compat_param_aliases(keyword) {
2241                if let Some(alias_params) = object_value_as_map(object.get(*alias)) {
2242                    params.extend(alias_params);
2243                }
2244            }
2245            for wrapper_key in compat_wrapper_param_keys(keyword) {
2246                if let Some(value) = object.get(*wrapper_key) {
2247                    params.insert((*wrapper_key).to_string(), value.clone());
2248                }
2249            }
2250        }
2251        let shape = match object.and_then(|object| object.get("shape")) {
2252            Some(shape) => Some(deserialize_value(
2253                shape.clone(),
2254                "pipeline DSL compat shape",
2255            )?),
2256            None => fallback_shape,
2257        };
2258        let mut step = PipelineDslOperatorStep {
2259            inner_cv: optional_object_field_from_option(object, "inner_cv")?,
2260            id: match object {
2261                Some(object) => {
2262                    explicit_or_generated_node_id(object, "id", || self.next_node_id(id_prefix))?
2263                }
2264                None => self.next_node_id(id_prefix)?,
2265            },
2266            operator: operator.clone(),
2267            params,
2268            metadata: optional_object_field_from_option(object, "metadata")?.unwrap_or_default(),
2269            seed_label: optional_object_field_from_option(object, "seed_label")?,
2270            representation: optional_object_field_from_option(object, "representation")?,
2271            train_params: optional_object_field_from_option(object, "train_params")?
2272                .unwrap_or_default(),
2273            tuning: optional_object_field_from_option(object, "tuning")?.or(
2274                optional_object_field_from_option(object, "finetune_params")?,
2275            ),
2276            variants: optional_object_field_from_option(object, "variants")?.unwrap_or_default(),
2277            param_generators: optional_object_field_from_option(object, "generators")?
2278                .unwrap_or_default(),
2279            shape,
2280        };
2281        step.metadata.insert(
2282            "dsl_compat_keyword".to_string(),
2283            serde_json::Value::String(keyword.to_string()),
2284        );
2285        if is_minimal_compat_operator_alias(object, operator) {
2286            step.metadata.insert(
2287                DSL_MINIMAL_OPERATOR_ALIAS.to_string(),
2288                serde_json::Value::Bool(true),
2289            );
2290        }
2291        if let Some(policy) = object.and_then(|object| object.get("policy")) {
2292            step.metadata
2293                .insert("dsl_compat_policy".to_string(), policy.clone());
2294        }
2295        if let Some(name) = object
2296            .and_then(|object| object.get("name"))
2297            .and_then(serde_json::Value::as_str)
2298        {
2299            step.metadata.insert(
2300                "dsl_name".to_string(),
2301                serde_json::Value::String(name.to_string()),
2302            );
2303        }
2304        if let Some(attachment) = attachment {
2305            step.variants.extend(attachment.variants);
2306            step.param_generators.extend(attachment.param_generators);
2307        }
2308        Ok(step)
2309    }
2310
2311    fn compat_operator_step_from_parts(
2312        &mut self,
2313        keyword: &str,
2314        operator: serde_json::Value,
2315        params: BTreeMap<String, serde_json::Value>,
2316        attachment: Option<CompatGenerationAttachment>,
2317        shape: Option<PipelineDslShapePlan>,
2318    ) -> Result<PipelineDslOperatorStep> {
2319        let mut step = PipelineDslOperatorStep {
2320            inner_cv: None,
2321            id: self.next_node_id(compat_node_prefix(keyword))?,
2322            operator,
2323            params,
2324            metadata: BTreeMap::from([(
2325                "dsl_compat_keyword".to_string(),
2326                serde_json::Value::String(keyword.to_string()),
2327            )]),
2328            seed_label: None,
2329            representation: None,
2330            train_params: BTreeMap::new(),
2331            tuning: None,
2332            variants: Vec::new(),
2333            param_generators: Vec::new(),
2334            shape,
2335        };
2336        if let Some(attachment) = attachment {
2337            step.variants.extend(attachment.variants);
2338            step.param_generators.extend(attachment.param_generators);
2339        }
2340        Ok(step)
2341    }
2342
2343    fn lower_split_invocation(
2344        &self,
2345        split: &serde_json::Value,
2346        object: &serde_json::Map<String, serde_json::Value>,
2347        path: &str,
2348    ) -> Result<SplitInvocation> {
2349        let mut params = BTreeMap::new();
2350        let mut id = object
2351            .get("id")
2352            .and_then(serde_json::Value::as_str)
2353            .unwrap_or("split:compat")
2354            .to_string();
2355        let mut controller_id = optional_object_field(object, "controller_id")?;
2356        let mut leakage_policy =
2357            optional_object_field(object, "leakage_policy")?.unwrap_or_default();
2358        let fold_set = optional_object_field(object, "fold_set")?;
2359        match split {
2360            serde_json::Value::String(kind) => {
2361                params.insert("kind".to_string(), serde_json::Value::String(kind.clone()));
2362                id = format!("split:{}", sanitize_generation_label(kind));
2363            }
2364            serde_json::Value::Object(split_object) => {
2365                if let Some(split_id) = split_object.get("id").and_then(serde_json::Value::as_str) {
2366                    id = split_id.to_string();
2367                }
2368                if controller_id.is_none() {
2369                    controller_id = optional_object_field(split_object, "controller_id")?;
2370                }
2371                if let Some(policy) = optional_object_field(split_object, "leakage_policy")? {
2372                    leakage_policy = policy;
2373                }
2374                if let Some(explicit_params) = object_value_as_map(split_object.get("params")) {
2375                    params.extend(explicit_params);
2376                }
2377                for (key, value) in split_object {
2378                    if !matches!(
2379                        key.as_str(),
2380                        "id" | "controller_id" | "leakage_policy" | "fold_set" | "params"
2381                    ) {
2382                        params.insert(key.clone(), value.clone());
2383                    }
2384                }
2385            }
2386            _ => {
2387                return Err(DagMlError::GraphValidation(format!(
2388                    "{path}.split must be a string or object"
2389                )));
2390            }
2391        }
2392        for (key, value) in object {
2393            if !matches!(
2394                key.as_str(),
2395                "split" | "id" | "controller_id" | "leakage_policy" | "fold_set" | "params"
2396            ) {
2397                params.entry(key.clone()).or_insert_with(|| value.clone());
2398            }
2399        }
2400        Ok(SplitInvocation {
2401            id,
2402            controller_id,
2403            leakage_policy,
2404            params,
2405            fold_set,
2406        })
2407    }
2408
2409    fn lower_plain_split_invocation(
2410        &self,
2411        value: &serde_json::Value,
2412        path: &str,
2413    ) -> Result<SplitInvocation> {
2414        let mut params = BTreeMap::new();
2415        let id;
2416        let mut controller_id = None;
2417        let mut leakage_policy = LeakageUnitPolicy::default();
2418        let mut fold_set = None;
2419        if let Some(object) = value.as_object() {
2420            id = object
2421                .get("id")
2422                .and_then(serde_json::Value::as_str)
2423                .map(str::to_string)
2424                .unwrap_or_else(|| {
2425                    compat_plain_operator_ref(value)
2426                        .map(|reference| format!("split:{}", sanitize_generation_label(reference)))
2427                        .unwrap_or_else(|| "split:compat".to_string())
2428                });
2429            controller_id = optional_object_field(object, "controller_id")?;
2430            leakage_policy = optional_object_field(object, "leakage_policy")?.unwrap_or_default();
2431            fold_set = optional_object_field(object, "fold_set")?;
2432            if let Some(explicit_params) = object_value_as_map(object.get("params")) {
2433                params.extend(explicit_params);
2434            }
2435            for (key, item) in object {
2436                if !matches!(
2437                    key.as_str(),
2438                    "id" | "controller_id" | "leakage_policy" | "fold_set" | "params" | "name"
2439                ) {
2440                    params.insert(key.clone(), item.clone());
2441                }
2442            }
2443        } else if let Some(reference) = compat_plain_operator_ref(value) {
2444            id = format!("split:{}", sanitize_generation_label(reference));
2445            params.insert(
2446                "class".to_string(),
2447                serde_json::Value::String(reference.to_string()),
2448            );
2449        } else {
2450            return Err(DagMlError::GraphValidation(format!(
2451                "{path} is not a nirs4all-compatible splitter alias"
2452            )));
2453        }
2454        if let Some(reference) = compat_plain_operator_ref(value) {
2455            params
2456                .entry("class".to_string())
2457                .or_insert_with(|| serde_json::Value::String(reference.to_string()));
2458        }
2459        Ok(SplitInvocation {
2460            id,
2461            controller_id,
2462            leakage_policy,
2463            params,
2464            fold_set,
2465        })
2466    }
2467
2468    fn set_split_invocation(&mut self, split: SplitInvocation, path: &str) -> Result<()> {
2469        let Some(existing) = self.split_invocation.as_mut() else {
2470            self.split_invocation = Some(split);
2471            return Ok(());
2472        };
2473        if existing.fold_set.is_some() && split.fold_set.is_some() {
2474            return Err(DagMlError::GraphValidation(format!(
2475                "{path} declares a second split with a fold_set; only one explicit fold_set can drive campaign OOF validation"
2476            )));
2477        }
2478        if existing.fold_set.is_none() {
2479            existing.fold_set = split.fold_set.clone();
2480        }
2481        let default_policy = LeakageUnitPolicy::default();
2482        if existing.leakage_policy == default_policy {
2483            existing.leakage_policy = split.leakage_policy.clone();
2484        } else if split.leakage_policy != default_policy
2485            && existing.leakage_policy != split.leakage_policy
2486        {
2487            return Err(DagMlError::GraphValidation(format!(
2488                "{path} declares split leakage_policy incompatible with the existing campaign split policy"
2489            )));
2490        }
2491        let first = split_invocation_chain_entry(existing)?;
2492        let second = split_invocation_chain_entry(&split)?;
2493        let mut chain = existing
2494            .params
2495            .remove("compat_split_chain")
2496            .and_then(|value| value.as_array().cloned())
2497            .unwrap_or_else(|| vec![first]);
2498        chain.push(second);
2499        existing.id = "split:compat.chain".to_string();
2500        existing.controller_id = None;
2501        existing.params.clear();
2502        existing.params.insert(
2503            "kind".to_string(),
2504            serde_json::Value::String("compat_split_chain".to_string()),
2505        );
2506        existing.params.insert(
2507            "compat_split_chain".to_string(),
2508            serde_json::Value::Array(chain),
2509        );
2510        Ok(())
2511    }
2512
2513    fn next_node_id(&mut self, prefix: &str) -> Result<NodeId> {
2514        let id = NodeId::new(format!("{prefix}:compat.{}", self.node_counter))?;
2515        self.node_counter += 1;
2516        Ok(id)
2517    }
2518
2519    fn next_generator_id(&mut self) -> Result<NodeId> {
2520        let id = NodeId::new(format!("generator:compat.{}", self.generator_counter))?;
2521        self.generator_counter += 1;
2522        Ok(id)
2523    }
2524}
2525
2526fn optional_root_field<T>(
2527    root: Option<&serde_json::Map<String, serde_json::Value>>,
2528    key: &str,
2529) -> Result<Option<T>>
2530where
2531    T: DeserializeOwned,
2532{
2533    match root.and_then(|object| object.get(key)) {
2534        Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2535        None => Ok(None),
2536    }
2537}
2538
2539fn optional_object_field<T>(
2540    object: &serde_json::Map<String, serde_json::Value>,
2541    key: &str,
2542) -> Result<Option<T>>
2543where
2544    T: DeserializeOwned,
2545{
2546    match object.get(key) {
2547        Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2548        None => Ok(None),
2549    }
2550}
2551
2552fn optional_object_field_from_option<T>(
2553    object: Option<&serde_json::Map<String, serde_json::Value>>,
2554    key: &str,
2555) -> Result<Option<T>>
2556where
2557    T: DeserializeOwned,
2558{
2559    match object.and_then(|object| object.get(key)) {
2560        Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2561        None => Ok(None),
2562    }
2563}
2564
2565fn compat_merge_field<T>(
2566    object: &serde_json::Map<String, serde_json::Value>,
2567    key: &str,
2568) -> Result<Option<T>>
2569where
2570    T: DeserializeOwned,
2571{
2572    let value = object.get(key).or_else(|| {
2573        object
2574            .get("merge")
2575            .and_then(serde_json::Value::as_object)
2576            .and_then(|merge| merge.get(key))
2577    });
2578    match value {
2579        Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2580        None => Ok(None),
2581    }
2582}
2583
2584fn deserialize_value<T>(value: serde_json::Value, label: &str) -> Result<T>
2585where
2586    T: DeserializeOwned,
2587{
2588    serde_json::from_value(value)
2589        .map_err(|error| DagMlError::GraphValidation(format!("failed to parse {label}: {error}")))
2590}
2591
2592fn explicit_or_generated_node_id<F>(
2593    object: &serde_json::Map<String, serde_json::Value>,
2594    key: &str,
2595    generated: F,
2596) -> Result<NodeId>
2597where
2598    F: FnOnce() -> Result<NodeId>,
2599{
2600    match object.get(key).and_then(serde_json::Value::as_str) {
2601        Some(id) => NodeId::new(id),
2602        None => generated(),
2603    }
2604}
2605
2606fn first_object_value<'a>(
2607    object: &'a serde_json::Map<String, serde_json::Value>,
2608    keys: &[&str],
2609) -> Option<&'a serde_json::Value> {
2610    keys.iter().find_map(|key| object.get(*key))
2611}
2612
2613fn is_comment_only_object(object: &serde_json::Map<String, serde_json::Value>) -> bool {
2614    !object.is_empty()
2615        && object
2616            .keys()
2617            .all(|key| matches!(key.as_str(), "_comment" | "comment" | "description"))
2618}
2619
2620fn value_can_receive_generation_attachment(value: &serde_json::Value) -> bool {
2621    let Some(object) = value.as_object() else {
2622        return false;
2623    };
2624    object.contains_key("model")
2625        || object.contains_key("tuner")
2626        || object.contains_key("finetune")
2627        || first_object_value(object, &["preprocessing", "processing", "transform"]).is_some()
2628        || compat_plain_operator_ref(value).is_some()
2629}
2630
2631fn object_value_as_map(
2632    value: Option<&serde_json::Value>,
2633) -> Option<BTreeMap<String, serde_json::Value>> {
2634    value.and_then(|value| {
2635        value.as_object().map(|object| {
2636            object
2637                .iter()
2638                .map(|(key, value)| (key.clone(), value.clone()))
2639                .collect()
2640        })
2641    })
2642}
2643
2644fn is_minimal_compat_operator_alias(
2645    object: Option<&serde_json::Map<String, serde_json::Value>>,
2646    operator: &serde_json::Value,
2647) -> bool {
2648    match object {
2649        None => compat_plain_operator_ref(operator).is_some(),
2650        Some(object) => {
2651            ["class", "function", "ref", "type"]
2652                .iter()
2653                .any(|key| object.contains_key(*key))
2654                && compat_plain_operator_ref(operator).is_some()
2655        }
2656    }
2657}
2658
2659fn annotate_named_steps(steps: &mut [PipelineDslStep], name: &str) {
2660    for step in steps {
2661        annotate_named_step(step, name);
2662    }
2663}
2664
2665fn annotate_named_step(step: &mut PipelineDslStep, name: &str) {
2666    let value = serde_json::Value::String(name.to_string());
2667    match step {
2668        PipelineDslStep::Transform(step)
2669        | PipelineDslStep::YTransform(step)
2670        | PipelineDslStep::Tag(step)
2671        | PipelineDslStep::Exclude(step)
2672        | PipelineDslStep::Filter(step)
2673        | PipelineDslStep::SampleFilter(step)
2674        | PipelineDslStep::Augmentation(step)
2675        | PipelineDslStep::FeatureAugmentation(step)
2676        | PipelineDslStep::SampleAugmentation(step)
2677        | PipelineDslStep::DataGeneration(step)
2678        | PipelineDslStep::Model(step)
2679        | PipelineDslStep::Tuner(step)
2680        | PipelineDslStep::Chart(step) => {
2681            step.metadata.insert("dsl_name".to_string(), value);
2682        }
2683        PipelineDslStep::ConcatTransform(step) => {
2684            step.metadata.insert("dsl_name".to_string(), value);
2685        }
2686        PipelineDslStep::Branch(step) => {
2687            step.metadata.insert("dsl_name".to_string(), value);
2688        }
2689        PipelineDslStep::Generator(step) => {
2690            step.metadata.insert("dsl_name".to_string(), value);
2691        }
2692        PipelineDslStep::Sequential(step) => {
2693            step.metadata.insert("dsl_name".to_string(), value);
2694        }
2695        PipelineDslStep::Merge(step) => {
2696            step.metadata.insert("dsl_name".to_string(), value);
2697        }
2698        PipelineDslStep::MergeModel(step) => {
2699            step.metadata.insert("dsl_name".to_string(), value);
2700        }
2701    }
2702}
2703
2704fn compat_plain_operator_ref(value: &serde_json::Value) -> Option<&str> {
2705    match value {
2706        serde_json::Value::String(reference) => Some(reference),
2707        serde_json::Value::Object(object) => ["class", "function", "ref", "type"]
2708            .into_iter()
2709            .find_map(|key| object.get(key).and_then(serde_json::Value::as_str)),
2710        _ => None,
2711    }
2712}
2713
2714fn compat_plain_operator_value(value: &serde_json::Value) -> Result<serde_json::Value> {
2715    match value {
2716        serde_json::Value::String(_) => Ok(value.clone()),
2717        serde_json::Value::Object(object) => {
2718            let mut operator = serde_json::Map::new();
2719            for key in ["class", "function", "ref", "type"] {
2720                if let Some(value) = object.get(key) {
2721                    operator.insert(key.to_string(), value.clone());
2722                }
2723            }
2724            if operator.is_empty() {
2725                return Err(DagMlError::GraphValidation(
2726                    "nirs4all-compatible plain operator object must contain class, function, ref or type"
2727                        .to_string(),
2728                ));
2729            }
2730            Ok(serde_json::Value::Object(operator))
2731        }
2732        _ => Err(DagMlError::GraphValidation(
2733            "nirs4all-compatible plain operator must be a string or object".to_string(),
2734        )),
2735    }
2736}
2737
2738fn compat_plain_operator_kind(value: &serde_json::Value) -> CompatPlainOperatorKind {
2739    let Some(reference) = compat_plain_operator_ref(value) else {
2740        return CompatPlainOperatorKind::Transform;
2741    };
2742    let lower = reference.to_ascii_lowercase();
2743    if compat_is_chart_alias(&lower) {
2744        CompatPlainOperatorKind::Chart
2745    } else if compat_is_tuner_alias(&lower) {
2746        CompatPlainOperatorKind::Tuner
2747    } else if compat_is_splitter_alias(&lower) {
2748        CompatPlainOperatorKind::Split
2749    } else if compat_is_model_alias(&lower) {
2750        CompatPlainOperatorKind::Model
2751    } else {
2752        CompatPlainOperatorKind::Transform
2753    }
2754}
2755
/// Tells whether a lowercased operator reference names a chart.
///
/// Matches the bare name `chart`, anything starting with `chart_`, and dotted
/// paths containing a `.charts.` or `.visualization.` segment.
fn compat_is_chart_alias(lower: &str) -> bool {
    if lower == "chart" || lower.starts_with("chart_") {
        return true;
    }
    [".charts.", ".visualization."]
        .iter()
        .any(|segment| lower.contains(*segment))
}
2762
/// Tells whether a lowercased operator reference names a tuner.
///
/// The full reference is searched for tuner-related path markers; the short name
/// (the part after the last `.` or `:`) is checked against known suffixes and an
/// explicit list of tuner class names.
fn compat_is_tuner_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 6] = [
        ".tuners.",
        ".tuning.",
        "operators.tuners",
        "optuna",
        "ray.tune",
        "hyperopt",
    ];
    const NAME_SUFFIXES: [&str; 2] = ["tuner", "searchcv"];
    const KNOWN_NAMES: [&str; 6] = [
        "gridsearchcv",
        "randomizedsearchcv",
        "halvinggridsearchcv",
        "halvingrandomsearchcv",
        "bayesiantuner",
        "optunatuner",
    ];

    let short = lower
        .rsplit(|separator: char| matches!(separator, '.' | ':'))
        .next()
        .unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || NAME_SUFFIXES.iter().any(|suffix| short.ends_with(*suffix))
        || KNOWN_NAMES.contains(&short)
}
2783
/// Tells whether a lowercased operator reference names a splitter.
///
/// The full reference is searched for splitter-related path markers; the short
/// name (the part after the last `.` or `:`) is checked for a `splitter` substring,
/// known suffixes and an explicit list of splitter class names.
fn compat_is_splitter_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 3] = ["model_selection", ".splitters.", "operators.splitters"];
    const NAME_SUFFIXES: [&str; 4] = ["kfold", "gfold", "fold", "split"];
    const KNOWN_NAMES: [&str; 4] = [
        "leaveoneout",
        "leavepout",
        "predefinedsplit",
        "timeseriessplit",
    ];

    let short = lower
        .rsplit(|separator: char| matches!(separator, '.' | ':'))
        .next()
        .unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || short.contains("splitter")
        || NAME_SUFFIXES.iter().any(|suffix| short.ends_with(*suffix))
        || KNOWN_NAMES.contains(&short)
}
2799
/// Tells whether a lowercased operator reference names a model.
///
/// The full reference is searched for model-related package markers; the short
/// name (the part after the last `.` or `:`) is checked against estimator suffixes
/// and an explicit list of model names.
fn compat_is_model_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 12] = [
        ".models.",
        "operators.models",
        "linear_model",
        "cross_decomposition",
        ".ensemble.",
        ".svm.",
        ".tree.",
        ".neighbors.",
        ".neural_network.",
        "xgboost",
        "lightgbm",
        "catboost",
    ];
    const NAME_SUFFIXES: [&str; 3] = ["regressor", "classifier", "regression"];
    const KNOWN_NAMES: [&str; 11] = [
        "ridge",
        "lasso",
        "elasticnet",
        "svr",
        "svc",
        "linearsvr",
        "linearsvc",
        "pls",
        "plsr",
        "plsregression",
        "metamodel",
    ];

    let short = lower
        .rsplit(|separator: char| matches!(separator, '.' | ':'))
        .next()
        .unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || NAME_SUFFIXES.iter().any(|suffix| short.ends_with(*suffix))
        || KNOWN_NAMES.contains(&short)
}
2832
/// Maps a compat keyword to the prefix used for generated node ids.
///
/// Keywords without a dedicated prefix fall back to `transform`.
fn compat_node_prefix(keyword: &str) -> &'static str {
    const PREFIXES: &[(&[&str], &str)] = &[
        (&["model"], "model"),
        (&["tuner", "finetune"], "tuner"),
        (&["y_processing", "y_transform"], "target"),
        (&["tag"], "tag"),
        (&["exclude", "filter", "sample_filter"], "filter"),
        (
            &["sample_augmentation", "feature_augmentation", "augmentation"],
            "augment",
        ),
        (&["data_generation", "generation"], "generator"),
        (&["chart"], "chart"),
    ];

    for (keywords, prefix) in PREFIXES {
        if keywords.contains(&keyword) {
            return prefix;
        }
    }
    "transform"
}
2846
/// Lists the wrapper keys whose object values are merged into an operator's params
/// for the given compat keyword. Keywords without aliases yield an empty slice.
fn compat_param_aliases(keyword: &str) -> &'static [&'static str] {
    const MODEL: &[&str] = &["model_params"];
    const TUNER: &[&str] = &["tuner_params", "finetune_params"];
    const TRANSFORM: &[&str] = &[
        "preprocessing_params",
        "processing_params",
        "transform_params",
    ];
    const AUGMENTATION: &[&str] = &["augmentation_params"];
    const GENERATION: &[&str] = &["generation_params"];

    if keyword == "model" {
        MODEL
    } else if matches!(keyword, "tuner" | "finetune") {
        TUNER
    } else if matches!(keyword, "preprocessing" | "processing" | "transform") {
        TRANSFORM
    } else if matches!(
        keyword,
        "sample_augmentation" | "feature_augmentation" | "augmentation"
    ) {
        AUGMENTATION
    } else if matches!(keyword, "data_generation" | "generation") {
        GENERATION
    } else {
        &[]
    }
}
2861
/// Lists the wrapper-level keys that are copied verbatim into an operator's params
/// for the given compat keyword. Keywords without such keys yield an empty slice.
fn compat_wrapper_param_keys(keyword: &str) -> &'static [&'static str] {
    const FILTERING: &[&str] = &["mode", "report", "tag_name"];
    const SAMPLE_AUGMENTATION: &[&str] = &[
        "count",
        "selection",
        "random_state",
        "mode",
        "action",
        "report",
    ];
    const FEATURE_AUGMENTATION: &[&str] = &[
        "size",
        "count",
        "selection",
        "random_state",
        "mode",
        "action",
        "report",
    ];
    const GENERATION: &[&str] = &["size", "count", "random_state", "mode", "report"];
    const TUNING: &[&str] = &["n_trials", "metric", "direction", "timeout", "random_state"];

    if matches!(keyword, "tag" | "exclude" | "filter" | "sample_filter") {
        FILTERING
    } else if keyword == "sample_augmentation" {
        SAMPLE_AUGMENTATION
    } else if matches!(keyword, "feature_augmentation" | "augmentation") {
        FEATURE_AUGMENTATION
    } else if matches!(keyword, "data_generation" | "generation") {
        GENERATION
    } else if matches!(keyword, "tuner" | "finetune") {
        TUNING
    } else {
        &[]
    }
}
2887
2888fn split_invocation_chain_entry(split: &SplitInvocation) -> Result<serde_json::Value> {
2889    let mut object = serde_json::Map::new();
2890    object.insert(
2891        "id".to_string(),
2892        serde_json::Value::String(split.id.clone()),
2893    );
2894    if let Some(controller_id) = &split.controller_id {
2895        object.insert(
2896            "controller_id".to_string(),
2897            serde_json::to_value(controller_id).map_err(|error| {
2898                DagMlError::GraphValidation(format!(
2899                    "failed to serialize split controller_id for compat split chain: {error}"
2900                ))
2901            })?,
2902        );
2903    }
2904    if split.leakage_policy != LeakageUnitPolicy::default() {
2905        object.insert(
2906            "leakage_policy".to_string(),
2907            serde_json::to_value(&split.leakage_policy).map_err(|error| {
2908                DagMlError::GraphValidation(format!(
2909                    "failed to serialize split leakage_policy for compat split chain: {error}"
2910                ))
2911            })?,
2912        );
2913    }
2914    if !split.params.is_empty() {
2915        object.insert(
2916            "params".to_string(),
2917            serde_json::to_value(&split.params).map_err(|error| {
2918                DagMlError::GraphValidation(format!(
2919                    "failed to serialize split params for compat split chain: {error}"
2920                ))
2921            })?,
2922        );
2923    }
2924    if let Some(fold_set) = &split.fold_set {
2925        object.insert(
2926            "fold_set".to_string(),
2927            serde_json::to_value(fold_set).map_err(|error| {
2928                DagMlError::GraphValidation(format!(
2929                    "failed to serialize split fold_set for compat split chain: {error}"
2930                ))
2931            })?,
2932        );
2933    }
2934    Ok(serde_json::Value::Object(object))
2935}
2936
2937fn compat_augmentation_shape(
2938    kind: &str,
2939    object: &serde_json::Map<String, serde_json::Value>,
2940) -> Result<PipelineDslShapePlan> {
2941    if let Some(shape) = object.get("shape") {
2942        return deserialize_value(shape.clone(), "augmentation shape");
2943    }
2944    let mut sample_scope = crate::policy::AugmentationScope::None;
2945    let mut feature_scope = crate::policy::AugmentationScope::None;
2946    match kind {
2947        "sample" => sample_scope = crate::policy::AugmentationScope::TrainOnly,
2948        "feature" => feature_scope = crate::policy::AugmentationScope::TrainOnly,
2949        _ => {
2950            sample_scope = crate::policy::AugmentationScope::TrainOnly;
2951            feature_scope = crate::policy::AugmentationScope::TrainOnly;
2952        }
2953    }
2954    if let Some(apply_to) = object
2955        .get("policy")
2956        .and_then(serde_json::Value::as_object)
2957        .and_then(|policy| policy.get("apply_to"))
2958        .and_then(serde_json::Value::as_str)
2959    {
2960        match apply_to {
2961            "train_only" => {}
2962            "all" | "all_partitions" => {
2963                if sample_scope != crate::policy::AugmentationScope::None {
2964                    sample_scope = crate::policy::AugmentationScope::AllPartitions;
2965                }
2966                if feature_scope != crate::policy::AugmentationScope::None {
2967                    feature_scope = crate::policy::AugmentationScope::AllPartitions;
2968                }
2969            }
2970            "none" => {
2971                sample_scope = crate::policy::AugmentationScope::None;
2972                feature_scope = crate::policy::AugmentationScope::None;
2973            }
2974            other => {
2975                return Err(DagMlError::GraphValidation(format!(
2976                    "unsupported nirs4all augmentation policy apply_to `{other}`"
2977                )));
2978            }
2979        }
2980    }
2981    Ok(PipelineDslShapePlan {
2982        input_granularity: None,
2983        target_granularity: None,
2984        fit_rows: Some(FitBoundary::FoldTrain),
2985        predict_rows: Some(FitBoundary::FoldValidation),
2986        feature_namespace: None,
2987        feature_schema_fingerprint: None,
2988        target_space: None,
2989        aggregation_policy: None,
2990        augmentation_policy: Some(AugmentationPolicy {
2991            sample_scope,
2992            feature_scope,
2993            require_origin_id: true,
2994            inherit_group: true,
2995            inherit_target: true,
2996            unsafe_flags: BTreeSet::new(),
2997        }),
2998        selection_policy: None,
2999    })
3000}
3001
3002fn compat_merge_modes(
3003    object: &serde_json::Map<String, serde_json::Value>,
3004) -> Result<(String, bool, PipelineDslMergeOutput)> {
3005    let merge = object
3006        .get("merge")
3007        .ok_or_else(|| DagMlError::GraphValidation("merge step lacks `merge`".to_string()))?;
3008    let merge_object = merge.as_object();
3009    let mode = merge
3010        .as_str()
3011        .or_else(|| {
3012            merge_object
3013                .and_then(|object| object.get("mode").or_else(|| object.get("strategy")))
3014                .and_then(serde_json::Value::as_str)
3015        })
3016        .map(str::to_string)
3017        .unwrap_or_else(|| infer_compat_merge_mode(merge_object));
3018    validate_compat_merge_mode(&mode)?;
3019    let include_original_data = object
3020        .get("include_original_data")
3021        .or_else(|| object.get("include_original"))
3022        .or_else(|| {
3023            merge_object.and_then(|object| {
3024                object
3025                    .get("include_original_data")
3026                    .or_else(|| object.get("include_original"))
3027            })
3028        })
3029        .and_then(serde_json::Value::as_bool)
3030        .unwrap_or(matches!(
3031            mode.as_str(),
3032            "all" | "mixed" | "predictions_plus_original"
3033        ));
3034    let output_as = object
3035        .get("output_as")
3036        .or_else(|| merge_object.and_then(|object| object.get("output_as")))
3037        .and_then(serde_json::Value::as_str)
3038        .map(compat_merge_output_as)
3039        .transpose()?
3040        .unwrap_or_else(|| compat_merge_output_for_mode(&mode));
3041    Ok((mode, include_original_data, output_as))
3042}
3043
3044fn infer_compat_merge_mode(
3045    merge_object: Option<&serde_json::Map<String, serde_json::Value>>,
3046) -> String {
3047    let Some(object) = merge_object else {
3048        return "predictions".to_string();
3049    };
3050    let has_predictions = object.contains_key("predictions") || object.contains_key("prediction");
3051    let has_features = object.contains_key("features") || object.contains_key("feature");
3052    let has_sources = object.contains_key("sources") || object.contains_key("source");
3053    match (has_predictions, has_features, has_sources) {
3054        (true, true, _) => "all",
3055        (true, false, _) => "predictions",
3056        (false, true, _) => "features",
3057        (false, false, true) => "sources",
3058        _ => "predictions",
3059    }
3060    .to_string()
3061}
3062
3063fn compat_merge_output_for_mode(mode: &str) -> PipelineDslMergeOutput {
3064    match mode {
3065        "predictions" | "prediction" => PipelineDslMergeOutput::Predictions,
3066        "sources" | "source" => PipelineDslMergeOutput::Sources,
3067        _ => PipelineDslMergeOutput::Features,
3068    }
3069}
3070
3071fn compat_merge_output_as(value: &str) -> Result<PipelineDslMergeOutput> {
3072    match value {
3073        "features" | "feature" => Ok(PipelineDslMergeOutput::Features),
3074        "predictions" | "prediction" => Ok(PipelineDslMergeOutput::Predictions),
3075        "sources" | "source" => Ok(PipelineDslMergeOutput::Sources),
3076        other => Err(DagMlError::GraphValidation(format!(
3077            "unsupported nirs4all merge output_as `{other}`"
3078        ))),
3079    }
3080}
3081
3082fn validate_compat_merge_mode(mode: &str) -> Result<()> {
3083    match mode {
3084        "predictions"
3085        | "prediction"
3086        | "sources"
3087        | "source"
3088        | "features"
3089        | "feature"
3090        | "concat"
3091        | "all"
3092        | "mixed"
3093        | "predictions_plus_original" => {}
3094        other => {
3095            return Err(DagMlError::GraphValidation(format!(
3096                "unsupported nirs4all merge mode `{other}`"
3097            )));
3098        }
3099    }
3100    Ok(())
3101}
3102
3103fn compat_generator_metadata(
3104    object: &serde_json::Map<String, serde_json::Value>,
3105    key: &str,
3106) -> Result<BTreeMap<String, serde_json::Value>> {
3107    let mut metadata: BTreeMap<String, serde_json::Value> =
3108        optional_object_field(object, "metadata")?.unwrap_or_default();
3109    metadata.insert(
3110        "dsl_compat_generator".to_string(),
3111        serde_json::Value::String(key.to_string()),
3112    );
3113    Ok(metadata)
3114}
3115
3116fn compat_branch_id(value: &serde_json::Value, index: usize) -> String {
3117    value
3118        .as_object()
3119        .and_then(|object| object.get("id"))
3120        .and_then(serde_json::Value::as_str)
3121        .map(|id| sanitize_branch_id(id, index))
3122        .unwrap_or_else(|| format!("choice{index}"))
3123}
3124
3125fn sanitize_branch_id(input: &str, index: usize) -> String {
3126    let sanitized = sanitize_generation_label(input);
3127    if sanitized == "value" {
3128        format!("branch{index}")
3129    } else {
3130        sanitized
3131    }
3132}
3133
3134fn step_has_prediction(step: &PipelineDslStep) -> bool {
3135    match step {
3136        PipelineDslStep::Model(_) | PipelineDslStep::Tuner(_) | PipelineDslStep::MergeModel(_) => {
3137            true
3138        }
3139        PipelineDslStep::Merge(step) => step.output_as == PipelineDslMergeOutput::Predictions,
3140        PipelineDslStep::Branch(step) => step
3141            .branches
3142            .iter()
3143            .any(|branch| branch.steps.iter().any(step_has_prediction)),
3144        PipelineDslStep::Generator(step) => generator_step_has_prediction(step),
3145        PipelineDslStep::Sequential(step) => step.steps.iter().any(step_has_prediction),
3146        _ => false,
3147    }
3148}
3149
3150fn generator_step_has_prediction(generator: &PipelineDslGeneratorStep) -> bool {
3151    generator
3152        .branches
3153        .iter()
3154        .any(|branch| branch.steps.iter().any(step_has_prediction))
3155        || generator.stages.iter().any(|stage| {
3156            stage
3157                .branches
3158                .iter()
3159                .any(|branch| branch.steps.iter().any(step_has_prediction))
3160        })
3161}
3162
3163fn generator_to_cartesian_stages(
3164    generator: PipelineDslGeneratorStep,
3165) -> Result<Vec<PipelineDslGeneratorStage>> {
3166    match generator.mode {
3167        PipelineDslGeneratorMode::Cartesian => Ok(generator.stages),
3168        PipelineDslGeneratorMode::Or => {
3169            if generator.pick.is_some()
3170                || generator.arrange.is_some()
3171                || generator.then_pick.is_some()
3172                || generator.then_arrange.is_some()
3173            {
3174                return Err(DagMlError::GraphValidation(format!(
3175                    "nirs4all-compatible data-only generator `{}` cannot be fused across downstream models when pick/arrange selectors are present",
3176                    generator.id
3177                )));
3178            }
3179            Ok(vec![PipelineDslGeneratorStage {
3180                id: sanitize_generation_label(generator.id.as_str()),
3181                selector: None,
3182                metadata: generator.metadata,
3183                branches: generator.branches,
3184            }])
3185        }
3186    }
3187}
3188
3189fn single_stage(
3190    id: String,
3191    branch_id: &str,
3192    steps: Vec<PipelineDslStep>,
3193) -> PipelineDslGeneratorStage {
3194    PipelineDslGeneratorStage {
3195        id,
3196        selector: None,
3197        metadata: BTreeMap::new(),
3198        branches: vec![PipelineDslBranch {
3199            id: branch_id.to_string(),
3200            selector: None,
3201            metadata: BTreeMap::new(),
3202            steps,
3203        }],
3204    }
3205}
3206
3207fn combined_cartesian_generator(
3208    id: NodeId,
3209    stages: Vec<PipelineDslGeneratorStage>,
3210) -> PipelineDslGeneratorStep {
3211    PipelineDslGeneratorStep {
3212        id,
3213        mode: PipelineDslGeneratorMode::Cartesian,
3214        branches: Vec::new(),
3215        stages,
3216        pick: None,
3217        arrange: None,
3218        then_pick: None,
3219        then_arrange: None,
3220        count: None,
3221        metadata: BTreeMap::from([(
3222            "dsl_compat_generator".to_string(),
3223            serde_json::Value::String("fused_data_to_prediction".to_string()),
3224        )]),
3225    }
3226}
3227
3228fn compat_grid_rows(
3229    value: &serde_json::Value,
3230    path: &str,
3231) -> Result<Vec<BTreeMap<String, serde_json::Value>>> {
3232    let object = value
3233        .as_object()
3234        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._grid_ must be an object")))?;
3235    if object.is_empty() {
3236        return Err(DagMlError::GraphValidation(format!(
3237            "{path}._grid_ must contain at least one parameter"
3238        )));
3239    }
3240    let entries = object
3241        .iter()
3242        .map(|(key, value)| {
3243            let values = match value {
3244                serde_json::Value::Array(values) => values.clone(),
3245                _ => vec![value.clone()],
3246            };
3247            if values.is_empty() {
3248                return Err(DagMlError::GraphValidation(format!(
3249                    "{path}._grid_.{key} has no values"
3250                )));
3251            }
3252            Ok((key.clone(), values))
3253        })
3254        .collect::<Result<Vec<_>>>()?;
3255    let mut rows = Vec::new();
3256    build_compat_grid_rows(&entries, 0, &mut BTreeMap::new(), &mut rows);
3257    Ok(rows)
3258}
3259
3260fn build_compat_grid_rows(
3261    entries: &[(String, Vec<serde_json::Value>)],
3262    index: usize,
3263    current: &mut BTreeMap<String, serde_json::Value>,
3264    rows: &mut Vec<BTreeMap<String, serde_json::Value>>,
3265) {
3266    if index == entries.len() {
3267        rows.push(current.clone());
3268        return;
3269    }
3270    let (key, values) = &entries[index];
3271    for value in values {
3272        current.insert(key.clone(), value.clone());
3273        build_compat_grid_rows(entries, index + 1, current, rows);
3274        current.remove(key);
3275    }
3276}
3277
3278fn compat_range_generator(
3279    value: &serde_json::Value,
3280    object: &serde_json::Map<String, serde_json::Value>,
3281    path: &str,
3282) -> Result<PipelineDslParamGenerator> {
3283    let param = object
3284        .get("param")
3285        .and_then(serde_json::Value::as_str)
3286        .unwrap_or("n_components")
3287        .to_string();
3288    let (start, stop, step) = if let Some(values) = value.as_array() {
3289        if values.len() != 3 {
3290            return Err(DagMlError::GraphValidation(format!(
3291                "{path}._range_ array must be [start, stop, step]"
3292            )));
3293        }
3294        (
3295            json_f64(&values[0], path, "_range_[0]")?,
3296            json_f64(&values[1], path, "_range_[1]")?,
3297            json_f64(&values[2], path, "_range_[2]")?,
3298        )
3299    } else if let Some(spec) = value.as_object() {
3300        (
3301            json_f64(
3302                spec.get("start").ok_or_else(|| {
3303                    DagMlError::GraphValidation(format!("{path}._range_ lacks start"))
3304                })?,
3305                path,
3306                "start",
3307            )?,
3308            json_f64(
3309                spec.get("stop").ok_or_else(|| {
3310                    DagMlError::GraphValidation(format!("{path}._range_ lacks stop"))
3311                })?,
3312                path,
3313                "stop",
3314            )?,
3315            json_f64(
3316                spec.get("step").ok_or_else(|| {
3317                    DagMlError::GraphValidation(format!("{path}._range_ lacks step"))
3318                })?,
3319                path,
3320                "step",
3321            )?,
3322        )
3323    } else {
3324        return Err(DagMlError::GraphValidation(format!(
3325            "{path}._range_ must be an array or object"
3326        )));
3327    };
3328    Ok(PipelineDslParamGenerator::Range {
3329        name: optional_object_field(object, "name")?,
3330        param,
3331        start,
3332        stop,
3333        step,
3334        inclusive: object
3335            .get("inclusive")
3336            .and_then(serde_json::Value::as_bool)
3337            .unwrap_or(true),
3338        count: optional_object_field(object, "count")?,
3339    })
3340}
3341
3342fn compat_log_range_generator(
3343    value: &serde_json::Value,
3344    object: &serde_json::Map<String, serde_json::Value>,
3345    path: &str,
3346) -> Result<PipelineDslParamGenerator> {
3347    let param = object
3348        .get("param")
3349        .and_then(serde_json::Value::as_str)
3350        .unwrap_or("alpha")
3351        .to_string();
3352    let spec = value.as_object().ok_or_else(|| {
3353        DagMlError::GraphValidation(format!("{path}._log_range_ must be an object"))
3354    })?;
3355    let start = json_f64(
3356        spec.get("start")
3357            .or_else(|| spec.get("from"))
3358            .ok_or_else(|| {
3359                DagMlError::GraphValidation(format!("{path}._log_range_ lacks start/from"))
3360            })?,
3361        path,
3362        "start",
3363    )?;
3364    let stop = json_f64(
3365        spec.get("stop").or_else(|| spec.get("to")).ok_or_else(|| {
3366            DagMlError::GraphValidation(format!("{path}._log_range_ lacks stop/to"))
3367        })?,
3368        path,
3369        "stop",
3370    )?;
3371    let count = spec
3372        .get("count")
3373        .or_else(|| spec.get("num"))
3374        .and_then(serde_json::Value::as_u64)
3375        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._log_range_ lacks count/num")))?
3376        as usize;
3377    Ok(PipelineDslParamGenerator::LogRange {
3378        name: optional_object_field(object, "name")?,
3379        param,
3380        start,
3381        stop,
3382        count,
3383        base: spec
3384            .get("base")
3385            .map(|value| json_f64(value, path, "base"))
3386            .transpose()?
3387            .unwrap_or(10.0),
3388    })
3389}
3390
3391fn compat_grid_param_generator(
3392    value: &serde_json::Value,
3393    object: &serde_json::Map<String, serde_json::Value>,
3394    path: &str,
3395) -> Result<PipelineDslParamGenerator> {
3396    let grid = value
3397        .as_object()
3398        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._grid_ must be an object")))?;
3399    let params = grid
3400        .iter()
3401        .map(|(key, value)| {
3402            let values = match value {
3403                serde_json::Value::Array(values) => values.clone(),
3404                _ => vec![value.clone()],
3405            };
3406            Ok((
3407                key.clone(),
3408                values
3409                    .into_iter()
3410                    .map(PipelineDslGeneratorValue::Value)
3411                    .collect::<Vec<_>>(),
3412            ))
3413        })
3414        .collect::<Result<BTreeMap<_, _>>>()?;
3415    Ok(PipelineDslParamGenerator::Grid {
3416        name: optional_object_field(object, "name")?,
3417        params,
3418        count: optional_object_field(object, "count")?,
3419    })
3420}
3421
3422fn compat_zip_variants(
3423    value: &serde_json::Value,
3424    path: &str,
3425) -> Result<Vec<PipelineDslVariantChoice>> {
3426    let object = value
3427        .as_object()
3428        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._zip_ must be an object")))?;
3429    let mut length = None;
3430    let mut columns = Vec::new();
3431    for (key, value) in object {
3432        let values = value.as_array().ok_or_else(|| {
3433            DagMlError::GraphValidation(format!("{path}._zip_.{key} must be an array"))
3434        })?;
3435        if let Some(expected) = length {
3436            if values.len() != expected {
3437                return Err(DagMlError::GraphValidation(format!(
3438                    "{path}._zip_ arrays must have equal lengths"
3439                )));
3440            }
3441        } else {
3442            length = Some(values.len());
3443        }
3444        columns.push((key.clone(), values.clone()));
3445    }
3446    let length = length.unwrap_or(0);
3447    if length == 0 {
3448        return Err(DagMlError::GraphValidation(format!(
3449            "{path}._zip_ must contain non-empty arrays"
3450        )));
3451    }
3452    Ok((0..length)
3453        .map(|index| {
3454            let params = columns
3455                .iter()
3456                .map(|(key, values)| (key.clone(), values[index].clone()))
3457                .collect::<BTreeMap<_, _>>();
3458            PipelineDslVariantChoice {
3459                label: format!("zip{index}"),
3460                params,
3461                value: None,
3462            }
3463        })
3464        .collect())
3465}
3466
3467fn compat_sample_rows(
3468    object: &serde_json::Map<String, serde_json::Value>,
3469    path: &str,
3470) -> Result<Vec<BTreeMap<String, serde_json::Value>>> {
3471    let param_names = if let Some(param) = object.get("param").and_then(serde_json::Value::as_str) {
3472        vec![param.to_string()]
3473    } else if let Some(tune) = object.get("tune").and_then(serde_json::Value::as_array) {
3474        let params = tune
3475            .iter()
3476            .map(|value| {
3477                value.as_str().map(str::to_string).ok_or_else(|| {
3478                    DagMlError::GraphValidation(format!(
3479                        "{path}._sample_.tune entries must be strings"
3480                    ))
3481                })
3482            })
3483            .collect::<Result<Vec<_>>>()?;
3484        if params.is_empty() {
3485            return Err(DagMlError::GraphValidation(format!(
3486                "{path}._sample_.tune cannot be empty"
3487            )));
3488        }
3489        params
3490    } else {
3491        return Err(DagMlError::GraphValidation(format!(
3492            "{path}._sample_ requires `param` or `tune` for deterministic JSON lowering"
3493        )));
3494    };
3495    let from = json_f64(
3496        object
3497            .get("from")
3498            .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks from")))?,
3499        path,
3500        "from",
3501    )?;
3502    let to = json_f64(
3503        object
3504            .get("to")
3505            .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks to")))?,
3506        path,
3507        "to",
3508    )?;
3509    let count = object
3510        .get("num")
3511        .or_else(|| object.get("count"))
3512        .and_then(serde_json::Value::as_u64)
3513        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks num/count")))?
3514        as usize;
3515    if count == 0 {
3516        return Err(DagMlError::GraphValidation(format!(
3517            "{path}._sample_ count cannot be zero"
3518        )));
3519    }
3520    let distribution = object
3521        .get("distribution")
3522        .and_then(serde_json::Value::as_str)
3523        .unwrap_or("uniform");
3524    if distribution == "log_uniform" && (from <= 0.0 || to <= 0.0) {
3525        return Err(DagMlError::GraphValidation(format!(
3526            "{path}._sample_ log_uniform requires positive from/to"
3527        )));
3528    }
3529    (0..count)
3530        .map(|index| {
3531            let ratio = if count == 1 {
3532                0.0
3533            } else {
3534                index as f64 / (count - 1) as f64
3535            };
3536            let sampled = match distribution {
3537                "uniform" => from + (to - from) * ratio,
3538                "log_uniform" => {
3539                    let start = from.log10();
3540                    let stop = to.log10();
3541                    10f64.powf(start + (stop - start) * ratio)
3542                }
3543                other => {
3544                    return Err(DagMlError::GraphValidation(format!(
3545                        "{path}._sample_ unsupported deterministic distribution `{other}`"
3546                    )));
3547                }
3548            };
3549            let mut row = BTreeMap::new();
3550            let value = serde_json::Value::Number(
3551                serde_json::Number::from_f64(sampled).ok_or_else(|| {
3552                    DagMlError::GraphValidation(format!(
3553                        "{path}._sample_ produced non-finite value"
3554                    ))
3555                })?,
3556            );
3557            for param in &param_names {
3558                row.insert(param.clone(), value.clone());
3559            }
3560            Ok(row)
3561        })
3562        .collect()
3563}
3564
3565fn compat_sample_variants(
3566    value: &serde_json::Value,
3567    path: &str,
3568) -> Result<Vec<PipelineDslVariantChoice>> {
3569    let object = value
3570        .as_object()
3571        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ must be an object")))?;
3572    compat_sample_rows(object, path)?
3573        .into_iter()
3574        .enumerate()
3575        .map(|(index, params)| {
3576            Ok(PipelineDslVariantChoice {
3577                label: format!("sample{index}"),
3578                params,
3579                value: None,
3580            })
3581        })
3582        .collect()
3583}
3584
3585fn json_f64(value: &serde_json::Value, path: &str, field: &str) -> Result<f64> {
3586    value
3587        .as_f64()
3588        .ok_or_else(|| DagMlError::GraphValidation(format!("{path}.{field} must be numeric")))
3589}
3590
3591impl PipelineCompiler {
3592    fn compile_top_level_step(
3593        &mut self,
3594        step: &PipelineDslStep,
3595        external_data: &DataSource,
3596        state: &mut SequenceCompileState,
3597    ) -> Result<()> {
3598        self.compile_sequence_step(step, external_data, state, None, BTreeMap::new())
3599    }
3600
3601    fn compile_sequence_step(
3602        &mut self,
3603        step: &PipelineDslStep,
3604        original_data: &DataSource,
3605        state: &mut SequenceCompileState,
3606        branch_id: Option<&str>,
3607        extra_metadata: BTreeMap<String, serde_json::Value>,
3608    ) -> Result<()> {
3609        match step {
3610            PipelineDslStep::Transform(step) => {
3611                state.current_data = self.compile_data_operator_with_extra(
3612                    NodeKind::Transform,
3613                    step,
3614                    &state.current_data,
3615                    extra_metadata,
3616                )?;
3617                state.clear_pending();
3618                Ok(())
3619            }
3620            PipelineDslStep::YTransform(step) => {
3621                self.compile_y_transform_with_extra(step, extra_metadata)?;
3622                state.clear_pending();
3623                Ok(())
3624            }
3625            PipelineDslStep::Tag(step) => {
3626                state.current_data = self.compile_data_operator_with_extra(
3627                    NodeKind::Tag,
3628                    step,
3629                    &state.current_data,
3630                    extra_metadata,
3631                )?;
3632                state.clear_pending();
3633                Ok(())
3634            }
3635            PipelineDslStep::Exclude(step) => {
3636                state.current_data = self.compile_data_operator_with_extra(
3637                    NodeKind::Exclude,
3638                    step,
3639                    &state.current_data,
3640                    extra_metadata,
3641                )?;
3642                state.clear_pending();
3643                Ok(())
3644            }
3645            PipelineDslStep::Filter(step) => {
3646                state.current_data = self.compile_filter_operator(
3647                    "filter",
3648                    step,
3649                    &state.current_data,
3650                    extra_metadata,
3651                )?;
3652                state.clear_pending();
3653                Ok(())
3654            }
3655            PipelineDslStep::SampleFilter(step) => {
3656                state.current_data = self.compile_filter_operator(
3657                    "sample",
3658                    step,
3659                    &state.current_data,
3660                    extra_metadata,
3661                )?;
3662                state.clear_pending();
3663                Ok(())
3664            }
3665            PipelineDslStep::Augmentation(step) => {
3666                state.current_data = self.compile_data_operator_with_extra(
3667                    NodeKind::Augmentation,
3668                    step,
3669                    &state.current_data,
3670                    extra_metadata,
3671                )?;
3672                state.clear_pending();
3673                Ok(())
3674            }
3675            PipelineDslStep::FeatureAugmentation(step) => {
3676                state.current_data = self.compile_augmentation_operator_with_extra(
3677                    "feature",
3678                    step,
3679                    &state.current_data,
3680                    extra_metadata,
3681                )?;
3682                state.clear_pending();
3683                Ok(())
3684            }
3685            PipelineDslStep::SampleAugmentation(step) => {
3686                state.current_data = self.compile_augmentation_operator_with_extra(
3687                    "sample",
3688                    step,
3689                    &state.current_data,
3690                    extra_metadata,
3691                )?;
3692                state.clear_pending();
3693                Ok(())
3694            }
3695            PipelineDslStep::DataGeneration(step) => {
3696                state.current_data = self.compile_data_generation_with_extra(
3697                    step,
3698                    &state.current_data,
3699                    extra_metadata,
3700                )?;
3701                state.clear_pending();
3702                Ok(())
3703            }
3704            PipelineDslStep::ConcatTransform(step) => {
3705                state.current_data = self.compile_concat_transform_with_extra(
3706                    step,
3707                    &state.current_data,
3708                    extra_metadata,
3709                )?;
3710                state.clear_pending();
3711                Ok(())
3712            }
3713            PipelineDslStep::Model(step) => {
3714                state
3715                    .pending_predictions
3716                    .push(self.compile_model_with_extra(
3717                        step,
3718                        &state.current_data,
3719                        branch_id,
3720                        extra_metadata,
3721                    )?);
3722                Ok(())
3723            }
3724            PipelineDslStep::Tuner(step) => {
3725                state
3726                    .pending_predictions
3727                    .push(self.compile_tuner_with_extra(
3728                        step,
3729                        &state.current_data,
3730                        branch_id,
3731                        extra_metadata,
3732                    )?);
3733                Ok(())
3734            }
3735            PipelineDslStep::Branch(step) => {
3736                let output =
3737                    self.compile_branch_with_extra(step, &state.current_data, extra_metadata)?;
3738                state.pending_predictions = output.predictions;
3739                state.pending_branch_data = output.data_sources;
3740                Ok(())
3741            }
3742            PipelineDslStep::Generator(step) => {
3743                state.pending_predictions =
3744                    self.compile_generator_with_extra(step, &state.current_data, extra_metadata)?;
3745                state.pending_branch_data.clear();
3746                Ok(())
3747            }
3748            PipelineDslStep::Sequential(step) => {
3749                self.compile_sequence_container(
3750                    step,
3751                    original_data,
3752                    state,
3753                    branch_id,
3754                    extra_metadata,
3755                )?;
3756                Ok(())
3757            }
3758            PipelineDslStep::Merge(step) => {
3759                match self.compile_merge_with_extra(
3760                    step,
3761                    &state.pending_predictions,
3762                    &state.pending_branch_data,
3763                    original_data,
3764                    extra_metadata,
3765                )? {
3766                    MergeOutputSource::Data(data) => {
3767                        state.current_data = data;
3768                        state.clear_pending();
3769                    }
3770                    MergeOutputSource::Prediction(prediction) => {
3771                        state.clear_pending();
3772                        state.pending_predictions.push(prediction);
3773                    }
3774                }
3775                Ok(())
3776            }
3777            PipelineDslStep::MergeModel(step) => {
3778                let prediction = self.compile_merge_model_with_extra(
3779                    step,
3780                    &state.pending_predictions,
3781                    original_data,
3782                    extra_metadata,
3783                )?;
3784                state.clear_pending();
3785                state.pending_predictions.push(prediction);
3786                Ok(())
3787            }
3788            PipelineDslStep::Chart(step) => {
3789                state.current_data = self.compile_data_operator_with_extra(
3790                    NodeKind::Chart,
3791                    step,
3792                    &state.current_data,
3793                    extra_metadata,
3794                )?;
3795                state.clear_pending();
3796                Ok(())
3797            }
3798        }
3799    }
3800
3801    fn compile_sequence_container(
3802        &mut self,
3803        step: &PipelineDslSequenceStep,
3804        original_data: &DataSource,
3805        state: &mut SequenceCompileState,
3806        branch_id: Option<&str>,
3807        mut extra_metadata: BTreeMap<String, serde_json::Value>,
3808    ) -> Result<()> {
3809        if step.steps.is_empty() {
3810            return Err(DagMlError::GraphValidation(
3811                "pipeline DSL sequential step has no child steps".to_string(),
3812            ));
3813        }
3814        if let Some(sequence_id) = &step.id {
3815            extra_metadata.insert(
3816                "dsl_sequence".to_string(),
3817                serde_json::Value::String(sequence_id.to_string()),
3818            );
3819        }
3820        if !step.metadata.is_empty() {
3821            extra_metadata.insert(
3822                "dsl_sequence_metadata".to_string(),
3823                serde_json::to_value(&step.metadata).map_err(|error| {
3824                    DagMlError::GraphValidation(format!(
3825                        "failed to serialize pipeline DSL sequential metadata: {error}"
3826                    ))
3827                })?,
3828            );
3829        }
3830        for child in &step.steps {
3831            self.compile_sequence_step(
3832                child,
3833                original_data,
3834                state,
3835                branch_id,
3836                extra_metadata.clone(),
3837            )?;
3838        }
3839        Ok(())
3840    }
3841
3842    fn compile_branch_with_extra(
3843        &mut self,
3844        step: &PipelineDslBranchStep,
3845        current_data: &DataSource,
3846        extra_metadata: BTreeMap<String, serde_json::Value>,
3847    ) -> Result<BranchCompileOutput> {
3848        if step.branches.is_empty() {
3849            return Err(DagMlError::GraphValidation(format!(
3850                "pipeline DSL graph `{}` has a branch step without branches",
3851                self.graph_id
3852            )));
3853        }
3854        let mut predictions = Vec::new();
3855        let mut data_sources = Vec::new();
3856        for (index, branch) in step.branches.iter().enumerate() {
3857            validate_branch_id(&branch.id)?;
3858            if branch.steps.is_empty() {
3859                return Err(DagMlError::GraphValidation(format!(
3860                    "pipeline DSL branch `{}` has no steps",
3861                    branch.id
3862                )));
3863            }
3864            let branch_view_plan = compile_branch_view_plan(step, branch)?;
3865            let mut branch_state = SequenceCompileState::new(current_data.clone());
3866            let mut branch_metadata = branch_context_metadata(step, branch)?;
3867            if let Some(plan) = &branch_view_plan {
3868                branch_metadata.insert(
3869                    "dsl_branch_view_plan".to_string(),
3870                    serde_json::to_value(plan).map_err(|error| {
3871                        DagMlError::GraphValidation(format!(
3872                            "failed to serialize branch view plan for `{}`: {error}",
3873                            branch.id
3874                        ))
3875                    })?,
3876                );
3877            }
3878            branch_metadata.extend(extra_metadata.clone());
3879            for branch_step in &branch.steps {
3880                self.compile_sequence_step(
3881                    branch_step,
3882                    current_data,
3883                    &mut branch_state,
3884                    Some(&branch.id),
3885                    branch_metadata.clone(),
3886                )?;
3887            }
3888            if branch_state.pending_predictions.is_empty()
3889                && branch_state.pending_branch_data.is_empty()
3890                && same_data_source(&branch_state.current_data, current_data)
3891            {
3892                return Err(DagMlError::GraphValidation(format!(
3893                    "pipeline DSL branch `{}` must produce at least one model, merge prediction or transformed data output",
3894                    branch.id
3895                )));
3896            }
3897            if let Some(plan) = branch_view_plan {
3898                self.collect_branch_view_plan(plan)?;
3899            }
3900            let data_input_name = format!("{}_x", branch_input_prefix(&branch.id, index));
3901            data_sources.push(BranchDataSource {
3902                source: branch_state.current_data,
3903                input_name: data_input_name,
3904                branch_id: Some(branch.id.clone()),
3905            });
3906            data_sources.extend(branch_state.pending_branch_data);
3907            let prediction_count = branch_state.pending_predictions.len();
3908            for (prediction_index, prediction) in
3909                branch_state.pending_predictions.into_iter().enumerate()
3910            {
3911                let input_name = if prediction_count == 1 {
3912                    format!("{}_oof", branch_input_prefix(&branch.id, index))
3913                } else {
3914                    branch_prediction_input_name(
3915                        &branch.id,
3916                        index,
3917                        prediction_index,
3918                        &prediction.node_id,
3919                    )
3920                };
3921                predictions.push(PredictionSource {
3922                    input_name,
3923                    ..prediction
3924                });
3925            }
3926        }
3927        Ok(BranchCompileOutput {
3928            predictions,
3929            data_sources,
3930        })
3931    }
3932
3933    fn compile_generator_with_extra(
3934        &mut self,
3935        step: &PipelineDslGeneratorStep,
3936        current_data: &DataSource,
3937        extra_metadata: BTreeMap<String, serde_json::Value>,
3938    ) -> Result<Vec<PredictionSource>> {
3939        let choices = expand_generator_sequences(step)?;
3940        if choices.is_empty() {
3941            return Err(DagMlError::GraphValidation(format!(
3942                "pipeline DSL generator `{}` produced no choices",
3943                step.id
3944            )));
3945        }
3946        let mut predictions = Vec::new();
3947        for (choice_index, choice) in choices.into_iter().enumerate() {
3948            let choice = namespace_generated_sequence(step, choice, choice_index)?;
3949            validate_branch_id(&choice.id)?;
3950            if choice.steps.is_empty() {
3951                return Err(DagMlError::GraphValidation(format!(
3952                    "pipeline DSL generator `{}` choice `{}` has no steps",
3953                    step.id, choice.id
3954                )));
3955            }
3956            let mut choice_state = SequenceCompileState::new(current_data.clone());
3957            let mut choice_metadata = generator_choice_metadata(step, &choice)?;
3958            choice_metadata.extend(extra_metadata.clone());
3959            for choice_step in &choice.steps {
3960                self.compile_sequence_step(
3961                    choice_step,
3962                    current_data,
3963                    &mut choice_state,
3964                    Some(&choice.id),
3965                    choice_metadata.clone(),
3966                )?;
3967            }
3968            if choice_state.pending_predictions.is_empty() {
3969                return Err(DagMlError::GraphValidation(format!(
3970                    "pipeline DSL generator `{}` choice `{}` must produce at least one model or merge prediction",
3971                    step.id, choice.id
3972                )));
3973            }
3974            let prediction_count = choice_state.pending_predictions.len();
3975            for (prediction_index, prediction) in
3976                choice_state.pending_predictions.into_iter().enumerate()
3977            {
3978                let input_name = if prediction_count == 1 {
3979                    format!("{}_oof", branch_input_prefix(&choice.id, choice_index))
3980                } else {
3981                    branch_prediction_input_name(
3982                        &choice.id,
3983                        choice_index,
3984                        prediction_index,
3985                        &prediction.node_id,
3986                    )
3987                };
3988                predictions.push(PredictionSource {
3989                    input_name,
3990                    ..prediction
3991                });
3992            }
3993        }
3994        Ok(predictions)
3995    }
3996
3997    fn compile_data_operator(
3998        &mut self,
3999        kind: NodeKind,
4000        step: &PipelineDslOperatorStep,
4001        input: &DataSource,
4002    ) -> Result<DataSource> {
4003        self.compile_data_operator_with_extra(kind, step, input, BTreeMap::new())
4004    }
4005
4006    fn compile_filter_operator(
4007        &mut self,
4008        filter_kind: &str,
4009        step: &PipelineDslOperatorStep,
4010        input: &DataSource,
4011        mut extra: BTreeMap<String, serde_json::Value>,
4012    ) -> Result<DataSource> {
4013        extra.insert(
4014            "dsl_filter_kind".to_string(),
4015            serde_json::Value::String(filter_kind.to_string()),
4016        );
4017        self.compile_data_operator_with_extra(NodeKind::Exclude, step, input, extra)
4018    }
4019
4020    fn compile_augmentation_operator_with_extra(
4021        &mut self,
4022        augmentation_kind: &str,
4023        step: &PipelineDslOperatorStep,
4024        input: &DataSource,
4025        mut extra: BTreeMap<String, serde_json::Value>,
4026    ) -> Result<DataSource> {
4027        extra.insert(
4028            "dsl_augmentation_kind".to_string(),
4029            serde_json::Value::String(augmentation_kind.to_string()),
4030        );
4031        self.compile_data_operator_with_extra(NodeKind::Augmentation, step, input, extra)
4032    }
4033
4034    fn compile_data_generation_with_extra(
4035        &mut self,
4036        step: &PipelineDslOperatorStep,
4037        input: &DataSource,
4038        mut extra: BTreeMap<String, serde_json::Value>,
4039    ) -> Result<DataSource> {
4040        if step.shape.is_none() {
4041            return Err(DagMlError::GraphValidation(format!(
4042                "pipeline DSL data_generation `{}` requires a shape plan for leakage-safe runtime generation",
4043                step.id
4044            )));
4045        }
4046        extra.insert(
4047            "dsl_generation_kind".to_string(),
4048            serde_json::Value::String("data".to_string()),
4049        );
4050        self.compile_data_operator_with_extra(NodeKind::Generator, step, input, extra)
4051    }
4052
4053    fn compile_data_operator_with_extra(
4054        &mut self,
4055        kind: NodeKind,
4056        step: &PipelineDslOperatorStep,
4057        input: &DataSource,
4058        extra_metadata: BTreeMap<String, serde_json::Value>,
4059    ) -> Result<DataSource> {
4060        if kind == NodeKind::Augmentation && step.shape.is_none() {
4061            return Err(DagMlError::GraphValidation(format!(
4062                "pipeline DSL augmentation `{}` requires a shape plan for leakage-safe scope validation",
4063                step.id
4064            )));
4065        }
4066        let representation = step
4067            .representation
4068            .clone()
4069            .or_else(|| input.representation.clone())
4070            .or_else(|| self.input_representation.clone());
4071        let mut metadata = operator_runtime_metadata(step, None)?;
4072        metadata.extend(extra_metadata);
4073        let node = NodeSpec {
4074            id: step.id.clone(),
4075            kind,
4076            operator: Some(step.operator.clone()),
4077            params: step.params.clone(),
4078            ports: PortSchema {
4079                inputs: vec![data_port("x", input.representation.clone(), "")],
4080                outputs: vec![data_port("x_out", representation.clone(), "")],
4081            },
4082            metadata,
4083            seed_label: step.seed_label.clone(),
4084        };
4085        self.push_node(node)?;
4086        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4087        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4088        self.connect_data(input, &step.id, "x")?;
4089        Ok(DataSource {
4090            node_id: Some(step.id.clone()),
4091            port_name: "x_out".to_string(),
4092            representation,
4093        })
4094    }
4095
4096    fn compile_y_transform_with_extra(
4097        &mut self,
4098        step: &PipelineDslOperatorStep,
4099        extra_metadata: BTreeMap<String, serde_json::Value>,
4100    ) -> Result<()> {
4101        let mut metadata = operator_runtime_metadata(step, None)?;
4102        metadata.extend(extra_metadata);
4103        let node = NodeSpec {
4104            id: step.id.clone(),
4105            kind: NodeKind::YTransform,
4106            operator: Some(step.operator.clone()),
4107            params: step.params.clone(),
4108            ports: PortSchema {
4109                inputs: vec![target_port("y", "")],
4110                outputs: vec![target_port("y_out", "")],
4111            },
4112            metadata,
4113            seed_label: step.seed_label.clone(),
4114        };
4115        self.push_node(node)?;
4116        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4117        self.collect_shape_plan(&step.id, step.shape.as_ref())
4118    }
4119
4120    fn compile_concat_transform_with_extra(
4121        &mut self,
4122        step: &PipelineDslConcatTransformStep,
4123        input: &DataSource,
4124        extra_metadata: BTreeMap<String, serde_json::Value>,
4125    ) -> Result<DataSource> {
4126        if step.branches.is_empty() {
4127            return Err(DagMlError::GraphValidation(format!(
4128                "pipeline DSL concat_transform `{}` has no branches",
4129                step.id
4130            )));
4131        }
4132        let representation = step
4133            .representation
4134            .clone()
4135            .or_else(|| input.representation.clone())
4136            .or_else(|| self.input_representation.clone());
4137        let mut branch_outputs = Vec::with_capacity(step.branches.len());
4138        for (index, branch) in step.branches.iter().enumerate() {
4139            validate_branch_id(&branch.id)?;
4140            let mut branch_data = input.clone();
4141            for branch_step in &branch.steps {
4142                branch_data =
4143                    self.compile_data_operator(NodeKind::Transform, branch_step, &branch_data)?;
4144            }
4145            let input_name = format!("{}_x", branch_input_prefix(&branch.id, index));
4146            branch_outputs.push((input_name, branch_data));
4147        }
4148        let node = NodeSpec {
4149            id: step.id.clone(),
4150            kind: NodeKind::FeatureJoin,
4151            operator: None,
4152            params: BTreeMap::new(),
4153            ports: PortSchema {
4154                inputs: branch_outputs
4155                    .iter()
4156                    .map(|(name, source)| data_port(name, source.representation.clone(), ""))
4157                    .collect(),
4158                outputs: vec![data_port("x_out", representation.clone(), "")],
4159            },
4160            metadata: {
4161                let mut metadata = step.metadata.clone();
4162                metadata.extend(extra_metadata);
4163                metadata
4164            },
4165            seed_label: step.seed_label.clone(),
4166        };
4167        self.push_node(node)?;
4168        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4169        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4170        for (input_name, source) in &branch_outputs {
4171            self.connect_data_to_port(source, &step.id, input_name)?;
4172        }
4173        Ok(DataSource {
4174            node_id: Some(step.id.clone()),
4175            port_name: "x_out".to_string(),
4176            representation,
4177        })
4178    }
4179
4180    fn compile_model_with_extra(
4181        &mut self,
4182        step: &PipelineDslOperatorStep,
4183        input: &DataSource,
4184        branch_id: Option<&str>,
4185        extra_metadata: BTreeMap<String, serde_json::Value>,
4186    ) -> Result<PredictionSource> {
4187        self.compile_prediction_operator_with_extra(
4188            NodeKind::Model,
4189            step,
4190            input,
4191            branch_id,
4192            extra_metadata,
4193        )
4194    }
4195
4196    fn compile_tuner_with_extra(
4197        &mut self,
4198        step: &PipelineDslOperatorStep,
4199        input: &DataSource,
4200        branch_id: Option<&str>,
4201        extra_metadata: BTreeMap<String, serde_json::Value>,
4202    ) -> Result<PredictionSource> {
4203        self.compile_prediction_operator_with_extra(
4204            NodeKind::Tuner,
4205            step,
4206            input,
4207            branch_id,
4208            extra_metadata,
4209        )
4210    }
4211
4212    fn compile_prediction_operator_with_extra(
4213        &mut self,
4214        kind: NodeKind,
4215        step: &PipelineDslOperatorStep,
4216        input: &DataSource,
4217        branch_id: Option<&str>,
4218        extra_metadata: BTreeMap<String, serde_json::Value>,
4219    ) -> Result<PredictionSource> {
4220        let mut metadata = operator_runtime_metadata(step, branch_id)?;
4221        metadata.extend(extra_metadata);
4222        let node = NodeSpec {
4223            id: step.id.clone(),
4224            kind,
4225            operator: Some(step.operator.clone()),
4226            params: step.params.clone(),
4227            ports: PortSchema {
4228                inputs: vec![data_port("x", input.representation.clone(), "")],
4229                outputs: vec![prediction_port("oof", "")],
4230            },
4231            metadata,
4232            seed_label: step.seed_label.clone(),
4233        };
4234        self.push_node(node)?;
4235        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4236        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4237        self.connect_data(input, &step.id, "x")?;
4238        Ok(PredictionSource {
4239            node_id: step.id.clone(),
4240            port_name: "oof".to_string(),
4241            input_name: "oof".to_string(),
4242            branch_id: branch_id.map(str::to_string),
4243        })
4244    }
4245
4246    fn compile_merge_with_extra(
4247        &mut self,
4248        step: &PipelineDslMergeStep,
4249        predictions: &[PredictionSource],
4250        branch_data: &[BranchDataSource],
4251        original_data: &DataSource,
4252        extra_metadata: BTreeMap<String, serde_json::Value>,
4253    ) -> Result<MergeOutputSource> {
4254        let consumes_predictions = merge_consumes_predictions(step);
4255        let consumes_branch_data = merge_consumes_branch_data(step);
4256        let prediction_inputs = if consumes_predictions {
4257            predictions
4258        } else {
4259            &[]
4260        };
4261        let branch_data_inputs = if consumes_branch_data {
4262            branch_data
4263        } else {
4264            &[]
4265        };
4266        if prediction_inputs.is_empty()
4267            && branch_data_inputs.is_empty()
4268            && !step.include_original_data
4269        {
4270            return Err(DagMlError::GraphValidation(format!(
4271                "pipeline DSL merge `{}` has no pending predictions, branch data or original data input",
4272                step.id
4273            )));
4274        }
4275        validate_merge_selectors(&step.id, &step.selectors, prediction_inputs)?;
4276        let outputs_prediction = step.output_as == PipelineDslMergeOutput::Predictions;
4277        let representation = step
4278            .representation
4279            .clone()
4280            .or_else(|| original_data.representation.clone())
4281            .or_else(|| self.input_representation.clone());
4282        let mut input_ports =
4283            Vec::with_capacity(prediction_inputs.len() + branch_data_inputs.len() + 1);
4284        for prediction in prediction_inputs {
4285            input_ports.push(prediction_port(&prediction.input_name, ""));
4286        }
4287        for branch_source in branch_data_inputs {
4288            input_ports.push(data_port(
4289                &branch_source.input_name,
4290                branch_source.source.representation.clone(),
4291                "",
4292            ));
4293        }
4294        if step.include_original_data {
4295            input_ports.push(data_port(
4296                "x_original",
4297                original_data.representation.clone(),
4298                "",
4299            ));
4300        }
4301        let mut metadata = step.metadata.clone();
4302        metadata.insert(
4303            "merge_mode".to_string(),
4304            serde_json::Value::String(step.merge_mode.clone()),
4305        );
4306        metadata.insert(
4307            "output_as".to_string(),
4308            serde_json::to_value(step.output_as).map_err(|error| {
4309                DagMlError::GraphValidation(format!(
4310                    "failed to serialize pipeline DSL merge `{}` output mode: {error}",
4311                    step.id
4312                ))
4313            })?,
4314        );
4315        metadata.insert(
4316            "include_original_data".to_string(),
4317            serde_json::Value::Bool(step.include_original_data),
4318        );
4319        if let Some(on_missing) = &step.on_missing {
4320            metadata.insert(
4321                "on_missing".to_string(),
4322                serde_json::Value::String(on_missing.clone()),
4323            );
4324        }
4325        if !step.selectors.is_empty() {
4326            metadata.insert(
4327                "selectors".to_string(),
4328                serde_json::to_value(&step.selectors).map_err(|error| {
4329                    DagMlError::GraphValidation(format!(
4330                        "failed to serialize pipeline DSL merge `{}` selectors: {error}",
4331                        step.id
4332                    ))
4333                })?,
4334            );
4335        }
4336        if !branch_data_inputs.is_empty() {
4337            metadata.insert(
4338                "branch_data_inputs".to_string(),
4339                serde_json::to_value(
4340                    branch_data_inputs
4341                        .iter()
4342                        .map(|source| {
4343                            BTreeMap::from([
4344                                (
4345                                    "input_name".to_string(),
4346                                    serde_json::Value::String(source.input_name.clone()),
4347                                ),
4348                                (
4349                                    "branch".to_string(),
4350                                    source
4351                                        .branch_id
4352                                        .as_ref()
4353                                        .map(|branch| serde_json::Value::String(branch.clone()))
4354                                        .unwrap_or(serde_json::Value::Null),
4355                                ),
4356                            ])
4357                        })
4358                        .collect::<Vec<_>>(),
4359                )
4360                .map_err(|error| {
4361                    DagMlError::GraphValidation(format!(
4362                        "failed to serialize pipeline DSL merge `{}` branch data inputs: {error}",
4363                        step.id
4364                    ))
4365                })?,
4366            );
4367        }
4368        let branch_id = branch_id_from_metadata(&extra_metadata);
4369        metadata.extend(extra_metadata);
4370        let node = NodeSpec {
4371            id: step.id.clone(),
4372            kind: merge_node_kind(
4373                step,
4374                !prediction_inputs.is_empty(),
4375                !branch_data_inputs.is_empty(),
4376            ),
4377            operator: None,
4378            params: BTreeMap::new(),
4379            ports: PortSchema {
4380                inputs: input_ports,
4381                outputs: if outputs_prediction {
4382                    vec![prediction_port("prediction", "")]
4383                } else {
4384                    vec![data_port("x_out", representation.clone(), "")]
4385                },
4386            },
4387            metadata,
4388            seed_label: step.seed_label.clone(),
4389        };
4390        self.push_node(node)?;
4391        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4392        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4393        for prediction in prediction_inputs {
4394            self.edges.push(EdgeSpec {
4395                source: PortRef {
4396                    node_id: prediction.node_id.clone(),
4397                    port_name: prediction.port_name.clone(),
4398                },
4399                target: PortRef {
4400                    node_id: step.id.clone(),
4401                    port_name: prediction.input_name.clone(),
4402                },
4403                contract: EdgeContract {
4404                    requires_oof: true,
4405                    requires_fold_alignment: true,
4406                    ..EdgeContract::new(PortKind::Prediction, None)
4407                },
4408            });
4409        }
4410        for branch_source in branch_data_inputs {
4411            self.connect_data_to_port(&branch_source.source, &step.id, &branch_source.input_name)?;
4412        }
4413        if step.include_original_data {
4414            self.connect_data_to_port(original_data, &step.id, "x_original")?;
4415        }
4416        if outputs_prediction {
4417            Ok(MergeOutputSource::Prediction(PredictionSource {
4418                node_id: step.id.clone(),
4419                port_name: "prediction".to_string(),
4420                input_name: "oof".to_string(),
4421                branch_id,
4422            }))
4423        } else {
4424            Ok(MergeOutputSource::Data(DataSource {
4425                node_id: Some(step.id.clone()),
4426                port_name: "x_out".to_string(),
4427                representation,
4428            }))
4429        }
4430    }
4431
4432    fn compile_merge_model_with_extra(
4433        &mut self,
4434        step: &PipelineDslMergeModelStep,
4435        predictions: &[PredictionSource],
4436        external_data: &DataSource,
4437        extra_metadata: BTreeMap<String, serde_json::Value>,
4438    ) -> Result<PredictionSource> {
4439        if predictions.is_empty() {
4440            return Err(DagMlError::GraphValidation(format!(
4441                "pipeline DSL merge_model `{}` has no pending branch predictions",
4442                step.id
4443            )));
4444        }
4445        let mut input_ports = Vec::with_capacity(predictions.len() + 1);
4446        for prediction in predictions {
4447            input_ports.push(prediction_port(&prediction.input_name, ""));
4448        }
4449        if step.include_original_data {
4450            input_ports.push(data_port(
4451                "x_original",
4452                external_data.representation.clone(),
4453                "",
4454            ));
4455        }
4456        let mut metadata = step.metadata.clone();
4457        insert_training_metadata(
4458            &mut metadata,
4459            &step.train_params,
4460            step.tuning.as_ref(),
4461            step.inner_cv.as_ref(),
4462            &step.id,
4463        )?;
4464        metadata.insert(
4465            "merge_mode".to_string(),
4466            serde_json::Value::String(step.merge_mode.clone()),
4467        );
4468        let branch_id = branch_id_from_metadata(&extra_metadata);
4469        metadata.extend(extra_metadata);
4470        let node = NodeSpec {
4471            id: step.id.clone(),
4472            kind: NodeKind::Model,
4473            operator: Some(step.operator.clone()),
4474            params: step.params.clone(),
4475            ports: PortSchema {
4476                inputs: input_ports,
4477                outputs: vec![prediction_port("oof", "")],
4478            },
4479            metadata,
4480            seed_label: step.seed_label.clone(),
4481        };
4482        self.push_node(node)?;
4483        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4484        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4485        for prediction in predictions {
4486            self.edges.push(EdgeSpec {
4487                source: PortRef {
4488                    node_id: prediction.node_id.clone(),
4489                    port_name: prediction.port_name.clone(),
4490                },
4491                target: PortRef {
4492                    node_id: step.id.clone(),
4493                    port_name: prediction.input_name.clone(),
4494                },
4495                contract: EdgeContract {
4496                    requires_oof: true,
4497                    requires_fold_alignment: true,
4498                    ..EdgeContract::new(PortKind::Prediction, None)
4499                },
4500            });
4501        }
4502        if step.include_original_data {
4503            self.connect_data_to_port(external_data, &step.id, "x_original")?;
4504        }
4505        Ok(PredictionSource {
4506            node_id: step.id.clone(),
4507            port_name: "oof".to_string(),
4508            input_name: "oof".to_string(),
4509            branch_id,
4510        })
4511    }
4512
4513    fn push_node(&mut self, node: NodeSpec) -> Result<()> {
4514        if self.nodes.iter().any(|existing| existing.id == node.id) {
4515            return Err(DagMlError::GraphValidation(format!(
4516                "pipeline DSL graph `{}` produced duplicate node `{}`",
4517                self.graph_id, node.id
4518            )));
4519        }
4520        self.nodes.push(node);
4521        Ok(())
4522    }
4523
4524    fn collect_operator_generation(
4525        &mut self,
4526        node_id: &NodeId,
4527        choices: &[PipelineDslVariantChoice],
4528        generators: &[PipelineDslParamGenerator],
4529    ) -> Result<()> {
4530        if !choices.is_empty() {
4531            self.generation_dimensions
4532                .push(compile_variant_choice_dimension(node_id, choices)?);
4533        }
4534        for generator in generators {
4535            self.generation_dimensions
4536                .push(compile_param_generator_dimension(node_id, generator)?);
4537        }
4538        Ok(())
4539    }
4540
4541    fn collect_shape_plan(
4542        &mut self,
4543        node_id: &NodeId,
4544        shape: Option<&PipelineDslShapePlan>,
4545    ) -> Result<()> {
4546        let Some(shape) = shape else {
4547            return Ok(());
4548        };
4549        let plan = shape.to_data_model_shape_plan(node_id)?;
4550        if self.shape_plans.insert(node_id.clone(), plan).is_some() {
4551            return Err(DagMlError::GraphValidation(format!(
4552                "pipeline DSL graph `{}` produced duplicate shape plan for `{node_id}`",
4553                self.graph_id
4554            )));
4555        }
4556        Ok(())
4557    }
4558
4559    fn collect_branch_view_plan(&mut self, plan: BranchViewPlan) -> Result<()> {
4560        plan.validate()
4561            .map_err(|error| DagMlError::GraphValidation(error.to_string()))?;
4562        if self
4563            .branch_view_plans
4564            .iter()
4565            .any(|existing| existing.view_id == plan.view_id)
4566        {
4567            return Err(DagMlError::GraphValidation(format!(
4568                "pipeline DSL graph `{}` produced duplicate branch view `{}`",
4569                self.graph_id, plan.view_id
4570            )));
4571        }
4572        self.branch_view_plans.push(plan);
4573        Ok(())
4574    }
4575
4576    fn connect_data(
4577        &mut self,
4578        input: &DataSource,
4579        target_id: &NodeId,
4580        target_port: &str,
4581    ) -> Result<()> {
4582        self.connect_data_to_port(input, target_id, target_port)
4583    }
4584
4585    fn connect_data_to_port(
4586        &mut self,
4587        input: &DataSource,
4588        target_id: &NodeId,
4589        target_port: &str,
4590    ) -> Result<()> {
4591        if let Some(source_id) = &input.node_id {
4592            self.edges.push(EdgeSpec {
4593                source: PortRef {
4594                    node_id: source_id.clone(),
4595                    port_name: input.port_name.clone(),
4596                },
4597                target: PortRef {
4598                    node_id: target_id.clone(),
4599                    port_name: target_port.to_string(),
4600                },
4601                contract: EdgeContract {
4602                    requires_oof: false,
4603                    requires_fold_alignment: true,
4604                    ..EdgeContract::new(PortKind::Data, input.representation.clone())
4605                },
4606            });
4607        }
4608        Ok(())
4609    }
4610}
4611
4612impl PipelineDslShapePlan {
4613    fn to_data_model_shape_plan(&self, node_id: &NodeId) -> Result<DataModelShapePlan> {
4614        let plan = DataModelShapePlan {
4615            node_id: node_id.clone(),
4616            input_granularity: self.input_granularity.unwrap_or(Granularity::Sample),
4617            target_granularity: self.target_granularity.unwrap_or(Granularity::Sample),
4618            fit_rows: self.fit_rows.unwrap_or(FitBoundary::FoldTrain),
4619            predict_rows: self.predict_rows.unwrap_or(FitBoundary::FoldValidation),
4620            feature_namespace: self.feature_namespace.clone(),
4621            feature_schema_fingerprint: self.feature_schema_fingerprint.clone(),
4622            target_space: self
4623                .target_space
4624                .clone()
4625                .unwrap_or_else(|| "raw".to_string()),
4626            aggregation_policy: self.aggregation_policy.clone().unwrap_or_default(),
4627            augmentation_policy: self.augmentation_policy.clone().unwrap_or_default(),
4628            selection_policy: self.selection_policy.clone().unwrap_or_default(),
4629        };
4630        plan.validate()?;
4631        Ok(plan)
4632    }
4633}
4634
4635fn validate_shape_plan_targets(
4636    shape_plans: &BTreeMap<NodeId, DataModelShapePlan>,
4637    graph: &GraphSpec,
4638) -> Result<()> {
4639    for (node_id, plan) in shape_plans {
4640        if node_id != &plan.node_id {
4641            return Err(DagMlError::GraphValidation(format!(
4642                "pipeline DSL shape plan key `{node_id}` does not match node_id `{}`",
4643                plan.node_id
4644            )));
4645        }
4646        if !graph.nodes.iter().any(|node| &node.id == node_id) {
4647            return Err(DagMlError::GraphValidation(format!(
4648                "pipeline DSL shape plan references unknown node `{node_id}`"
4649            )));
4650        }
4651    }
4652    Ok(())
4653}
4654
4655fn compile_explicit_generation_dimensions(
4656    dimensions: &[PipelineDslGenerationDimension],
4657    nodes: &[NodeSpec],
4658) -> Result<Vec<GenerationDimension>> {
4659    if dimensions.is_empty() {
4660        return Ok(Vec::new());
4661    }
4662    let node_ids = nodes
4663        .iter()
4664        .map(|node| node.id.clone())
4665        .collect::<BTreeSet<_>>();
4666    dimensions
4667        .iter()
4668        .map(|dimension| compile_explicit_generation_dimension(dimension, &node_ids))
4669        .collect()
4670}
4671
4672fn compile_explicit_generation_dimension(
4673    dimension: &PipelineDslGenerationDimension,
4674    node_ids: &BTreeSet<NodeId>,
4675) -> Result<GenerationDimension> {
4676    let choices = dimension
4677        .choices
4678        .iter()
4679        .map(|choice| compile_explicit_generation_choice(&dimension.name, choice, node_ids))
4680        .collect::<Result<Vec<_>>>()?;
4681    Ok(GenerationDimension {
4682        name: dimension.name.clone(),
4683        choices,
4684    })
4685}
4686
4687fn compile_explicit_generation_choice(
4688    dimension_name: &str,
4689    choice: &PipelineDslGenerationChoice,
4690    node_ids: &BTreeSet<NodeId>,
4691) -> Result<GenerationChoice> {
4692    if choice.param_overrides.is_empty() {
4693        return Err(DagMlError::GraphValidation(format!(
4694            "pipeline DSL generation choice `{}` in dimension `{dimension_name}` has no param_overrides",
4695            choice.label
4696        )));
4697    }
4698    let param_overrides = choice
4699        .param_overrides
4700        .iter()
4701        .map(|override_spec| {
4702            if !node_ids.contains(&override_spec.node_id) {
4703                return Err(DagMlError::GraphValidation(format!(
4704                    "pipeline DSL generation choice `{}` in dimension `{dimension_name}` references unknown node `{}`",
4705                    choice.label, override_spec.node_id
4706                )));
4707            }
4708            Ok(GenerationParamOverride {
4709                node_id: override_spec.node_id.clone(),
4710                params: override_spec.params.clone(),
4711            })
4712        })
4713        .collect::<Result<Vec<_>>>()?;
4714    let value = match &choice.value {
4715        Some(value) => value.clone(),
4716        None => explicit_generation_choice_value(&param_overrides)?,
4717    };
4718    Ok(GenerationChoice {
4719        label: choice.label.clone(),
4720        value,
4721        param_overrides,
4722    })
4723}
4724
4725fn explicit_generation_choice_value(
4726    param_overrides: &[GenerationParamOverride],
4727) -> Result<serde_json::Value> {
4728    let mut by_node = serde_json::Map::new();
4729    for override_spec in param_overrides {
4730        let value = serde_json::to_value(&override_spec.params).map_err(|error| {
4731            DagMlError::GraphValidation(format!(
4732                "failed to serialize DSL generation override for node `{}`: {error}",
4733                override_spec.node_id
4734            ))
4735        })?;
4736        by_node.insert(override_spec.node_id.to_string(), value);
4737    }
4738    Ok(serde_json::Value::Object(by_node))
4739}
4740
4741fn build_campaign_template(
4742    spec: &PipelineDslSpec,
4743    generation: &GenerationSpec,
4744    shape_plans: &BTreeMap<NodeId, DataModelShapePlan>,
4745    data_bindings: &BTreeMap<NodeId, Vec<DataBinding>>,
4746    branch_view_plans: &[BranchViewPlan],
4747) -> Result<CampaignSpec> {
4748    let campaign = CampaignSpec {
4749        inner_cv: spec.inner_cv.clone(),
4750        id: spec
4751            .campaign_id
4752            .clone()
4753            .unwrap_or_else(|| format!("campaign:{}", spec.id)),
4754        root_seed: spec.root_seed,
4755        leakage_policy: spec.leakage_policy.clone().unwrap_or_default(),
4756        aggregation_policy: spec.aggregation_policy.clone().unwrap_or_default(),
4757        split_invocation: spec.split_invocation.clone(),
4758        generation: generation.clone(),
4759        shape_plans: shape_plans.clone(),
4760        data_bindings: data_bindings.clone(),
4761        branch_view_plans: branch_view_plans.to_vec(),
4762        metadata: spec.campaign_metadata.clone(),
4763    };
4764    campaign.validate()?;
4765    Ok(campaign)
4766}
4767
4768fn compile_data_bindings(
4769    bindings: &[DataBinding],
4770    graph: &GraphSpec,
4771) -> Result<BTreeMap<NodeId, Vec<DataBinding>>> {
4772    let mut by_node = BTreeMap::<NodeId, Vec<DataBinding>>::new();
4773    for binding in bindings {
4774        validate_dsl_data_binding(binding, graph)?;
4775        by_node
4776            .entry(binding.node_id.clone())
4777            .or_default()
4778            .push(binding.clone());
4779    }
4780    Ok(by_node)
4781}
4782
4783fn validate_dsl_data_binding(binding: &DataBinding, graph: &GraphSpec) -> Result<()> {
4784    binding.validate()?;
4785    let node = graph
4786        .nodes
4787        .iter()
4788        .find(|node| node.id == binding.node_id)
4789        .ok_or_else(|| {
4790            DagMlError::GraphValidation(format!(
4791                "pipeline DSL data binding references unknown node `{}`",
4792                binding.node_id
4793            ))
4794        })?;
4795    let Some(input_port) = node
4796        .ports
4797        .inputs
4798        .iter()
4799        .find(|port| port.name == binding.input_name)
4800    else {
4801        return Err(DagMlError::GraphValidation(format!(
4802            "pipeline DSL data binding `{}` references unknown input port `{}` on node `{}`",
4803            binding.request_id, binding.input_name, binding.node_id
4804        )));
4805    };
4806    if input_port.kind != PortKind::Data {
4807        return Err(DagMlError::GraphValidation(format!(
4808            "pipeline DSL data binding `{}` targets non-data input `{}.{}`",
4809            binding.request_id, binding.node_id, binding.input_name
4810        )));
4811    }
4812    Ok(())
4813}
4814
4815fn compile_variant_choice_dimension(
4816    node_id: &NodeId,
4817    choices: &[PipelineDslVariantChoice],
4818) -> Result<GenerationDimension> {
4819    Ok(GenerationDimension {
4820        name: format!("{node_id}.params"),
4821        choices: choices
4822            .iter()
4823            .map(|choice| {
4824                if choice.params.is_empty() {
4825                    return Err(DagMlError::GraphValidation(format!(
4826                        "pipeline DSL variant `{}` for node `{node_id}` has no params",
4827                        choice.label
4828                    )));
4829                }
4830                let value = match &choice.value {
4831                    Some(value) => value.clone(),
4832                    None => serde_json::to_value(&choice.params).map_err(|error| {
4833                        DagMlError::GraphValidation(format!(
4834                            "failed to serialize pipeline DSL variant `{}` for node `{node_id}`: {error}",
4835                            choice.label
4836                        ))
4837                    })?,
4838                };
4839                Ok(GenerationChoice {
4840                    label: choice.label.clone(),
4841                    value,
4842                    param_overrides: vec![GenerationParamOverride {
4843                        node_id: node_id.clone(),
4844                        params: choice.params.clone(),
4845                    }],
4846                })
4847            })
4848            .collect::<Result<Vec<_>>>()?,
4849    })
4850}
4851
4852fn compile_param_generator_dimension(
4853    node_id: &NodeId,
4854    generator: &PipelineDslParamGenerator,
4855) -> Result<GenerationDimension> {
4856    match generator {
4857        PipelineDslParamGenerator::Or {
4858            name,
4859            param,
4860            values,
4861            count,
4862        } => compile_or_generator(node_id, name.as_deref(), param, values, *count),
4863        PipelineDslParamGenerator::Range {
4864            name,
4865            param,
4866            start,
4867            stop,
4868            step,
4869            inclusive,
4870            count,
4871        } => compile_range_generator(RangeGeneratorSpec {
4872            node_id,
4873            name: name.as_deref(),
4874            param,
4875            start: *start,
4876            stop: *stop,
4877            step: *step,
4878            inclusive: *inclusive,
4879            count: *count,
4880        }),
4881        PipelineDslParamGenerator::LogRange {
4882            name,
4883            param,
4884            start,
4885            stop,
4886            count,
4887            base,
4888        } => compile_log_range_generator(
4889            node_id,
4890            name.as_deref(),
4891            param,
4892            *start,
4893            *stop,
4894            *count,
4895            *base,
4896        ),
4897        PipelineDslParamGenerator::Grid {
4898            name,
4899            params,
4900            count,
4901        } => compile_grid_generator(node_id, name.as_deref(), params, *count),
4902        PipelineDslParamGenerator::Pick {
4903            name,
4904            param,
4905            values,
4906            sizes,
4907            count,
4908        } => compile_pick_arrange_generator(
4909            node_id,
4910            name.as_deref(),
4911            param,
4912            values,
4913            sizes,
4914            *count,
4915            PickArrangeMode::Pick,
4916        ),
4917        PipelineDslParamGenerator::Arrange {
4918            name,
4919            param,
4920            values,
4921            sizes,
4922            count,
4923        } => compile_pick_arrange_generator(
4924            node_id,
4925            name.as_deref(),
4926            param,
4927            values,
4928            sizes,
4929            *count,
4930            PickArrangeMode::Arrange,
4931        ),
4932    }
4933}
4934
4935fn compile_or_generator(
4936    node_id: &NodeId,
4937    name: Option<&str>,
4938    param: &str,
4939    values: &[PipelineDslGeneratorValue],
4940    count: Option<usize>,
4941) -> Result<GenerationDimension> {
4942    validate_param_name(node_id, param)?;
4943    validate_count(node_id, name, count)?;
4944    if values.is_empty() {
4945        return Err(DagMlError::GraphValidation(format!(
4946            "pipeline DSL generator `{}` for node `{node_id}` has no values",
4947            generator_dimension_name(node_id, name, Some(param), "or")
4948        )));
4949    }
4950    let mut choices = values
4951        .iter()
4952        .enumerate()
4953        .map(|(index, value)| single_param_generation_choice(node_id, param, index, value))
4954        .collect::<Result<Vec<_>>>()?;
4955    apply_choice_count(&mut choices, count);
4956    Ok(GenerationDimension {
4957        name: generator_dimension_name(node_id, name, Some(param), "or"),
4958        choices,
4959    })
4960}
4961
/// Borrowed arguments of `compile_range_generator`, bundled into a single value.
struct RangeGeneratorSpec<'a> {
    /// Node whose param the generated choices override.
    node_id: &'a NodeId,
    /// Optional explicit dimension name; when present it is used verbatim as the name.
    name: Option<&'a str>,
    /// Name of the param receiving the generated numbers.
    param: &'a str,
    /// First value of the range.
    start: f64,
    /// Bound the range runs towards.
    stop: f64,
    /// Increment between values; must be finite, non-zero and point from `start` to `stop`.
    step: f64,
    /// Forwarded to `range_contains` to decide whether `stop` itself may be emitted.
    inclusive: bool,
    /// Optional cap on the number of choices kept.
    count: Option<usize>,
}
4972
4973fn compile_range_generator(spec: RangeGeneratorSpec<'_>) -> Result<GenerationDimension> {
4974    validate_param_name(spec.node_id, spec.param)?;
4975    validate_count(spec.node_id, spec.name, spec.count)?;
4976    validate_finite(spec.node_id, spec.param, "range start", spec.start)?;
4977    validate_finite(spec.node_id, spec.param, "range stop", spec.stop)?;
4978    validate_finite(spec.node_id, spec.param, "range step", spec.step)?;
4979    if spec.step == 0.0 {
4980        return Err(DagMlError::GraphValidation(format!(
4981            "pipeline DSL range generator for `{}.{}` has zero step",
4982            spec.node_id, spec.param
4983        )));
4984    }
4985    if spec.start < spec.stop && spec.step < 0.0 {
4986        return Err(DagMlError::GraphValidation(format!(
4987            "pipeline DSL range generator for `{}.{}` steps away from stop",
4988            spec.node_id, spec.param
4989        )));
4990    }
4991    if spec.start > spec.stop && spec.step > 0.0 {
4992        return Err(DagMlError::GraphValidation(format!(
4993            "pipeline DSL range generator for `{}.{}` steps away from stop",
4994            spec.node_id, spec.param
4995        )));
4996    }
4997    let mut values = Vec::new();
4998    let mut current = spec.start;
4999    let mut guard = 0usize;
5000    while range_contains(current, spec.stop, spec.step, spec.inclusive) {
5001        values.push(json_number(current, spec.node_id, spec.param)?);
5002        current += spec.step;
5003        guard += 1;
5004        if guard > 10_000 {
5005            return Err(DagMlError::GraphValidation(format!(
5006                "pipeline DSL range generator for `{}.{}` produced more than 10000 values",
5007                spec.node_id, spec.param
5008            )));
5009        }
5010    }
5011    if values.is_empty() {
5012        return Err(DagMlError::GraphValidation(format!(
5013            "pipeline DSL range generator for `{}.{}` produced no values",
5014            spec.node_id, spec.param
5015        )));
5016    }
5017    let wrapped = values
5018        .into_iter()
5019        .map(PipelineDslGeneratorValue::Value)
5020        .collect::<Vec<_>>();
5021    compile_or_generator(spec.node_id, spec.name, spec.param, &wrapped, spec.count).map(
5022        |mut dimension| {
5023            dimension.name =
5024                generator_dimension_name(spec.node_id, spec.name, Some(spec.param), "range");
5025            dimension
5026        },
5027    )
5028}
5029
5030fn compile_log_range_generator(
5031    node_id: &NodeId,
5032    name: Option<&str>,
5033    param: &str,
5034    start: f64,
5035    stop: f64,
5036    count: usize,
5037    base: f64,
5038) -> Result<GenerationDimension> {
5039    validate_param_name(node_id, param)?;
5040    validate_finite(node_id, param, "log_range start", start)?;
5041    validate_finite(node_id, param, "log_range stop", stop)?;
5042    validate_finite(node_id, param, "log_range base", base)?;
5043    if start <= 0.0 || stop <= 0.0 {
5044        return Err(DagMlError::GraphValidation(format!(
5045            "pipeline DSL log_range generator for `{node_id}.{param}` requires positive start and stop"
5046        )));
5047    }
5048    if count == 0 {
5049        return Err(DagMlError::GraphValidation(format!(
5050            "pipeline DSL log_range generator for `{node_id}.{param}` has count=0"
5051        )));
5052    }
5053    if base <= 0.0 || (base - 1.0).abs() < f64::EPSILON {
5054        return Err(DagMlError::GraphValidation(format!(
5055            "pipeline DSL log_range generator for `{node_id}.{param}` requires base > 0 and != 1"
5056        )));
5057    }
5058    let start_log = start.log(base);
5059    let stop_log = stop.log(base);
5060    let values = if count == 1 {
5061        vec![json_number(start, node_id, param)?]
5062    } else {
5063        (0..count)
5064            .map(|index| {
5065                let ratio = index as f64 / (count - 1) as f64;
5066                json_number(
5067                    base.powf(start_log + (stop_log - start_log) * ratio),
5068                    node_id,
5069                    param,
5070                )
5071            })
5072            .collect::<Result<Vec<_>>>()?
5073    };
5074    let wrapped = values
5075        .into_iter()
5076        .map(PipelineDslGeneratorValue::Value)
5077        .collect::<Vec<_>>();
5078    compile_or_generator(node_id, name, param, &wrapped, None).map(|mut dimension| {
5079        dimension.name = generator_dimension_name(node_id, name, Some(param), "log_range");
5080        dimension
5081    })
5082}
5083
5084fn compile_grid_generator(
5085    node_id: &NodeId,
5086    name: Option<&str>,
5087    params: &BTreeMap<String, Vec<PipelineDslGeneratorValue>>,
5088    count: Option<usize>,
5089) -> Result<GenerationDimension> {
5090    validate_count(node_id, name, count)?;
5091    if params.is_empty() {
5092        return Err(DagMlError::GraphValidation(format!(
5093            "pipeline DSL grid generator for node `{node_id}` has no params"
5094        )));
5095    }
5096    for (param, values) in params {
5097        validate_param_name(node_id, param)?;
5098        if values.is_empty() {
5099            return Err(DagMlError::GraphValidation(format!(
5100                "pipeline DSL grid generator for `{node_id}.{param}` has no values"
5101            )));
5102        }
5103    }
5104    let entries = params
5105        .iter()
5106        .map(|(param, values)| (param.as_str(), values.as_slice()))
5107        .collect::<Vec<_>>();
5108    let mut rows = Vec::<BTreeMap<String, PipelineDslGeneratorValue>>::new();
5109    build_grid_rows(&entries, 0, &mut BTreeMap::new(), &mut rows, count);
5110    let choices = rows
5111        .into_iter()
5112        .enumerate()
5113        .map(|(index, row)| multi_param_generation_choice(node_id, index, row))
5114        .collect::<Result<Vec<_>>>()?;
5115    Ok(GenerationDimension {
5116        name: generator_dimension_name(node_id, name, None, "grid"),
5117        choices,
5118    })
5119}
5120
/// Selects how `compile_pick_arrange_generator` enumerates selections of value indices.
///
/// The `Debug` rendering of the variant is embedded in that function's error messages.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PickArrangeMode {
    /// Selections are enumerated with `build_combinations`.
    Pick,
    /// Selections are enumerated with `build_permutations`.
    Arrange,
}
5126
5127fn compile_pick_arrange_generator(
5128    node_id: &NodeId,
5129    name: Option<&str>,
5130    param: &str,
5131    values: &[PipelineDslGeneratorValue],
5132    sizes: &[usize],
5133    count: Option<usize>,
5134    mode: PickArrangeMode,
5135) -> Result<GenerationDimension> {
5136    validate_param_name(node_id, param)?;
5137    validate_count(node_id, name, count)?;
5138    if values.is_empty() {
5139        return Err(DagMlError::GraphValidation(format!(
5140            "pipeline DSL {:?} generator for `{node_id}.{param}` has no values",
5141            mode
5142        )));
5143    }
5144    if sizes.is_empty() {
5145        return Err(DagMlError::GraphValidation(format!(
5146            "pipeline DSL {:?} generator for `{node_id}.{param}` has no sizes",
5147            mode
5148        )));
5149    }
5150    let mut selections = Vec::<Vec<usize>>::new();
5151    for size in sizes {
5152        if *size == 0 || *size > values.len() {
5153            return Err(DagMlError::GraphValidation(format!(
5154                "pipeline DSL {:?} generator for `{node_id}.{param}` has invalid size `{size}`",
5155                mode
5156            )));
5157        }
5158        match mode {
5159            PickArrangeMode::Pick => build_combinations(
5160                values.len(),
5161                *size,
5162                0,
5163                &mut Vec::new(),
5164                &mut selections,
5165                count,
5166            ),
5167            PickArrangeMode::Arrange => build_permutations(
5168                values.len(),
5169                *size,
5170                &mut BTreeSet::new(),
5171                &mut Vec::new(),
5172                &mut selections,
5173                count,
5174            ),
5175        }
5176        if count.is_some_and(|limit| selections.len() >= limit) {
5177            break;
5178        }
5179    }
5180    let mut choices = selections
5181        .into_iter()
5182        .enumerate()
5183        .map(|(index, selection)| {
5184            let selected_values = selection
5185                .iter()
5186                .map(|selected| values[*selected].value().clone())
5187                .collect::<Vec<_>>();
5188            let selected_labels = selection
5189                .iter()
5190                .map(|selected| values[*selected].label_fragment())
5191                .collect::<Vec<_>>();
5192            let mut params = BTreeMap::new();
5193            params.insert(param.to_string(), serde_json::Value::Array(selected_values));
5194            Ok(GenerationChoice {
5195                label: format!(
5196                    "{index:04}_{}_{}",
5197                    match mode {
5198                        PickArrangeMode::Pick => "pick",
5199                        PickArrangeMode::Arrange => "arrange",
5200                    },
5201                    sanitize_generation_label(&selected_labels.join("_"))
5202                ),
5203                value: serde_json::to_value(&params).map_err(|error| {
5204                    DagMlError::GraphValidation(format!(
5205                        "failed to serialize pipeline DSL {:?} generator choice for `{node_id}.{param}`: {error}",
5206                        mode
5207                    ))
5208                })?,
5209                param_overrides: vec![GenerationParamOverride {
5210                    node_id: node_id.clone(),
5211                    params,
5212                }],
5213            })
5214        })
5215        .collect::<Result<Vec<_>>>()?;
5216    apply_choice_count(&mut choices, count);
5217    Ok(GenerationDimension {
5218        name: generator_dimension_name(
5219            node_id,
5220            name,
5221            Some(param),
5222            match mode {
5223                PickArrangeMode::Pick => "pick",
5224                PickArrangeMode::Arrange => "arrange",
5225            },
5226        ),
5227        choices,
5228    })
5229}
5230
5231fn single_param_generation_choice(
5232    node_id: &NodeId,
5233    param: &str,
5234    index: usize,
5235    value: &PipelineDslGeneratorValue,
5236) -> Result<GenerationChoice> {
5237    let mut params = BTreeMap::new();
5238    params.insert(param.to_string(), value.value().clone());
5239    Ok(GenerationChoice {
5240        label: format!(
5241            "{index:04}_{}_{}",
5242            sanitize_generation_label(param),
5243            value.label_fragment()
5244        ),
5245        value: serde_json::to_value(&params).map_err(|error| {
5246            DagMlError::GraphValidation(format!(
5247                "failed to serialize pipeline DSL generator choice for `{node_id}.{param}`: {error}"
5248            ))
5249        })?,
5250        param_overrides: vec![GenerationParamOverride {
5251            node_id: node_id.clone(),
5252            params,
5253        }],
5254    })
5255}
5256
5257fn multi_param_generation_choice(
5258    node_id: &NodeId,
5259    index: usize,
5260    row: BTreeMap<String, PipelineDslGeneratorValue>,
5261) -> Result<GenerationChoice> {
5262    let mut params = BTreeMap::new();
5263    let mut label_parts = Vec::new();
5264    for (param, value) in row {
5265        label_parts.push(format!(
5266            "{}_{}",
5267            sanitize_generation_label(&param),
5268            value.label_fragment()
5269        ));
5270        params.insert(param, value.value().clone());
5271    }
5272    Ok(GenerationChoice {
5273        label: format!("{index:04}_{}", label_parts.join("__")),
5274        value: serde_json::to_value(&params).map_err(|error| {
5275            DagMlError::GraphValidation(format!(
5276                "failed to serialize pipeline DSL grid generator choice for node `{node_id}`: {error}"
5277            ))
5278        })?,
5279        param_overrides: vec![GenerationParamOverride {
5280            node_id: node_id.clone(),
5281            params,
5282        }],
5283    })
5284}
5285
5286fn build_grid_rows(
5287    entries: &[(&str, &[PipelineDslGeneratorValue])],
5288    entry_index: usize,
5289    current: &mut BTreeMap<String, PipelineDslGeneratorValue>,
5290    rows: &mut Vec<BTreeMap<String, PipelineDslGeneratorValue>>,
5291    count: Option<usize>,
5292) {
5293    if count.is_some_and(|limit| rows.len() >= limit) {
5294        return;
5295    }
5296    if entry_index == entries.len() {
5297        rows.push(current.clone());
5298        return;
5299    }
5300    let (param, values) = entries[entry_index];
5301    for value in values {
5302        current.insert(param.to_string(), value.clone());
5303        build_grid_rows(entries, entry_index + 1, current, rows, count);
5304        current.remove(param);
5305        if count.is_some_and(|limit| rows.len() >= limit) {
5306            break;
5307        }
5308    }
5309}
5310
/// Recursively appends to `selections` every ascending index combination of length `size`
/// drawn from `0..value_count`, extending the partial combination in `current` with indices
/// not smaller than `start`.
///
/// Nothing is appended once `selections` holds `count` entries; entries already present in
/// `selections` count towards that limit. `current` is restored before returning.
fn build_combinations(
    value_count: usize,
    size: usize,
    start: usize,
    current: &mut Vec<usize>,
    selections: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    let limit_reached = |found: usize| count.is_some_and(|limit| found >= limit);
    if limit_reached(selections.len()) {
        return;
    }
    if current.len() == size {
        selections.push(current.clone());
        return;
    }
    let missing = size - current.len();
    if missing > value_count {
        return;
    }
    // Largest index that still leaves room for the `missing - 1` indices after it.
    let last_candidate = value_count - missing;
    for candidate in start..=last_candidate {
        if limit_reached(selections.len()) {
            break;
        }
        current.push(candidate);
        build_combinations(value_count, size, candidate + 1, current, selections, count);
        current.pop();
    }
}
5339
/// Recursively appends to `selections` every ordered arrangement of `size` distinct indices
/// drawn from `0..value_count`, extending the partial arrangement in `current`.
///
/// `used` tracks the indices already present in `current`; both are restored before
/// returning. Nothing is appended once `selections` holds `count` entries.
fn build_permutations(
    value_count: usize,
    size: usize,
    used: &mut BTreeSet<usize>,
    current: &mut Vec<usize>,
    selections: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    let limit_reached = |found: usize| count.is_some_and(|limit| found >= limit);
    if limit_reached(selections.len()) {
        return;
    }
    if current.len() == size {
        selections.push(current.clone());
        return;
    }
    for candidate in 0..value_count {
        // `insert` returns false, leaving the set untouched, when the index is already taken.
        if !used.insert(candidate) {
            continue;
        }
        current.push(candidate);
        build_permutations(value_count, size, used, current, selections, count);
        current.pop();
        used.remove(&candidate);
        if limit_reached(selections.len()) {
            break;
        }
    }
}
5369
5370fn apply_choice_count(choices: &mut Vec<GenerationChoice>, count: Option<usize>) {
5371    if let Some(limit) = count {
5372        choices.truncate(limit);
5373    }
5374}
5375
5376fn validate_count(node_id: &NodeId, name: Option<&str>, count: Option<usize>) -> Result<()> {
5377    if count == Some(0) {
5378        return Err(DagMlError::GraphValidation(format!(
5379            "pipeline DSL generator `{}` for node `{node_id}` has count=0",
5380            generator_dimension_name(node_id, name, None, "params")
5381        )));
5382    }
5383    Ok(())
5384}
5385
5386fn validate_param_name(node_id: &NodeId, param: &str) -> Result<()> {
5387    if param.trim().is_empty() {
5388        return Err(DagMlError::GraphValidation(format!(
5389            "pipeline DSL param generator for node `{node_id}` has an empty param name"
5390        )));
5391    }
5392    Ok(())
5393}
5394
5395fn validate_finite(node_id: &NodeId, param: &str, field: &str, value: f64) -> Result<()> {
5396    if !value.is_finite() {
5397        return Err(DagMlError::GraphValidation(format!(
5398            "pipeline DSL {field} for `{node_id}.{param}` must be finite"
5399        )));
5400    }
5401    Ok(())
5402}
5403
/// Tells whether `current` still belongs to a range running towards `stop` by `step`.
///
/// A positive `step` treats `stop` as an upper bound, any other `step` as a lower bound.
/// A small tolerance (`|step| * 1e-12 + f64::EPSILON`) widens an inclusive bound and
/// narrows an exclusive one.
fn range_contains(current: f64, stop: f64, step: f64, inclusive: bool) -> bool {
    let tolerance = step.abs() * 1e-12 + f64::EPSILON;
    let ascending = step > 0.0;
    match (ascending, inclusive) {
        (true, true) => current <= stop + tolerance,
        (true, false) => current < stop - tolerance,
        (false, true) => current >= stop - tolerance,
        (false, false) => current > stop + tolerance,
    }
}
5418
5419fn json_number(value: f64, node_id: &NodeId, param: &str) -> Result<serde_json::Value> {
5420    let number = serde_json::Number::from_f64(value).ok_or_else(|| {
5421        DagMlError::GraphValidation(format!(
5422            "pipeline DSL numeric generator for `{node_id}.{param}` produced a non-finite value"
5423        ))
5424    })?;
5425    Ok(serde_json::Value::Number(number))
5426}
5427
5428fn generator_dimension_name(
5429    node_id: &NodeId,
5430    name: Option<&str>,
5431    param: Option<&str>,
5432    suffix: &str,
5433) -> String {
5434    if let Some(name) = name {
5435        return name.to_string();
5436    }
5437    match param {
5438        Some(param) => format!("{node_id}.{param}.{suffix}"),
5439        None => format!("{node_id}.{suffix}"),
5440    }
5441}
5442
5443impl PipelineDslGeneratorValue {
5444    fn value(&self) -> &serde_json::Value {
5445        match self {
5446            Self::Labeled { value, .. } | Self::Value(value) => value,
5447        }
5448    }
5449
5450    fn label_fragment(&self) -> String {
5451        match self {
5452            Self::Labeled { label, .. } => sanitize_generation_label(label),
5453            Self::Value(value) => {
5454                let rendered = match value {
5455                    serde_json::Value::String(value) => value.clone(),
5456                    _ => serde_json::to_string(value).unwrap_or_else(|_| "value".to_string()),
5457                };
5458                sanitize_generation_label(&rendered)
5459            }
5460        }
5461    }
5462}
5463
/// Turns arbitrary text into a label fragment.
///
/// Every character other than an ASCII letter, an ASCII digit, `_`, `-` or `.` becomes `_`;
/// leading and trailing `_` are then stripped. An empty result is replaced by `"value"`.
fn sanitize_generation_label(input: &str) -> String {
    let replaced: String = input
        .chars()
        .map(|character| match character {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | '.' => character,
            _ => '_',
        })
        .collect();
    match replaced.trim_matches('_') {
        "" => "value".to_string(),
        trimmed => trimmed.to_string(),
    }
}
5483
5484fn build_generation_spec(
5485    requested_strategy: Option<GenerationStrategy>,
5486    max_variants: Option<usize>,
5487    dimensions: Vec<GenerationDimension>,
5488) -> Result<GenerationSpec> {
5489    let strategy = requested_strategy.unwrap_or(if dimensions.is_empty() {
5490        GenerationStrategy::None
5491    } else {
5492        GenerationStrategy::Cartesian
5493    });
5494    let generation = GenerationSpec {
5495        strategy,
5496        dimensions,
5497        max_variants: if strategy == GenerationStrategy::None {
5498            Some(1)
5499        } else {
5500            max_variants
5501        },
5502    };
5503    generation.validate()?;
5504    Ok(generation)
5505}
5506
5507fn operator_runtime_metadata(
5508    step: &PipelineDslOperatorStep,
5509    branch_id: Option<&str>,
5510) -> Result<BTreeMap<String, serde_json::Value>> {
5511    let mut metadata = step.metadata.clone();
5512    if let Some(branch_id) = branch_id {
5513        metadata.insert(
5514            "dsl_branch".to_string(),
5515            serde_json::Value::String(branch_id.to_string()),
5516        );
5517    }
5518    insert_training_metadata(
5519        &mut metadata,
5520        &step.train_params,
5521        step.tuning.as_ref(),
5522        step.inner_cv.as_ref(),
5523        &step.id,
5524    )?;
5525    Ok(metadata)
5526}
5527
5528fn branch_context_metadata(
5529    branch_step: &PipelineDslBranchStep,
5530    branch: &PipelineDslBranch,
5531) -> Result<BTreeMap<String, serde_json::Value>> {
5532    let mut metadata = BTreeMap::new();
5533    metadata.insert(
5534        "dsl_branch".to_string(),
5535        serde_json::Value::String(branch.id.clone()),
5536    );
5537    metadata.insert(
5538        "dsl_branch_mode".to_string(),
5539        serde_json::to_value(branch_step.mode).map_err(|error| {
5540            DagMlError::GraphValidation(format!(
5541                "failed to serialize pipeline DSL branch mode for `{}`: {error}",
5542                branch.id
5543            ))
5544        })?,
5545    );
5546    if let Some(selector) = &branch_step.selector {
5547        metadata.insert("dsl_branch_step_selector".to_string(), selector.clone());
5548    }
5549    if !branch_step.metadata.is_empty() {
5550        metadata.insert(
5551            "dsl_branch_step_metadata".to_string(),
5552            serde_json::to_value(&branch_step.metadata).map_err(|error| {
5553                DagMlError::GraphValidation(format!(
5554                    "failed to serialize pipeline DSL branch step metadata for `{}`: {error}",
5555                    branch.id
5556                ))
5557            })?,
5558        );
5559    }
5560    if let Some(selector) = &branch.selector {
5561        metadata.insert("dsl_branch_selector".to_string(), selector.clone());
5562    }
5563    if !branch.metadata.is_empty() {
5564        metadata.insert(
5565            "dsl_branch_metadata".to_string(),
5566            serde_json::to_value(&branch.metadata).map_err(|error| {
5567                DagMlError::GraphValidation(format!(
5568                    "failed to serialize pipeline DSL branch metadata for `{}`: {error}",
5569                    branch.id
5570                ))
5571            })?,
5572        );
5573    }
5574    Ok(metadata)
5575}
5576
5577fn compile_branch_view_plan(
5578    branch_step: &PipelineDslBranchStep,
5579    branch: &PipelineDslBranch,
5580) -> Result<Option<BranchViewPlan>> {
5581    let Some(mode) = branch_view_mode(branch_step.mode) else {
5582        return Ok(None);
5583    };
5584    let selector = branch_view_selector(mode, branch_step.selector.as_ref(), branch)?;
5585    let mut metadata = branch.metadata.clone();
5586    if let Some(step_selector) = &branch_step.selector {
5587        metadata.insert(
5588            "dsl_branch_step_selector".to_string(),
5589            step_selector.clone(),
5590        );
5591    }
5592    if let Some(branch_selector) = &branch.selector {
5593        metadata.insert("dsl_branch_selector".to_string(), branch_selector.clone());
5594    }
5595    if !branch_step.metadata.is_empty() {
5596        metadata.insert(
5597            "dsl_branch_step_metadata".to_string(),
5598            serde_json::to_value(&branch_step.metadata).map_err(|error| {
5599                DagMlError::GraphValidation(format!(
5600                    "failed to serialize pipeline DSL branch step metadata for `{}`: {error}",
5601                    branch.id
5602                ))
5603            })?,
5604        );
5605    }
5606    let plan = BranchViewPlan {
5607        view_id: format!("branch_view:{}", branch.id),
5608        branch_id: branch.id.clone(),
5609        mode,
5610        selector,
5611        allow_overlap: branch_overlap_allowed(branch_step, branch),
5612        metadata,
5613    };
5614    plan.validate()
5615        .map_err(|error| DagMlError::GraphValidation(error.to_string()))?;
5616    Ok(Some(plan))
5617}
5618
5619fn branch_view_mode(mode: PipelineDslBranchMode) -> Option<BranchViewMode> {
5620    match mode {
5621        PipelineDslBranchMode::Duplication => None,
5622        PipelineDslBranchMode::Separation => Some(BranchViewMode::Separation),
5623        PipelineDslBranchMode::BySource => Some(BranchViewMode::BySource),
5624        PipelineDslBranchMode::ByMetadata => Some(BranchViewMode::ByMetadata),
5625        PipelineDslBranchMode::ByTag => Some(BranchViewMode::ByTag),
5626        PipelineDslBranchMode::ByFilter => Some(BranchViewMode::ByFilter),
5627    }
5628}
5629
5630fn branch_view_selector(
5631    mode: BranchViewMode,
5632    step_selector: Option<&serde_json::Value>,
5633    branch: &PipelineDslBranch,
5634) -> Result<DataViewSelector> {
5635    match mode {
5636        BranchViewMode::BySource => branch_view_selector_by_source(branch),
5637        BranchViewMode::ByMetadata => branch_view_selector_by_metadata(step_selector, branch),
5638        BranchViewMode::ByTag => branch_view_selector_by_tag(branch),
5639        BranchViewMode::ByFilter => branch_view_selector_by_filter(branch),
5640        BranchViewMode::Separation => branch_view_selector_generic(step_selector, branch),
5641    }
5642}
5643
5644fn branch_view_selector_by_source(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5645    let Some(selector) = &branch.selector else {
5646        return Err(DagMlError::GraphValidation(format!(
5647            "pipeline DSL by_source branch `{}` requires a selector",
5648            branch.id
5649        )));
5650    };
5651    let source_ids = selector_strings(selector, &["source", "source_id"], &["sources", "source_ids"])
5652        .or_else(|| selector.as_str().map(|value| vec![value.to_string()]))
5653        .ok_or_else(|| {
5654            DagMlError::GraphValidation(format!(
5655                "pipeline DSL by_source branch `{}` selector must be a source string or object with source/source_ids",
5656                branch.id
5657            ))
5658        })?;
5659    Ok(DataViewSelector {
5660        source_ids,
5661        ..DataViewSelector::default()
5662    })
5663}
5664
5665fn branch_view_selector_by_metadata(
5666    step_selector: Option<&serde_json::Value>,
5667    branch: &PipelineDslBranch,
5668) -> Result<DataViewSelector> {
5669    let Some(selector) = &branch.selector else {
5670        return Err(DagMlError::GraphValidation(format!(
5671            "pipeline DSL by_metadata branch `{}` requires a selector",
5672            branch.id
5673        )));
5674    };
5675    if let Some(metadata) = selector_metadata_map(selector)? {
5676        return Ok(DataViewSelector {
5677            metadata,
5678            ..DataViewSelector::default()
5679        });
5680    }
5681    let branch_key = selector
5682        .as_object()
5683        .and_then(|_| selector_metadata_key(selector));
5684    let key = branch_key
5685        .or_else(|| step_selector.and_then(selector_metadata_key))
5686        .ok_or_else(|| {
5687            DagMlError::GraphValidation(format!(
5688                "pipeline DSL by_metadata branch `{}` requires a metadata key on the branch or branch step selector",
5689                branch.id
5690            ))
5691        })?;
5692    let value = selector_value(selector).ok_or_else(|| {
5693        DagMlError::GraphValidation(format!(
5694            "pipeline DSL by_metadata branch `{}` requires a metadata value",
5695            branch.id
5696        ))
5697    })?;
5698    Ok(DataViewSelector {
5699        metadata: BTreeMap::from([(key, value)]),
5700        ..DataViewSelector::default()
5701    })
5702}
5703
5704fn branch_view_selector_by_tag(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5705    let Some(selector) = &branch.selector else {
5706        return Err(DagMlError::GraphValidation(format!(
5707            "pipeline DSL by_tag branch `{}` requires a selector",
5708            branch.id
5709        )));
5710    };
5711    let tags = selector_strings(selector, &["tag"], &["tags"])
5712        .or_else(|| selector.as_str().map(|value| vec![value.to_string()]))
5713        .ok_or_else(|| {
5714            DagMlError::GraphValidation(format!(
5715                "pipeline DSL by_tag branch `{}` selector must be a tag string or object with tag/tags",
5716                branch.id
5717            ))
5718        })?;
5719    Ok(DataViewSelector {
5720        tags,
5721        ..DataViewSelector::default()
5722    })
5723}
5724
5725fn branch_view_selector_by_filter(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5726    let Some(selector) = &branch.selector else {
5727        return Err(DagMlError::GraphValidation(format!(
5728            "pipeline DSL by_filter branch `{}` requires a selector",
5729            branch.id
5730        )));
5731    };
5732    let filter = selector
5733        .as_object()
5734        .and_then(|object| object.get("filter").cloned())
5735        .unwrap_or_else(|| selector.clone());
5736    Ok(DataViewSelector {
5737        filter: Some(filter),
5738        ..DataViewSelector::default()
5739    })
5740}
5741
5742fn branch_view_selector_generic(
5743    step_selector: Option<&serde_json::Value>,
5744    branch: &PipelineDslBranch,
5745) -> Result<DataViewSelector> {
5746    let Some(selector) = &branch.selector else {
5747        return Err(DagMlError::GraphValidation(format!(
5748            "pipeline DSL separation branch `{}` requires a selector",
5749            branch.id
5750        )));
5751    };
5752    if selector_strings(
5753        selector,
5754        &["source", "source_id"],
5755        &["sources", "source_ids"],
5756    )
5757    .is_some()
5758        || selector
5759            .as_object()
5760            .is_some_and(|object| object.contains_key("source") || object.contains_key("sources"))
5761    {
5762        return branch_view_selector_by_source(branch);
5763    }
5764    if selector_metadata_map(selector)?.is_some()
5765        || selector
5766            .as_object()
5767            .and_then(|_| selector_metadata_key(selector))
5768            .is_some()
5769        || step_selector.and_then(selector_metadata_key).is_some()
5770    {
5771        return branch_view_selector_by_metadata(step_selector, branch);
5772    }
5773    if selector_strings(selector, &["tag"], &["tags"]).is_some() {
5774        return branch_view_selector_by_tag(branch);
5775    }
5776    if selector
5777        .as_object()
5778        .is_some_and(|object| object.contains_key("filter"))
5779    {
5780        return branch_view_selector_by_filter(branch);
5781    }
5782    Err(DagMlError::GraphValidation(format!(
5783        "pipeline DSL separation branch `{}` selector must declare source_ids, metadata, tags or filter",
5784        branch.id
5785    )))
5786}
5787
5788fn selector_strings(
5789    value: &serde_json::Value,
5790    singular_keys: &[&str],
5791    plural_keys: &[&str],
5792) -> Option<Vec<String>> {
5793    let object = value.as_object()?;
5794    for key in singular_keys {
5795        if let Some(text) = object.get(*key).and_then(serde_json::Value::as_str) {
5796            return Some(vec![text.to_string()]);
5797        }
5798    }
5799    for key in plural_keys {
5800        if let Some(values) = object.get(*key).and_then(serde_json::Value::as_array) {
5801            let parsed = values
5802                .iter()
5803                .filter_map(serde_json::Value::as_str)
5804                .map(str::to_string)
5805                .collect::<Vec<_>>();
5806            if parsed.len() == values.len() {
5807                return Some(parsed);
5808            }
5809        }
5810    }
5811    None
5812}
5813
5814fn selector_metadata_map(
5815    value: &serde_json::Value,
5816) -> Result<Option<BTreeMap<String, serde_json::Value>>> {
5817    let Some(object) = value.as_object() else {
5818        return Ok(None);
5819    };
5820    let Some(metadata) = object.get("metadata") else {
5821        return Ok(None);
5822    };
5823    let Some(metadata) = metadata.as_object() else {
5824        return Err(DagMlError::GraphValidation(
5825            "pipeline DSL branch metadata selector must be an object".to_string(),
5826        ));
5827    };
5828    Ok(Some(
5829        metadata
5830            .iter()
5831            .map(|(key, value)| (key.clone(), value.clone()))
5832            .collect(),
5833    ))
5834}
5835
5836fn selector_metadata_key(value: &serde_json::Value) -> Option<String> {
5837    if let Some(text) = value.as_str() {
5838        return Some(text.to_string());
5839    }
5840    let object = value.as_object()?;
5841    ["metadata_key", "column", "key", "by_metadata"]
5842        .into_iter()
5843        .find_map(|key| object.get(key).and_then(serde_json::Value::as_str))
5844        .map(str::to_string)
5845}
5846
5847fn selector_value(value: &serde_json::Value) -> Option<serde_json::Value> {
5848    match value {
5849        serde_json::Value::String(_)
5850        | serde_json::Value::Bool(_)
5851        | serde_json::Value::Number(_) => Some(value.clone()),
5852        serde_json::Value::Object(object) => object
5853            .get("value")
5854            .or_else(|| object.get("equals"))
5855            .cloned(),
5856        _ => None,
5857    }
5858}
5859
5860fn branch_overlap_allowed(branch_step: &PipelineDslBranchStep, branch: &PipelineDslBranch) -> bool {
5861    branch
5862        .metadata
5863        .get("allow_overlap")
5864        .or_else(|| branch_step.metadata.get("allow_overlap"))
5865        .and_then(serde_json::Value::as_bool)
5866        .unwrap_or(false)
5867}
5868
5869fn branch_id_from_metadata(metadata: &BTreeMap<String, serde_json::Value>) -> Option<String> {
5870    metadata
5871        .get("dsl_branch")
5872        .and_then(|value| value.as_str())
5873        .map(str::to_string)
5874}
5875
5876fn expand_generator_sequences(step: &PipelineDslGeneratorStep) -> Result<Vec<GeneratedSequence>> {
5877    if step.count == Some(0) {
5878        return Err(DagMlError::GraphValidation(format!(
5879            "pipeline DSL generator `{}` count cannot be zero",
5880            step.id
5881        )));
5882    }
5883    match step.mode {
5884        PipelineDslGeneratorMode::Or => expand_or_generator_sequences(step),
5885        PipelineDslGeneratorMode::Cartesian => expand_cartesian_generator_sequences(step),
5886    }
5887}
5888
5889fn expand_or_generator_sequences(
5890    step: &PipelineDslGeneratorStep,
5891) -> Result<Vec<GeneratedSequence>> {
5892    if !step.stages.is_empty() {
5893        return Err(DagMlError::GraphValidation(format!(
5894            "pipeline DSL generator `{}` uses mode `or` but declares cartesian stages",
5895            step.id
5896        )));
5897    }
5898    if step.branches.is_empty() {
5899        return Err(DagMlError::GraphValidation(format!(
5900            "pipeline DSL generator `{}` has no branches",
5901            step.id
5902        )));
5903    }
5904    let options = step
5905        .branches
5906        .iter()
5907        .enumerate()
5908        .map(|(index, branch)| {
5909            validate_branch_id(&branch.id)?;
5910            Ok(GeneratedSequence {
5911                id: generator_choice_id(&step.id, index),
5912                labels: vec![branch.id.clone()],
5913                steps: branch.steps.clone(),
5914                metadata: branch.metadata.clone(),
5915            })
5916        })
5917        .collect::<Result<Vec<_>>>()?;
5918
5919    let choices = if let Some(sizes) = selection_sizes(step.pick)? {
5920        generated_pick_sequences(&options, &step.id, "pick", &sizes, step.count)?
5921    } else if let Some(sizes) = selection_sizes(step.arrange)? {
5922        generated_arrange_sequences(&options, &step.id, "arrange", &sizes, step.count)?
5923    } else {
5924        truncate_generated_sequences(options, step.count)
5925    };
5926
5927    let choices = if let Some(sizes) = selection_sizes(step.then_pick)? {
5928        generated_pick_sequences(&choices, &step.id, "then_pick", &sizes, step.count)?
5929    } else if let Some(sizes) = selection_sizes(step.then_arrange)? {
5930        generated_arrange_sequences(&choices, &step.id, "then_arrange", &sizes, step.count)?
5931    } else {
5932        choices
5933    };
5934    Ok(truncate_generated_sequences(choices, step.count))
5935}
5936
5937fn expand_cartesian_generator_sequences(
5938    step: &PipelineDslGeneratorStep,
5939) -> Result<Vec<GeneratedSequence>> {
5940    if !step.branches.is_empty() {
5941        return Err(DagMlError::GraphValidation(format!(
5942            "pipeline DSL generator `{}` uses mode `cartesian` but declares direct branches",
5943            step.id
5944        )));
5945    }
5946    if step.stages.is_empty() {
5947        return Err(DagMlError::GraphValidation(format!(
5948            "pipeline DSL generator `{}` has no cartesian stages",
5949            step.id
5950        )));
5951    }
5952    if step.pick.is_some()
5953        || step.arrange.is_some()
5954        || step.then_pick.is_some()
5955        || step.then_arrange.is_some()
5956    {
5957        return Err(DagMlError::GraphValidation(format!(
5958            "pipeline DSL generator `{}` cannot combine cartesian mode with pick/arrange selectors",
5959            step.id
5960        )));
5961    }
5962
5963    let mut stage_options = Vec::<Vec<GeneratedSequence>>::new();
5964    for (stage_index, stage) in step.stages.iter().enumerate() {
5965        validate_branch_id(&stage.id)?;
5966        if stage.branches.is_empty() {
5967            return Err(DagMlError::GraphValidation(format!(
5968                "pipeline DSL generator `{}` stage `{}` has no branches",
5969                step.id, stage.id
5970            )));
5971        }
5972        let mut options = Vec::new();
5973        for branch in &stage.branches {
5974            validate_branch_id(&branch.id)?;
5975            let mut metadata = branch.metadata.clone();
5976            if let Some(selector) = &stage.selector {
5977                metadata.insert("dsl_generator_stage_selector".to_string(), selector.clone());
5978            }
5979            if !stage.metadata.is_empty() {
5980                metadata.insert(
5981                    "dsl_generator_stage_metadata".to_string(),
5982                    serde_json::to_value(&stage.metadata).map_err(|error| {
5983                        DagMlError::GraphValidation(format!(
5984                            "failed to serialize pipeline DSL generator `{}` stage `{}` metadata: {error}",
5985                            step.id, stage.id
5986                        ))
5987                    })?,
5988                );
5989            }
5990            options.push(GeneratedSequence {
5991                id: format!("{stage_index}:{}", branch.id),
5992                labels: vec![format!("{}:{}", stage.id, branch.id)],
5993                steps: branch.steps.clone(),
5994                metadata,
5995            });
5996        }
5997        stage_options.push(options);
5998    }
5999
6000    let mut rows = Vec::<Vec<usize>>::new();
6001    build_cartesian_indices(&stage_options, 0, &mut Vec::new(), &mut rows, step.count);
6002    let mut choices = Vec::with_capacity(rows.len());
6003    for (choice_index, row) in rows.into_iter().enumerate() {
6004        let selected = row
6005            .into_iter()
6006            .enumerate()
6007            .map(|(stage_index, option_index)| stage_options[stage_index][option_index].clone())
6008            .collect::<Vec<_>>();
6009        choices.push(merge_generated_sequence(
6010            generator_choice_id(&step.id, choice_index),
6011            selected,
6012        )?);
6013    }
6014    Ok(choices)
6015}
6016
6017fn generated_pick_sequences(
6018    options: &[GeneratedSequence],
6019    generator_id: &NodeId,
6020    mode: &str,
6021    sizes: &[usize],
6022    count: Option<usize>,
6023) -> Result<Vec<GeneratedSequence>> {
6024    let mut selections = Vec::<Vec<usize>>::new();
6025    for size in sizes {
6026        if *size == 0 || *size > options.len() {
6027            return Err(DagMlError::GraphValidation(format!(
6028                "pipeline DSL generator `{generator_id}` {mode} size {size} is outside 1..={}",
6029                options.len()
6030            )));
6031        }
6032        build_combinations(
6033            options.len(),
6034            *size,
6035            0,
6036            &mut Vec::new(),
6037            &mut selections,
6038            count,
6039        );
6040    }
6041    selections
6042        .into_iter()
6043        .enumerate()
6044        .map(|(index, selection)| {
6045            let selected = selection
6046                .into_iter()
6047                .map(|option_index| options[option_index].clone())
6048                .collect::<Vec<_>>();
6049            merge_generated_sequence(generator_choice_id(generator_id, index), selected)
6050        })
6051        .collect()
6052}
6053
6054fn generated_arrange_sequences(
6055    options: &[GeneratedSequence],
6056    generator_id: &NodeId,
6057    mode: &str,
6058    sizes: &[usize],
6059    count: Option<usize>,
6060) -> Result<Vec<GeneratedSequence>> {
6061    let mut selections = Vec::<Vec<usize>>::new();
6062    for size in sizes {
6063        if *size == 0 || *size > options.len() {
6064            return Err(DagMlError::GraphValidation(format!(
6065                "pipeline DSL generator `{generator_id}` {mode} size {size} is outside 1..={}",
6066                options.len()
6067            )));
6068        }
6069        build_permutations(
6070            options.len(),
6071            *size,
6072            &mut BTreeSet::new(),
6073            &mut Vec::new(),
6074            &mut selections,
6075            count,
6076        );
6077    }
6078    selections
6079        .into_iter()
6080        .enumerate()
6081        .map(|(index, selection)| {
6082            let selected = selection
6083                .into_iter()
6084                .map(|option_index| options[option_index].clone())
6085                .collect::<Vec<_>>();
6086            merge_generated_sequence(generator_choice_id(generator_id, index), selected)
6087        })
6088        .collect()
6089}
6090
6091fn merge_generated_sequence(
6092    id: String,
6093    sequences: Vec<GeneratedSequence>,
6094) -> Result<GeneratedSequence> {
6095    if sequences.is_empty() {
6096        return Err(DagMlError::GraphValidation(format!(
6097            "pipeline DSL generated sequence `{id}` has no selected options"
6098        )));
6099    }
6100    let mut labels = Vec::new();
6101    let mut steps = Vec::new();
6102    let mut metadata = BTreeMap::new();
6103    for sequence in sequences {
6104        labels.extend(sequence.labels);
6105        steps.extend(sequence.steps);
6106        if !sequence.metadata.is_empty() {
6107            metadata.insert(
6108                format!("option:{}", metadata.len()),
6109                serde_json::to_value(sequence.metadata).map_err(|error| {
6110                    DagMlError::GraphValidation(format!(
6111                        "failed to serialize generated sequence `{id}` metadata: {error}"
6112                    ))
6113                })?,
6114            );
6115        }
6116    }
6117    Ok(GeneratedSequence {
6118        id,
6119        labels,
6120        steps,
6121        metadata,
6122    })
6123}
6124
6125fn truncate_generated_sequences(
6126    mut sequences: Vec<GeneratedSequence>,
6127    count: Option<usize>,
6128) -> Vec<GeneratedSequence> {
6129    if let Some(limit) = count {
6130        sequences.truncate(limit);
6131    }
6132    sequences
6133}
6134
/// Recursively enumerates index rows of the cartesian product of `stages`.
///
/// `current` holds the option indices chosen for the stages before `stage_index`. Once every
/// stage has an index the row is appended to `rows`. Rows are produced with the last stage
/// varying fastest, and enumeration stops as soon as `rows` holds `count` rows (when `count`
/// is `Some`). A stage without options produces no rows; an empty stage list produces a single
/// empty row.
fn build_cartesian_indices<T>(
    stages: &[Vec<T>],
    stage_index: usize,
    current: &mut Vec<usize>,
    rows: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    /// True once the row budget (if any) has been used up.
    fn budget_spent(rows: &[Vec<usize>], count: Option<usize>) -> bool {
        matches!(count, Some(limit) if rows.len() >= limit)
    }

    if budget_spent(rows, count) {
        return;
    }
    if stage_index == stages.len() {
        rows.push(current.to_vec());
        return;
    }

    let option_count = stages[stage_index].len();
    let mut option_index = 0;
    while option_index < option_count {
        current.push(option_index);
        build_cartesian_indices(stages, stage_index + 1, current, rows, count);
        current.pop();
        // Re-check after every subtree so no further siblings are explored once full.
        if budget_spent(rows, count) {
            break;
        }
        option_index += 1;
    }
}
6158
6159fn selection_sizes(selection: Option<PipelineDslSelectionSpec>) -> Result<Option<Vec<usize>>> {
6160    selection
6161        .map(|selection| match selection {
6162            PipelineDslSelectionSpec::Single(size) => {
6163                if size == 0 {
6164                    return Err(DagMlError::GraphValidation(
6165                        "pipeline DSL generator selection size cannot be zero".to_string(),
6166                    ));
6167                }
6168                Ok(vec![size])
6169            }
6170            PipelineDslSelectionSpec::Range([start, stop]) => {
6171                if start == 0 || stop == 0 || start > stop {
6172                    return Err(DagMlError::GraphValidation(format!(
6173                        "pipeline DSL generator selection range [{start}, {stop}] is invalid"
6174                    )));
6175                }
6176                Ok((start..=stop).collect())
6177            }
6178        })
6179        .transpose()
6180}
6181
6182fn generator_choice_id(generator_id: &NodeId, choice_index: usize) -> String {
6183    format!("{generator_id}:choice{choice_index}")
6184}
6185
6186fn generator_choice_metadata(
6187    step: &PipelineDslGeneratorStep,
6188    choice: &GeneratedSequence,
6189) -> Result<BTreeMap<String, serde_json::Value>> {
6190    let mut metadata = step.metadata.clone();
6191    metadata.insert(
6192        "dsl_generator".to_string(),
6193        serde_json::Value::String(step.id.to_string()),
6194    );
6195    metadata.insert(
6196        "dsl_generator_mode".to_string(),
6197        serde_json::to_value(step.mode).map_err(|error| {
6198            DagMlError::GraphValidation(format!(
6199                "failed to serialize pipeline DSL generator `{}` mode: {error}",
6200                step.id
6201            ))
6202        })?,
6203    );
6204    metadata.insert(
6205        "dsl_generator_choice".to_string(),
6206        serde_json::Value::String(choice.id.clone()),
6207    );
6208    metadata.insert(
6209        "dsl_generator_labels".to_string(),
6210        serde_json::to_value(&choice.labels).map_err(|error| {
6211            DagMlError::GraphValidation(format!(
6212                "failed to serialize pipeline DSL generator `{}` choice labels: {error}",
6213                step.id
6214            ))
6215        })?,
6216    );
6217    if !choice.metadata.is_empty() {
6218        metadata.insert(
6219            "dsl_generator_choice_metadata".to_string(),
6220            serde_json::to_value(&choice.metadata).map_err(|error| {
6221                DagMlError::GraphValidation(format!(
6222                    "failed to serialize pipeline DSL generator `{}` choice metadata: {error}",
6223                    step.id
6224                ))
6225            })?,
6226        );
6227    }
6228    Ok(metadata)
6229}
6230
6231fn namespace_generated_sequence(
6232    generator: &PipelineDslGeneratorStep,
6233    mut choice: GeneratedSequence,
6234    choice_index: usize,
6235) -> Result<GeneratedSequence> {
6236    let mut node_map = BTreeMap::<NodeId, NodeId>::new();
6237    let mut counter = 0usize;
6238    for step in &mut choice.steps {
6239        namespace_step_ids(generator, choice_index, step, &mut counter, &mut node_map)?;
6240    }
6241    for step in &mut choice.steps {
6242        rewrite_step_node_refs(step, &node_map);
6243    }
6244    Ok(choice)
6245}
6246
6247fn namespace_step_ids(
6248    generator: &PipelineDslGeneratorStep,
6249    choice_index: usize,
6250    step: &mut PipelineDslStep,
6251    counter: &mut usize,
6252    node_map: &mut BTreeMap<NodeId, NodeId>,
6253) -> Result<()> {
6254    match step {
6255        PipelineDslStep::Transform(step)
6256        | PipelineDslStep::YTransform(step)
6257        | PipelineDslStep::Tag(step)
6258        | PipelineDslStep::Exclude(step)
6259        | PipelineDslStep::Filter(step)
6260        | PipelineDslStep::SampleFilter(step)
6261        | PipelineDslStep::Augmentation(step)
6262        | PipelineDslStep::FeatureAugmentation(step)
6263        | PipelineDslStep::SampleAugmentation(step)
6264        | PipelineDslStep::DataGeneration(step)
6265        | PipelineDslStep::Model(step)
6266        | PipelineDslStep::Tuner(step)
6267        | PipelineDslStep::Chart(step) => {
6268            namespace_operator_step_id(generator, choice_index, step, counter, node_map)?;
6269        }
6270        PipelineDslStep::ConcatTransform(step) => {
6271            namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6272            for branch in &mut step.branches {
6273                for branch_step in &mut branch.steps {
6274                    namespace_operator_step_id(
6275                        generator,
6276                        choice_index,
6277                        branch_step,
6278                        counter,
6279                        node_map,
6280                    )?;
6281                }
6282            }
6283        }
6284        PipelineDslStep::Branch(step) => {
6285            for branch in &mut step.branches {
6286                for branch_step in &mut branch.steps {
6287                    namespace_step_ids(generator, choice_index, branch_step, counter, node_map)?;
6288                }
6289            }
6290        }
6291        PipelineDslStep::Generator(step) => {
6292            namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6293            for branch in &mut step.branches {
6294                for branch_step in &mut branch.steps {
6295                    namespace_step_ids(generator, choice_index, branch_step, counter, node_map)?;
6296                }
6297            }
6298            for stage in &mut step.stages {
6299                for branch in &mut stage.branches {
6300                    for branch_step in &mut branch.steps {
6301                        namespace_step_ids(
6302                            generator,
6303                            choice_index,
6304                            branch_step,
6305                            counter,
6306                            node_map,
6307                        )?;
6308                    }
6309                }
6310            }
6311        }
6312        PipelineDslStep::Sequential(step) => {
6313            if let Some(id) = &mut step.id {
6314                namespace_node_id_field(generator, choice_index, id, counter, node_map)?;
6315            }
6316            for child in &mut step.steps {
6317                namespace_step_ids(generator, choice_index, child, counter, node_map)?;
6318            }
6319        }
6320        PipelineDslStep::Merge(step) => {
6321            namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6322        }
6323        PipelineDslStep::MergeModel(step) => {
6324            namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6325        }
6326    }
6327    Ok(())
6328}
6329
6330fn namespace_operator_step_id(
6331    generator: &PipelineDslGeneratorStep,
6332    choice_index: usize,
6333    step: &mut PipelineDslOperatorStep,
6334    counter: &mut usize,
6335    node_map: &mut BTreeMap<NodeId, NodeId>,
6336) -> Result<()> {
6337    namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)
6338}
6339
6340fn namespace_node_id_field(
6341    generator: &PipelineDslGeneratorStep,
6342    choice_index: usize,
6343    node_id: &mut NodeId,
6344    counter: &mut usize,
6345    node_map: &mut BTreeMap<NodeId, NodeId>,
6346) -> Result<()> {
6347    let original = node_id.clone();
6348    if node_map.contains_key(&original) {
6349        return Err(DagMlError::GraphValidation(format!(
6350            "pipeline DSL generator `{}` choice `{}` reuses node id `{original}`; generated choices require unique node ids inside each expanded sequence",
6351            generator.id, choice_index
6352        )));
6353    }
6354    let next = namespaced_generated_node_id(&generator.id, choice_index, *counter, &original)?;
6355    *counter += 1;
6356    *node_id = next.clone();
6357    node_map.insert(original, next);
6358    Ok(())
6359}
6360
6361fn namespaced_generated_node_id(
6362    generator_id: &NodeId,
6363    choice_index: usize,
6364    node_index: usize,
6365    original: &NodeId,
6366) -> Result<NodeId> {
6367    let generator = sanitized_id_fragment(generator_id.as_str(), 32);
6368    let suffix = sanitized_id_fragment(original.as_str(), 28);
6369    NodeId::new(format!(
6370        "gen:{generator}:c{choice_index}:n{node_index}.{suffix}"
6371    ))
6372}
6373
6374fn sanitized_id_fragment(input: &str, max_len: usize) -> String {
6375    let sanitized = sanitize_generation_label(input);
6376    let mut fragment = sanitized.chars().take(max_len).collect::<String>();
6377    if fragment.is_empty() {
6378        fragment = "x".to_string();
6379    }
6380    fragment
6381}
6382
6383fn rewrite_step_node_refs(step: &mut PipelineDslStep, node_map: &BTreeMap<NodeId, NodeId>) {
6384    match step {
6385        PipelineDslStep::Transform(_)
6386        | PipelineDslStep::YTransform(_)
6387        | PipelineDslStep::Tag(_)
6388        | PipelineDslStep::Exclude(_)
6389        | PipelineDslStep::Filter(_)
6390        | PipelineDslStep::SampleFilter(_)
6391        | PipelineDslStep::Augmentation(_)
6392        | PipelineDslStep::FeatureAugmentation(_)
6393        | PipelineDslStep::SampleAugmentation(_)
6394        | PipelineDslStep::DataGeneration(_)
6395        | PipelineDslStep::Model(_)
6396        | PipelineDslStep::Tuner(_)
6397        | PipelineDslStep::Chart(_) => {}
6398        PipelineDslStep::ConcatTransform(step) => {
6399            for branch in &mut step.branches {
6400                for branch_step in &mut branch.steps {
6401                    rewrite_operator_step_refs(branch_step, node_map);
6402                }
6403            }
6404        }
6405        PipelineDslStep::Branch(step) => {
6406            for branch in &mut step.branches {
6407                for branch_step in &mut branch.steps {
6408                    rewrite_step_node_refs(branch_step, node_map);
6409                }
6410            }
6411        }
6412        PipelineDslStep::Generator(step) => {
6413            for branch in &mut step.branches {
6414                for branch_step in &mut branch.steps {
6415                    rewrite_step_node_refs(branch_step, node_map);
6416                }
6417            }
6418            for stage in &mut step.stages {
6419                for branch in &mut stage.branches {
6420                    for branch_step in &mut branch.steps {
6421                        rewrite_step_node_refs(branch_step, node_map);
6422                    }
6423                }
6424            }
6425        }
6426        PipelineDslStep::Sequential(step) => {
6427            for child in &mut step.steps {
6428                rewrite_step_node_refs(child, node_map);
6429            }
6430        }
6431        PipelineDslStep::Merge(step) => {
6432            rewrite_merge_selectors(&mut step.selectors, node_map);
6433        }
6434        PipelineDslStep::MergeModel(_) => {}
6435    }
6436}
6437
6438fn rewrite_operator_step_refs(
6439    _step: &mut PipelineDslOperatorStep,
6440    _node_map: &BTreeMap<NodeId, NodeId>,
6441) {
6442}
6443
6444fn rewrite_merge_selectors(
6445    selectors: &mut [PipelineDslMergeSelector],
6446    node_map: &BTreeMap<NodeId, NodeId>,
6447) {
6448    for selector in selectors {
6449        if let Some(model) = &selector.model {
6450            if let Some(rewritten) = node_map.get(model) {
6451                selector.model = Some(rewritten.clone());
6452            }
6453        }
6454    }
6455}
6456
6457fn validate_merge_selectors(
6458    merge_id: &NodeId,
6459    selectors: &[PipelineDslMergeSelector],
6460    predictions: &[PredictionSource],
6461) -> Result<()> {
6462    if selectors.is_empty() {
6463        return Ok(());
6464    }
6465    if predictions.is_empty() {
6466        return Err(DagMlError::GraphValidation(format!(
6467            "pipeline DSL merge `{merge_id}` declares selectors but has no prediction inputs"
6468        )));
6469    }
6470    for (selector_index, selector) in selectors.iter().enumerate() {
6471        let mut matched = predictions.iter().collect::<Vec<_>>();
6472        if let Some(input_name) = &selector.input_name {
6473            if input_name.trim().is_empty() {
6474                return Err(DagMlError::GraphValidation(format!(
6475                    "pipeline DSL merge `{merge_id}` selector {selector_index} has an empty input_name"
6476                )));
6477            }
6478            matched.retain(|prediction| prediction.input_name == *input_name);
6479        }
6480        if let Some(branch) = &selector.branch {
6481            if branch.trim().is_empty() {
6482                return Err(DagMlError::GraphValidation(format!(
6483                    "pipeline DSL merge `{merge_id}` selector {selector_index} has an empty branch"
6484                )));
6485            }
6486            matched.retain(|prediction| prediction.branch_id.as_deref() == Some(branch.as_str()));
6487        }
6488        if let Some(model) = &selector.model {
6489            matched.retain(|prediction| prediction.node_id == *model);
6490        }
6491        if matched.is_empty() {
6492            return Err(DagMlError::GraphValidation(format!(
6493                "pipeline DSL merge `{merge_id}` selector {selector_index} does not match any pending prediction input"
6494            )));
6495        }
6496        validate_merge_selector_select(merge_id, selector_index, selector, matched.len())?;
6497    }
6498    Ok(())
6499}
6500
6501fn validate_merge_selector_select(
6502    merge_id: &NodeId,
6503    selector_index: usize,
6504    selector: &PipelineDslMergeSelector,
6505    matched_count: usize,
6506) -> Result<()> {
6507    let Some(select) = &selector.select else {
6508        return Ok(());
6509    };
6510    if let Some(mode) = select.as_str() {
6511        match mode {
6512            "all" => return Ok(()),
6513            "best" => {
6514                require_selector_metric(merge_id, selector_index, selector, mode)?;
6515                return Ok(());
6516            }
6517            _ => {
6518                return Err(DagMlError::GraphValidation(format!(
6519                    "pipeline DSL merge `{merge_id}` selector {selector_index} has unsupported select mode `{mode}`"
6520                )));
6521            }
6522        }
6523    }
6524    let Some(object) = select.as_object() else {
6525        return Err(DagMlError::GraphValidation(format!(
6526            "pipeline DSL merge `{merge_id}` selector {selector_index} select must be `all`, `best` or an object with `top_k`"
6527        )));
6528    };
6529    if object.len() != 1 || !object.contains_key("top_k") {
6530        return Err(DagMlError::GraphValidation(format!(
6531            "pipeline DSL merge `{merge_id}` selector {selector_index} object select currently supports only `top_k`"
6532        )));
6533    }
6534    let Some(top_k) = object.get("top_k").and_then(|value| value.as_u64()) else {
6535        return Err(DagMlError::GraphValidation(format!(
6536            "pipeline DSL merge `{merge_id}` selector {selector_index} top_k must be a positive integer"
6537        )));
6538    };
6539    if top_k == 0 {
6540        return Err(DagMlError::GraphValidation(format!(
6541            "pipeline DSL merge `{merge_id}` selector {selector_index} top_k must be positive"
6542        )));
6543    }
6544    if top_k as usize > matched_count {
6545        return Err(DagMlError::GraphValidation(format!(
6546            "pipeline DSL merge `{merge_id}` selector {selector_index} top_k={top_k} exceeds {matched_count} matched prediction inputs"
6547        )));
6548    }
6549    require_selector_metric(merge_id, selector_index, selector, "top_k")
6550}
6551
6552fn require_selector_metric(
6553    merge_id: &NodeId,
6554    selector_index: usize,
6555    selector: &PipelineDslMergeSelector,
6556    select_mode: &str,
6557) -> Result<()> {
6558    if selector
6559        .metric
6560        .as_ref()
6561        .is_some_and(|metric| !metric.trim().is_empty())
6562    {
6563        return Ok(());
6564    }
6565    Err(DagMlError::GraphValidation(format!(
6566        "pipeline DSL merge `{merge_id}` selector {selector_index} select `{select_mode}` requires a non-empty metric"
6567    )))
6568}
6569
6570fn insert_training_metadata(
6571    metadata: &mut BTreeMap<String, serde_json::Value>,
6572    train_params: &BTreeMap<String, serde_json::Value>,
6573    tuning: Option<&PipelineDslTuningSpec>,
6574    inner_cv: Option<&NestedCvSpec>,
6575    node_id: &NodeId,
6576) -> Result<()> {
6577    if let Some(inner_cv) = inner_cv {
6578        // Carry the node-local nested-CV policy on the graph node so
6579        // build_execution_plan can lower it into NodePlan.inner_cv.
6580        metadata.insert(
6581            "dsl_inner_cv".to_string(),
6582            serde_json::to_value(inner_cv).map_err(|error| {
6583                DagMlError::GraphValidation(format!(
6584                    "failed to serialize pipeline DSL inner_cv for node `{node_id}`: {error}"
6585                ))
6586            })?,
6587        );
6588    }
6589    if !train_params.is_empty() {
6590        metadata.insert(
6591            "dsl_train_params".to_string(),
6592            serde_json::to_value(train_params).map_err(|error| {
6593                DagMlError::GraphValidation(format!(
6594                    "failed to serialize pipeline DSL train params for node `{node_id}`: {error}"
6595                ))
6596            })?,
6597        );
6598    }
6599    if let Some(tuning) = tuning {
6600        metadata.insert(
6601            "dsl_tuning".to_string(),
6602            serde_json::to_value(tuning).map_err(|error| {
6603                DagMlError::GraphValidation(format!(
6604                    "failed to serialize pipeline DSL tuning for node `{node_id}`: {error}"
6605                ))
6606            })?,
6607        );
6608    }
6609    Ok(())
6610}
6611
6612fn same_data_source(left: &DataSource, right: &DataSource) -> bool {
6613    left.node_id == right.node_id
6614        && left.port_name == right.port_name
6615        && left.representation == right.representation
6616}
6617
6618fn merge_consumes_predictions(step: &PipelineDslMergeStep) -> bool {
6619    match step.output_as {
6620        PipelineDslMergeOutput::Predictions => true,
6621        PipelineDslMergeOutput::Sources => false,
6622        PipelineDslMergeOutput::Features => {
6623            matches!(
6624                step.merge_mode.as_str(),
6625                "predictions" | "prediction" | "all" | "mixed" | "predictions_plus_original"
6626            ) || !step.selectors.is_empty()
6627        }
6628    }
6629}
6630
6631fn merge_consumes_branch_data(step: &PipelineDslMergeStep) -> bool {
6632    match step.output_as {
6633        PipelineDslMergeOutput::Predictions => false,
6634        PipelineDslMergeOutput::Sources => true,
6635        PipelineDslMergeOutput::Features => matches!(
6636            step.merge_mode.as_str(),
6637            "features" | "feature" | "concat" | "all" | "mixed" | "sources" | "source"
6638        ),
6639    }
6640}
6641
6642fn merge_node_kind(
6643    step: &PipelineDslMergeStep,
6644    has_predictions: bool,
6645    has_branch_data: bool,
6646) -> NodeKind {
6647    match step.output_as {
6648        PipelineDslMergeOutput::Predictions => NodeKind::PredictionJoin,
6649        PipelineDslMergeOutput::Sources => NodeKind::SourceJoin,
6650        PipelineDslMergeOutput::Features => {
6651            if has_predictions && (step.include_original_data || has_branch_data) {
6652                NodeKind::MixedJoin
6653            } else if has_predictions {
6654                NodeKind::PredictionJoin
6655            } else {
6656                NodeKind::FeatureJoin
6657            }
6658        }
6659    }
6660}
6661
6662fn data_port(name: &str, representation: Option<String>, description: &str) -> PortSpec {
6663    PortSpec {
6664        name: name.to_string(),
6665        kind: PortKind::Data,
6666        representation,
6667        cardinality: PortCardinality::One,
6668        unit_level: None,
6669        alignment_key: None,
6670        target_level: None,
6671        description: description.to_string(),
6672    }
6673}
6674
6675fn apply_data_unit_contract(port: &mut PortSpec, contract: &PipelineDslDataPort) {
6676    port.unit_level = contract.unit_level;
6677    port.alignment_key = contract.alignment_key.clone();
6678    port.target_level = contract.target_level;
6679}
6680
6681fn target_port(name: &str, description: &str) -> PortSpec {
6682    PortSpec {
6683        name: name.to_string(),
6684        kind: PortKind::Target,
6685        representation: None,
6686        cardinality: PortCardinality::One,
6687        unit_level: None,
6688        alignment_key: None,
6689        target_level: None,
6690        description: description.to_string(),
6691    }
6692}
6693
6694fn prediction_port(name: &str, description: &str) -> PortSpec {
6695    PortSpec {
6696        name: name.to_string(),
6697        kind: PortKind::Prediction,
6698        representation: None,
6699        cardinality: PortCardinality::One,
6700        unit_level: None,
6701        alignment_key: None,
6702        target_level: None,
6703        description: description.to_string(),
6704    }
6705}
6706
6707fn apply_prediction_unit_contract(port: &mut PortSpec, contract: &PipelineDslPredictionPort) {
6708    port.representation = contract.representation.clone();
6709    port.unit_level = contract.unit_level;
6710    port.alignment_key = contract.alignment_key.clone();
6711    port.target_level = contract.target_level;
6712}
6713
6714fn validate_branch_id(branch_id: &str) -> Result<()> {
6715    if branch_id.trim().is_empty() {
6716        return Err(DagMlError::GraphValidation(
6717            "pipeline DSL branch id must not be empty".to_string(),
6718        ));
6719    }
6720    if !branch_id
6721        .bytes()
6722        .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.' | b':'))
6723    {
6724        return Err(DagMlError::GraphValidation(format!(
6725            "pipeline DSL branch id `{branch_id}` contains unsupported characters"
6726        )));
6727    }
6728    Ok(())
6729}
6730
/// Derives an identifier-safe prefix from a branch id.
///
/// Every character that is not an ASCII letter or digit becomes `_`, then leading and
/// trailing underscores are stripped. When nothing remains, the fallback
/// `branch<index>` is returned instead.
fn branch_input_prefix(branch_id: &str, index: usize) -> String {
    let mut cleaned = String::with_capacity(branch_id.len());
    for character in branch_id.chars() {
        // `_` maps to itself, so only ASCII alphanumerics need to be kept.
        cleaned.push(if character.is_ascii_alphanumeric() {
            character
        } else {
            '_'
        });
    }
    match cleaned.trim_matches('_') {
        "" => format!("branch{index}"),
        trimmed => trimmed.to_owned(),
    }
}
6750
6751fn branch_prediction_input_name(
6752    branch_id: &str,
6753    branch_index: usize,
6754    prediction_index: usize,
6755    node_id: &NodeId,
6756) -> String {
6757    let branch = branch_input_prefix(branch_id, branch_index);
6758    let model = node_id
6759        .as_str()
6760        .chars()
6761        .map(|character| {
6762            if character.is_ascii_alphanumeric() || character == '_' {
6763                character
6764            } else {
6765                '_'
6766            }
6767        })
6768        .collect::<String>()
6769        .trim_matches('_')
6770        .to_string();
6771    if model.is_empty() {
6772        format!("{branch}_model{prediction_index}_oof")
6773    } else {
6774        format!("{branch}_{model}_oof")
6775    }
6776}
6777
/// Returns the default input name, `"x"`.
///
/// NOTE(review): presumably referenced by a serde `default = "..."` attribute declared
/// elsewhere in this file — confirm against the struct definitions.
fn default_input_name() -> String {
    String::from("x")
}
6781
/// Returns the default output name, `"prediction"`.
///
/// NOTE(review): presumably referenced by a serde `default = "..."` attribute declared
/// elsewhere in this file — confirm against the struct definitions.
fn default_output_name() -> String {
    String::from("prediction")
}
6785
/// Returns the default data representation, `"tabular_numeric"`.
///
/// NOTE(review): presumably referenced by a serde `default = "..."` attribute declared
/// elsewhere in this file — confirm against the struct definitions.
fn default_data_representation() -> String {
    String::from("tabular_numeric")
}
6789
/// Returns `true`.
///
/// NOTE(review): presumably a serde default provider for boolean fields that should be
/// enabled when omitted — confirm against the struct definitions.
fn default_true() -> bool {
    true
}
6793
/// Returns the default logarithm base, `10.0`.
///
/// NOTE(review): presumably referenced by a serde `default = "..."` attribute declared
/// elsewhere in this file — confirm against the struct definitions.
fn default_log_base() -> f64 {
    10.0_f64
}
6797
/// Returns the default merge mode, `"predictions_plus_original"`.
///
/// NOTE(review): presumably referenced by a serde `default = "..."` attribute declared
/// elsewhere in this file — confirm against the struct definitions.
fn default_merge_mode() -> String {
    String::from("predictions_plus_original")
}
6801
6802#[cfg(test)]
6803mod tests {
6804    use super::*;
6805    use crate::controller::{
6806        ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest,
6807        OperatorSelector, RngPolicy,
6808    };
6809    use crate::phase::Phase;
6810
6811    fn registry_manifest(id: &str, kind: NodeKind, aliases: &[&str]) -> ControllerManifest {
6812        ControllerManifest {
6813            controller_id: crate::ids::ControllerId::new(id).unwrap(),
6814            controller_version: "0.1.0".to_string(),
6815            operator_kind: kind,
6816            priority: 0,
6817            supported_phases: BTreeSet::from([Phase::FitCv]),
6818            input_ports: Vec::new(),
6819            output_ports: Vec::new(),
6820            data_requirements: None,
6821            capabilities: BTreeSet::from([ControllerCapability::Deterministic]),
6822            operator_selectors: vec![OperatorSelector {
6823                aliases: aliases.iter().map(|alias| (*alias).to_string()).collect(),
6824                ..OperatorSelector::default()
6825            }],
6826            fit_scope: ControllerFitScope::FoldTrain,
6827            rng_policy: RngPolicy::UsesCoreSeed,
6828            artifact_policy: ArtifactPolicy::Serializable,
6829        }
6830    }
6831
6832    #[test]
6833    fn compiles_linear_pipeline_dsl_to_valid_graph() {
6834        let spec: PipelineDslSpec = serde_json::from_str(
6835            r#"{
6836  "id": "dsl-linear-smoke",
6837  "steps": [
6838    {
6839      "kind": "transform",
6840      "id": "transform:snv",
6841      "operator": {"type": "StandardNormalVariate"},
6842      "seed_label": "snv"
6843    },
6844    {
6845      "kind": "model",
6846      "id": "model:base",
6847      "operator": {"type": "RandomForestRegressor"},
6848      "params": {"n_estimators": 100},
6849      "seed_label": "base"
6850    }
6851  ]
6852}"#,
6853        )
6854        .unwrap();
6855
6856        let graph = compile_pipeline_dsl(&spec).unwrap();
6857
6858        assert_eq!(graph.id, "dsl-linear-smoke");
6859        assert_eq!(graph.nodes.len(), 2);
6860        assert_eq!(graph.edges.len(), 1);
6861        assert_eq!(graph.nodes[0].kind, NodeKind::Transform);
6862        assert_eq!(graph.nodes[1].kind, NodeKind::Model);
6863        assert_eq!(graph.edges[0].source.node_id.as_str(), "transform:snv");
6864        assert_eq!(graph.edges[0].target.node_id.as_str(), "model:base");
6865        assert_eq!(graph.edges[0].contract.kind, PortKind::Data);
6866        graph.validate().unwrap();
6867    }
6868
6869    #[test]
6870    fn compiles_pipeline_dsl_unit_contracts_to_graph_interface() {
6871        let spec: PipelineDslSpec = serde_json::from_str(
6872            r#"{
6873  "id": "dsl-unit-contract-smoke",
6874  "input": {
6875    "name": "spectra",
6876    "representation": "tabular",
6877    "unit_level": "observation",
6878    "alignment_key": "sample_id",
6879    "target_level": "physical_sample"
6880  },
6881  "output": {
6882    "name": "prediction",
6883    "representation": "regression",
6884    "unit_level": "physical_sample",
6885    "alignment_key": "sample_id",
6886    "target_level": "physical_sample"
6887  },
6888  "steps": [
6889    {
6890      "kind": "model",
6891      "id": "model:base",
6892      "operator": {"type": "RandomForestRegressor"}
6893    }
6894  ]
6895}"#,
6896        )
6897        .unwrap();
6898
6899        let graph = compile_pipeline_dsl(&spec).unwrap();
6900
6901        assert_eq!(
6902            graph.interface.inputs[0].unit_level,
6903            Some(EntityUnitLevel::Observation)
6904        );
6905        assert_eq!(
6906            graph.interface.inputs[0].alignment_key.as_deref(),
6907            Some("sample_id")
6908        );
6909        assert_eq!(
6910            graph.interface.outputs[0].unit_level,
6911            Some(EntityUnitLevel::PhysicalSample)
6912        );
6913        assert_eq!(
6914            graph.interface.outputs[0].representation.as_deref(),
6915            Some("regression")
6916        );
6917    }
6918
6919    #[test]
6920    fn compiles_branch_merge_predictions_plus_original_dsl() {
6921        let spec: PipelineDslSpec = serde_json::from_str(
6922            r#"{
6923  "id": "dsl-branch-merge-smoke",
6924  "steps": [
6925    {
6926      "kind": "branch",
6927      "branches": [
6928        {
6929          "id": "b0",
6930          "steps": [
6931            {
6932              "kind": "model",
6933              "id": "branch:b0.model:ridge",
6934              "operator": {"type": "Ridge"},
6935              "params": {"alpha": 0.3},
6936              "seed_label": "branch:b0"
6937            }
6938          ]
6939        },
6940        {
6941          "id": "b1",
6942          "steps": [
6943            {
6944              "kind": "augmentation",
6945              "id": "branch:b1.augment:noise",
6946              "operator": {"type": "GaussianNoise"},
6947              "params": {"scope": "train_only"},
6948              "seed_label": "branch:b1.augment",
6949              "shape": {
6950                "fit_rows": "fold_train",
6951                "predict_rows": "fold_validation",
6952                "augmentation_policy": {
6953                  "sample_scope": "train_only",
6954                  "feature_scope": "none",
6955                  "require_origin_id": true,
6956                  "inherit_group": true,
6957                  "inherit_target": true
6958                }
6959              }
6960            },
6961            {
6962              "kind": "model",
6963              "id": "branch:b1.model:rf",
6964              "operator": {"type": "RandomForestRegressor"},
6965              "params": {"n_estimators": 64},
6966              "seed_label": "branch:b1"
6967            }
6968          ]
6969        }
6970      ]
6971    },
6972    {
6973      "kind": "merge_model",
6974      "id": "merge:stack.pred_plus_original.meta:ridge",
6975      "operator": {"type": "RidgeMetaStacker"},
6976      "params": {"alpha": 0.2},
6977      "seed_label": "merge:stack"
6978    }
6979  ]
6980}"#,
6981        )
6982        .unwrap();
6983
6984        let graph = compile_pipeline_dsl(&spec).unwrap();
6985
6986        assert_eq!(graph.nodes.len(), 4);
6987        assert_eq!(graph.edges.len(), 3);
6988        let merge = graph
6989            .nodes
6990            .iter()
6991            .find(|node| node.id.as_str() == "merge:stack.pred_plus_original.meta:ridge")
6992            .unwrap();
6993        assert_eq!(merge.ports.inputs.len(), 3);
6994        assert_eq!(merge.ports.inputs[0].name, "b0_oof");
6995        assert_eq!(merge.ports.inputs[1].name, "b1_oof");
6996        assert_eq!(merge.ports.inputs[2].name, "x_original");
6997        let prediction_edges = graph
6998            .edges
6999            .iter()
7000            .filter(|edge| edge.contract.kind == PortKind::Prediction)
7001            .collect::<Vec<_>>();
7002        assert_eq!(prediction_edges.len(), 2);
7003        assert!(prediction_edges
7004            .iter()
7005            .all(|edge| edge.contract.requires_oof));
7006        assert!(prediction_edges
7007            .iter()
7008            .all(|edge| edge.contract.requires_fold_alignment));
7009        assert!(graph.edges.iter().any(|edge| edge.source.node_id.as_str()
7010            == "branch:b1.augment:noise"
7011            && edge.target.node_id.as_str() == "branch:b1.model:rf"));
7012        graph.validate().unwrap();
7013    }
7014
7015    #[test]
7016    fn compiles_separation_branch_view_plans() {
7017        let spec: PipelineDslSpec = serde_json::from_str(
7018            r#"{
7019  "id": "dsl-separation-branch-views",
7020  "steps": [
7021    {
7022      "kind": "branch",
7023      "mode": "by_metadata",
7024      "selector": {"metadata_key": "site"},
7025      "branches": [
7026        {
7027          "id": "site_a",
7028          "selector": "A",
7029          "steps": [
7030            {"kind": "model", "id": "model:site.a", "operator": {"type": "PLSRegression"}}
7031          ]
7032        },
7033        {
7034          "id": "site_b",
7035          "selector": {"value": "B"},
7036          "steps": [
7037            {"kind": "model", "id": "model:site.b", "operator": {"type": "Ridge"}}
7038          ]
7039        }
7040      ]
7041    },
7042    {
7043      "kind": "merge_model",
7044      "id": "model:site.meta",
7045      "operator": {"type": "Ridge"},
7046      "include_original_data": false
7047    }
7048  ]
7049}"#,
7050        )
7051        .unwrap();
7052
7053        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7054
7055        assert_eq!(compiled.branch_view_plans.len(), 2);
7056        assert_eq!(
7057            compiled.campaign_template.branch_view_plans,
7058            compiled.branch_view_plans
7059        );
7060        assert_eq!(
7061            compiled.branch_view_plans[0].mode,
7062            BranchViewMode::ByMetadata
7063        );
7064        assert_eq!(compiled.branch_view_plans[0].selector.metadata["site"], "A");
7065        assert_eq!(compiled.branch_view_plans[1].selector.metadata["site"], "B");
7066        let site_model = compiled
7067            .graph
7068            .nodes
7069            .iter()
7070            .find(|node| node.id.as_str() == "model:site.a")
7071            .unwrap();
7072        assert_eq!(
7073            site_model.metadata["dsl_branch_view_plan"]["selector"]["metadata"]["site"],
7074            "A"
7075        );
7076    }
7077
7078    #[test]
7079    fn refuses_separation_branch_without_selector() {
7080        let spec: PipelineDslSpec = serde_json::from_str(
7081            r#"{
7082  "id": "dsl-bad-separation-branch",
7083  "steps": [
7084    {
7085      "kind": "branch",
7086      "mode": "by_source",
7087      "branches": [
7088        {
7089          "id": "nir",
7090          "steps": [
7091            {"kind": "model", "id": "model:nir", "operator": {"type": "Ridge"}}
7092          ]
7093        }
7094      ]
7095    }
7096  ]
7097}"#,
7098        )
7099        .unwrap();
7100
7101        let error = compile_pipeline_dsl_with_generation(&spec)
7102            .unwrap_err()
7103            .to_string();
7104
7105        assert!(error.contains("by_source branch `nir` requires a selector"));
7106    }
7107
7108    #[test]
7109    fn compiles_branch_feature_merge_into_downstream_model() {
7110        let spec: PipelineDslSpec = serde_json::from_str(
7111            r#"{
7112  "id": "dsl-branch-feature-merge",
7113  "steps": [
7114    {
7115      "kind": "branch",
7116      "branches": [
7117        {
7118          "id": "snv",
7119          "steps": [
7120            {
7121              "kind": "transform",
7122              "id": "branch:snv.transform",
7123              "operator": {"type": "SNV"}
7124            }
7125          ]
7126        },
7127        {
7128          "id": "msc",
7129          "steps": [
7130            {
7131              "kind": "transform",
7132              "id": "branch:msc.transform",
7133              "operator": {"type": "MSC"}
7134            }
7135          ]
7136        }
7137      ]
7138    },
7139    {
7140      "kind": "merge",
7141      "id": "merge:features",
7142      "merge_mode": "features",
7143      "output_as": "features",
7144      "include_original_data": false
7145    },
7146    {
7147      "kind": "model",
7148      "id": "model:pls",
7149      "operator": {"type": "PLSRegression"}
7150    }
7151  ]
7152}"#,
7153        )
7154        .unwrap();
7155
7156        let graph = compile_pipeline_dsl(&spec).unwrap();
7157        graph.validate().unwrap();
7158        let merge = graph
7159            .nodes
7160            .iter()
7161            .find(|node| node.id.as_str() == "merge:features")
7162            .unwrap();
7163        assert_eq!(merge.kind, NodeKind::FeatureJoin);
7164        assert_eq!(merge.ports.inputs.len(), 2);
7165        assert!(merge.ports.inputs.iter().any(|port| port.name == "snv_x"));
7166        assert!(merge.ports.inputs.iter().any(|port| port.name == "msc_x"));
7167        assert!(graph.edges.iter().any(|edge| {
7168            edge.source.node_id.as_str() == "branch:snv.transform"
7169                && edge.target.node_id.as_str() == "merge:features"
7170                && edge.target.port_name == "snv_x"
7171                && edge.contract.kind == PortKind::Data
7172        }));
7173        assert!(graph.edges.iter().any(|edge| {
7174            edge.source.node_id.as_str() == "merge:features"
7175                && edge.target.node_id.as_str() == "model:pls"
7176                && edge.contract.kind == PortKind::Data
7177        }));
7178        assert!(!graph
7179            .edges
7180            .iter()
7181            .any(|edge| edge.contract.kind == PortKind::Prediction));
7182    }
7183
7184    #[test]
7185    fn compiles_nirs4all_style_multi_model_branch_and_separate_merge() {
7186        let spec: PipelineDslSpec = serde_json::from_str(
7187            r#"{
7188  "id": "dsl-nirs4all-branch-parity",
7189  "steps": [
7190    {
7191      "kind": "branch",
7192      "mode": "duplication",
7193      "selector": {"scope": "all_samples"},
7194      "branches": [
7195        {
7196          "id": "pls_path",
7197          "steps": [
7198            {
7199              "kind": "model",
7200              "id": "branch:pls.model:pls5",
7201              "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
7202              "params": {"n_components": 5}
7203            },
7204            {
7205              "kind": "model",
7206              "id": "branch:pls.model:pls10",
7207              "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
7208              "params": {"n_components": 10}
7209            }
7210          ]
7211        },
7212        {
7213          "id": "rf_path",
7214          "selector": {"source": "nir"},
7215          "steps": [
7216            {
7217              "kind": "transform",
7218              "id": "branch:rf.transform:snv",
7219              "operator": {"class": "nirs4all.operators.transforms.StandardNormalVariate"}
7220            },
7221            {
7222              "kind": "model",
7223              "id": "branch:rf.model:rf",
7224              "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
7225              "params": {"n_estimators": 64}
7226            },
7227            {
7228              "kind": "model",
7229              "id": "branch:rf.model:gbr",
7230              "operator": {"class": "sklearn.ensemble.GradientBoostingRegressor"},
7231              "params": {"n_estimators": 32}
7232            }
7233          ]
7234        }
7235      ]
7236    },
7237    {
7238      "kind": "merge",
7239      "id": "merge:stack.predictions_plus_original",
7240      "merge_mode": "predictions_plus_original",
7241      "output_as": "features",
7242      "include_original_data": true,
7243      "selectors": [
7244        {"branch": "pls_path", "select": "best", "metric": "rmse"},
7245        {"branch": "rf_path", "select": {"top_k": 2}, "metric": "r2"}
7246      ],
7247      "metadata": {"on_missing": "warn"}
7248    },
7249    {
7250      "kind": "model",
7251      "id": "model:meta.ridge",
7252      "operator": {"class": "sklearn.linear_model.Ridge"},
7253      "variants": [
7254        {"label": "low", "params": {"alpha": 0.1}},
7255        {"label": "mid", "params": {"alpha": 0.5}}
7256      ]
7257    },
7258    {
7259      "kind": "model",
7260      "id": "model:meta.rf",
7261      "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
7262      "params": {"n_estimators": 30}
7263    }
7264  ]
7265}"#,
7266        )
7267        .unwrap();
7268
7269        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7270        let graph = compiled.graph;
7271        let merge = graph
7272            .nodes
7273            .iter()
7274            .find(|node| node.id.as_str() == "merge:stack.predictions_plus_original")
7275            .unwrap();
7276
7277        assert_eq!(merge.kind, NodeKind::MixedJoin);
7278        assert_eq!(merge.ports.inputs.len(), 5);
7279        assert_eq!(merge.ports.outputs[0].kind, PortKind::Data);
7280        assert_eq!(merge.metadata["merge_mode"], "predictions_plus_original");
7281        assert_eq!(merge.metadata["selectors"][0]["branch"], "pls_path");
7282        let rf_model = graph
7283            .nodes
7284            .iter()
7285            .find(|node| node.id.as_str() == "branch:rf.model:rf")
7286            .unwrap();
7287        assert_eq!(rf_model.metadata["dsl_branch"], "rf_path");
7288        assert_eq!(rf_model.metadata["dsl_branch_mode"], "duplication");
7289        assert_eq!(
7290            rf_model.metadata["dsl_branch_step_selector"]["scope"],
7291            "all_samples"
7292        );
7293        assert_eq!(rf_model.metadata["dsl_branch_selector"]["source"], "nir");
7294        assert_eq!(
7295            graph
7296                .edges
7297                .iter()
7298                .filter(|edge| edge.target.node_id == merge.id
7299                    && edge.contract.kind == PortKind::Prediction
7300                    && edge.contract.requires_oof)
7301                .count(),
7302            4
7303        );
7304        assert!(graph
7305            .edges
7306            .iter()
7307            .any(|edge| edge.source.node_id == merge.id
7308                && edge.target.node_id.as_str() == "model:meta.ridge"
7309                && edge.contract.kind == PortKind::Data));
7310        assert!(graph
7311            .edges
7312            .iter()
7313            .any(|edge| edge.source.node_id == merge.id
7314                && edge.target.node_id.as_str() == "model:meta.rf"
7315                && edge.contract.kind == PortKind::Data));
7316        assert_eq!(compiled.generation.dimensions.len(), 1);
7317        assert_eq!(
7318            compiled.generation.dimensions[0].name,
7319            "model:meta.ridge.params"
7320        );
7321        graph.validate().unwrap();
7322    }
7323
7324    #[test]
7325    fn merge_selectors_reject_unknown_branch_and_missing_metric() {
7326        let unknown_branch: PipelineDslSpec = serde_json::from_str(
7327            r#"{
7328  "id": "dsl-bad-merge-selector-branch",
7329  "steps": [
7330    {
7331      "kind": "branch",
7332      "branches": [
7333        {
7334          "id": "known",
7335          "steps": [
7336            {
7337              "kind": "model",
7338              "id": "branch:known.model:ridge",
7339              "operator": {"type": "Ridge"}
7340            }
7341          ]
7342        }
7343      ]
7344    },
7345    {
7346      "kind": "merge",
7347      "id": "merge:bad.selector",
7348      "selectors": [
7349        {"branch": "missing", "select": "all"}
7350      ]
7351    }
7352  ]
7353}"#,
7354        )
7355        .unwrap();
7356        let error = compile_pipeline_dsl_with_generation(&unknown_branch).unwrap_err();
7357        assert!(format!("{error}").contains("does not match any pending prediction input"));
7358
7359        let missing_metric: PipelineDslSpec = serde_json::from_str(
7360            r#"{
7361  "id": "dsl-bad-merge-selector-metric",
7362  "steps": [
7363    {
7364      "kind": "branch",
7365      "branches": [
7366        {
7367          "id": "known",
7368          "steps": [
7369            {
7370              "kind": "model",
7371              "id": "branch:known.model:ridge",
7372              "operator": {"type": "Ridge"}
7373            }
7374          ]
7375        }
7376      ]
7377    },
7378    {
7379      "kind": "merge",
7380      "id": "merge:bad.metric",
7381      "selectors": [
7382        {"branch": "known", "select": "best"}
7383      ]
7384    }
7385  ]
7386}"#,
7387        )
7388        .unwrap();
7389        let error = compile_pipeline_dsl_with_generation(&missing_metric).unwrap_err();
7390        assert!(format!("{error}").contains("requires a non-empty metric"));
7391    }
7392
7393    #[test]
7394    fn merge_selectors_reject_top_k_above_scope() {
7395        let spec: PipelineDslSpec = serde_json::from_str(
7396            r#"{
7397  "id": "dsl-bad-merge-selector-top-k",
7398  "steps": [
7399    {
7400      "kind": "branch",
7401      "branches": [
7402        {
7403          "id": "known",
7404          "steps": [
7405            {
7406              "kind": "model",
7407              "id": "branch:known.model:ridge",
7408              "operator": {"type": "Ridge"}
7409            }
7410          ]
7411        }
7412      ]
7413    },
7414    {
7415      "kind": "merge",
7416      "id": "merge:bad.topk",
7417      "selectors": [
7418        {"branch": "known", "select": {"top_k": 2}, "metric": "rmse"}
7419      ]
7420    }
7421  ]
7422}"#,
7423        )
7424        .unwrap();
7425
7426        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7427        assert!(format!("{error}").contains("top_k=2 exceeds 1 matched prediction inputs"));
7428    }
7429
7430    #[test]
7431    fn compiles_nirs4all_shape_changing_and_tuning_surface() {
7432        let spec: PipelineDslSpec = serde_json::from_str(
7433            r#"{
7434  "id": "dsl-nirs4all-shape-parity",
7435  "steps": [
7436    {
7437      "kind": "y_transform",
7438      "id": "target:scale",
7439      "operator": {"class": "sklearn.preprocessing.StandardScaler"}
7440    },
7441    {
7442      "kind": "tag",
7443      "id": "tag:y_outliers",
7444      "operator": {"class": "nirs4all.filters.YOutlierFilter"},
7445      "params": {"method": "iqr"}
7446    },
7447    {
7448      "kind": "exclude",
7449      "id": "exclude:train_outliers",
7450      "operator": {"class": "nirs4all.filters.YOutlierFilter"},
7451      "params": {"mode": "any"}
7452    },
7453    {
7454      "kind": "sample_augmentation",
7455      "id": "augment:sample.noise",
7456      "operator": {"class": "nirs4all.operators.transforms.GaussianAdditiveNoise"},
7457      "params": {"count": 3, "selection": "random"},
7458      "shape": {
7459        "fit_rows": "fold_train",
7460        "predict_rows": "fold_validation",
7461        "augmentation_policy": {
7462          "sample_scope": "train_only",
7463          "feature_scope": "none",
7464          "require_origin_id": true,
7465          "inherit_group": true,
7466          "inherit_target": true
7467        }
7468      }
7469    },
7470    {
7471      "kind": "feature_augmentation",
7472      "id": "augment:feature.views",
7473      "operator": {"class": "nirs4all.operators.transforms.FeatureAugmentation"},
7474      "params": {"action": "extend"},
7475      "shape": {
7476        "fit_rows": "fold_train",
7477        "predict_rows": "fold_validation",
7478        "feature_namespace": "augmented_views",
7479        "augmentation_policy": {
7480          "sample_scope": "none",
7481          "feature_scope": "train_only",
7482          "require_origin_id": false
7483        }
7484      }
7485    },
7486    {
7487      "kind": "concat_transform",
7488      "id": "join:concat.multi_view",
7489      "branches": [
7490        {
7491          "id": "pca",
7492          "steps": [
7493            {
7494              "id": "concat:pca",
7495              "operator": {"class": "sklearn.decomposition.PCA"},
7496              "params": {"n_components": 20}
7497            }
7498          ]
7499        },
7500        {
7501          "id": "derivative_pca",
7502          "steps": [
7503            {
7504              "id": "concat:derivative",
7505              "operator": {"class": "nirs4all.operators.transforms.FirstDerivative"}
7506            },
7507            {
7508              "id": "concat:derivative.pca",
7509              "operator": {"class": "sklearn.decomposition.PCA"},
7510              "params": {"n_components": 10}
7511            }
7512          ]
7513        }
7514      ],
7515      "shape": {
7516        "fit_rows": "fold_train",
7517        "feature_namespace": "concat.multi_view",
7518        "selection_policy": {
7519          "scope": "unsupervised"
7520        }
7521      }
7522    },
7523    {
7524      "kind": "model",
7525      "id": "model:tuned",
7526      "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
7527      "finetune_params": {
7528        "n_trials": 10,
7529        "approach": "single",
7530        "eval_mode": "mean",
7531        "sampler": "random",
7532        "metric": "rmse",
7533        "model_params": {
7534          "max_depth": [3, 5, 7]
7535        }
7536      },
7537      "train_params": {
7538        "sample_weight": "balanced"
7539      }
7540    }
7541  ]
7542}"#,
7543        )
7544        .unwrap();
7545
7546        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7547        let graph = compiled.graph;
7548        let kinds = graph
7549            .nodes
7550            .iter()
7551            .map(|node| node.kind.clone())
7552            .collect::<Vec<_>>();
7553        assert!(kinds.contains(&NodeKind::YTransform));
7554        assert!(kinds.contains(&NodeKind::Tag));
7555        assert!(kinds.contains(&NodeKind::Exclude));
7556        assert!(kinds.contains(&NodeKind::Augmentation));
7557        assert!(kinds.contains(&NodeKind::FeatureJoin));
7558        assert_eq!(compiled.shape_plans.len(), 3);
7559
7560        let sample_aug = graph
7561            .nodes
7562            .iter()
7563            .find(|node| node.id.as_str() == "augment:sample.noise")
7564            .unwrap();
7565        assert_eq!(sample_aug.metadata["dsl_augmentation_kind"], "sample");
7566        let feature_aug = graph
7567            .nodes
7568            .iter()
7569            .find(|node| node.id.as_str() == "augment:feature.views")
7570            .unwrap();
7571        assert_eq!(feature_aug.metadata["dsl_augmentation_kind"], "feature");
7572        let model = graph
7573            .nodes
7574            .iter()
7575            .find(|node| node.id.as_str() == "model:tuned")
7576            .unwrap();
7577        assert_eq!(model.metadata["dsl_tuning"]["n_trials"], 10);
7578        assert_eq!(
7579            model.metadata["dsl_train_params"]["sample_weight"],
7580            "balanced"
7581        );
7582        graph.validate().unwrap();
7583    }
7584
7585    #[test]
7586    fn extracts_node_param_variants_into_generation_spec() {
7587        let spec: PipelineDslSpec = serde_json::from_str(
7588            r#"{
7589  "id": "dsl-generation-smoke",
7590  "max_variants": 4,
7591  "steps": [
7592    {
7593      "kind": "transform",
7594      "id": "transform:preprocess",
7595      "operator": {"type": "Preprocess"},
7596      "variants": [
7597        {
7598          "label": "snv",
7599          "params": {"method": "snv"}
7600        },
7601        {
7602          "label": "msc",
7603          "params": {"method": "msc"}
7604        }
7605      ]
7606    },
7607    {
7608      "kind": "model",
7609      "id": "model:base",
7610      "operator": {"type": "Ridge"},
7611      "variants": [
7612        {
7613          "label": "low",
7614          "params": {"alpha": 0.1}
7615        },
7616        {
7617          "label": "high",
7618          "params": {"alpha": 1.0}
7619        }
7620      ]
7621    }
7622  ]
7623}"#,
7624        )
7625        .unwrap();
7626
7627        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7628
7629        assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7630        assert_eq!(compiled.generation.max_variants, Some(4));
7631        assert_eq!(compiled.generation.dimensions.len(), 2);
7632        assert_eq!(
7633            compiled.generation.dimensions[0].name,
7634            "transform:preprocess.params"
7635        );
7636        assert_eq!(compiled.generation.dimensions[0].choices[0].label, "snv");
7637        assert_eq!(
7638            compiled.generation.dimensions[0].choices[0].param_overrides[0].node_id,
7639            NodeId::new("transform:preprocess").unwrap()
7640        );
7641        assert_eq!(
7642            compiled.generation.dimensions[1].choices[1].param_overrides[0].params["alpha"],
7643            1.0
7644        );
7645        assert!(compiled.generation_fingerprint.is_some());
7646        assert_eq!(
7647            compiled.graph.search_space_fingerprint,
7648            compiled.generation_fingerprint
7649        );
7650        compiled.graph.validate().unwrap();
7651    }
7652
7653    #[test]
7654    fn expands_compact_param_generators_into_generation_dimensions() {
7655        let spec: PipelineDslSpec = serde_json::from_str(
7656            r#"{
7657  "id": "dsl-compact-generation",
7658  "steps": [
7659    {
7660      "kind": "model",
7661      "id": "model:tuned",
7662      "operator": {"type": "TunedModel"},
7663      "generators": [
7664        {
7665          "kind": "or",
7666          "name": "model_family",
7667          "param": "family",
7668          "values": [
7669            {"label": "ridge", "value": "ridge"},
7670            {"label": "rf", "value": "random_forest"}
7671          ]
7672        },
7673        {
7674          "kind": "range",
7675          "param": "alpha",
7676          "start": 0.1,
7677          "stop": 0.9,
7678          "step": 0.4
7679        },
7680        {
7681          "kind": "log_range",
7682          "param": "lambda",
7683          "start": 0.01,
7684          "stop": 1.0,
7685          "count": 3
7686        },
7687        {
7688          "kind": "grid",
7689          "name": "tree_grid",
7690          "params": {
7691            "max_depth": [3, 5],
7692            "n_estimators": [50, 100]
7693          },
7694          "count": 3
7695        },
7696        {
7697          "kind": "pick",
7698          "param": "views",
7699          "values": ["snv", "msc", "derivative"],
7700          "sizes": [1, 2],
7701          "count": 4
7702        },
7703        {
7704          "kind": "arrange",
7705          "param": "chain",
7706          "values": ["snv", "pca", "pls"],
7707          "sizes": [2],
7708          "count": 3
7709        }
7710      ]
7711    }
7712  ]
7713}"#,
7714        )
7715        .unwrap();
7716
7717        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7718
7719        assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7720        assert_eq!(compiled.generation.dimensions.len(), 6);
7721        assert_eq!(compiled.generation.dimensions[0].name, "model_family");
7722        assert_eq!(compiled.generation.dimensions[0].choices.len(), 2);
7723        assert_eq!(
7724            compiled.generation.dimensions[1].name,
7725            "model:tuned.alpha.range"
7726        );
7727        assert_eq!(compiled.generation.dimensions[1].choices.len(), 3);
7728        assert_eq!(
7729            compiled.generation.dimensions[1].choices[1].param_overrides[0].params["alpha"],
7730            0.5
7731        );
7732        assert_eq!(
7733            compiled.generation.dimensions[2].name,
7734            "model:tuned.lambda.log_range"
7735        );
7736        assert_eq!(compiled.generation.dimensions[2].choices.len(), 3);
7737        assert_eq!(compiled.generation.dimensions[3].name, "tree_grid");
7738        assert_eq!(compiled.generation.dimensions[3].choices.len(), 3);
7739        assert_eq!(
7740            compiled.generation.dimensions[3].choices[2].param_overrides[0].params["n_estimators"],
7741            50
7742        );
7743        assert_eq!(
7744            compiled.generation.dimensions[4].choices[3].param_overrides[0].params["views"],
7745            serde_json::json!(["snv", "msc"])
7746        );
7747        assert_eq!(
7748            compiled.generation.dimensions[5].choices[2].param_overrides[0].params["chain"],
7749            serde_json::json!(["pca", "snv"])
7750        );
7751        assert!(compiled.generation_fingerprint.is_some());
7752    }
7753
7754    #[test]
7755    fn compact_param_generators_reject_invalid_counts() {
7756        let spec: PipelineDslSpec = serde_json::from_str(
7757            r#"{
7758  "id": "dsl-bad-compact-generation",
7759  "steps": [
7760    {
7761      "kind": "model",
7762      "id": "model:bad",
7763      "operator": {"type": "Ridge"},
7764      "generators": [
7765        {
7766          "kind": "or",
7767          "param": "alpha",
7768          "values": [0.1, 1.0],
7769          "count": 0
7770        }
7771      ]
7772    }
7773  ]
7774}"#,
7775        )
7776        .unwrap();
7777
7778        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7779        assert!(format!("{error}").contains("count=0"));
7780    }
7781
7782    #[test]
7783    fn compiles_coordinated_generation_dimensions() {
7784        let spec: PipelineDslSpec = serde_json::from_str(
7785            r#"{
7786  "id": "dsl-coordinated-generation",
7787  "max_variants": 2,
7788  "generation_dimensions": [
7789    {
7790      "name": "stack_profile",
7791      "choices": [
7792        {
7793          "label": "linear_stack",
7794          "param_overrides": [
7795            {"node_id": "branch:b0.model:ridge", "params": {"alpha": 0.1}},
7796            {"node_id": "branch:b1.model:rf", "params": {"max_depth": 4}},
7797            {"node_id": "merge:stack.pred_plus_original.meta:ridge", "params": {"alpha": 0.05}}
7798          ]
7799        },
7800        {
7801          "label": "robust_stack",
7802          "param_overrides": [
7803            {"node_id": "branch:b0.model:ridge", "params": {"alpha": 1.0}},
7804            {"node_id": "branch:b1.model:rf", "params": {"max_depth": 8}},
7805            {"node_id": "merge:stack.pred_plus_original.meta:ridge", "params": {"alpha": 0.5}}
7806          ]
7807        }
7808      ]
7809    }
7810  ],
7811  "steps": [
7812    {
7813      "kind": "branch",
7814      "branches": [
7815        {
7816          "id": "b0",
7817          "steps": [
7818            {
7819              "kind": "model",
7820              "id": "branch:b0.model:ridge",
7821              "operator": {"type": "Ridge"}
7822            }
7823          ]
7824        },
7825        {
7826          "id": "b1",
7827          "steps": [
7828            {
7829              "kind": "model",
7830              "id": "branch:b1.model:rf",
7831              "operator": {"type": "RandomForestRegressor"}
7832            }
7833          ]
7834        }
7835      ]
7836    },
7837    {
7838      "kind": "merge_model",
7839      "id": "merge:stack.pred_plus_original.meta:ridge",
7840      "operator": {"type": "RidgeMetaStacker"}
7841    }
7842  ]
7843}"#,
7844        )
7845        .unwrap();
7846
7847        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7848
7849        assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7850        assert_eq!(compiled.generation.max_variants, Some(2));
7851        assert_eq!(compiled.generation.dimensions.len(), 1);
7852        assert_eq!(compiled.generation.dimensions[0].name, "stack_profile");
7853        assert_eq!(
7854            compiled.generation.dimensions[0].choices[0]
7855                .param_overrides
7856                .len(),
7857            3
7858        );
7859        assert_eq!(
7860            compiled.generation.dimensions[0].choices[1].param_overrides[2].node_id,
7861            NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap()
7862        );
7863        assert_eq!(
7864            compiled.generation.dimensions[0].choices[1].value
7865                ["merge:stack.pred_plus_original.meta:ridge"]["alpha"],
7866            0.5
7867        );
7868        assert_eq!(
7869            compiled.graph.search_space_fingerprint,
7870            compiled.generation_fingerprint
7871        );
7872        compiled.graph.validate().unwrap();
7873    }
7874
7875    #[test]
7876    fn refuses_coordinated_generation_for_unknown_node() {
7877        let spec: PipelineDslSpec = serde_json::from_str(
7878            r#"{
7879  "id": "dsl-bad-generation-target",
7880  "generation_dimensions": [
7881    {
7882      "name": "bad_target",
7883      "choices": [
7884        {
7885          "label": "bad",
7886          "param_overrides": [
7887            {"node_id": "model:missing", "params": {"alpha": 0.1}}
7888          ]
7889        }
7890      ]
7891    }
7892  ],
7893  "steps": [
7894    {
7895      "kind": "model",
7896      "id": "model:base",
7897      "operator": {"type": "Ridge"}
7898    }
7899  ]
7900}"#,
7901        )
7902        .unwrap();
7903
7904        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7905        assert!(format!("{error}").contains("references unknown node `model:missing`"));
7906    }
7907
7908    #[test]
7909    fn artifact_contains_campaign_template_without_split_graph_nodes() {
7910        let spec: PipelineDslSpec = serde_json::from_str(
7911            r#"{
7912  "id": "dsl-campaign-template",
7913  "campaign_id": "campaign:dsl.template",
7914  "root_seed": 123,
7915  "leakage_policy": {
7916    "split_unit": "group",
7917    "require_group_ids": true
7918  },
7919  "split_invocation": {
7920    "id": "split:group-kfold",
7921    "leakage_policy": {
7922      "split_unit": "group",
7923      "require_group_ids": true
7924    },
7925    "params": {
7926      "n_splits": 3
7927    }
7928  },
7929  "generation_dimensions": [
7930    {
7931      "name": "model_family",
7932      "choices": [
7933        {
7934          "label": "ridge_low",
7935          "param_overrides": [
7936            {"node_id": "model:base", "params": {"alpha": 0.1}}
7937          ]
7938        },
7939        {
7940          "label": "ridge_high",
7941          "param_overrides": [
7942            {"node_id": "model:base", "params": {"alpha": 1.0}}
7943          ]
7944        }
7945      ]
7946    }
7947  ],
7948  "data_bindings": [
7949    {
7950      "node_id": "model:base",
7951      "input_name": "x",
7952      "request_id": "data:model.base.x",
7953      "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
7954      "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
7955      "relation_fingerprint": "a3a7e329df35db9f2883a17b8611b7fae6dcaa031875e3ec2c9be1b9e29cbe10",
7956      "output_representation": "tabular_numeric",
7957      "feature_set_id": "x",
7958      "source_ids": ["nir"],
7959      "require_relations": true
7960    }
7961  ],
7962  "steps": [
7963    {
7964      "kind": "model",
7965      "id": "model:base",
7966      "operator": {"type": "Ridge"}
7967    }
7968  ],
7969  "campaign_metadata": {
7970    "owner": "dsl-test"
7971  }
7972}"#,
7973        )
7974        .unwrap();
7975
7976        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7977
7978        assert_eq!(compiled.campaign_template.id, "campaign:dsl.template");
7979        assert_eq!(compiled.campaign_template.root_seed, Some(123));
7980        assert_eq!(
7981            compiled
7982                .campaign_template
7983                .split_invocation
7984                .as_ref()
7985                .unwrap()
7986                .id,
7987            "split:group-kfold"
7988        );
7989        assert_eq!(compiled.campaign_template.generation, compiled.generation);
7990        assert_eq!(
7991            compiled.data_bindings[&NodeId::new("model:base").unwrap()][0].request_id,
7992            "data:model.base.x"
7993        );
7994        assert_eq!(
7995            compiled.campaign_template.data_bindings,
7996            compiled.data_bindings
7997        );
7998        assert_eq!(compiled.graph.nodes.len(), 1);
7999        assert!(compiled
8000            .graph
8001            .nodes
8002            .iter()
8003            .all(|node| !node.id.as_str().starts_with("split:")));
8004    }
8005
8006    #[test]
8007    fn refuses_data_binding_for_unknown_or_non_data_port() {
8008        let unknown_input_spec: PipelineDslSpec = serde_json::from_str(
8009            r#"{
8010  "id": "dsl-bad-data-binding",
8011  "data_bindings": [
8012    {
8013      "node_id": "model:base",
8014      "input_name": "missing",
8015      "request_id": "data:bad",
8016      "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
8017      "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
8018      "output_representation": "tabular_numeric"
8019    }
8020  ],
8021  "steps": [
8022    {
8023      "kind": "model",
8024      "id": "model:base",
8025      "operator": {"type": "Ridge"}
8026    }
8027  ]
8028}"#,
8029        )
8030        .unwrap();
8031        let error = compile_pipeline_dsl_with_generation(&unknown_input_spec).unwrap_err();
8032        assert!(format!("{error}").contains("unknown input port `missing`"));
8033
8034        let prediction_input_spec: PipelineDslSpec = serde_json::from_str(
8035            r#"{
8036  "id": "dsl-prediction-port-data-binding",
8037  "data_bindings": [
8038    {
8039      "node_id": "merge:stack.pred_plus_original.meta:ridge",
8040      "input_name": "b0_oof",
8041      "request_id": "data:bad.prediction-port",
8042      "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
8043      "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
8044      "output_representation": "tabular_numeric"
8045    }
8046  ],
8047  "steps": [
8048    {
8049      "kind": "branch",
8050      "branches": [
8051        {
8052          "id": "b0",
8053          "steps": [
8054            {
8055              "kind": "model",
8056              "id": "branch:b0.model:ridge",
8057              "operator": {"type": "Ridge"}
8058            }
8059          ]
8060        }
8061      ]
8062    },
8063    {
8064      "kind": "merge_model",
8065      "id": "merge:stack.pred_plus_original.meta:ridge",
8066      "operator": {"type": "RidgeMetaStacker"}
8067    }
8068  ]
8069}"#,
8070        )
8071        .unwrap();
8072        let error = compile_pipeline_dsl_with_generation(&prediction_input_spec).unwrap_err();
8073        assert!(format!("{error}").contains("targets non-data input"));
8074    }
8075
8076    #[test]
8077    fn extracts_shape_plans_into_compiled_artifact() {
8078        let spec: PipelineDslSpec = serde_json::from_str(
8079            r#"{
8080  "id": "dsl-shape-plan-smoke",
8081  "steps": [
8082    {
8083      "kind": "augmentation",
8084      "id": "augment:synthetic",
8085      "operator": {"type": "SampleAugmenter"},
8086      "shape": {
8087        "input_granularity": "sample",
8088        "target_granularity": "sample",
8089        "fit_rows": "fold_train",
8090        "predict_rows": "fold_validation",
8091        "feature_namespace": "aug.synthetic",
8092        "augmentation_policy": {
8093          "sample_scope": "train_only",
8094          "feature_scope": "none",
8095          "require_origin_id": true,
8096          "inherit_group": true,
8097          "inherit_target": true
8098        }
8099      }
8100    },
8101    {
8102      "kind": "transform",
8103      "id": "transform:select",
8104      "operator": {"type": "SupervisedFeatureSelector"},
8105      "shape": {
8106        "fit_rows": "fold_train",
8107        "feature_namespace": "selected",
8108        "selection_policy": {
8109          "scope": "supervised_fold_train",
8110          "store_masks": true
8111        }
8112      }
8113    },
8114    {
8115      "kind": "model",
8116      "id": "model:base",
8117      "operator": {"type": "Ridge"}
8118    }
8119  ]
8120}"#,
8121        )
8122        .unwrap();
8123
8124        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8125
8126        assert_eq!(compiled.shape_plans.len(), 2);
8127        let augment_plan = compiled
8128            .shape_plans
8129            .get(&NodeId::new("augment:synthetic").unwrap())
8130            .unwrap();
8131        assert_eq!(
8132            augment_plan.feature_namespace.as_deref(),
8133            Some("aug.synthetic")
8134        );
8135        assert_eq!(
8136            augment_plan.augmentation_policy.sample_scope,
8137            crate::policy::AugmentationScope::TrainOnly
8138        );
8139        let select_plan = compiled
8140            .shape_plans
8141            .get(&NodeId::new("transform:select").unwrap())
8142            .unwrap();
8143        assert_eq!(
8144            select_plan.selection_policy.scope,
8145            crate::policy::FeatureSelectionScope::SupervisedFoldTrain
8146        );
8147        assert_eq!(compiled.generation.strategy, GenerationStrategy::None);
8148        compiled.graph.validate().unwrap();
8149    }
8150
8151    #[test]
8152    fn compiles_sequential_filter_and_or_generator_surface() {
8153        let spec: PipelineDslSpec = serde_json::from_str(
8154            r#"{
8155  "id": "dsl-generator-or-parity",
8156  "steps": [
8157    {
8158      "kind": "sequential",
8159      "id": "seq:pre",
8160      "steps": [
8161        {
8162          "kind": "sample_filter",
8163          "id": "filter:y_outlier",
8164          "operator": {"class": "nirs4all.operators.filters.YOutlierFilter"},
8165          "params": {"mode": "any"}
8166        },
8167        {
8168          "kind": "transform",
8169          "id": "transform:scale",
8170          "operator": {"class": "sklearn.preprocessing.StandardScaler"}
8171        }
8172      ]
8173    },
8174    {
8175      "kind": "generator",
8176      "id": "generator:model_choices",
8177      "mode": "or",
8178      "pick": 1,
8179      "branches": [
8180        {
8181          "id": "pls",
8182          "steps": [
8183            {
8184              "kind": "model",
8185              "id": "model:pls",
8186              "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
8187              "params": {"n_components": 8}
8188            }
8189          ]
8190        },
8191        {
8192          "id": "rf",
8193          "steps": [
8194            {
8195              "kind": "model",
8196              "id": "model:rf",
8197              "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
8198              "params": {"n_estimators": 64}
8199            }
8200          ]
8201        }
8202      ]
8203    },
8204    {
8205      "kind": "merge",
8206      "id": "merge:generated",
8207      "output_as": "features",
8208      "include_original_data": false,
8209      "selectors": [
8210        {"branch": "generator:model_choices:choice0", "select": "all"}
8211      ]
8212    }
8213  ]
8214}"#,
8215        )
8216        .unwrap();
8217
8218        let graph = compile_pipeline_dsl(&spec).unwrap();
8219        graph.validate().unwrap();
8220        let filter = graph
8221            .nodes
8222            .iter()
8223            .find(|node| node.id.as_str() == "filter:y_outlier")
8224            .unwrap();
8225        assert_eq!(filter.kind, NodeKind::Exclude);
8226        assert_eq!(filter.metadata["dsl_filter_kind"], "sample");
8227
8228        let generated_models = graph
8229            .nodes
8230            .iter()
8231            .filter(|node| node.kind == NodeKind::Model)
8232            .collect::<Vec<_>>();
8233        assert_eq!(generated_models.len(), 2);
8234        assert!(generated_models
8235            .iter()
8236            .all(|node| node.id.as_str().starts_with("gen:generator_model_choices")));
8237        assert!(generated_models.iter().all(|node| {
8238            node.metadata
8239                .get("dsl_generator")
8240                .and_then(|value| value.as_str())
8241                == Some("generator:model_choices")
8242        }));
8243
8244        let merge_inputs = graph
8245            .nodes
8246            .iter()
8247            .find(|node| node.id.as_str() == "merge:generated")
8248            .unwrap()
8249            .ports
8250            .inputs
8251            .iter()
8252            .map(|port| port.name.as_str())
8253            .collect::<Vec<_>>();
8254        assert_eq!(
8255            merge_inputs,
8256            vec![
8257                "generator_model_choices_choice0_oof",
8258                "generator_model_choices_choice1_oof"
8259            ]
8260        );
8261    }
8262
8263    #[test]
8264    fn compiles_cartesian_generator_as_explicit_prediction_choices() {
8265        let spec: PipelineDslSpec = serde_json::from_str(
8266            r#"{
8267  "id": "dsl-generator-cartesian-parity",
8268  "steps": [
8269    {
8270      "kind": "generator",
8271      "id": "generator:cartesian",
8272      "mode": "cartesian",
8273      "stages": [
8274        {
8275          "id": "preproc",
8276          "branches": [
8277            {
8278              "id": "snv",
8279              "steps": [
8280                {
8281                  "kind": "transform",
8282                  "id": "transform:snv",
8283                  "operator": {"class": "nirs4all.operators.transforms.StandardNormalVariate"}
8284                }
8285              ]
8286            },
8287            {
8288              "id": "msc",
8289              "steps": [
8290                {
8291                  "kind": "transform",
8292                  "id": "transform:msc",
8293                  "operator": {"class": "nirs4all.operators.transforms.MultiplicativeScatterCorrection"}
8294                }
8295              ]
8296            }
8297          ]
8298        },
8299        {
8300          "id": "model",
8301          "branches": [
8302            {
8303              "id": "ridge",
8304              "steps": [
8305                {
8306                  "kind": "model",
8307                  "id": "model:ridge",
8308                  "operator": {"class": "sklearn.linear_model.Ridge"}
8309                }
8310              ]
8311            },
8312            {
8313              "id": "lasso",
8314              "steps": [
8315                {
8316                  "kind": "model",
8317                  "id": "model:lasso",
8318                  "operator": {"class": "sklearn.linear_model.Lasso"}
8319                }
8320              ]
8321            }
8322          ]
8323        }
8324      ]
8325    },
8326    {
8327      "kind": "merge",
8328      "id": "merge:cartesian",
8329      "output_as": "features",
8330      "include_original_data": false
8331    }
8332  ]
8333}"#,
8334        )
8335        .unwrap();
8336
8337        let graph = compile_pipeline_dsl(&spec).unwrap();
8338        graph.validate().unwrap();
8339        let models = graph
8340            .nodes
8341            .iter()
8342            .filter(|node| node.kind == NodeKind::Model)
8343            .collect::<Vec<_>>();
8344        assert_eq!(models.len(), 4);
8345        assert!(models.iter().all(|node| {
8346            node.metadata
8347                .get("dsl_generator_mode")
8348                .and_then(|value| value.as_str())
8349                == Some("cartesian")
8350        }));
8351        let merge = graph
8352            .nodes
8353            .iter()
8354            .find(|node| node.id.as_str() == "merge:cartesian")
8355            .unwrap();
8356        assert_eq!(merge.ports.inputs.len(), 4);
8357        assert_eq!(
8358            graph
8359                .edges
8360                .iter()
8361                .filter(|edge| edge.target.node_id.as_str() == "merge:cartesian")
8362                .count(),
8363            4
8364        );
8365    }
8366
8367    #[test]
8368    fn refuses_generator_choice_without_prediction_output() {
8369        let spec: PipelineDslSpec = serde_json::from_str(
8370            r#"{
8371  "id": "dsl-generator-bad-choice",
8372  "steps": [
8373    {
8374      "kind": "generator",
8375      "id": "generator:bad",
8376      "branches": [
8377        {
8378          "id": "transform_only",
8379          "steps": [
8380            {
8381              "kind": "transform",
8382              "id": "transform:only",
8383              "operator": {"class": "sklearn.preprocessing.StandardScaler"}
8384            }
8385          ]
8386        }
8387      ]
8388    }
8389  ]
8390}"#,
8391        )
8392        .unwrap();
8393
8394        let error = compile_pipeline_dsl(&spec).unwrap_err();
8395        assert!(format!("{error}").contains("must produce at least one model or merge prediction"));
8396    }
8397
8398    #[test]
8399    fn parses_nirs4all_compat_pipeline_and_fuses_data_generators() {
8400        let spec = parse_pipeline_dsl_json(
8401            br#"{
8402  "id": "dsl-nirs4all-compat-fused",
8403  "pipeline": [
8404    {"sources": ["nir"]},
8405    {"_cartesian_": [
8406      {"_or_": ["SNV", "MSC", null]},
8407      {"_or_": [null, {"preprocessing": "SavitzkyGolay", "params": {"window": 11, "deriv": 1}}]}
8408    ]},
8409    {"split": {"type": "GroupKFold", "n_splits": 3}},
8410    {"_chain_": [
8411      {"_grid_": {"model": ["PLSRegression"], "n_components": [5, 10]}},
8412      {"_grid_": {"model": ["Ridge"], "alpha": [0.1, 1.0]}},
8413      {"_sample_": {"model": "SVR", "distribution": "log_uniform", "from": 0.001, "to": 1.0, "num": 2, "tune": ["C", "gamma"], "kernel": "rbf"}}
8414    ]},
8415    {"merge": "all"},
8416    {"model": "Ridge", "id": "model:meta", "params": {"alpha": 0.5}}
8417  ]
8418}"#,
8419        )
8420        .unwrap();
8421
8422        assert_eq!(spec.steps.len(), 2);
8423        assert_eq!(
8424            spec.split_invocation
8425                .as_ref()
8426                .unwrap()
8427                .params
8428                .get("type")
8429                .unwrap(),
8430            "GroupKFold"
8431        );
8432
8433        let graph = compile_pipeline_dsl(&spec).unwrap();
8434        graph.validate().unwrap();
8435        let meta = graph
8436            .nodes
8437            .iter()
8438            .find(|node| node.id.as_str() == "model:meta")
8439            .unwrap();
8440        assert_eq!(meta.kind, NodeKind::Model);
8441        assert!(meta
8442            .ports
8443            .inputs
8444            .iter()
8445            .any(|port| port.name == "x_original"));
8446        assert!(graph.edges.iter().any(|edge| {
8447            edge.target.node_id.as_str() == "model:meta"
8448                && edge.contract.kind == PortKind::Prediction
8449                && edge.contract.requires_oof
8450        }));
8451        assert!(graph.nodes.iter().any(|node| {
8452            node.metadata
8453                .get("dsl_compat_keyword")
8454                .and_then(serde_json::Value::as_str)
8455                == Some("preprocessing")
8456        }));
8457        assert!(graph.nodes.iter().any(|node| {
8458            node.kind == NodeKind::Model
8459                && node.params.contains_key("C")
8460                && node.params.contains_key("gamma")
8461        }));
8462    }
8463
8464    #[test]
8465    fn parses_nirs4all_range_attached_to_following_model() {
8466        let spec = parse_pipeline_dsl_json(
8467            br#"{
8468  "id": "dsl-nirs4all-compat-range",
8469  "pipeline": [
8470    {"_range_": [5, 15, 5]},
8471    {"model": "PLSRegression", "id": "model:pls"}
8472  ]
8473}"#,
8474        )
8475        .unwrap();
8476
8477        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8478        assert_eq!(compiled.generation.dimensions.len(), 1);
8479        assert_eq!(compiled.generation.dimensions[0].choices.len(), 3);
8480        assert_eq!(
8481            compiled.generation.dimensions[0].choices[0].param_overrides[0].params["n_components"],
8482            5.0
8483        );
8484    }
8485
8486    #[test]
8487    fn parses_nirs4all_minimal_aliases_plain_classes_and_split_chain() {
8488        let spec = parse_pipeline_dsl_json(
8489            br#"{
8490  "id": "dsl-nirs4all-compat-minimal-aliases",
8491  "pipeline": [
8492    "chart_2d",
8493    {"class": "sklearn.preprocessing.MinMaxScaler", "params": {"feature_range": [0, 1]}},
8494    {"class": "nirs4all.operators.splitters.SPXYGFold", "params": {"n_splits": 1, "test_size": 0.2}, "group": "Sample_ID"},
8495    {"class": "sklearn.model_selection.KFold", "params": {"n_splits": 3, "shuffle": true, "random_state": 42}},
8496    "SNV",
8497    "PLSRegression"
8498  ]
8499}"#,
8500        )
8501        .unwrap();
8502
8503        let split = spec.split_invocation.as_ref().unwrap();
8504        assert_eq!(split.id, "split:compat.chain");
8505        let chain = split.params["compat_split_chain"].as_array().unwrap();
8506        assert_eq!(chain.len(), 2);
8507        assert_eq!(
8508            chain[0]["params"]["class"],
8509            "nirs4all.operators.splitters.SPXYGFold"
8510        );
8511        assert_eq!(chain[0]["params"]["group"], "Sample_ID");
8512        assert_eq!(chain[1]["params"]["class"], "sklearn.model_selection.KFold");
8513
8514        let graph = compile_pipeline_dsl(&spec).unwrap();
8515        graph.validate().unwrap();
8516        assert!(graph.nodes.iter().any(|node| node.kind == NodeKind::Chart));
8517        assert!(graph.nodes.iter().any(|node| {
8518            node.kind == NodeKind::Transform
8519                && node.operator.as_ref().unwrap()["class"] == "sklearn.preprocessing.MinMaxScaler"
8520        }));
8521        assert!(graph.nodes.iter().any(|node| {
8522            node.kind == NodeKind::Transform
8523                && node.operator.as_ref().unwrap().as_str() == Some("SNV")
8524        }));
8525        assert!(graph.nodes.iter().any(|node| {
8526            node.kind == NodeKind::Model
8527                && node.operator.as_ref().unwrap().as_str() == Some("PLSRegression")
8528        }));
8529    }
8530
8531    #[test]
8532    fn registry_reclassifies_non_heuristic_minimal_aliases_before_compile() {
8533        let spec = parse_pipeline_dsl_json(
8534            br#"{
8535  "id": "dsl-registry-minimal-aliases",
8536  "pipeline": [
8537    "SNV",
8538    "ElasticSpectra"
8539  ]
8540}"#,
8541        )
8542        .unwrap();
8543        let mut registry = ControllerRegistry::new();
8544        registry
8545            .register(registry_manifest(
8546                "controller:transformer.mixin",
8547                NodeKind::Transform,
8548                &["SNV"],
8549            ))
8550            .unwrap();
8551        registry
8552            .register(registry_manifest(
8553                "controller:elastic.spectra",
8554                NodeKind::Model,
8555                &["ElasticSpectra"],
8556            ))
8557            .unwrap();
8558
8559        let compiled =
8560            compile_pipeline_dsl_with_generation_and_controller_registry(&spec, &registry).unwrap();
8561        let model = compiled
8562            .graph
8563            .nodes
8564            .iter()
8565            .find(|node| {
8566                node.operator.as_ref().and_then(serde_json::Value::as_str) == Some("ElasticSpectra")
8567            })
8568            .unwrap();
8569
8570        assert_eq!(model.kind, NodeKind::Model);
8571        assert_eq!(model.metadata[DSL_REGISTRY_INFERRED_KIND], "model");
8572        assert_eq!(model.metadata[DSL_COMPAT_ORIGINAL_KEYWORD], "preprocessing");
8573        assert!(compiled.graph.nodes.iter().any(|node| {
8574            node.kind == NodeKind::Transform
8575                && node.operator.as_ref().and_then(serde_json::Value::as_str) == Some("SNV")
8576        }));
8577    }
8578
8579    #[test]
8580    fn parses_nirs4all_named_step_wrapper_and_plain_class_model() {
8581        let spec = parse_pipeline_dsl_json(
8582            br#"{
8583  "id": "dsl-nirs4all-compat-named-step",
8584  "pipeline": [
8585    {"name": "scaled", "step": {"class": "sklearn.preprocessing.StandardScaler"}},
8586    {"class": "sklearn.ensemble.RandomForestRegressor", "params": {"n_estimators": 10, "random_state": 42}}
8587  ]
8588}"#,
8589        )
8590        .unwrap();
8591
8592        let graph = compile_pipeline_dsl(&spec).unwrap();
8593        graph.validate().unwrap();
8594        let scaled = graph
8595            .nodes
8596            .iter()
8597            .find(|node| node.kind == NodeKind::Transform)
8598            .unwrap();
8599        assert_eq!(scaled.metadata["dsl_name"], "scaled");
8600        let model = graph
8601            .nodes
8602            .iter()
8603            .find(|node| node.kind == NodeKind::Model)
8604            .unwrap();
8605        assert_eq!(
8606            model.operator.as_ref().unwrap()["class"],
8607            "sklearn.ensemble.RandomForestRegressor"
8608        );
8609        assert_eq!(model.params["n_estimators"], 10);
8610    }
8611
8612    #[test]
8613    fn compiles_tuner_as_external_prediction_node() {
8614        let spec: PipelineDslSpec = serde_json::from_str(
8615            r#"{
8616  "id": "dsl-tuner",
8617  "steps": [
8618    {
8619      "kind": "tuner",
8620      "id": "tuner:optuna",
8621      "operator": "OptunaTuner",
8622      "params": {"sampler": "tpe"},
8623      "tuning": {"n_trials": 4, "metric": "rmse"}
8624    },
8625    {
8626      "kind": "merge_model",
8627      "id": "model:meta",
8628      "operator": "Ridge"
8629    }
8630  ]
8631}"#,
8632        )
8633        .unwrap();
8634
8635        let graph = compile_pipeline_dsl(&spec).unwrap();
8636        graph.validate().unwrap();
8637        let tuner = graph
8638            .nodes
8639            .iter()
8640            .find(|node| node.id.as_str() == "tuner:optuna")
8641            .unwrap();
8642        assert_eq!(tuner.kind, NodeKind::Tuner);
8643        assert_eq!(
8644            tuner.operator.as_ref().unwrap().as_str(),
8645            Some("OptunaTuner")
8646        );
8647        assert_eq!(tuner.metadata["dsl_tuning"]["n_trials"], 4);
8648        assert!(graph.edges.iter().any(|edge| {
8649            edge.source.node_id.as_str() == "tuner:optuna"
8650                && edge.source.port_name == "oof"
8651                && edge.target.node_id.as_str() == "model:meta"
8652                && edge.contract.kind == PortKind::Prediction
8653                && edge.contract.requires_oof
8654                && edge.contract.requires_fold_alignment
8655        }));
8656    }
8657
8658    #[test]
8659    fn parses_compat_tuner_minimal_alias_and_wrappers() {
8660        let spec = parse_pipeline_dsl_json(
8661            br#"{
8662  "id": "dsl-compat-tuner",
8663  "pipeline": [
8664    "SNV",
8665    {"tuner": "OptunaTuner", "id": "tuner:compat", "n_trials": 3, "metric": "rmse"},
8666    {"merge": "all"},
8667    {"model": "Ridge"}
8668  ]
8669}"#,
8670        )
8671        .unwrap();
8672
8673        let graph = compile_pipeline_dsl(&spec).unwrap();
8674        graph.validate().unwrap();
8675        let transform = graph
8676            .nodes
8677            .iter()
8678            .find(|node| node.kind == NodeKind::Transform)
8679            .unwrap();
8680        assert_eq!(transform.operator.as_ref().unwrap().as_str(), Some("SNV"));
8681        let tuner = graph
8682            .nodes
8683            .iter()
8684            .find(|node| node.id.as_str() == "tuner:compat")
8685            .unwrap();
8686        assert_eq!(tuner.kind, NodeKind::Tuner);
8687        assert_eq!(tuner.params["n_trials"], 3);
8688        assert_eq!(tuner.metadata["dsl_compat_keyword"], "tuner");
8689    }
8690
8691    #[test]
8692    fn parses_bare_tuner_alias_as_tuner_node() {
8693        let spec = parse_pipeline_dsl_json(
8694            br#"{
8695  "id": "dsl-bare-tuner-alias",
8696  "pipeline": ["SNV", "OptunaTuner"]
8697}"#,
8698        )
8699        .unwrap();
8700
8701        let graph = compile_pipeline_dsl(&spec).unwrap();
8702        graph.validate().unwrap();
8703        assert!(graph.nodes.iter().any(|node| {
8704            node.kind == NodeKind::Transform
8705                && node.operator.as_ref().unwrap().as_str() == Some("SNV")
8706        }));
8707        assert!(graph.nodes.iter().any(|node| {
8708            node.kind == NodeKind::Tuner
8709                && node.operator.as_ref().unwrap().as_str() == Some("OptunaTuner")
8710        }));
8711    }
8712
8713    #[test]
8714    fn compiles_runtime_data_generation_as_external_generator_node() {
8715        let spec: PipelineDslSpec = serde_json::from_str(
8716            r#"{
8717  "id": "dsl-runtime-data-generation",
8718  "steps": [
8719    {
8720      "kind": "generation",
8721      "id": "generator:synthetic.train",
8722      "operator": "SMOTE",
8723      "params": {"ratio": 0.5},
8724      "shape": {
8725        "fit_rows": "fold_train",
8726        "predict_rows": "fold_validation",
8727        "augmentation_policy": {
8728          "sample_scope": "train_only",
8729          "feature_scope": "none",
8730          "require_origin_id": true,
8731          "inherit_group": true,
8732          "inherit_target": true
8733        }
8734      }
8735    },
8736    {
8737      "kind": "model",
8738      "id": "model:ridge",
8739      "operator": "Ridge"
8740    }
8741  ]
8742}"#,
8743        )
8744        .unwrap();
8745
8746        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8747        compiled.graph.validate().unwrap();
8748        let generator = compiled
8749            .graph
8750            .nodes
8751            .iter()
8752            .find(|node| node.id.as_str() == "generator:synthetic.train")
8753            .unwrap();
8754        assert_eq!(generator.kind, NodeKind::Generator);
8755        assert_eq!(generator.operator.as_ref().unwrap().as_str(), Some("SMOTE"));
8756        assert_eq!(generator.metadata["dsl_generation_kind"], "data");
8757        assert!(compiled
8758            .shape_plans
8759            .contains_key(&NodeId::new("generator:synthetic.train").unwrap()));
8760        assert!(compiled.graph.edges.iter().any(|edge| {
8761            edge.source.node_id.as_str() == "generator:synthetic.train"
8762                && edge.source.port_name == "x_out"
8763                && edge.target.node_id.as_str() == "model:ridge"
8764                && edge.target.port_name == "x"
8765                && edge.contract.kind == PortKind::Data
8766        }));
8767    }
8768
8769    #[test]
8770    fn parses_compat_runtime_generation_step() {
8771        let spec = parse_pipeline_dsl_json(
8772            br#"{
8773  "id": "dsl-compat-runtime-generation",
8774  "pipeline": [
8775    {
8776      "generation": "SMOTE",
8777      "id": "generator:compat.synthetic",
8778      "generation_params": {"ratio": 0.25},
8779      "shape": {
8780        "fit_rows": "fold_train",
8781        "predict_rows": "fold_validation",
8782        "augmentation_policy": {
8783          "sample_scope": "train_only",
8784          "feature_scope": "none",
8785          "require_origin_id": true,
8786          "inherit_group": true,
8787          "inherit_target": true
8788        }
8789      }
8790    },
8791    "Ridge"
8792  ]
8793}"#,
8794        )
8795        .unwrap();
8796
8797        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8798        let generator = compiled
8799            .graph
8800            .nodes
8801            .iter()
8802            .find(|node| node.id.as_str() == "generator:compat.synthetic")
8803            .unwrap();
8804        assert_eq!(generator.kind, NodeKind::Generator);
8805        assert_eq!(generator.params["ratio"], 0.25);
8806        assert_eq!(generator.metadata["dsl_compat_keyword"], "data_generation");
8807    }
8808
8809    #[test]
8810    fn parses_nirs4all_compat_feature_branch_merge_dict() {
8811        let spec = parse_pipeline_dsl_json(
8812            br#"{
8813  "id": "dsl-nirs4all-compat-feature-merge",
8814  "pipeline": [
8815    {
8816      "branch": {
8817        "snv": ["SNV"],
8818        "msc": ["MSC"]
8819      }
8820    },
8821    {
8822      "merge": {
8823        "features": "all",
8824        "output_as": "features",
8825        "on_missing": "error"
8826      }
8827    },
8828    "PLSRegression"
8829  ]
8830}"#,
8831        )
8832        .unwrap();
8833
8834        let graph = compile_pipeline_dsl(&spec).unwrap();
8835        graph.validate().unwrap();
8836        let merge = graph
8837            .nodes
8838            .iter()
8839            .find(|node| node.kind == NodeKind::FeatureJoin)
8840            .unwrap();
8841        assert_eq!(merge.metadata["merge_mode"], "features");
8842        assert_eq!(merge.metadata["on_missing"], "error");
8843        assert!(merge.metadata.contains_key("dsl_compat_merge"));
8844        assert!(merge.ports.inputs.iter().any(|port| port.name == "snv_x"));
8845        assert!(merge.ports.inputs.iter().any(|port| port.name == "msc_x"));
8846        assert!(graph.nodes.iter().any(|node| node.kind == NodeKind::Model
8847            && node.operator.as_ref().unwrap().as_str() == Some("PLSRegression")));
8848    }
8849
8850    #[test]
8851    fn published_pipeline_dsl_schema_declares_current_contract() {
8852        let schema: serde_json::Value = serde_json::from_str(include_str!(
8853            "../../../docs/contracts/pipeline_dsl.schema.json"
8854        ))
8855        .unwrap();
8856
8857        assert_eq!(schema["$id"], PIPELINE_DSL_SCHEMA_ID);
8858        assert!(schema["oneOf"].is_array());
8859        assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8860            .as_array()
8861            .unwrap()
8862            .iter()
8863            .any(|value| value.as_str() == Some("generator")));
8864        assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8865            .as_array()
8866            .unwrap()
8867            .iter()
8868            .any(|value| value.as_str() == Some("data_generation")));
8869        assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8870            .as_array()
8871            .unwrap()
8872            .iter()
8873            .any(|value| value.as_str() == Some("tuner")));
8874        assert!(schema["$defs"]["compat_generator_key"]["enum"]
8875            .as_array()
8876            .unwrap()
8877            .iter()
8878            .any(|value| value.as_str() == Some("_cartesian_")));
8879        assert!(schema["$defs"]["compat_step_object"]["properties"]
8880            .as_object()
8881            .unwrap()
8882            .contains_key("class"));
8883        assert!(schema["$defs"]["compat_step_object"]["properties"]
8884            .as_object()
8885            .unwrap()
8886            .contains_key("step"));
8887        assert!(schema["$defs"]["pipeline_unit_contract"]["properties"]
8888            .as_object()
8889            .unwrap()
8890            .contains_key("unit_level"));
8891        assert!(schema["$defs"]["entity_unit_level"]["enum"]
8892            .as_array()
8893            .unwrap()
8894            .iter()
8895            .any(|value| value.as_str() == Some("observation")));
8896    }
8897
8898    #[test]
8899    fn refuses_unsafe_shape_plan_from_dsl() {
8900        let spec: PipelineDslSpec = serde_json::from_str(
8901            r#"{
8902  "id": "dsl-unsafe-shape-plan",
8903  "steps": [
8904    {
8905      "kind": "augmentation",
8906      "id": "augment:bad",
8907      "operator": {"type": "LeakyAugmenter"},
8908      "shape": {
8909        "augmentation_policy": {
8910          "sample_scope": "all_partitions"
8911        }
8912      }
8913    }
8914  ]
8915}"#,
8916        )
8917        .unwrap();
8918
8919        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8920        assert!(format!("{error}").contains("sample augmentation over all partitions"));
8921    }
8922
8923    #[test]
8924    fn refuses_augmentation_without_shape_plan() {
8925        let spec: PipelineDslSpec = serde_json::from_str(
8926            r#"{
8927  "id": "dsl-augmentation-without-shape",
8928  "steps": [
8929    {
8930      "kind": "augmentation",
8931      "id": "augment:missing-shape",
8932      "operator": {"type": "GaussianNoise"}
8933    }
8934  ]
8935}"#,
8936        )
8937        .unwrap();
8938
8939        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8940        assert!(format!("{error}").contains("requires a shape plan"));
8941    }
8942
8943    #[test]
8944    fn refuses_data_generation_without_shape_plan() {
8945        let spec: PipelineDslSpec = serde_json::from_str(
8946            r#"{
8947  "id": "dsl-generation-without-shape",
8948  "steps": [
8949    {
8950      "kind": "data_generation",
8951      "id": "generator:missing-shape",
8952      "operator": {"type": "SMOTE"}
8953    }
8954  ]
8955}"#,
8956        )
8957        .unwrap();
8958
8959        let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8960        assert!(format!("{error}").contains("requires a shape plan"));
8961    }
8962
8963    #[test]
8964    fn refuses_branch_without_prediction_or_data_output() {
8965        let spec: PipelineDslSpec = serde_json::from_str(
8966            r#"{
8967  "id": "dsl-bad-branch",
8968  "steps": [
8969    {
8970      "kind": "branch",
8971      "branches": [
8972        {
8973          "id": "b0",
8974          "steps": [
8975            {
8976              "kind": "y_transform",
8977              "id": "target:only",
8978              "operator": {"type": "StandardScaler"}
8979            }
8980          ]
8981        }
8982      ]
8983    }
8984  ]
8985}"#,
8986        )
8987        .unwrap();
8988
8989        let error = compile_pipeline_dsl(&spec).unwrap_err();
8990        assert!(format!("{error}")
8991            .contains("must produce at least one model, merge prediction or transformed data"));
8992    }
8993
8994    #[test]
8995    fn dsl_top_level_inner_cv_maps_to_campaign_template() {
8996        let spec: PipelineDslSpec = serde_json::from_str(
8997            r#"{
8998  "id": "dsl-inner-cv-campaign",
8999  "inner_cv": {"kind": "kfold", "n_splits": 4, "shuffle": true, "seed": 7},
9000  "steps": [
9001    {"kind": "model", "id": "model:base", "operator": {"type": "Ridge"}, "params": {"alpha": 0.5}}
9002  ]
9003}"#,
9004        )
9005        .unwrap();
9006
9007        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
9008        match compiled.campaign_template.inner_cv {
9009            Some(crate::fold::NestedCvSpec::KFold(ref k)) => {
9010                assert_eq!(k.n_splits, 4);
9011                assert!(k.shuffle);
9012                assert_eq!(k.seed, Some(7));
9013            }
9014            ref other => panic!("expected campaign-level KFold inner_cv, got {other:?}"),
9015        }
9016    }
9017
9018    #[test]
9019    fn dsl_model_step_inner_cv_maps_to_node_metadata() {
9020        let spec: PipelineDslSpec = serde_json::from_str(
9021            r#"{
9022  "id": "dsl-inner-cv-node",
9023  "steps": [
9024    {
9025      "kind": "model",
9026      "id": "model:meta",
9027      "operator": {"type": "Ridge"},
9028      "inner_cv": {"kind": "group_kfold", "n_splits": 3}
9029    }
9030  ]
9031}"#,
9032        )
9033        .unwrap();
9034
9035        let graph = compile_pipeline_dsl(&spec).unwrap();
9036        let node = graph
9037            .nodes
9038            .iter()
9039            .find(|node| node.id.as_str() == "model:meta")
9040            .expect("compiled model node exists");
9041        let value = node
9042            .metadata
9043            .get("dsl_inner_cv")
9044            .expect("node carries dsl_inner_cv metadata");
9045        let inner: crate::fold::NestedCvSpec = serde_json::from_value(value.clone()).unwrap();
9046        match inner {
9047            crate::fold::NestedCvSpec::GroupKFold(ref g) => assert_eq!(g.n_splits, 3),
9048            other => panic!("expected node-local GroupKFold inner_cv, got {other:?}"),
9049        }
9050    }
9051
9052    #[test]
9053    fn dsl_absent_inner_cv_leaves_campaign_and_nodes_unset() {
9054        let spec: PipelineDslSpec = serde_json::from_str(
9055            r#"{
9056  "id": "dsl-no-inner-cv",
9057  "steps": [
9058    {"kind": "model", "id": "model:base", "operator": {"type": "Ridge"}}
9059  ]
9060}"#,
9061        )
9062        .unwrap();
9063
9064        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
9065        assert!(compiled.campaign_template.inner_cv.is_none());
9066        for node in &compiled.graph.nodes {
9067            assert!(!node.metadata.contains_key("dsl_inner_cv"));
9068        }
9069    }
9070
9071    #[test]
9072    fn compat_pipeline_preserves_campaign_and_model_inner_cv() {
9073        // nirs4all-compatible dict form ("pipeline" key) routes through the compat
9074        // lowerer; campaign-global and node-local inner_cv must survive lowering.
9075        let spec = parse_pipeline_dsl_json(
9076            br#"{
9077  "id": "dsl-compat-inner-cv",
9078  "inner_cv": {"kind": "kfold", "n_splits": 5, "shuffle": false, "seed": 3},
9079  "pipeline": [
9080    {"split": {"type": "KFold", "n_splits": 4}},
9081    {"model": "Ridge", "id": "model:base", "inner_cv": {"kind": "group_kfold", "n_splits": 3}}
9082  ]
9083}"#,
9084        )
9085        .unwrap();
9086
9087        match spec.inner_cv {
9088            Some(crate::fold::NestedCvSpec::KFold(ref k)) => assert_eq!(k.n_splits, 5),
9089            ref other => panic!("expected compat campaign-global KFold inner_cv, got {other:?}"),
9090        }
9091
9092        let graph = compile_pipeline_dsl(&spec).unwrap();
9093        let node = graph
9094            .nodes
9095            .iter()
9096            .find(|node| node.id.as_str() == "model:base")
9097            .expect("compat model node exists");
9098        let inner: crate::fold::NestedCvSpec =
9099            serde_json::from_value(node.metadata.get("dsl_inner_cv").cloned().unwrap()).unwrap();
9100        match inner {
9101            crate::fold::NestedCvSpec::GroupKFold(ref g) => assert_eq!(g.n_splits, 3),
9102            other => panic!("expected compat node-local GroupKFold inner_cv, got {other:?}"),
9103        }
9104    }
9105
9106    #[test]
9107    fn compat_merge_model_collapse_preserves_inner_cv() {
9108        // The compat `merge` + `model` stacker shorthand collapses into a
9109        // merge-model step; its node-local inner_cv must reach the graph node.
9110        let spec = parse_pipeline_dsl_json(
9111            br#"{
9112  "id": "dsl-compat-merge-inner-cv",
9113  "pipeline": [
9114    {"_chain_": [
9115      {"_grid_": {"model": ["PLSRegression"], "n_components": [5, 10]}},
9116      {"_grid_": {"model": ["Ridge"], "alpha": [0.1, 1.0]}}
9117    ]},
9118    {"merge": "predictions"},
9119    {"model": "Ridge", "id": "model:meta", "params": {"alpha": 0.5}, "inner_cv": {"kind": "kfold", "n_splits": 4, "shuffle": false, "seed": null}}
9120  ]
9121}"#,
9122        )
9123        .unwrap();
9124
9125        let graph = compile_pipeline_dsl(&spec).unwrap();
9126        let node = graph
9127            .nodes
9128            .iter()
9129            .find(|node| node.id.as_str() == "model:meta")
9130            .expect("compat merge-model node exists");
9131        let inner: crate::fold::NestedCvSpec =
9132            serde_json::from_value(node.metadata.get("dsl_inner_cv").cloned().unwrap()).unwrap();
9133        match inner {
9134            crate::fold::NestedCvSpec::KFold(ref k) => assert_eq!(k.n_splits, 4),
9135            other => panic!("expected merge-model KFold inner_cv, got {other:?}"),
9136        }
9137    }
9138}