1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4
5use crate::controller::ControllerRegistry;
6use crate::data::{BranchViewMode, BranchViewPlan, DataBinding, DataViewSelector};
7use crate::error::{DagMlError, Result};
8use crate::fold::NestedCvSpec;
9use crate::generation::{
10 generation_spec_fingerprint, GenerationChoice, GenerationDimension, GenerationParamOverride,
11 GenerationSpec, GenerationStrategy,
12};
13use crate::graph::{
14 EdgeContract, EdgeSpec, GraphInterface, GraphSpec, NodeKind, NodeSpec, PortCardinality,
15 PortKind, PortRef, PortSchema, PortSpec,
16};
17use crate::ids::NodeId;
18use crate::plan::{CampaignSpec, SplitInvocation};
19use crate::policy::{
20 AggregationPolicy, AugmentationPolicy, DataModelShapePlan, FeatureSelectionPolicy, FitBoundary,
21 Granularity, LeakageUnitPolicy,
22};
23use crate::relation::EntityUnitLevel;
24
/// Version number of the pipeline DSL JSON schema.
pub const PIPELINE_DSL_SCHEMA_VERSION: u32 = 1;
/// Identifier (URL) of the pipeline DSL v1 JSON schema.
pub const PIPELINE_DSL_SCHEMA_ID: &str = concat!(
    "https://github.com/GBeurier/dag-ml/schemas/",
    "pipeline_dsl.v1.schema.json"
);

/// Step-metadata flag (read as a JSON boolean) marking an operator step written in minimal
/// alias form, whose node kind may be re-inferred from the controller registry.
const DSL_MINIMAL_OPERATOR_ALIAS: &str = "dsl_minimal_operator_alias";
/// Step-metadata key that records the node kind inferred by the controller registry.
const DSL_REGISTRY_INFERRED_KIND: &str = "dsl_registry_inferred_kind";
/// Step-metadata key that preserves the compat keyword present before registry inference.
const DSL_COMPAT_ORIGINAL_KEYWORD: &str = "dsl_compat_original_keyword";
31
32#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
33pub struct PipelineDslSpec {
34 pub id: String,
35 #[serde(default)]
36 pub input: PipelineDslDataPort,
37 #[serde(default)]
38 pub output: PipelineDslPredictionPort,
39 #[serde(default)]
40 pub generation_strategy: Option<GenerationStrategy>,
41 #[serde(default)]
42 pub max_variants: Option<usize>,
43 #[serde(default, skip_serializing_if = "Vec::is_empty")]
44 pub generation_dimensions: Vec<PipelineDslGenerationDimension>,
45 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub campaign_id: Option<String>,
47 #[serde(default)]
48 pub root_seed: Option<u64>,
49 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub leakage_policy: Option<LeakageUnitPolicy>,
51 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub aggregation_policy: Option<AggregationPolicy>,
53 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub split_invocation: Option<SplitInvocation>,
55 #[serde(default, skip_serializing_if = "Option::is_none")]
58 pub inner_cv: Option<NestedCvSpec>,
59 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
60 pub campaign_metadata: BTreeMap<String, serde_json::Value>,
61 #[serde(default, skip_serializing_if = "Vec::is_empty")]
62 pub data_bindings: Vec<DataBinding>,
63 #[serde(default)]
64 pub steps: Vec<PipelineDslStep>,
65 #[serde(default)]
66 pub metadata: BTreeMap<String, serde_json::Value>,
67}
68
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
70pub struct PipelineDslDataPort {
71 #[serde(default = "default_input_name")]
72 pub name: String,
73 #[serde(default = "default_data_representation")]
74 pub representation: String,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
76 pub unit_level: Option<EntityUnitLevel>,
77 #[serde(default, skip_serializing_if = "Option::is_none")]
78 pub alignment_key: Option<String>,
79 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub target_level: Option<EntityUnitLevel>,
81 #[serde(default)]
82 pub description: String,
83}
84
85impl Default for PipelineDslDataPort {
86 fn default() -> Self {
87 Self {
88 name: default_input_name(),
89 representation: default_data_representation(),
90 unit_level: None,
91 alignment_key: None,
92 target_level: None,
93 description: String::new(),
94 }
95 }
96}
97
98#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
99pub struct PipelineDslPredictionPort {
100 #[serde(default = "default_output_name")]
101 pub name: String,
102 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub representation: Option<String>,
104 #[serde(default, skip_serializing_if = "Option::is_none")]
105 pub unit_level: Option<EntityUnitLevel>,
106 #[serde(default, skip_serializing_if = "Option::is_none")]
107 pub alignment_key: Option<String>,
108 #[serde(default, skip_serializing_if = "Option::is_none")]
109 pub target_level: Option<EntityUnitLevel>,
110 #[serde(default)]
111 pub description: String,
112}
113
114impl Default for PipelineDslPredictionPort {
115 fn default() -> Self {
116 Self {
117 name: default_output_name(),
118 representation: None,
119 unit_level: None,
120 alignment_key: None,
121 target_level: None,
122 description: String::new(),
123 }
124 }
125}
126
127#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
128#[serde(tag = "kind", rename_all = "snake_case")]
129pub enum PipelineDslStep {
130 Transform(PipelineDslOperatorStep),
131 YTransform(PipelineDslOperatorStep),
132 Tag(PipelineDslOperatorStep),
133 Exclude(PipelineDslOperatorStep),
134 Filter(PipelineDslOperatorStep),
135 SampleFilter(PipelineDslOperatorStep),
136 Augmentation(PipelineDslOperatorStep),
137 FeatureAugmentation(PipelineDslOperatorStep),
138 SampleAugmentation(PipelineDslOperatorStep),
139 #[serde(alias = "generation")]
140 DataGeneration(PipelineDslOperatorStep),
141 ConcatTransform(PipelineDslConcatTransformStep),
142 Model(PipelineDslOperatorStep),
143 #[serde(alias = "finetune")]
144 Tuner(PipelineDslOperatorStep),
145 Branch(PipelineDslBranchStep),
146 Generator(PipelineDslGeneratorStep),
147 Sequential(PipelineDslSequenceStep),
148 Merge(PipelineDslMergeStep),
149 MergeModel(PipelineDslMergeModelStep),
150 Chart(PipelineDslOperatorStep),
151}
152
153#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
154pub struct PipelineDslOperatorStep {
155 pub id: NodeId,
156 pub operator: serde_json::Value,
157 #[serde(default)]
158 pub params: BTreeMap<String, serde_json::Value>,
159 #[serde(default)]
160 pub metadata: BTreeMap<String, serde_json::Value>,
161 #[serde(default)]
162 pub seed_label: Option<String>,
163 #[serde(default)]
164 pub representation: Option<String>,
165 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
166 pub train_params: BTreeMap<String, serde_json::Value>,
167 #[serde(
168 default,
169 alias = "finetune_params",
170 skip_serializing_if = "Option::is_none"
171 )]
172 pub tuning: Option<PipelineDslTuningSpec>,
173 #[serde(default, skip_serializing_if = "Vec::is_empty")]
174 pub variants: Vec<PipelineDslVariantChoice>,
175 #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
176 pub param_generators: Vec<PipelineDslParamGenerator>,
177 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub shape: Option<PipelineDslShapePlan>,
179 #[serde(default, skip_serializing_if = "Option::is_none")]
183 pub inner_cv: Option<NestedCvSpec>,
184}
185
186#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
187pub struct PipelineDslTuningSpec {
188 #[serde(default, skip_serializing_if = "Option::is_none")]
189 pub n_trials: Option<usize>,
190 #[serde(default, skip_serializing_if = "Option::is_none")]
191 pub approach: Option<String>,
192 #[serde(default, skip_serializing_if = "Option::is_none")]
193 pub eval_mode: Option<String>,
194 #[serde(default, skip_serializing_if = "Option::is_none")]
195 pub sampler: Option<String>,
196 #[serde(default, skip_serializing_if = "Option::is_none")]
197 pub metric: Option<String>,
198 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
199 pub model_params: BTreeMap<String, serde_json::Value>,
200 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
201 pub train_params: BTreeMap<String, serde_json::Value>,
202 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
203 pub metadata: BTreeMap<String, serde_json::Value>,
204}
205
206#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
207pub struct PipelineDslVariantChoice {
208 pub label: String,
209 #[serde(default)]
210 pub params: BTreeMap<String, serde_json::Value>,
211 #[serde(default)]
212 pub value: Option<serde_json::Value>,
213}
214
215#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
216#[serde(tag = "kind", rename_all = "snake_case")]
217pub enum PipelineDslParamGenerator {
218 Or {
219 #[serde(default, skip_serializing_if = "Option::is_none")]
220 name: Option<String>,
221 param: String,
222 values: Vec<PipelineDslGeneratorValue>,
223 #[serde(default, skip_serializing_if = "Option::is_none")]
224 count: Option<usize>,
225 },
226 Range {
227 #[serde(default, skip_serializing_if = "Option::is_none")]
228 name: Option<String>,
229 param: String,
230 start: f64,
231 stop: f64,
232 step: f64,
233 #[serde(default = "default_true")]
234 inclusive: bool,
235 #[serde(default, skip_serializing_if = "Option::is_none")]
236 count: Option<usize>,
237 },
238 LogRange {
239 #[serde(default, skip_serializing_if = "Option::is_none")]
240 name: Option<String>,
241 param: String,
242 start: f64,
243 stop: f64,
244 count: usize,
245 #[serde(default = "default_log_base")]
246 base: f64,
247 },
248 Grid {
249 #[serde(default, skip_serializing_if = "Option::is_none")]
250 name: Option<String>,
251 params: BTreeMap<String, Vec<PipelineDslGeneratorValue>>,
252 #[serde(default, skip_serializing_if = "Option::is_none")]
253 count: Option<usize>,
254 },
255 Pick {
256 #[serde(default, skip_serializing_if = "Option::is_none")]
257 name: Option<String>,
258 param: String,
259 values: Vec<PipelineDslGeneratorValue>,
260 sizes: Vec<usize>,
261 #[serde(default, skip_serializing_if = "Option::is_none")]
262 count: Option<usize>,
263 },
264 Arrange {
265 #[serde(default, skip_serializing_if = "Option::is_none")]
266 name: Option<String>,
267 param: String,
268 values: Vec<PipelineDslGeneratorValue>,
269 sizes: Vec<usize>,
270 #[serde(default, skip_serializing_if = "Option::is_none")]
271 count: Option<usize>,
272 },
273}
274
275#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
276#[serde(untagged)]
277pub enum PipelineDslGeneratorValue {
278 Labeled {
279 label: String,
280 value: serde_json::Value,
281 },
282 Value(serde_json::Value),
283}
284
285#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
286pub struct PipelineDslGenerationDimension {
287 pub name: String,
288 #[serde(default)]
289 pub choices: Vec<PipelineDslGenerationChoice>,
290}
291
292#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
293pub struct PipelineDslGenerationChoice {
294 pub label: String,
295 #[serde(default)]
296 pub value: Option<serde_json::Value>,
297 #[serde(default)]
298 pub param_overrides: Vec<PipelineDslGenerationParamOverride>,
299}
300
301#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
302pub struct PipelineDslGenerationParamOverride {
303 pub node_id: NodeId,
304 #[serde(default)]
305 pub params: BTreeMap<String, serde_json::Value>,
306}
307
308#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
309pub struct PipelineDslBranchStep {
310 #[serde(default)]
311 pub mode: PipelineDslBranchMode,
312 #[serde(default, skip_serializing_if = "Option::is_none")]
313 pub selector: Option<serde_json::Value>,
314 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
315 pub metadata: BTreeMap<String, serde_json::Value>,
316 pub branches: Vec<PipelineDslBranch>,
317}
318
319#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
320#[serde(rename_all = "snake_case")]
321pub enum PipelineDslBranchMode {
322 #[default]
323 Duplication,
324 Separation,
325 BySource,
326 ByMetadata,
327 ByTag,
328 ByFilter,
329}
330
331#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
332pub struct PipelineDslBranch {
333 pub id: String,
334 #[serde(default, skip_serializing_if = "Option::is_none")]
335 pub selector: Option<serde_json::Value>,
336 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
337 pub metadata: BTreeMap<String, serde_json::Value>,
338 #[serde(default)]
339 pub steps: Vec<PipelineDslStep>,
340}
341
342#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
343pub struct PipelineDslSequenceStep {
344 #[serde(default, skip_serializing_if = "Option::is_none")]
345 pub id: Option<NodeId>,
346 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
347 pub metadata: BTreeMap<String, serde_json::Value>,
348 #[serde(default)]
349 pub steps: Vec<PipelineDslStep>,
350}
351
352#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
353pub struct PipelineDslGeneratorStep {
354 pub id: NodeId,
355 #[serde(default)]
356 pub mode: PipelineDslGeneratorMode,
357 #[serde(default, skip_serializing_if = "Vec::is_empty")]
358 pub branches: Vec<PipelineDslBranch>,
359 #[serde(default, skip_serializing_if = "Vec::is_empty")]
360 pub stages: Vec<PipelineDslGeneratorStage>,
361 #[serde(default, skip_serializing_if = "Option::is_none")]
362 pub pick: Option<PipelineDslSelectionSpec>,
363 #[serde(default, skip_serializing_if = "Option::is_none")]
364 pub arrange: Option<PipelineDslSelectionSpec>,
365 #[serde(default, skip_serializing_if = "Option::is_none")]
366 pub then_pick: Option<PipelineDslSelectionSpec>,
367 #[serde(default, skip_serializing_if = "Option::is_none")]
368 pub then_arrange: Option<PipelineDslSelectionSpec>,
369 #[serde(default, skip_serializing_if = "Option::is_none")]
370 pub count: Option<usize>,
371 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
372 pub metadata: BTreeMap<String, serde_json::Value>,
373}
374
375#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
376#[serde(rename_all = "snake_case")]
377pub enum PipelineDslGeneratorMode {
378 #[default]
379 Or,
380 Cartesian,
381}
382
383#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
384pub struct PipelineDslGeneratorStage {
385 pub id: String,
386 #[serde(default, skip_serializing_if = "Option::is_none")]
387 pub selector: Option<serde_json::Value>,
388 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
389 pub metadata: BTreeMap<String, serde_json::Value>,
390 #[serde(default)]
391 pub branches: Vec<PipelineDslBranch>,
392}
393
394#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
395#[serde(untagged)]
396pub enum PipelineDslSelectionSpec {
397 Single(usize),
398 Range([usize; 2]),
399}
400
401#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
402pub struct PipelineDslConcatTransformStep {
403 pub id: NodeId,
404 #[serde(default)]
405 pub branches: Vec<PipelineDslConcatBranch>,
406 #[serde(default)]
407 pub metadata: BTreeMap<String, serde_json::Value>,
408 #[serde(default)]
409 pub seed_label: Option<String>,
410 #[serde(default)]
411 pub representation: Option<String>,
412 #[serde(default, skip_serializing_if = "Vec::is_empty")]
413 pub variants: Vec<PipelineDslVariantChoice>,
414 #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
415 pub param_generators: Vec<PipelineDslParamGenerator>,
416 #[serde(default, skip_serializing_if = "Option::is_none")]
417 pub shape: Option<PipelineDslShapePlan>,
418}
419
420#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
421pub struct PipelineDslConcatBranch {
422 pub id: String,
423 #[serde(default)]
424 pub steps: Vec<PipelineDslOperatorStep>,
425}
426
427#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
428pub struct PipelineDslMergeStep {
429 pub id: NodeId,
430 #[serde(default = "default_merge_mode")]
431 pub merge_mode: String,
432 #[serde(default)]
433 pub output_as: PipelineDslMergeOutput,
434 #[serde(default = "default_true")]
435 pub include_original_data: bool,
436 #[serde(default, skip_serializing_if = "Option::is_none")]
437 pub on_missing: Option<String>,
438 #[serde(default, skip_serializing_if = "Vec::is_empty")]
439 pub selectors: Vec<PipelineDslMergeSelector>,
440 #[serde(default)]
441 pub metadata: BTreeMap<String, serde_json::Value>,
442 #[serde(default)]
443 pub seed_label: Option<String>,
444 #[serde(default)]
445 pub representation: Option<String>,
446 #[serde(default, skip_serializing_if = "Vec::is_empty")]
447 pub variants: Vec<PipelineDslVariantChoice>,
448 #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
449 pub param_generators: Vec<PipelineDslParamGenerator>,
450 #[serde(default, skip_serializing_if = "Option::is_none")]
451 pub shape: Option<PipelineDslShapePlan>,
452}
453
454#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
455#[serde(rename_all = "snake_case")]
456pub enum PipelineDslMergeOutput {
457 #[default]
458 Features,
459 Predictions,
460 Sources,
461}
462
463#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
464pub struct PipelineDslMergeSelector {
465 #[serde(default, skip_serializing_if = "Option::is_none")]
466 pub input_name: Option<String>,
467 #[serde(default, skip_serializing_if = "Option::is_none")]
468 pub branch: Option<String>,
469 #[serde(default, skip_serializing_if = "Option::is_none")]
470 pub model: Option<NodeId>,
471 #[serde(default, skip_serializing_if = "Option::is_none")]
472 pub select: Option<serde_json::Value>,
473 #[serde(default, skip_serializing_if = "Option::is_none")]
474 pub metric: Option<String>,
475 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
476 pub metadata: BTreeMap<String, serde_json::Value>,
477}
478
479#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
480pub struct PipelineDslMergeModelStep {
481 pub id: NodeId,
482 pub operator: serde_json::Value,
483 #[serde(default)]
484 pub params: BTreeMap<String, serde_json::Value>,
485 #[serde(default)]
486 pub metadata: BTreeMap<String, serde_json::Value>,
487 #[serde(default)]
488 pub seed_label: Option<String>,
489 #[serde(default = "default_true")]
490 pub include_original_data: bool,
491 #[serde(default = "default_merge_mode")]
492 pub merge_mode: String,
493 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
494 pub train_params: BTreeMap<String, serde_json::Value>,
495 #[serde(
496 default,
497 alias = "finetune_params",
498 skip_serializing_if = "Option::is_none"
499 )]
500 pub tuning: Option<PipelineDslTuningSpec>,
501 #[serde(default, skip_serializing_if = "Vec::is_empty")]
502 pub variants: Vec<PipelineDslVariantChoice>,
503 #[serde(default, alias = "generators", skip_serializing_if = "Vec::is_empty")]
504 pub param_generators: Vec<PipelineDslParamGenerator>,
505 #[serde(default, skip_serializing_if = "Option::is_none")]
506 pub shape: Option<PipelineDslShapePlan>,
507 #[serde(default, skip_serializing_if = "Option::is_none")]
510 pub inner_cv: Option<NestedCvSpec>,
511}
512
513#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
514pub struct PipelineDslShapePlan {
515 #[serde(default)]
516 pub input_granularity: Option<Granularity>,
517 #[serde(default)]
518 pub target_granularity: Option<Granularity>,
519 #[serde(default)]
520 pub fit_rows: Option<FitBoundary>,
521 #[serde(default)]
522 pub predict_rows: Option<FitBoundary>,
523 #[serde(default)]
524 pub feature_namespace: Option<String>,
525 #[serde(default)]
526 pub feature_schema_fingerprint: Option<String>,
527 #[serde(default)]
528 pub target_space: Option<String>,
529 #[serde(default)]
530 pub aggregation_policy: Option<AggregationPolicy>,
531 #[serde(default)]
532 pub augmentation_policy: Option<AugmentationPolicy>,
533 #[serde(default)]
534 pub selection_policy: Option<FeatureSelectionPolicy>,
535}
536
537#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
538pub struct CompiledPipelineDsl {
539 pub graph: GraphSpec,
540 pub generation: GenerationSpec,
541 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
542 pub shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
543 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
544 pub data_bindings: BTreeMap<NodeId, Vec<DataBinding>>,
545 #[serde(default, skip_serializing_if = "Vec::is_empty")]
546 pub branch_view_plans: Vec<BranchViewPlan>,
547 pub campaign_template: CampaignSpec,
548 #[serde(default, skip_serializing_if = "Option::is_none")]
549 pub generation_fingerprint: Option<String>,
550}
551
552pub fn compile_pipeline_dsl(spec: &PipelineDslSpec) -> Result<GraphSpec> {
553 Ok(compile_pipeline_dsl_with_generation(spec)?.graph)
554}
555
556pub fn compile_pipeline_dsl_with_controller_registry(
557 spec: &PipelineDslSpec,
558 registry: &ControllerRegistry,
559) -> Result<GraphSpec> {
560 Ok(compile_pipeline_dsl_with_generation_and_controller_registry(spec, registry)?.graph)
561}
562
563pub fn parse_pipeline_dsl_json(data: &[u8]) -> Result<PipelineDslSpec> {
564 match serde_json::from_slice::<PipelineDslSpec>(data) {
565 Ok(spec) if validate_pipeline_dsl(&spec).is_ok() => Ok(spec),
566 Ok(spec) => {
567 let strict_error = validate_pipeline_dsl(&spec)
568 .err()
569 .map(|error| error.to_string())
570 .unwrap_or_else(|| "unknown validation error".to_string());
571 let value = serde_json::from_slice::<serde_json::Value>(data).map_err(|error| {
572 DagMlError::GraphValidation(format!("failed to parse pipeline DSL JSON: {error}"))
573 })?;
574 lower_nirs4all_compat_pipeline_dsl(&value).map_err(|compat_error| {
575 DagMlError::GraphValidation(format!(
576 "failed to parse pipeline DSL as valid canonical PipelineDslSpec ({strict_error}) or nirs4all-compatible JSON ({compat_error})"
577 ))
578 })
579 }
580 Err(strict_error) => {
581 let value = serde_json::from_slice::<serde_json::Value>(data).map_err(|error| {
582 DagMlError::GraphValidation(format!("failed to parse pipeline DSL JSON: {error}"))
583 })?;
584 lower_nirs4all_compat_pipeline_dsl(&value).map_err(|compat_error| {
585 DagMlError::GraphValidation(format!(
586 "failed to parse pipeline DSL as canonical PipelineDslSpec ({strict_error}) or nirs4all-compatible JSON ({compat_error})"
587 ))
588 })
589 }
590 }
591}
592
593pub fn lower_nirs4all_compat_pipeline_dsl(value: &serde_json::Value) -> Result<PipelineDslSpec> {
594 CompatDslLowerer::default().lower_root(value)
595}
596
597pub fn resolve_pipeline_dsl_minimal_aliases(
598 spec: &PipelineDslSpec,
599 registry: &ControllerRegistry,
600) -> Result<PipelineDslSpec> {
601 let mut resolved = spec.clone();
602 for step in &mut resolved.steps {
603 resolve_step_minimal_aliases(step, registry)?;
604 }
605 validate_pipeline_dsl(&resolved)?;
606 Ok(resolved)
607}
608
609pub fn compile_pipeline_dsl_with_generation_and_controller_registry(
610 spec: &PipelineDslSpec,
611 registry: &ControllerRegistry,
612) -> Result<CompiledPipelineDsl> {
613 let resolved = resolve_pipeline_dsl_minimal_aliases(spec, registry)?;
614 compile_pipeline_dsl_with_generation(&resolved)
615}
616
617pub fn compile_pipeline_dsl_with_generation(spec: &PipelineDslSpec) -> Result<CompiledPipelineDsl> {
618 validate_pipeline_dsl(spec)?;
619 let input_representation = Some(spec.input.representation.clone());
620 let external_data = DataSource {
621 node_id: None,
622 port_name: spec.input.name.clone(),
623 representation: input_representation.clone(),
624 };
625 let mut compiler = PipelineCompiler {
626 graph_id: spec.id.clone(),
627 input_representation: input_representation.clone(),
628 nodes: Vec::new(),
629 edges: Vec::new(),
630 generation_dimensions: Vec::new(),
631 shape_plans: BTreeMap::new(),
632 branch_view_plans: Vec::new(),
633 };
634 let mut sequence_state = SequenceCompileState::new(external_data.clone());
635
636 for step in &spec.steps {
637 compiler.compile_top_level_step(step, &external_data, &mut sequence_state)?;
638 }
639
640 let mut generation_dimensions =
641 compile_explicit_generation_dimensions(&spec.generation_dimensions, &compiler.nodes)?;
642 generation_dimensions.extend(compiler.generation_dimensions);
643 let generation = build_generation_spec(
644 spec.generation_strategy,
645 spec.max_variants,
646 generation_dimensions,
647 )?;
648 let generation_fingerprint = if generation.strategy == GenerationStrategy::None {
649 None
650 } else {
651 Some(generation_spec_fingerprint(&generation)?)
652 };
653 let mut interface_input = data_port(
654 &spec.input.name,
655 input_representation.clone(),
656 &spec.input.description,
657 );
658 apply_data_unit_contract(&mut interface_input, &spec.input);
659 let mut interface_output = prediction_port(&spec.output.name, &spec.output.description);
660 apply_prediction_unit_contract(&mut interface_output, &spec.output);
661
662 let graph = GraphSpec {
663 id: spec.id.clone(),
664 interface: GraphInterface {
665 inputs: vec![interface_input],
666 outputs: vec![interface_output],
667 },
668 nodes: compiler.nodes,
669 edges: compiler.edges,
670 search_space_fingerprint: generation_fingerprint.clone(),
671 metadata: spec.metadata.clone(),
672 };
673 graph.validate()?;
674 validate_shape_plan_targets(&compiler.shape_plans, &graph)?;
675 let data_bindings = compile_data_bindings(&spec.data_bindings, &graph)?;
676 let campaign_template = build_campaign_template(
677 spec,
678 &generation,
679 &compiler.shape_plans,
680 &data_bindings,
681 &compiler.branch_view_plans,
682 )?;
683 Ok(CompiledPipelineDsl {
684 graph,
685 generation,
686 shape_plans: compiler.shape_plans,
687 data_bindings,
688 branch_view_plans: compiler.branch_view_plans,
689 campaign_template,
690 generation_fingerprint,
691 })
692}
693
694fn resolve_step_minimal_aliases(
695 step: &mut PipelineDslStep,
696 registry: &ControllerRegistry,
697) -> Result<()> {
698 if let Some(resolved) = resolve_operator_step_minimal_alias(step, registry)? {
699 *step = resolved;
700 }
701 match step {
702 PipelineDslStep::Branch(branch) => {
703 for branch in &mut branch.branches {
704 for child in &mut branch.steps {
705 resolve_step_minimal_aliases(child, registry)?;
706 }
707 }
708 }
709 PipelineDslStep::Generator(generator) => {
710 for branch in &mut generator.branches {
711 for child in &mut branch.steps {
712 resolve_step_minimal_aliases(child, registry)?;
713 }
714 }
715 for stage in &mut generator.stages {
716 for branch in &mut stage.branches {
717 for child in &mut branch.steps {
718 resolve_step_minimal_aliases(child, registry)?;
719 }
720 }
721 }
722 }
723 PipelineDslStep::Sequential(sequence) => {
724 for child in &mut sequence.steps {
725 resolve_step_minimal_aliases(child, registry)?;
726 }
727 }
728 _ => {}
729 }
730 Ok(())
731}
732
733fn resolve_operator_step_minimal_alias(
734 step: &PipelineDslStep,
735 registry: &ControllerRegistry,
736) -> Result<Option<PipelineDslStep>> {
737 let Some((current_kind, operator_step)) = operator_step_node_kind(step) else {
738 return Ok(None);
739 };
740 if !is_minimal_operator_alias(operator_step) {
741 return Ok(None);
742 }
743 let Some(inferred_kind) = registry.infer_operator_kind(&operator_step.operator)? else {
744 return Ok(None);
745 };
746 if inferred_kind == current_kind {
747 return Ok(None);
748 }
749 let mut resolved = operator_step.clone();
750 annotate_registry_inferred_operator_step(&mut resolved, &inferred_kind)?;
751 Ok(Some(operator_pipeline_step_for_node_kind(
752 inferred_kind,
753 resolved,
754 )?))
755}
756
757fn operator_step_node_kind(step: &PipelineDslStep) -> Option<(NodeKind, &PipelineDslOperatorStep)> {
758 match step {
759 PipelineDslStep::Transform(step) => Some((NodeKind::Transform, step)),
760 PipelineDslStep::YTransform(step) => Some((NodeKind::YTransform, step)),
761 PipelineDslStep::Tag(step) => Some((NodeKind::Tag, step)),
762 PipelineDslStep::Exclude(step) => Some((NodeKind::Exclude, step)),
763 PipelineDslStep::Filter(step) | PipelineDslStep::SampleFilter(step) => {
764 Some((NodeKind::Exclude, step))
765 }
766 PipelineDslStep::Augmentation(step)
767 | PipelineDslStep::FeatureAugmentation(step)
768 | PipelineDslStep::SampleAugmentation(step) => Some((NodeKind::Augmentation, step)),
769 PipelineDslStep::DataGeneration(step) => Some((NodeKind::Generator, step)),
770 PipelineDslStep::Model(step) => Some((NodeKind::Model, step)),
771 PipelineDslStep::Tuner(step) => Some((NodeKind::Tuner, step)),
772 PipelineDslStep::Chart(step) => Some((NodeKind::Chart, step)),
773 _ => None,
774 }
775}
776
777fn is_minimal_operator_alias(step: &PipelineDslOperatorStep) -> bool {
778 step.metadata
779 .get(DSL_MINIMAL_OPERATOR_ALIAS)
780 .and_then(serde_json::Value::as_bool)
781 .unwrap_or(false)
782}
783
784fn annotate_registry_inferred_operator_step(
785 step: &mut PipelineDslOperatorStep,
786 inferred_kind: &NodeKind,
787) -> Result<()> {
788 if let Some(keyword) = step.metadata.get("dsl_compat_keyword").cloned() {
789 step.metadata
790 .entry(DSL_COMPAT_ORIGINAL_KEYWORD.to_string())
791 .or_insert(keyword);
792 }
793 step.metadata.insert(
794 "dsl_compat_keyword".to_string(),
795 serde_json::Value::String(compat_keyword_for_node_kind(inferred_kind)?.to_string()),
796 );
797 step.metadata.insert(
798 DSL_REGISTRY_INFERRED_KIND.to_string(),
799 serde_json::to_value(inferred_kind).map_err(|error| {
800 DagMlError::GraphValidation(format!(
801 "failed to serialize registry-inferred operator kind: {error}"
802 ))
803 })?,
804 );
805 Ok(())
806}
807
808fn operator_pipeline_step_for_node_kind(
809 kind: NodeKind,
810 step: PipelineDslOperatorStep,
811) -> Result<PipelineDslStep> {
812 match kind {
813 NodeKind::Transform => Ok(PipelineDslStep::Transform(step)),
814 NodeKind::YTransform => Ok(PipelineDslStep::YTransform(step)),
815 NodeKind::Tag => Ok(PipelineDslStep::Tag(step)),
816 NodeKind::Exclude => Ok(PipelineDslStep::Exclude(step)),
817 NodeKind::Augmentation => Ok(PipelineDslStep::Augmentation(step)),
818 NodeKind::Generator => Ok(PipelineDslStep::DataGeneration(step)),
819 NodeKind::Model => Ok(PipelineDslStep::Model(step)),
820 NodeKind::Tuner => Ok(PipelineDslStep::Tuner(step)),
821 NodeKind::Chart => Ok(PipelineDslStep::Chart(step)),
822 unsupported => Err(DagMlError::GraphValidation(format!(
823 "minimal operator alias matched unsupported node kind {:?}; use explicit DSL syntax",
824 unsupported
825 ))),
826 }
827}
828
829fn compat_keyword_for_node_kind(kind: &NodeKind) -> Result<&'static str> {
830 match kind {
831 NodeKind::Transform => Ok("preprocessing"),
832 NodeKind::YTransform => Ok("y_processing"),
833 NodeKind::Tag => Ok("tag"),
834 NodeKind::Exclude => Ok("exclude"),
835 NodeKind::Augmentation => Ok("augmentation"),
836 NodeKind::Generator => Ok("data_generation"),
837 NodeKind::Model => Ok("model"),
838 NodeKind::Tuner => Ok("tuner"),
839 NodeKind::Chart => Ok("chart"),
840 unsupported => Err(DagMlError::GraphValidation(format!(
841 "minimal operator alias matched unsupported node kind {:?}; use explicit DSL syntax",
842 unsupported
843 ))),
844 }
845}
846
847fn validate_pipeline_dsl(spec: &PipelineDslSpec) -> Result<()> {
848 if spec.id.trim().is_empty() {
849 return Err(DagMlError::GraphValidation(
850 "pipeline DSL graph id must not be empty".to_string(),
851 ));
852 }
853 if spec.input.name.trim().is_empty() {
854 return Err(DagMlError::GraphValidation(
855 "pipeline DSL input name must not be empty".to_string(),
856 ));
857 }
858 if spec.input.representation.trim().is_empty() {
859 return Err(DagMlError::GraphValidation(
860 "pipeline DSL input representation must not be empty".to_string(),
861 ));
862 }
863 if spec.output.name.trim().is_empty() {
864 return Err(DagMlError::GraphValidation(
865 "pipeline DSL output name must not be empty".to_string(),
866 ));
867 }
868 if spec.steps.is_empty() {
869 return Err(DagMlError::GraphValidation(
870 "pipeline DSL must contain at least one step".to_string(),
871 ));
872 }
873 Ok(())
874}
875
876struct PipelineCompiler {
877 graph_id: String,
878 input_representation: Option<String>,
879 nodes: Vec<NodeSpec>,
880 edges: Vec<EdgeSpec>,
881 generation_dimensions: Vec<GenerationDimension>,
882 shape_plans: BTreeMap<NodeId, DataModelShapePlan>,
883 branch_view_plans: Vec<BranchViewPlan>,
884}
885
886#[derive(Clone, Debug)]
887struct DataSource {
888 node_id: Option<NodeId>,
889 port_name: String,
890 representation: Option<String>,
891}
892
893#[derive(Clone, Debug)]
894struct PredictionSource {
895 node_id: NodeId,
896 port_name: String,
897 input_name: String,
898 branch_id: Option<String>,
899}
900
/// A data output produced inside a branch, tagged with its consuming input name and branch.
#[derive(Clone, Debug)]
struct BranchDataSource {
    /// The underlying data port.
    source: DataSource,
    /// Name under which this data is consumed downstream (assumed; confirm with the compiler
    /// code).
    input_name: String,
    /// Branch the data originates from, if any.
    branch_id: Option<String>,
}
907
/// Outputs collected from compiling branches: prediction ports and data ports.
#[derive(Clone, Debug, Default)]
struct BranchCompileOutput {
    /// Prediction outputs produced by the branches.
    predictions: Vec<PredictionSource>,
    /// Data outputs produced by the branches.
    data_sources: Vec<BranchDataSource>,
}
913
/// Cursor state while compiling a sequence of steps.
#[derive(Clone, Debug)]
struct SequenceCompileState {
    /// Data port that the next step reads from.
    current_data: DataSource,
    /// Predictions produced earlier in the sequence that have not been consumed yet.
    pending_predictions: Vec<PredictionSource>,
    /// Branch data outputs that have not been consumed yet.
    pending_branch_data: Vec<BranchDataSource>,
}
920
921impl SequenceCompileState {
922 fn new(current_data: DataSource) -> Self {
923 Self {
924 current_data,
925 pending_predictions: Vec::new(),
926 pending_branch_data: Vec::new(),
927 }
928 }
929
930 fn clear_pending(&mut self) {
931 self.pending_predictions.clear();
932 self.pending_branch_data.clear();
933 }
934}
935
/// The output a merge resolves to: either a data port or a prediction port.
#[derive(Clone, Debug)]
enum MergeOutputSource {
    /// Merge yields data.
    Data(DataSource),
    /// Merge yields a prediction.
    Prediction(PredictionSource),
}
941
/// A step sequence with an id, labels and metadata.
///
/// Presumably one concrete variant produced by expanding a generator step (TODO confirm
/// against its construction site, which is outside this section).
#[derive(Clone, Debug)]
struct GeneratedSequence {
    /// Identifier of this sequence.
    id: String,
    /// Labels describing the sequence (assumed: the choices that produced it).
    labels: Vec<String>,
    /// Steps that make up the sequence.
    steps: Vec<PipelineDslStep>,
    /// Free-form metadata attached to the sequence.
    metadata: BTreeMap<String, serde_json::Value>,
}
949
/// Generation settings parsed from a nirs4all parameter-generator entry by
/// `parse_attached_generation`.
///
/// They are applied to a single operator step through `lower_value_with_attachment`.
#[derive(Clone, Debug, Default)]
struct CompatGenerationAttachment {
    /// Variant choices to attach to the operator.
    variants: Vec<PipelineDslVariantChoice>,
    /// Parameter generators to attach to the operator.
    param_generators: Vec<PipelineDslParamGenerator>,
}
955
/// Lowers nirs4all-compatible JSON into a canonical [`PipelineDslSpec`].
///
/// Entry point is `lower_root`, which consumes the lowerer.
#[derive(Default)]
struct CompatDslLowerer {
    /// Counter for generated node ids (presumably consumed by `next_node_id`; TODO confirm).
    node_counter: usize,
    /// Counter for generated generator ids (presumably consumed by `next_generator_id`).
    generator_counter: usize,
    /// Campaign split captured from a pipeline split step via `set_split_invocation`.
    split_invocation: Option<SplitInvocation>,
    /// Metadata gathered while lowering (e.g. `compat_sources`), merged into the result.
    metadata: BTreeMap<String, serde_json::Value>,
}
963
/// Step kind that `compat_plain_operator_kind` assigns to a plain operator alias or reference.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CompatPlainOperatorKind {
    /// Lowered as a preprocessing/transform step.
    Transform,
    /// Lowered as a model step.
    Model,
    /// Lowered as a tuner step.
    Tuner,
    /// Consumed as the campaign split rather than lowered to a step.
    Split,
    /// Lowered as a chart step.
    Chart,
}
972
973impl CompatDslLowerer {
974 fn lower_root(mut self, value: &serde_json::Value) -> Result<PipelineDslSpec> {
975 let root = value.as_object();
976 let pipeline = match value {
977 serde_json::Value::Array(_) => value,
978 serde_json::Value::Object(object) => object
979 .get("pipeline")
980 .or_else(|| object.get("steps"))
981 .ok_or_else(|| {
982 DagMlError::GraphValidation(
983 "nirs4all-compatible pipeline DSL must be a JSON array or an object with `pipeline`/`steps`".to_string(),
984 )
985 })?,
986 _ => {
987 return Err(DagMlError::GraphValidation(
988 "nirs4all-compatible pipeline DSL must be a JSON array or object".to_string(),
989 ));
990 }
991 };
992 let pipeline = pipeline.as_array().ok_or_else(|| {
993 DagMlError::GraphValidation(
994 "nirs4all-compatible pipeline field must be an array".to_string(),
995 )
996 })?;
997 let steps = self.lower_steps(pipeline, "pipeline")?;
998 let id = root
999 .and_then(|object| object.get("id"))
1000 .and_then(serde_json::Value::as_str)
1001 .unwrap_or("dsl-nirs4all-compat")
1002 .to_string();
1003 let mut metadata: BTreeMap<String, serde_json::Value> =
1004 optional_root_field(root, "metadata")?.unwrap_or_default();
1005 metadata.extend(std::mem::take(&mut self.metadata));
1006 metadata.insert(
1007 "dsl_compat_profile".to_string(),
1008 serde_json::Value::String("nirs4all_json_v1".to_string()),
1009 );
1010 let root_split = optional_root_field(root, "split_invocation")?;
1011 let split_invocation = match (root_split, self.split_invocation) {
1012 (Some(_), Some(_)) => {
1013 return Err(DagMlError::GraphValidation(
1014 "nirs4all-compatible pipeline declares split_invocation and a pipeline split step".to_string(),
1015 ));
1016 }
1017 (Some(split), None) | (None, Some(split)) => Some(split),
1018 (None, None) => None,
1019 };
1020 Ok(PipelineDslSpec {
1021 inner_cv: optional_root_field(root, "inner_cv")?,
1022 id,
1023 input: optional_root_field(root, "input")?.unwrap_or_default(),
1024 output: optional_root_field(root, "output")?.unwrap_or_default(),
1025 generation_strategy: optional_root_field(root, "generation_strategy")?,
1026 max_variants: optional_root_field(root, "max_variants")?,
1027 generation_dimensions: optional_root_field(root, "generation_dimensions")?
1028 .unwrap_or_default(),
1029 campaign_id: optional_root_field(root, "campaign_id")?,
1030 root_seed: optional_root_field(root, "root_seed")?,
1031 leakage_policy: optional_root_field(root, "leakage_policy")?,
1032 aggregation_policy: optional_root_field(root, "aggregation_policy")?,
1033 split_invocation,
1034 campaign_metadata: optional_root_field(root, "campaign_metadata")?.unwrap_or_default(),
1035 data_bindings: optional_root_field(root, "data_bindings")?.unwrap_or_default(),
1036 steps,
1037 metadata,
1038 })
1039 }
1040
1041 fn lower_steps(
1042 &mut self,
1043 values: &[serde_json::Value],
1044 path: &str,
1045 ) -> Result<Vec<PipelineDslStep>> {
1046 let mut lowered = Vec::new();
1047 let mut index = 0usize;
1048 while index < values.len() {
1049 let current_path = format!("{path}[{index}]");
1050 if self.consume_side_effect_step(&values[index], ¤t_path)? {
1051 index += 1;
1052 continue;
1053 }
1054 if let Some(attachment) =
1055 self.parse_attached_generation(&values[index], ¤t_path)?
1056 {
1057 if value_can_receive_generation_attachment(&values[index]) {
1058 let mut attached = self.lower_value_with_attachment(
1059 &values[index],
1060 ¤t_path,
1061 attachment,
1062 )?;
1063 lowered.append(&mut attached);
1064 index += 1;
1065 continue;
1066 }
1067 let next = values.get(index + 1).ok_or_else(|| {
1068 DagMlError::GraphValidation(format!(
1069 "{current_path} declares a parameter generator but has no following operator/model step"
1070 ))
1071 })?;
1072 let mut attached = self.lower_value_with_attachment(
1073 next,
1074 &format!("{path}[{}]", index + 1),
1075 attachment,
1076 )?;
1077 lowered.append(&mut attached);
1078 index += 2;
1079 continue;
1080 }
1081 if let Some(merge_model) =
1082 self.lower_merge_followed_by_model(values, index, ¤t_path)?
1083 {
1084 lowered.push(PipelineDslStep::MergeModel(merge_model));
1085 index += 2;
1086 continue;
1087 }
1088
1089 let steps = self.lower_value_as_steps(&values[index], ¤t_path)?;
1090 if let [PipelineDslStep::Generator(generator)] = steps.as_slice() {
1091 if !generator_step_has_prediction(generator) {
1092 if let Some((combined, consumed)) = self.combine_data_generator_with_following(
1093 generator.clone(),
1094 &values[index + 1..],
1095 path,
1096 index + 1,
1097 )? {
1098 lowered.push(PipelineDslStep::Generator(combined));
1099 index += consumed + 1;
1100 continue;
1101 }
1102 }
1103 }
1104 lowered.extend(steps);
1105 index += 1;
1106 }
1107 Ok(lowered)
1108 }
1109
1110 fn consume_side_effect_step(&mut self, value: &serde_json::Value, path: &str) -> Result<bool> {
1111 if compat_plain_operator_kind(value) == CompatPlainOperatorKind::Split {
1112 self.set_split_invocation(self.lower_plain_split_invocation(value, path)?, path)?;
1113 return Ok(true);
1114 }
1115 let Some(object) = value.as_object() else {
1116 return Ok(false);
1117 };
1118 if is_comment_only_object(object) {
1119 return Ok(true);
1120 }
1121 if let Some(split) = object.get("split") {
1122 self.set_split_invocation(self.lower_split_invocation(split, object, path)?, path)?;
1123 return Ok(true);
1124 }
1125 if let Some(sources) = object.get("sources") {
1126 self.metadata
1127 .insert("compat_sources".to_string(), sources.clone());
1128 return Ok(true);
1129 }
1130 Ok(false)
1131 }
1132
    /// Lowers one nirs4all-compatible value into zero or more canonical DSL steps.
    ///
    /// * `null` produces no steps.
    /// * An array becomes one anonymous `Sequential` step whose children go through
    ///   `lower_steps`.
    /// * A string is a plain operator alias whose kind comes from `compat_plain_operator_kind`.
    /// * An object is dispatched on its keys. The first matching rule wins, in this order:
    ///   1. `kind` (canonical step);
    ///   2. side-effect entries;
    ///   3. the keyed operators;
    ///   4. `branch`, `concat_transform`, `merge`;
    ///   5. `step` (optionally named);
    ///   6. the `_or_`/`_chain_`/`_cartesian_`/`_grid_`/`_sample_` generators;
    ///   7. plain operator references;
    ///   8. `type`/`ref` objects.
    ///
    /// # Errors
    /// Fails on values or objects that match no rule, on splitters that reach this point
    /// (they should have been consumed as the campaign split), and on any error from the
    /// helpers it calls.
    fn lower_value_as_steps(
        &mut self,
        value: &serde_json::Value,
        path: &str,
    ) -> Result<Vec<PipelineDslStep>> {
        match value {
            // Null entries contribute nothing.
            serde_json::Value::Null => Ok(Vec::new()),
            // A nested array is an anonymous sequential sub-pipeline.
            serde_json::Value::Array(children) => {
                Ok(vec![PipelineDslStep::Sequential(PipelineDslSequenceStep {
                    id: None,
                    metadata: BTreeMap::new(),
                    steps: self.lower_steps(children, path)?,
                })])
            }
            // A bare string is a plain operator alias; its inferred kind selects the step type.
            serde_json::Value::String(_) => {
                let step = match compat_plain_operator_kind(value) {
                    CompatPlainOperatorKind::Transform => PipelineDslStep::Transform(
                        self.compat_operator_step(None, "preprocessing", value, None, None)?,
                    ),
                    CompatPlainOperatorKind::Model => PipelineDslStep::Model(
                        self.compat_operator_step(None, "model", value, None, None)?,
                    ),
                    CompatPlainOperatorKind::Tuner => PipelineDslStep::Tuner(
                        self.compat_operator_step(None, "tuner", value, None, None)?,
                    ),
                    CompatPlainOperatorKind::Chart => PipelineDslStep::Chart(
                        self.compat_operator_step(None, "chart", value, None, None)?,
                    ),
                    // Splitter aliases are expected to be handled by `consume_side_effect_step`.
                    CompatPlainOperatorKind::Split => {
                        return Err(DagMlError::GraphValidation(format!(
                            "{path} splitter alias was not consumed as a campaign split"
                        )));
                    }
                };
                Ok(vec![step])
            }
            serde_json::Value::Object(object) => {
                // Objects carrying `kind` are already canonical DSL steps.
                if object.contains_key("kind") {
                    let step = serde_json::from_value::<PipelineDslStep>(value.clone()).map_err(
                        |error| {
                            DagMlError::GraphValidation(format!(
                                "failed to parse canonical DSL step at {path}: {error}"
                            ))
                        },
                    )?;
                    return Ok(vec![step]);
                }
                // Re-checked here because nested fragments may reach this method directly.
                if self.consume_side_effect_step(value, path)? {
                    return Ok(Vec::new());
                }
                // Keyed operators, in precedence order.
                if let Some(operator) =
                    first_object_value(object, &["preprocessing", "processing", "transform"])
                {
                    return Ok(vec![PipelineDslStep::Transform(
                        self.compat_operator_step(
                            Some(object),
                            "preprocessing",
                            operator,
                            None,
                            None,
                        )?,
                    )]);
                }
                if let Some(operator) = first_object_value(object, &["y_processing", "y_transform"])
                {
                    return Ok(vec![PipelineDslStep::YTransform(
                        self.compat_operator_step(
                            Some(object),
                            "y_processing",
                            operator,
                            None,
                            None,
                        )?,
                    )]);
                }
                if let Some(operator) = object.get("tag") {
                    return Ok(vec![PipelineDslStep::Tag(self.compat_operator_step(
                        Some(object),
                        "tag",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                if let Some(operator) = object.get("exclude") {
                    return Ok(vec![PipelineDslStep::Exclude(self.compat_operator_step(
                        Some(object),
                        "exclude",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                if let Some(operator) = object.get("filter") {
                    return Ok(vec![PipelineDslStep::Filter(self.compat_operator_step(
                        Some(object),
                        "filter",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                if let Some(operator) = object.get("sample_filter") {
                    return Ok(vec![PipelineDslStep::SampleFilter(
                        self.compat_operator_step(
                            Some(object),
                            "sample_filter",
                            operator,
                            None,
                            None,
                        )?,
                    )]);
                }
                // Augmentation steps also carry a shape derived from the augmentation scope.
                if let Some(operator) = object.get("sample_augmentation") {
                    return Ok(vec![PipelineDslStep::SampleAugmentation(
                        self.compat_operator_step(
                            Some(object),
                            "sample_augmentation",
                            operator,
                            None,
                            Some(compat_augmentation_shape("sample", object)?),
                        )?,
                    )]);
                }
                if let Some(operator) = object.get("feature_augmentation") {
                    return Ok(vec![PipelineDslStep::FeatureAugmentation(
                        self.compat_operator_step(
                            Some(object),
                            "feature_augmentation",
                            operator,
                            None,
                            Some(compat_augmentation_shape("feature", object)?),
                        )?,
                    )]);
                }
                if let Some(operator) = object.get("augmentation") {
                    return Ok(vec![PipelineDslStep::Augmentation(
                        self.compat_operator_step(
                            Some(object),
                            "augmentation",
                            operator,
                            None,
                            Some(compat_augmentation_shape("both", object)?),
                        )?,
                    )]);
                }
                if let Some(operator) =
                    first_object_value(object, &["data_generation", "generation"])
                {
                    return Ok(vec![PipelineDslStep::DataGeneration(
                        self.compat_operator_step(
                            Some(object),
                            "data_generation",
                            operator,
                            None,
                            None,
                        )?,
                    )]);
                }
                if let Some(operator) = object.get("model") {
                    return Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
                        Some(object),
                        "model",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                if let Some(operator) = first_object_value(object, &["tuner", "finetune"]) {
                    return Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
                        Some(object),
                        "tuner",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                if let Some(operator) = object.get("chart") {
                    return Ok(vec![PipelineDslStep::Chart(self.compat_operator_step(
                        Some(object),
                        "chart",
                        operator,
                        None,
                        None,
                    )?)]);
                }
                // Structural steps.
                if object.contains_key("branch") {
                    return Ok(vec![PipelineDslStep::Branch(
                        self.lower_branch_step(object, path)?,
                    )]);
                }
                if object.contains_key("concat_transform") {
                    return Ok(vec![PipelineDslStep::ConcatTransform(
                        self.lower_concat_transform_step(object, path)?,
                    )]);
                }
                if object.contains_key("merge") {
                    return Ok(vec![PipelineDslStep::Merge(
                        self.lower_merge_step(object, path)?,
                    )]);
                }
                // `{"step": ..., "name": ...}` wraps a fragment; the name annotates its steps.
                if let Some(step_value) = object.get("step") {
                    let mut steps =
                        self.lower_pipeline_fragment(step_value, &format!("{path}.step"))?;
                    if let Some(name) = object.get("name").and_then(serde_json::Value::as_str) {
                        annotate_named_steps(&mut steps, name);
                    }
                    return Ok(steps);
                }
                // Generator keywords; `_chain_` reuses the `_or_` lowering with its own key.
                if object.contains_key("_or_") {
                    return Ok(vec![PipelineDslStep::Generator(
                        self.lower_or_generator(object, "_or_", path)?,
                    )]);
                }
                if object.contains_key("_chain_") {
                    return Ok(vec![PipelineDslStep::Generator(
                        self.lower_or_generator(object, "_chain_", path)?,
                    )]);
                }
                if object.contains_key("_cartesian_") {
                    return Ok(vec![PipelineDslStep::Generator(
                        self.lower_cartesian_generator(object, path)?,
                    )]);
                }
                if object.contains_key("_grid_") {
                    return Ok(vec![PipelineDslStep::Generator(
                        self.lower_grid_generator(object, path)?,
                    )]);
                }
                if object.contains_key("_sample_") {
                    return Ok(vec![PipelineDslStep::Generator(
                        self.lower_sample_generator(object, path)?,
                    )]);
                }
                // Plain operator reference objects: resolve the operator, then dispatch on kind.
                if compat_plain_operator_ref(value).is_some() {
                    let operator = compat_plain_operator_value(value)?;
                    return match compat_plain_operator_kind(value) {
                        CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
                            self.compat_operator_step(
                                Some(object),
                                "preprocessing",
                                &operator,
                                None,
                                None,
                            )?,
                        )]),
                        CompatPlainOperatorKind::Model => {
                            Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
                                Some(object),
                                "model",
                                &operator,
                                None,
                                None,
                            )?)])
                        }
                        CompatPlainOperatorKind::Tuner => {
                            Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
                                Some(object),
                                "tuner",
                                &operator,
                                None,
                                None,
                            )?)])
                        }
                        CompatPlainOperatorKind::Chart => {
                            Ok(vec![PipelineDslStep::Chart(self.compat_operator_step(
                                Some(object),
                                "chart",
                                &operator,
                                None,
                                None,
                            )?)])
                        }
                        CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(
                            format!("{path} splitter object was not consumed as a campaign split"),
                        )),
                    };
                }
                // Fallback: a `type`/`ref` object is itself the preprocessing operator.
                if object.contains_key("type") || object.contains_key("ref") {
                    return Ok(vec![PipelineDslStep::Transform(
                        self.compat_operator_step(None, "preprocessing", value, None, None)?,
                    )]);
                }
                Err(DagMlError::GraphValidation(format!(
                    "unsupported nirs4all-compatible DSL object at {path}"
                )))
            }
            // Booleans and numbers are not valid steps.
            _ => Err(DagMlError::GraphValidation(format!(
                "unsupported nirs4all-compatible DSL value at {path}"
            ))),
        }
    }
1425
1426 fn lower_value_with_attachment(
1427 &mut self,
1428 value: &serde_json::Value,
1429 path: &str,
1430 attachment: CompatGenerationAttachment,
1431 ) -> Result<Vec<PipelineDslStep>> {
1432 match value {
1433 serde_json::Value::String(_) => match compat_plain_operator_kind(value) {
1434 CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
1435 self.compat_operator_step(
1436 None,
1437 "preprocessing",
1438 value,
1439 Some(attachment),
1440 None,
1441 )?,
1442 )]),
1443 CompatPlainOperatorKind::Model => Ok(vec![PipelineDslStep::Model(
1444 self.compat_operator_step(None, "model", value, Some(attachment), None)?,
1445 )]),
1446 CompatPlainOperatorKind::Tuner => Ok(vec![PipelineDslStep::Tuner(
1447 self.compat_operator_step(None, "tuner", value, Some(attachment), None)?,
1448 )]),
1449 CompatPlainOperatorKind::Chart => Ok(vec![PipelineDslStep::Chart(
1450 self.compat_operator_step(None, "chart", value, Some(attachment), None)?,
1451 )]),
1452 CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(format!(
1453 "{path} splitter alias cannot receive a parameter generator"
1454 ))),
1455 },
1456 serde_json::Value::Object(object) => {
1457 if let Some(operator) = object.get("model") {
1458 return Ok(vec![PipelineDslStep::Model(self.compat_operator_step(
1459 Some(object),
1460 "model",
1461 operator,
1462 Some(attachment),
1463 None,
1464 )?)]);
1465 }
1466 if let Some(operator) = first_object_value(object, &["tuner", "finetune"]) {
1467 return Ok(vec![PipelineDslStep::Tuner(self.compat_operator_step(
1468 Some(object),
1469 "tuner",
1470 operator,
1471 Some(attachment),
1472 None,
1473 )?)]);
1474 }
1475 if let Some(operator) =
1476 first_object_value(object, &["preprocessing", "processing", "transform"])
1477 {
1478 return Ok(vec![PipelineDslStep::Transform(self.compat_operator_step(
1479 Some(object),
1480 "preprocessing",
1481 operator,
1482 Some(attachment),
1483 None,
1484 )?)]);
1485 }
1486 if compat_plain_operator_ref(value).is_some() {
1487 let operator = compat_plain_operator_value(value)?;
1488 return match compat_plain_operator_kind(value) {
1489 CompatPlainOperatorKind::Transform => Ok(vec![PipelineDslStep::Transform(
1490 self.compat_operator_step(
1491 Some(object),
1492 "preprocessing",
1493 &operator,
1494 Some(attachment),
1495 None,
1496 )?,
1497 )]),
1498 CompatPlainOperatorKind::Model => Ok(vec![PipelineDslStep::Model(
1499 self.compat_operator_step(
1500 Some(object),
1501 "model",
1502 &operator,
1503 Some(attachment),
1504 None,
1505 )?,
1506 )]),
1507 CompatPlainOperatorKind::Tuner => Ok(vec![PipelineDslStep::Tuner(
1508 self.compat_operator_step(
1509 Some(object),
1510 "tuner",
1511 &operator,
1512 Some(attachment),
1513 None,
1514 )?,
1515 )]),
1516 CompatPlainOperatorKind::Chart => Ok(vec![PipelineDslStep::Chart(
1517 self.compat_operator_step(
1518 Some(object),
1519 "chart",
1520 &operator,
1521 Some(attachment),
1522 None,
1523 )?,
1524 )]),
1525 CompatPlainOperatorKind::Split => Err(DagMlError::GraphValidation(
1526 format!("{path} splitter object cannot receive a parameter generator"),
1527 )),
1528 };
1529 }
1530 Err(DagMlError::GraphValidation(format!(
1531 "{path} cannot receive a preceding nirs4all parameter generator; expected model, tuner or preprocessing"
1532 )))
1533 }
1534 _ => Err(DagMlError::GraphValidation(format!(
1535 "{path} cannot receive a preceding nirs4all parameter generator; expected model, tuner or preprocessing"
1536 ))),
1537 }
1538 }
1539
1540 fn lower_merge_followed_by_model(
1541 &mut self,
1542 values: &[serde_json::Value],
1543 index: usize,
1544 _path: &str,
1545 ) -> Result<Option<PipelineDslMergeModelStep>> {
1546 let Some(merge_object) = values[index].as_object() else {
1547 return Ok(None);
1548 };
1549 if !merge_object.contains_key("merge") {
1550 return Ok(None);
1551 }
1552 let Some(next) = values.get(index + 1).and_then(serde_json::Value::as_object) else {
1553 return Ok(None);
1554 };
1555 let Some(operator) = next.get("model") else {
1556 return Ok(None);
1557 };
1558 let (merge_mode, include_original_data, _) = compat_merge_modes(merge_object)?;
1559 let operator_step = self.compat_operator_step(Some(next), "model", operator, None, None)?;
1560 Ok(Some(PipelineDslMergeModelStep {
1561 inner_cv: operator_step.inner_cv,
1562 id: operator_step.id,
1563 operator: operator_step.operator,
1564 params: operator_step.params,
1565 metadata: operator_step.metadata,
1566 seed_label: operator_step.seed_label,
1567 include_original_data,
1568 merge_mode,
1569 train_params: operator_step.train_params,
1570 tuning: operator_step.tuning,
1571 variants: operator_step.variants,
1572 param_generators: operator_step.param_generators,
1573 shape: operator_step.shape,
1574 }))
1575 }
1576
1577 fn combine_data_generator_with_following(
1578 &mut self,
1579 generator: PipelineDslGeneratorStep,
1580 remaining: &[serde_json::Value],
1581 path: &str,
1582 absolute_start: usize,
1583 ) -> Result<Option<(PipelineDslGeneratorStep, usize)>> {
1584 let fused_id = generator.id.clone();
1585 let mut stages = generator_to_cartesian_stages(generator)?;
1586 let mut prefix_steps = Vec::new();
1587 let mut consumed = 0usize;
1588 while consumed < remaining.len() {
1589 let current_path = format!("{path}[{}]", absolute_start + consumed);
1590 if self.consume_side_effect_step(&remaining[consumed], ¤t_path)? {
1591 consumed += 1;
1592 continue;
1593 }
1594 let steps = if let Some(attachment) =
1595 self.parse_attached_generation(&remaining[consumed], ¤t_path)?
1596 {
1597 let next = remaining.get(consumed + 1).ok_or_else(|| {
1598 DagMlError::GraphValidation(format!(
1599 "{current_path} declares a parameter generator but has no following operator/model step"
1600 ))
1601 })?;
1602 consumed += 1;
1603 self.lower_value_with_attachment(
1604 next,
1605 &format!("{path}[{}]", absolute_start + consumed),
1606 attachment,
1607 )?
1608 } else if let Some(merge_model) =
1609 self.lower_merge_followed_by_model(remaining, consumed, ¤t_path)?
1610 {
1611 consumed += 1;
1612 vec![PipelineDslStep::MergeModel(merge_model)]
1613 } else {
1614 self.lower_value_as_steps(&remaining[consumed], ¤t_path)?
1615 };
1616 consumed += 1;
1617 if steps.is_empty() {
1618 continue;
1619 }
1620 if let [PipelineDslStep::Generator(next_generator)] = steps.as_slice() {
1621 if !prefix_steps.is_empty() {
1622 stages.push(single_stage(
1623 format!("stage{}", stages.len()),
1624 "prefix",
1625 std::mem::take(&mut prefix_steps),
1626 ));
1627 }
1628 let next_has_prediction = generator_step_has_prediction(next_generator);
1629 stages.extend(generator_to_cartesian_stages(next_generator.clone())?);
1630 if next_has_prediction {
1631 return Ok(Some((
1632 combined_cartesian_generator(fused_id.clone(), stages),
1633 consumed,
1634 )));
1635 }
1636 continue;
1637 }
1638 let has_prediction = steps.iter().any(step_has_prediction);
1639 prefix_steps.extend(steps);
1640 if has_prediction {
1641 stages.push(single_stage(
1642 format!("stage{}", stages.len()),
1643 "then",
1644 std::mem::take(&mut prefix_steps),
1645 ));
1646 return Ok(Some((
1647 combined_cartesian_generator(fused_id.clone(), stages),
1648 consumed,
1649 )));
1650 }
1651 }
1652 Ok(None)
1653 }
1654
1655 fn lower_branch_step(
1656 &mut self,
1657 object: &serde_json::Map<String, serde_json::Value>,
1658 path: &str,
1659 ) -> Result<PipelineDslBranchStep> {
1660 let branch_value = object.get("branch").expect("checked by caller");
1661 let mode = optional_object_field(object, "mode")?.unwrap_or_default();
1662 let selector = object.get("selector").cloned();
1663 let metadata = optional_object_field(object, "metadata")?.unwrap_or_default();
1664 let branches = match branch_value {
1665 serde_json::Value::Array(values) => values
1666 .iter()
1667 .enumerate()
1668 .map(|(index, value)| {
1669 let id = compat_branch_id(value, index);
1670 Ok(PipelineDslBranch {
1671 id,
1672 selector: None,
1673 metadata: BTreeMap::new(),
1674 steps: self
1675 .lower_pipeline_fragment(value, &format!("{path}.branch[{index}]"))?,
1676 })
1677 })
1678 .collect::<Result<Vec<_>>>()?,
1679 serde_json::Value::Object(branch_object) => {
1680 if let Some(values) = branch_object
1681 .get("branches")
1682 .and_then(serde_json::Value::as_array)
1683 {
1684 values
1685 .iter()
1686 .enumerate()
1687 .map(|(index, value)| {
1688 self.lower_named_branch(
1689 value,
1690 index,
1691 &format!("{path}.branch.branches[{index}]"),
1692 )
1693 })
1694 .collect::<Result<Vec<_>>>()?
1695 } else {
1696 branch_object
1697 .iter()
1698 .filter(|(key, _)| {
1699 !matches!(key.as_str(), "mode" | "selector" | "metadata")
1700 })
1701 .enumerate()
1702 .map(|(index, (key, value))| {
1703 Ok(PipelineDslBranch {
1704 id: sanitize_branch_id(key, index),
1705 selector: None,
1706 metadata: BTreeMap::new(),
1707 steps: self.lower_pipeline_fragment(
1708 value,
1709 &format!("{path}.branch.{key}"),
1710 )?,
1711 })
1712 })
1713 .collect::<Result<Vec<_>>>()?
1714 }
1715 }
1716 _ => {
1717 return Err(DagMlError::GraphValidation(format!(
1718 "{path}.branch must be an array or object"
1719 )));
1720 }
1721 };
1722 Ok(PipelineDslBranchStep {
1723 mode,
1724 selector,
1725 metadata,
1726 branches,
1727 })
1728 }
1729
1730 fn lower_named_branch(
1731 &mut self,
1732 value: &serde_json::Value,
1733 index: usize,
1734 path: &str,
1735 ) -> Result<PipelineDslBranch> {
1736 if let Some(object) = value.as_object() {
1737 if object.contains_key("steps") || object.contains_key("pipeline") {
1738 let id = object
1739 .get("id")
1740 .and_then(serde_json::Value::as_str)
1741 .map(|id| sanitize_branch_id(id, index))
1742 .unwrap_or_else(|| format!("branch{index}"));
1743 let selector = object.get("selector").cloned();
1744 let metadata = optional_object_field(object, "metadata")?.unwrap_or_default();
1745 let steps_value = object
1746 .get("steps")
1747 .or_else(|| object.get("pipeline"))
1748 .ok_or_else(|| {
1749 DagMlError::GraphValidation(format!(
1750 "{path} branch object must contain steps or pipeline"
1751 ))
1752 })?;
1753 return Ok(PipelineDslBranch {
1754 id,
1755 selector,
1756 metadata,
1757 steps: self.lower_pipeline_fragment(steps_value, path)?,
1758 });
1759 }
1760 }
1761 Ok(PipelineDslBranch {
1762 id: compat_branch_id(value, index),
1763 selector: None,
1764 metadata: BTreeMap::new(),
1765 steps: self.lower_pipeline_fragment(value, path)?,
1766 })
1767 }
1768
1769 fn lower_concat_transform_step(
1770 &mut self,
1771 object: &serde_json::Map<String, serde_json::Value>,
1772 path: &str,
1773 ) -> Result<PipelineDslConcatTransformStep> {
1774 let value = object.get("concat_transform").expect("checked by caller");
1775 let branches = match value {
1776 serde_json::Value::Array(values) => values
1777 .iter()
1778 .enumerate()
1779 .map(|(index, value)| {
1780 Ok(PipelineDslConcatBranch {
1781 id: compat_branch_id(value, index),
1782 steps: self.lower_concat_operator_steps(
1783 value,
1784 &format!("{path}.concat_transform[{index}]"),
1785 )?,
1786 })
1787 })
1788 .collect::<Result<Vec<_>>>()?,
1789 serde_json::Value::Object(map) => map
1790 .iter()
1791 .enumerate()
1792 .map(|(index, (key, value))| {
1793 Ok(PipelineDslConcatBranch {
1794 id: sanitize_branch_id(key, index),
1795 steps: self.lower_concat_operator_steps(
1796 value,
1797 &format!("{path}.concat_transform.{key}"),
1798 )?,
1799 })
1800 })
1801 .collect::<Result<Vec<_>>>()?,
1802 _ => {
1803 return Err(DagMlError::GraphValidation(format!(
1804 "{path}.concat_transform must be an array or object"
1805 )));
1806 }
1807 };
1808 Ok(PipelineDslConcatTransformStep {
1809 id: explicit_or_generated_node_id(object, "id", || self.next_node_id("join"))?,
1810 branches,
1811 metadata: optional_object_field(object, "metadata")?.unwrap_or_default(),
1812 seed_label: optional_object_field(object, "seed_label")?,
1813 representation: optional_object_field(object, "representation")?,
1814 variants: Vec::new(),
1815 param_generators: Vec::new(),
1816 shape: optional_object_field(object, "shape")?,
1817 })
1818 }
1819
1820 fn lower_concat_operator_steps(
1821 &mut self,
1822 value: &serde_json::Value,
1823 path: &str,
1824 ) -> Result<Vec<PipelineDslOperatorStep>> {
1825 let steps = self.lower_pipeline_fragment(value, path)?;
1826 steps
1827 .into_iter()
1828 .map(|step| match step {
1829 PipelineDslStep::Transform(step) => Ok(step),
1830 _ => Err(DagMlError::GraphValidation(format!(
1831 "{path} concat_transform branches currently accept only preprocessing/transform steps"
1832 ))),
1833 })
1834 .collect()
1835 }
1836
1837 fn lower_merge_step(
1838 &mut self,
1839 object: &serde_json::Map<String, serde_json::Value>,
1840 _path: &str,
1841 ) -> Result<PipelineDslMergeStep> {
1842 let (merge_mode, include_original_data, output_as) = compat_merge_modes(object)?;
1843 let mut metadata: BTreeMap<String, serde_json::Value> =
1844 optional_object_field(object, "metadata")?.unwrap_or_default();
1845 if let Some(merge) = object.get("merge").filter(|merge| merge.is_object()) {
1846 metadata.insert("dsl_compat_merge".to_string(), merge.clone());
1847 }
1848 Ok(PipelineDslMergeStep {
1849 id: explicit_or_generated_node_id(object, "id", || self.next_node_id("merge"))?,
1850 merge_mode,
1851 output_as,
1852 include_original_data,
1853 on_missing: compat_merge_field(object, "on_missing")?,
1854 selectors: compat_merge_field(object, "selectors")?.unwrap_or_default(),
1855 metadata,
1856 seed_label: optional_object_field(object, "seed_label")?,
1857 representation: optional_object_field(object, "representation")?,
1858 variants: Vec::new(),
1859 param_generators: Vec::new(),
1860 shape: optional_object_field(object, "shape")?,
1861 })
1862 }
1863
1864 fn lower_or_generator(
1865 &mut self,
1866 object: &serde_json::Map<String, serde_json::Value>,
1867 key: &str,
1868 path: &str,
1869 ) -> Result<PipelineDslGeneratorStep> {
1870 let values = object
1871 .get(key)
1872 .and_then(serde_json::Value::as_array)
1873 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}.{key} must be an array")))?;
1874 let branches = values
1875 .iter()
1876 .enumerate()
1877 .map(|(index, value)| {
1878 Ok(PipelineDslBranch {
1879 id: compat_branch_id(value, index),
1880 selector: None,
1881 metadata: BTreeMap::new(),
1882 steps: self
1883 .lower_pipeline_fragment(value, &format!("{path}.{key}[{index}]"))?,
1884 })
1885 })
1886 .collect::<Result<Vec<_>>>()?;
1887 Ok(PipelineDslGeneratorStep {
1888 id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1889 mode: PipelineDslGeneratorMode::Or,
1890 branches,
1891 stages: Vec::new(),
1892 pick: optional_object_field(object, "pick")?,
1893 arrange: optional_object_field(object, "arrange")?,
1894 then_pick: optional_object_field(object, "then_pick")?,
1895 then_arrange: optional_object_field(object, "then_arrange")?,
1896 count: optional_object_field(object, "count")?,
1897 metadata: compat_generator_metadata(object, key)?,
1898 })
1899 }
1900
1901 fn lower_cartesian_generator(
1902 &mut self,
1903 object: &serde_json::Map<String, serde_json::Value>,
1904 path: &str,
1905 ) -> Result<PipelineDslGeneratorStep> {
1906 let values = object
1907 .get("_cartesian_")
1908 .and_then(serde_json::Value::as_array)
1909 .ok_or_else(|| {
1910 DagMlError::GraphValidation(format!("{path}._cartesian_ must be an array"))
1911 })?;
1912 let stages = values
1913 .iter()
1914 .enumerate()
1915 .map(|(index, value)| {
1916 self.lower_cartesian_stage(value, index, &format!("{path}._cartesian_[{index}]"))
1917 })
1918 .collect::<Result<Vec<_>>>()?;
1919 Ok(PipelineDslGeneratorStep {
1920 id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1921 mode: PipelineDslGeneratorMode::Cartesian,
1922 branches: Vec::new(),
1923 stages,
1924 pick: None,
1925 arrange: None,
1926 then_pick: None,
1927 then_arrange: None,
1928 count: optional_object_field(object, "count")?,
1929 metadata: compat_generator_metadata(object, "_cartesian_")?,
1930 })
1931 }
1932
1933 fn lower_cartesian_stage(
1934 &mut self,
1935 value: &serde_json::Value,
1936 index: usize,
1937 path: &str,
1938 ) -> Result<PipelineDslGeneratorStage> {
1939 if let Some(object) = value.as_object() {
1940 if object.contains_key("_or_") {
1941 let generator = self.lower_or_generator(object, "_or_", path)?;
1942 return Ok(PipelineDslGeneratorStage {
1943 id: format!("stage{index}"),
1944 selector: None,
1945 metadata: BTreeMap::new(),
1946 branches: generator.branches,
1947 });
1948 }
1949 if object.contains_key("_chain_") {
1950 let generator = self.lower_or_generator(object, "_chain_", path)?;
1951 return Ok(PipelineDslGeneratorStage {
1952 id: format!("stage{index}"),
1953 selector: None,
1954 metadata: BTreeMap::new(),
1955 branches: generator.branches,
1956 });
1957 }
1958 if object.contains_key("_grid_") {
1959 return Ok(PipelineDslGeneratorStage {
1960 id: format!("stage{index}"),
1961 selector: None,
1962 metadata: BTreeMap::new(),
1963 branches: self.lower_grid_branches(object.get("_grid_").unwrap(), path)?,
1964 });
1965 }
1966 if object.contains_key("_sample_") {
1967 let generator = self.lower_sample_generator(object, path)?;
1968 return Ok(PipelineDslGeneratorStage {
1969 id: format!("stage{index}"),
1970 selector: None,
1971 metadata: BTreeMap::new(),
1972 branches: generator.branches,
1973 });
1974 }
1975 }
1976 Ok(PipelineDslGeneratorStage {
1977 id: format!("stage{index}"),
1978 selector: None,
1979 metadata: BTreeMap::new(),
1980 branches: vec![PipelineDslBranch {
1981 id: "option0".to_string(),
1982 selector: None,
1983 metadata: BTreeMap::new(),
1984 steps: self.lower_pipeline_fragment(value, path)?,
1985 }],
1986 })
1987 }
1988
1989 fn lower_grid_generator(
1990 &mut self,
1991 object: &serde_json::Map<String, serde_json::Value>,
1992 path: &str,
1993 ) -> Result<PipelineDslGeneratorStep> {
1994 Ok(PipelineDslGeneratorStep {
1995 id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
1996 mode: PipelineDslGeneratorMode::Or,
1997 branches: self.lower_grid_branches(object.get("_grid_").unwrap(), path)?,
1998 stages: Vec::new(),
1999 pick: None,
2000 arrange: None,
2001 then_pick: None,
2002 then_arrange: None,
2003 count: optional_object_field(object, "count")?,
2004 metadata: compat_generator_metadata(object, "_grid_")?,
2005 })
2006 }
2007
2008 fn lower_sample_generator(
2009 &mut self,
2010 object: &serde_json::Map<String, serde_json::Value>,
2011 path: &str,
2012 ) -> Result<PipelineDslGeneratorStep> {
2013 Ok(PipelineDslGeneratorStep {
2014 id: explicit_or_generated_node_id(object, "id", || self.next_generator_id())?,
2015 mode: PipelineDslGeneratorMode::Or,
2016 branches: self.lower_sample_branches(object.get("_sample_").unwrap(), path)?,
2017 stages: Vec::new(),
2018 pick: None,
2019 arrange: None,
2020 then_pick: None,
2021 then_arrange: None,
2022 count: optional_object_field(object, "count")?,
2023 metadata: compat_generator_metadata(object, "_sample_")?,
2024 })
2025 }
2026
2027 fn lower_sample_branches(
2028 &mut self,
2029 value: &serde_json::Value,
2030 path: &str,
2031 ) -> Result<Vec<PipelineDslBranch>> {
2032 let sample = value.as_object().ok_or_else(|| {
2033 DagMlError::GraphValidation(format!("{path}._sample_ must be an object"))
2034 })?;
2035 let rows = compat_sample_rows(sample, path)?;
2036 let operator = sample
2037 .get("model")
2038 .or_else(|| sample.get("tuner"))
2039 .or_else(|| sample.get("finetune"))
2040 .or_else(|| sample.get("preprocessing"))
2041 .or_else(|| sample.get("transform"))
2042 .ok_or_else(|| {
2043 DagMlError::GraphValidation(format!(
2044 "{path}._sample_ structural lowering requires `model`, `tuner`, `preprocessing` or `transform`"
2045 ))
2046 })?
2047 .clone();
2048 let keyword = if sample.contains_key("model") {
2049 "model"
2050 } else if sample.contains_key("tuner") || sample.contains_key("finetune") {
2051 "tuner"
2052 } else {
2053 "preprocessing"
2054 };
2055 let fixed_params = sample
2056 .iter()
2057 .filter(|(key, _)| {
2058 !matches!(
2059 key.as_str(),
2060 "model"
2061 | "tuner"
2062 | "finetune"
2063 | "preprocessing"
2064 | "transform"
2065 | "distribution"
2066 | "from"
2067 | "to"
2068 | "num"
2069 | "count"
2070 | "param"
2071 | "tune"
2072 )
2073 })
2074 .map(|(key, value)| (key.clone(), value.clone()))
2075 .collect::<BTreeMap<_, _>>();
2076 rows.into_iter()
2077 .enumerate()
2078 .map(|(index, mut row)| {
2079 row.extend(fixed_params.clone());
2080 let step = self.compat_operator_step_from_parts(
2081 keyword,
2082 operator.clone(),
2083 row,
2084 None,
2085 None,
2086 )?;
2087 Ok(PipelineDslBranch {
2088 id: format!("sample{index}"),
2089 selector: None,
2090 metadata: BTreeMap::new(),
2091 steps: vec![if keyword == "model" {
2092 PipelineDslStep::Model(step)
2093 } else if keyword == "tuner" {
2094 PipelineDslStep::Tuner(step)
2095 } else {
2096 PipelineDslStep::Transform(step)
2097 }],
2098 })
2099 })
2100 .collect()
2101 }
2102
2103 fn lower_grid_branches(
2104 &mut self,
2105 value: &serde_json::Value,
2106 path: &str,
2107 ) -> Result<Vec<PipelineDslBranch>> {
2108 let rows = compat_grid_rows(value, path)?;
2109 rows.into_iter()
2110 .enumerate()
2111 .map(|(index, row)| {
2112 let metadata = BTreeMap::from([(
2113 "compat_grid_row".to_string(),
2114 serde_json::to_value(&row).map_err(|error| {
2115 DagMlError::GraphValidation(format!(
2116 "failed to serialize grid row at {path}: {error}"
2117 ))
2118 })?,
2119 )]);
2120 Ok(PipelineDslBranch {
2121 id: format!("grid{index}"),
2122 selector: None,
2123 metadata,
2124 steps: self.lower_grid_row(row, path)?,
2125 })
2126 })
2127 .collect()
2128 }
2129
2130 fn lower_grid_row(
2131 &mut self,
2132 mut row: BTreeMap<String, serde_json::Value>,
2133 path: &str,
2134 ) -> Result<Vec<PipelineDslStep>> {
2135 if let Some(operator) = row.remove("model") {
2136 return Ok(vec![PipelineDslStep::Model(
2137 self.compat_operator_step_from_parts("model", operator, row, None, None)?,
2138 )]);
2139 }
2140 if let Some(operator) = row.remove("tuner").or_else(|| row.remove("finetune")) {
2141 return Ok(vec![PipelineDslStep::Tuner(
2142 self.compat_operator_step_from_parts("tuner", operator, row, None, None)?,
2143 )]);
2144 }
2145 if let Some(operator) = row
2146 .remove("preprocessing")
2147 .or_else(|| row.remove("processing"))
2148 .or_else(|| row.remove("transform"))
2149 {
2150 return Ok(vec![PipelineDslStep::Transform(
2151 self.compat_operator_step_from_parts("preprocessing", operator, row, None, None)?,
2152 )]);
2153 }
2154 Err(DagMlError::GraphValidation(format!(
2155 "{path}._grid_ rows must contain `model`, `tuner`, `preprocessing` or `transform` for structural lowering"
2156 )))
2157 }
2158
2159 fn lower_pipeline_fragment(
2160 &mut self,
2161 value: &serde_json::Value,
2162 path: &str,
2163 ) -> Result<Vec<PipelineDslStep>> {
2164 match value {
2165 serde_json::Value::Null => Ok(Vec::new()),
2166 serde_json::Value::Array(values) => self.lower_steps(values, path),
2167 _ => self.lower_value_as_steps(value, path),
2168 }
2169 }
2170
2171 fn parse_attached_generation(
2172 &mut self,
2173 value: &serde_json::Value,
2174 path: &str,
2175 ) -> Result<Option<CompatGenerationAttachment>> {
2176 let Some(object) = value.as_object() else {
2177 return Ok(None);
2178 };
2179 if let Some(range) = object.get("_range_") {
2180 return Ok(Some(CompatGenerationAttachment {
2181 variants: Vec::new(),
2182 param_generators: vec![compat_range_generator(range, object, path)?],
2183 }));
2184 }
2185 if let Some(range) = object.get("_log_range_") {
2186 return Ok(Some(CompatGenerationAttachment {
2187 variants: Vec::new(),
2188 param_generators: vec![compat_log_range_generator(range, object, path)?],
2189 }));
2190 }
2191 if let Some(grid) = object.get("_grid_") {
2192 if grid.as_object().is_some_and(|grid| {
2193 !grid.contains_key("model")
2194 && !grid.contains_key("preprocessing")
2195 && !grid.contains_key("transform")
2196 }) {
2197 return Ok(Some(CompatGenerationAttachment {
2198 variants: Vec::new(),
2199 param_generators: vec![compat_grid_param_generator(grid, object, path)?],
2200 }));
2201 }
2202 }
2203 if let Some(zip) = object.get("_zip_") {
2204 return Ok(Some(CompatGenerationAttachment {
2205 variants: compat_zip_variants(zip, path)?,
2206 param_generators: Vec::new(),
2207 }));
2208 }
2209 if let Some(sample) = object.get("_sample_") {
2210 if sample.as_object().is_some_and(|sample| {
2211 sample.contains_key("model")
2212 || sample.contains_key("tuner")
2213 || sample.contains_key("finetune")
2214 || sample.contains_key("preprocessing")
2215 || sample.contains_key("transform")
2216 }) {
2217 return Ok(None);
2218 }
2219 return Ok(Some(CompatGenerationAttachment {
2220 variants: compat_sample_variants(sample, path)?,
2221 param_generators: Vec::new(),
2222 }));
2223 }
2224 Ok(None)
2225 }
2226
2227 fn compat_operator_step(
2228 &mut self,
2229 object: Option<&serde_json::Map<String, serde_json::Value>>,
2230 keyword: &str,
2231 operator: &serde_json::Value,
2232 attachment: Option<CompatGenerationAttachment>,
2233 fallback_shape: Option<PipelineDslShapePlan>,
2234 ) -> Result<PipelineDslOperatorStep> {
2235 let id_prefix = compat_node_prefix(keyword);
2236 let mut params = object
2237 .and_then(|object| object_value_as_map(object.get("params")))
2238 .unwrap_or_default();
2239 if let Some(object) = object {
2240 for alias in compat_param_aliases(keyword) {
2241 if let Some(alias_params) = object_value_as_map(object.get(*alias)) {
2242 params.extend(alias_params);
2243 }
2244 }
2245 for wrapper_key in compat_wrapper_param_keys(keyword) {
2246 if let Some(value) = object.get(*wrapper_key) {
2247 params.insert((*wrapper_key).to_string(), value.clone());
2248 }
2249 }
2250 }
2251 let shape = match object.and_then(|object| object.get("shape")) {
2252 Some(shape) => Some(deserialize_value(
2253 shape.clone(),
2254 "pipeline DSL compat shape",
2255 )?),
2256 None => fallback_shape,
2257 };
2258 let mut step = PipelineDslOperatorStep {
2259 inner_cv: optional_object_field_from_option(object, "inner_cv")?,
2260 id: match object {
2261 Some(object) => {
2262 explicit_or_generated_node_id(object, "id", || self.next_node_id(id_prefix))?
2263 }
2264 None => self.next_node_id(id_prefix)?,
2265 },
2266 operator: operator.clone(),
2267 params,
2268 metadata: optional_object_field_from_option(object, "metadata")?.unwrap_or_default(),
2269 seed_label: optional_object_field_from_option(object, "seed_label")?,
2270 representation: optional_object_field_from_option(object, "representation")?,
2271 train_params: optional_object_field_from_option(object, "train_params")?
2272 .unwrap_or_default(),
2273 tuning: optional_object_field_from_option(object, "tuning")?.or(
2274 optional_object_field_from_option(object, "finetune_params")?,
2275 ),
2276 variants: optional_object_field_from_option(object, "variants")?.unwrap_or_default(),
2277 param_generators: optional_object_field_from_option(object, "generators")?
2278 .unwrap_or_default(),
2279 shape,
2280 };
2281 step.metadata.insert(
2282 "dsl_compat_keyword".to_string(),
2283 serde_json::Value::String(keyword.to_string()),
2284 );
2285 if is_minimal_compat_operator_alias(object, operator) {
2286 step.metadata.insert(
2287 DSL_MINIMAL_OPERATOR_ALIAS.to_string(),
2288 serde_json::Value::Bool(true),
2289 );
2290 }
2291 if let Some(policy) = object.and_then(|object| object.get("policy")) {
2292 step.metadata
2293 .insert("dsl_compat_policy".to_string(), policy.clone());
2294 }
2295 if let Some(name) = object
2296 .and_then(|object| object.get("name"))
2297 .and_then(serde_json::Value::as_str)
2298 {
2299 step.metadata.insert(
2300 "dsl_name".to_string(),
2301 serde_json::Value::String(name.to_string()),
2302 );
2303 }
2304 if let Some(attachment) = attachment {
2305 step.variants.extend(attachment.variants);
2306 step.param_generators.extend(attachment.param_generators);
2307 }
2308 Ok(step)
2309 }
2310
2311 fn compat_operator_step_from_parts(
2312 &mut self,
2313 keyword: &str,
2314 operator: serde_json::Value,
2315 params: BTreeMap<String, serde_json::Value>,
2316 attachment: Option<CompatGenerationAttachment>,
2317 shape: Option<PipelineDslShapePlan>,
2318 ) -> Result<PipelineDslOperatorStep> {
2319 let mut step = PipelineDslOperatorStep {
2320 inner_cv: None,
2321 id: self.next_node_id(compat_node_prefix(keyword))?,
2322 operator,
2323 params,
2324 metadata: BTreeMap::from([(
2325 "dsl_compat_keyword".to_string(),
2326 serde_json::Value::String(keyword.to_string()),
2327 )]),
2328 seed_label: None,
2329 representation: None,
2330 train_params: BTreeMap::new(),
2331 tuning: None,
2332 variants: Vec::new(),
2333 param_generators: Vec::new(),
2334 shape,
2335 };
2336 if let Some(attachment) = attachment {
2337 step.variants.extend(attachment.variants);
2338 step.param_generators.extend(attachment.param_generators);
2339 }
2340 Ok(step)
2341 }
2342
2343 fn lower_split_invocation(
2344 &self,
2345 split: &serde_json::Value,
2346 object: &serde_json::Map<String, serde_json::Value>,
2347 path: &str,
2348 ) -> Result<SplitInvocation> {
2349 let mut params = BTreeMap::new();
2350 let mut id = object
2351 .get("id")
2352 .and_then(serde_json::Value::as_str)
2353 .unwrap_or("split:compat")
2354 .to_string();
2355 let mut controller_id = optional_object_field(object, "controller_id")?;
2356 let mut leakage_policy =
2357 optional_object_field(object, "leakage_policy")?.unwrap_or_default();
2358 let fold_set = optional_object_field(object, "fold_set")?;
2359 match split {
2360 serde_json::Value::String(kind) => {
2361 params.insert("kind".to_string(), serde_json::Value::String(kind.clone()));
2362 id = format!("split:{}", sanitize_generation_label(kind));
2363 }
2364 serde_json::Value::Object(split_object) => {
2365 if let Some(split_id) = split_object.get("id").and_then(serde_json::Value::as_str) {
2366 id = split_id.to_string();
2367 }
2368 if controller_id.is_none() {
2369 controller_id = optional_object_field(split_object, "controller_id")?;
2370 }
2371 if let Some(policy) = optional_object_field(split_object, "leakage_policy")? {
2372 leakage_policy = policy;
2373 }
2374 if let Some(explicit_params) = object_value_as_map(split_object.get("params")) {
2375 params.extend(explicit_params);
2376 }
2377 for (key, value) in split_object {
2378 if !matches!(
2379 key.as_str(),
2380 "id" | "controller_id" | "leakage_policy" | "fold_set" | "params"
2381 ) {
2382 params.insert(key.clone(), value.clone());
2383 }
2384 }
2385 }
2386 _ => {
2387 return Err(DagMlError::GraphValidation(format!(
2388 "{path}.split must be a string or object"
2389 )));
2390 }
2391 }
2392 for (key, value) in object {
2393 if !matches!(
2394 key.as_str(),
2395 "split" | "id" | "controller_id" | "leakage_policy" | "fold_set" | "params"
2396 ) {
2397 params.entry(key.clone()).or_insert_with(|| value.clone());
2398 }
2399 }
2400 Ok(SplitInvocation {
2401 id,
2402 controller_id,
2403 leakage_policy,
2404 params,
2405 fold_set,
2406 })
2407 }
2408
2409 fn lower_plain_split_invocation(
2410 &self,
2411 value: &serde_json::Value,
2412 path: &str,
2413 ) -> Result<SplitInvocation> {
2414 let mut params = BTreeMap::new();
2415 let id;
2416 let mut controller_id = None;
2417 let mut leakage_policy = LeakageUnitPolicy::default();
2418 let mut fold_set = None;
2419 if let Some(object) = value.as_object() {
2420 id = object
2421 .get("id")
2422 .and_then(serde_json::Value::as_str)
2423 .map(str::to_string)
2424 .unwrap_or_else(|| {
2425 compat_plain_operator_ref(value)
2426 .map(|reference| format!("split:{}", sanitize_generation_label(reference)))
2427 .unwrap_or_else(|| "split:compat".to_string())
2428 });
2429 controller_id = optional_object_field(object, "controller_id")?;
2430 leakage_policy = optional_object_field(object, "leakage_policy")?.unwrap_or_default();
2431 fold_set = optional_object_field(object, "fold_set")?;
2432 if let Some(explicit_params) = object_value_as_map(object.get("params")) {
2433 params.extend(explicit_params);
2434 }
2435 for (key, item) in object {
2436 if !matches!(
2437 key.as_str(),
2438 "id" | "controller_id" | "leakage_policy" | "fold_set" | "params" | "name"
2439 ) {
2440 params.insert(key.clone(), item.clone());
2441 }
2442 }
2443 } else if let Some(reference) = compat_plain_operator_ref(value) {
2444 id = format!("split:{}", sanitize_generation_label(reference));
2445 params.insert(
2446 "class".to_string(),
2447 serde_json::Value::String(reference.to_string()),
2448 );
2449 } else {
2450 return Err(DagMlError::GraphValidation(format!(
2451 "{path} is not a nirs4all-compatible splitter alias"
2452 )));
2453 }
2454 if let Some(reference) = compat_plain_operator_ref(value) {
2455 params
2456 .entry("class".to_string())
2457 .or_insert_with(|| serde_json::Value::String(reference.to_string()));
2458 }
2459 Ok(SplitInvocation {
2460 id,
2461 controller_id,
2462 leakage_policy,
2463 params,
2464 fold_set,
2465 })
2466 }
2467
2468 fn set_split_invocation(&mut self, split: SplitInvocation, path: &str) -> Result<()> {
2469 let Some(existing) = self.split_invocation.as_mut() else {
2470 self.split_invocation = Some(split);
2471 return Ok(());
2472 };
2473 if existing.fold_set.is_some() && split.fold_set.is_some() {
2474 return Err(DagMlError::GraphValidation(format!(
2475 "{path} declares a second split with a fold_set; only one explicit fold_set can drive campaign OOF validation"
2476 )));
2477 }
2478 if existing.fold_set.is_none() {
2479 existing.fold_set = split.fold_set.clone();
2480 }
2481 let default_policy = LeakageUnitPolicy::default();
2482 if existing.leakage_policy == default_policy {
2483 existing.leakage_policy = split.leakage_policy.clone();
2484 } else if split.leakage_policy != default_policy
2485 && existing.leakage_policy != split.leakage_policy
2486 {
2487 return Err(DagMlError::GraphValidation(format!(
2488 "{path} declares split leakage_policy incompatible with the existing campaign split policy"
2489 )));
2490 }
2491 let first = split_invocation_chain_entry(existing)?;
2492 let second = split_invocation_chain_entry(&split)?;
2493 let mut chain = existing
2494 .params
2495 .remove("compat_split_chain")
2496 .and_then(|value| value.as_array().cloned())
2497 .unwrap_or_else(|| vec![first]);
2498 chain.push(second);
2499 existing.id = "split:compat.chain".to_string();
2500 existing.controller_id = None;
2501 existing.params.clear();
2502 existing.params.insert(
2503 "kind".to_string(),
2504 serde_json::Value::String("compat_split_chain".to_string()),
2505 );
2506 existing.params.insert(
2507 "compat_split_chain".to_string(),
2508 serde_json::Value::Array(chain),
2509 );
2510 Ok(())
2511 }
2512
2513 fn next_node_id(&mut self, prefix: &str) -> Result<NodeId> {
2514 let id = NodeId::new(format!("{prefix}:compat.{}", self.node_counter))?;
2515 self.node_counter += 1;
2516 Ok(id)
2517 }
2518
2519 fn next_generator_id(&mut self) -> Result<NodeId> {
2520 let id = NodeId::new(format!("generator:compat.{}", self.generator_counter))?;
2521 self.generator_counter += 1;
2522 Ok(id)
2523 }
2524}
2525
2526fn optional_root_field<T>(
2527 root: Option<&serde_json::Map<String, serde_json::Value>>,
2528 key: &str,
2529) -> Result<Option<T>>
2530where
2531 T: DeserializeOwned,
2532{
2533 match root.and_then(|object| object.get(key)) {
2534 Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2535 None => Ok(None),
2536 }
2537}
2538
2539fn optional_object_field<T>(
2540 object: &serde_json::Map<String, serde_json::Value>,
2541 key: &str,
2542) -> Result<Option<T>>
2543where
2544 T: DeserializeOwned,
2545{
2546 match object.get(key) {
2547 Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2548 None => Ok(None),
2549 }
2550}
2551
2552fn optional_object_field_from_option<T>(
2553 object: Option<&serde_json::Map<String, serde_json::Value>>,
2554 key: &str,
2555) -> Result<Option<T>>
2556where
2557 T: DeserializeOwned,
2558{
2559 match object.and_then(|object| object.get(key)) {
2560 Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2561 None => Ok(None),
2562 }
2563}
2564
2565fn compat_merge_field<T>(
2566 object: &serde_json::Map<String, serde_json::Value>,
2567 key: &str,
2568) -> Result<Option<T>>
2569where
2570 T: DeserializeOwned,
2571{
2572 let value = object.get(key).or_else(|| {
2573 object
2574 .get("merge")
2575 .and_then(serde_json::Value::as_object)
2576 .and_then(|merge| merge.get(key))
2577 });
2578 match value {
2579 Some(value) => Ok(Some(deserialize_value(value.clone(), key)?)),
2580 None => Ok(None),
2581 }
2582}
2583
2584fn deserialize_value<T>(value: serde_json::Value, label: &str) -> Result<T>
2585where
2586 T: DeserializeOwned,
2587{
2588 serde_json::from_value(value)
2589 .map_err(|error| DagMlError::GraphValidation(format!("failed to parse {label}: {error}")))
2590}
2591
2592fn explicit_or_generated_node_id<F>(
2593 object: &serde_json::Map<String, serde_json::Value>,
2594 key: &str,
2595 generated: F,
2596) -> Result<NodeId>
2597where
2598 F: FnOnce() -> Result<NodeId>,
2599{
2600 match object.get(key).and_then(serde_json::Value::as_str) {
2601 Some(id) => NodeId::new(id),
2602 None => generated(),
2603 }
2604}
2605
2606fn first_object_value<'a>(
2607 object: &'a serde_json::Map<String, serde_json::Value>,
2608 keys: &[&str],
2609) -> Option<&'a serde_json::Value> {
2610 keys.iter().find_map(|key| object.get(*key))
2611}
2612
2613fn is_comment_only_object(object: &serde_json::Map<String, serde_json::Value>) -> bool {
2614 !object.is_empty()
2615 && object
2616 .keys()
2617 .all(|key| matches!(key.as_str(), "_comment" | "comment" | "description"))
2618}
2619
2620fn value_can_receive_generation_attachment(value: &serde_json::Value) -> bool {
2621 let Some(object) = value.as_object() else {
2622 return false;
2623 };
2624 object.contains_key("model")
2625 || object.contains_key("tuner")
2626 || object.contains_key("finetune")
2627 || first_object_value(object, &["preprocessing", "processing", "transform"]).is_some()
2628 || compat_plain_operator_ref(value).is_some()
2629}
2630
2631fn object_value_as_map(
2632 value: Option<&serde_json::Value>,
2633) -> Option<BTreeMap<String, serde_json::Value>> {
2634 value.and_then(|value| {
2635 value.as_object().map(|object| {
2636 object
2637 .iter()
2638 .map(|(key, value)| (key.clone(), value.clone()))
2639 .collect()
2640 })
2641 })
2642}
2643
2644fn is_minimal_compat_operator_alias(
2645 object: Option<&serde_json::Map<String, serde_json::Value>>,
2646 operator: &serde_json::Value,
2647) -> bool {
2648 match object {
2649 None => compat_plain_operator_ref(operator).is_some(),
2650 Some(object) => {
2651 ["class", "function", "ref", "type"]
2652 .iter()
2653 .any(|key| object.contains_key(*key))
2654 && compat_plain_operator_ref(operator).is_some()
2655 }
2656 }
2657}
2658
2659fn annotate_named_steps(steps: &mut [PipelineDslStep], name: &str) {
2660 for step in steps {
2661 annotate_named_step(step, name);
2662 }
2663}
2664
/// Inserts `name` under the `dsl_name` key of `step`'s metadata map.
///
/// The match lists every `PipelineDslStep` variant with no wildcard arm, so a new
/// variant must be handled here explicitly.
fn annotate_named_step(step: &mut PipelineDslStep, name: &str) {
    let value = serde_json::Value::String(name.to_string());
    match step {
        // These variants bind `step` through one or-pattern, so they share a payload type.
        PipelineDslStep::Transform(step)
        | PipelineDslStep::YTransform(step)
        | PipelineDslStep::Tag(step)
        | PipelineDslStep::Exclude(step)
        | PipelineDslStep::Filter(step)
        | PipelineDslStep::SampleFilter(step)
        | PipelineDslStep::Augmentation(step)
        | PipelineDslStep::FeatureAugmentation(step)
        | PipelineDslStep::SampleAugmentation(step)
        | PipelineDslStep::DataGeneration(step)
        | PipelineDslStep::Model(step)
        | PipelineDslStep::Tuner(step)
        | PipelineDslStep::Chart(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        // The remaining variants presumably each wrap a distinct step type, so they
        // cannot join the or-pattern above even though the body is identical.
        PipelineDslStep::ConcatTransform(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        PipelineDslStep::Branch(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        PipelineDslStep::Generator(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        PipelineDslStep::Sequential(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        PipelineDslStep::Merge(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
        PipelineDslStep::MergeModel(step) => {
            step.metadata.insert("dsl_name".to_string(), value);
        }
    }
}
2703
2704fn compat_plain_operator_ref(value: &serde_json::Value) -> Option<&str> {
2705 match value {
2706 serde_json::Value::String(reference) => Some(reference),
2707 serde_json::Value::Object(object) => ["class", "function", "ref", "type"]
2708 .into_iter()
2709 .find_map(|key| object.get(key).and_then(serde_json::Value::as_str)),
2710 _ => None,
2711 }
2712}
2713
2714fn compat_plain_operator_value(value: &serde_json::Value) -> Result<serde_json::Value> {
2715 match value {
2716 serde_json::Value::String(_) => Ok(value.clone()),
2717 serde_json::Value::Object(object) => {
2718 let mut operator = serde_json::Map::new();
2719 for key in ["class", "function", "ref", "type"] {
2720 if let Some(value) = object.get(key) {
2721 operator.insert(key.to_string(), value.clone());
2722 }
2723 }
2724 if operator.is_empty() {
2725 return Err(DagMlError::GraphValidation(
2726 "nirs4all-compatible plain operator object must contain class, function, ref or type"
2727 .to_string(),
2728 ));
2729 }
2730 Ok(serde_json::Value::Object(operator))
2731 }
2732 _ => Err(DagMlError::GraphValidation(
2733 "nirs4all-compatible plain operator must be a string or object".to_string(),
2734 )),
2735 }
2736}
2737
2738fn compat_plain_operator_kind(value: &serde_json::Value) -> CompatPlainOperatorKind {
2739 let Some(reference) = compat_plain_operator_ref(value) else {
2740 return CompatPlainOperatorKind::Transform;
2741 };
2742 let lower = reference.to_ascii_lowercase();
2743 if compat_is_chart_alias(&lower) {
2744 CompatPlainOperatorKind::Chart
2745 } else if compat_is_tuner_alias(&lower) {
2746 CompatPlainOperatorKind::Tuner
2747 } else if compat_is_splitter_alias(&lower) {
2748 CompatPlainOperatorKind::Split
2749 } else if compat_is_model_alias(&lower) {
2750 CompatPlainOperatorKind::Model
2751 } else {
2752 CompatPlainOperatorKind::Transform
2753 }
2754}
2755
/// True for a lower-cased reference that names a chart.
///
/// Matches `chart`, any `chart_*` name, and any path through a `charts` or
/// `visualization` module.
fn compat_is_chart_alias(lower: &str) -> bool {
    if lower == "chart" || lower.starts_with("chart_") {
        return true;
    }
    [".charts.", ".visualization."]
        .iter()
        .any(|module| lower.contains(*module))
}
2762
/// True for a lower-cased reference that names a hyper-parameter tuner.
///
/// Matches on tuning modules and libraries (optuna, ray.tune, hyperopt), and on
/// class names ending in `tuner` or `searchcv`. The suffix rule covers
/// GridSearchCV, RandomizedSearchCV, the Halving* searches, BayesianTuner and
/// OptunaTuner.
fn compat_is_tuner_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 6] = [
        ".tuners.",
        ".tuning.",
        "operators.tuners",
        "optuna",
        "ray.tune",
        "hyperopt",
    ];
    // Last segment after the final `.` or `:` (the whole string if neither occurs).
    let short = lower.rsplit(['.', ':']).next().unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || short.ends_with("tuner")
        || short.ends_with("searchcv")
}
2783
/// True for a lower-cased reference that names a cross-validation splitter.
///
/// Matches on splitter modules (`model_selection`, `splitters`) and on class names
/// containing `splitter` or ending in `fold` / `split` (KFold, GroupKFold,
/// ShuffleSplit, PredefinedSplit, TimeSeriesSplit, ...). LeaveOneOut and
/// LeavePOut are matched by name.
fn compat_is_splitter_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 3] = ["model_selection", ".splitters.", "operators.splitters"];
    // Last segment after the final `.` or `:` (the whole string if neither occurs).
    let short = lower.rsplit(['.', ':']).next().unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || short.contains("splitter")
        || short.ends_with("fold")
        || short.ends_with("split")
        || matches!(short, "leaveoneout" | "leavepout")
}
2799
/// True for a lower-cased reference that names a predictive model.
///
/// Matches any of:
/// - a path through a known model module or library;
/// - a class name ending in `regressor`, `classifier` or `regression`;
/// - one of a short list of well-known estimator names.
fn compat_is_model_alias(lower: &str) -> bool {
    const PATH_MARKERS: [&str; 12] = [
        ".models.",
        "operators.models",
        "linear_model",
        "cross_decomposition",
        ".ensemble.",
        ".svm.",
        ".tree.",
        ".neighbors.",
        ".neural_network.",
        "xgboost",
        "lightgbm",
        "catboost",
    ];
    const NAME_SUFFIXES: [&str; 3] = ["regressor", "classifier", "regression"];
    const SHORT_NAMES: [&str; 11] = [
        "ridge",
        "lasso",
        "elasticnet",
        "svr",
        "svc",
        "linearsvr",
        "linearsvc",
        "pls",
        "plsr",
        "plsregression",
        "metamodel",
    ];
    // Last segment after the final `.` or `:` (the whole string if neither occurs).
    let short = lower.rsplit(['.', ':']).next().unwrap_or(lower);
    PATH_MARKERS.iter().any(|marker| lower.contains(*marker))
        || NAME_SUFFIXES.iter().any(|suffix| short.ends_with(*suffix))
        || SHORT_NAMES.contains(&short)
}
2832
/// Maps a compat keyword to the node-id prefix used for generated ids.
///
/// Unknown keywords fall back to `transform`.
fn compat_node_prefix(keyword: &str) -> &'static str {
    const PREFIXES: &[(&str, &str)] = &[
        ("model", "model"),
        ("tuner", "tuner"),
        ("finetune", "tuner"),
        ("y_processing", "target"),
        ("y_transform", "target"),
        ("tag", "tag"),
        ("exclude", "filter"),
        ("filter", "filter"),
        ("sample_filter", "filter"),
        ("sample_augmentation", "augment"),
        ("feature_augmentation", "augment"),
        ("augmentation", "augment"),
        ("data_generation", "generator"),
        ("generation", "generator"),
        ("chart", "chart"),
    ];
    PREFIXES
        .iter()
        .find(|(alias, _)| *alias == keyword)
        .map_or("transform", |&(_, prefix)| prefix)
}
2846
/// Returns the `*_params` keys whose maps are merged into a step's params for `keyword`.
///
/// Unknown keywords have no aliases.
fn compat_param_aliases(keyword: &str) -> &'static [&'static str] {
    const MODEL_ALIASES: &[&str] = &["model_params"];
    const TUNER_ALIASES: &[&str] = &["tuner_params", "finetune_params"];
    const TRANSFORM_ALIASES: &[&str] = &[
        "preprocessing_params",
        "processing_params",
        "transform_params",
    ];
    const AUGMENTATION_ALIASES: &[&str] = &["augmentation_params"];
    const GENERATION_ALIASES: &[&str] = &["generation_params"];
    match keyword {
        "model" => MODEL_ALIASES,
        "tuner" | "finetune" => TUNER_ALIASES,
        "preprocessing" | "processing" | "transform" => TRANSFORM_ALIASES,
        "sample_augmentation" | "feature_augmentation" | "augmentation" => AUGMENTATION_ALIASES,
        "data_generation" | "generation" => GENERATION_ALIASES,
        _ => &[],
    }
}
2861
/// Returns the wrapper-level keys copied verbatim into a step's params for `keyword`.
///
/// Unknown keywords copy nothing.
fn compat_wrapper_param_keys(keyword: &str) -> &'static [&'static str] {
    const FILTER_KEYS: &[&str] = &["mode", "report", "tag_name"];
    const SAMPLE_AUGMENTATION_KEYS: &[&str] = &[
        "count",
        "selection",
        "random_state",
        "mode",
        "action",
        "report",
    ];
    const FEATURE_AUGMENTATION_KEYS: &[&str] = &[
        "size",
        "count",
        "selection",
        "random_state",
        "mode",
        "action",
        "report",
    ];
    const GENERATION_KEYS: &[&str] = &["size", "count", "random_state", "mode", "report"];
    const TUNER_KEYS: &[&str] = &["n_trials", "metric", "direction", "timeout", "random_state"];
    match keyword {
        "tag" | "exclude" | "filter" | "sample_filter" => FILTER_KEYS,
        "sample_augmentation" => SAMPLE_AUGMENTATION_KEYS,
        "feature_augmentation" | "augmentation" => FEATURE_AUGMENTATION_KEYS,
        "data_generation" | "generation" => GENERATION_KEYS,
        "tuner" | "finetune" => TUNER_KEYS,
        _ => &[],
    }
}
2887
2888fn split_invocation_chain_entry(split: &SplitInvocation) -> Result<serde_json::Value> {
2889 let mut object = serde_json::Map::new();
2890 object.insert(
2891 "id".to_string(),
2892 serde_json::Value::String(split.id.clone()),
2893 );
2894 if let Some(controller_id) = &split.controller_id {
2895 object.insert(
2896 "controller_id".to_string(),
2897 serde_json::to_value(controller_id).map_err(|error| {
2898 DagMlError::GraphValidation(format!(
2899 "failed to serialize split controller_id for compat split chain: {error}"
2900 ))
2901 })?,
2902 );
2903 }
2904 if split.leakage_policy != LeakageUnitPolicy::default() {
2905 object.insert(
2906 "leakage_policy".to_string(),
2907 serde_json::to_value(&split.leakage_policy).map_err(|error| {
2908 DagMlError::GraphValidation(format!(
2909 "failed to serialize split leakage_policy for compat split chain: {error}"
2910 ))
2911 })?,
2912 );
2913 }
2914 if !split.params.is_empty() {
2915 object.insert(
2916 "params".to_string(),
2917 serde_json::to_value(&split.params).map_err(|error| {
2918 DagMlError::GraphValidation(format!(
2919 "failed to serialize split params for compat split chain: {error}"
2920 ))
2921 })?,
2922 );
2923 }
2924 if let Some(fold_set) = &split.fold_set {
2925 object.insert(
2926 "fold_set".to_string(),
2927 serde_json::to_value(fold_set).map_err(|error| {
2928 DagMlError::GraphValidation(format!(
2929 "failed to serialize split fold_set for compat split chain: {error}"
2930 ))
2931 })?,
2932 );
2933 }
2934 Ok(serde_json::Value::Object(object))
2935}
2936
2937fn compat_augmentation_shape(
2938 kind: &str,
2939 object: &serde_json::Map<String, serde_json::Value>,
2940) -> Result<PipelineDslShapePlan> {
2941 if let Some(shape) = object.get("shape") {
2942 return deserialize_value(shape.clone(), "augmentation shape");
2943 }
2944 let mut sample_scope = crate::policy::AugmentationScope::None;
2945 let mut feature_scope = crate::policy::AugmentationScope::None;
2946 match kind {
2947 "sample" => sample_scope = crate::policy::AugmentationScope::TrainOnly,
2948 "feature" => feature_scope = crate::policy::AugmentationScope::TrainOnly,
2949 _ => {
2950 sample_scope = crate::policy::AugmentationScope::TrainOnly;
2951 feature_scope = crate::policy::AugmentationScope::TrainOnly;
2952 }
2953 }
2954 if let Some(apply_to) = object
2955 .get("policy")
2956 .and_then(serde_json::Value::as_object)
2957 .and_then(|policy| policy.get("apply_to"))
2958 .and_then(serde_json::Value::as_str)
2959 {
2960 match apply_to {
2961 "train_only" => {}
2962 "all" | "all_partitions" => {
2963 if sample_scope != crate::policy::AugmentationScope::None {
2964 sample_scope = crate::policy::AugmentationScope::AllPartitions;
2965 }
2966 if feature_scope != crate::policy::AugmentationScope::None {
2967 feature_scope = crate::policy::AugmentationScope::AllPartitions;
2968 }
2969 }
2970 "none" => {
2971 sample_scope = crate::policy::AugmentationScope::None;
2972 feature_scope = crate::policy::AugmentationScope::None;
2973 }
2974 other => {
2975 return Err(DagMlError::GraphValidation(format!(
2976 "unsupported nirs4all augmentation policy apply_to `{other}`"
2977 )));
2978 }
2979 }
2980 }
2981 Ok(PipelineDslShapePlan {
2982 input_granularity: None,
2983 target_granularity: None,
2984 fit_rows: Some(FitBoundary::FoldTrain),
2985 predict_rows: Some(FitBoundary::FoldValidation),
2986 feature_namespace: None,
2987 feature_schema_fingerprint: None,
2988 target_space: None,
2989 aggregation_policy: None,
2990 augmentation_policy: Some(AugmentationPolicy {
2991 sample_scope,
2992 feature_scope,
2993 require_origin_id: true,
2994 inherit_group: true,
2995 inherit_target: true,
2996 unsafe_flags: BTreeSet::new(),
2997 }),
2998 selection_policy: None,
2999 })
3000}
3001
3002fn compat_merge_modes(
3003 object: &serde_json::Map<String, serde_json::Value>,
3004) -> Result<(String, bool, PipelineDslMergeOutput)> {
3005 let merge = object
3006 .get("merge")
3007 .ok_or_else(|| DagMlError::GraphValidation("merge step lacks `merge`".to_string()))?;
3008 let merge_object = merge.as_object();
3009 let mode = merge
3010 .as_str()
3011 .or_else(|| {
3012 merge_object
3013 .and_then(|object| object.get("mode").or_else(|| object.get("strategy")))
3014 .and_then(serde_json::Value::as_str)
3015 })
3016 .map(str::to_string)
3017 .unwrap_or_else(|| infer_compat_merge_mode(merge_object));
3018 validate_compat_merge_mode(&mode)?;
3019 let include_original_data = object
3020 .get("include_original_data")
3021 .or_else(|| object.get("include_original"))
3022 .or_else(|| {
3023 merge_object.and_then(|object| {
3024 object
3025 .get("include_original_data")
3026 .or_else(|| object.get("include_original"))
3027 })
3028 })
3029 .and_then(serde_json::Value::as_bool)
3030 .unwrap_or(matches!(
3031 mode.as_str(),
3032 "all" | "mixed" | "predictions_plus_original"
3033 ));
3034 let output_as = object
3035 .get("output_as")
3036 .or_else(|| merge_object.and_then(|object| object.get("output_as")))
3037 .and_then(serde_json::Value::as_str)
3038 .map(compat_merge_output_as)
3039 .transpose()?
3040 .unwrap_or_else(|| compat_merge_output_for_mode(&mode));
3041 Ok((mode, include_original_data, output_as))
3042}
3043
3044fn infer_compat_merge_mode(
3045 merge_object: Option<&serde_json::Map<String, serde_json::Value>>,
3046) -> String {
3047 let Some(object) = merge_object else {
3048 return "predictions".to_string();
3049 };
3050 let has_predictions = object.contains_key("predictions") || object.contains_key("prediction");
3051 let has_features = object.contains_key("features") || object.contains_key("feature");
3052 let has_sources = object.contains_key("sources") || object.contains_key("source");
3053 match (has_predictions, has_features, has_sources) {
3054 (true, true, _) => "all",
3055 (true, false, _) => "predictions",
3056 (false, true, _) => "features",
3057 (false, false, true) => "sources",
3058 _ => "predictions",
3059 }
3060 .to_string()
3061}
3062
3063fn compat_merge_output_for_mode(mode: &str) -> PipelineDslMergeOutput {
3064 match mode {
3065 "predictions" | "prediction" => PipelineDslMergeOutput::Predictions,
3066 "sources" | "source" => PipelineDslMergeOutput::Sources,
3067 _ => PipelineDslMergeOutput::Features,
3068 }
3069}
3070
3071fn compat_merge_output_as(value: &str) -> Result<PipelineDslMergeOutput> {
3072 match value {
3073 "features" | "feature" => Ok(PipelineDslMergeOutput::Features),
3074 "predictions" | "prediction" => Ok(PipelineDslMergeOutput::Predictions),
3075 "sources" | "source" => Ok(PipelineDslMergeOutput::Sources),
3076 other => Err(DagMlError::GraphValidation(format!(
3077 "unsupported nirs4all merge output_as `{other}`"
3078 ))),
3079 }
3080}
3081
3082fn validate_compat_merge_mode(mode: &str) -> Result<()> {
3083 match mode {
3084 "predictions"
3085 | "prediction"
3086 | "sources"
3087 | "source"
3088 | "features"
3089 | "feature"
3090 | "concat"
3091 | "all"
3092 | "mixed"
3093 | "predictions_plus_original" => {}
3094 other => {
3095 return Err(DagMlError::GraphValidation(format!(
3096 "unsupported nirs4all merge mode `{other}`"
3097 )));
3098 }
3099 }
3100 Ok(())
3101}
3102
3103fn compat_generator_metadata(
3104 object: &serde_json::Map<String, serde_json::Value>,
3105 key: &str,
3106) -> Result<BTreeMap<String, serde_json::Value>> {
3107 let mut metadata: BTreeMap<String, serde_json::Value> =
3108 optional_object_field(object, "metadata")?.unwrap_or_default();
3109 metadata.insert(
3110 "dsl_compat_generator".to_string(),
3111 serde_json::Value::String(key.to_string()),
3112 );
3113 Ok(metadata)
3114}
3115
3116fn compat_branch_id(value: &serde_json::Value, index: usize) -> String {
3117 value
3118 .as_object()
3119 .and_then(|object| object.get("id"))
3120 .and_then(serde_json::Value::as_str)
3121 .map(|id| sanitize_branch_id(id, index))
3122 .unwrap_or_else(|| format!("choice{index}"))
3123}
3124
3125fn sanitize_branch_id(input: &str, index: usize) -> String {
3126 let sanitized = sanitize_generation_label(input);
3127 if sanitized == "value" {
3128 format!("branch{index}")
3129 } else {
3130 sanitized
3131 }
3132}
3133
3134fn step_has_prediction(step: &PipelineDslStep) -> bool {
3135 match step {
3136 PipelineDslStep::Model(_) | PipelineDslStep::Tuner(_) | PipelineDslStep::MergeModel(_) => {
3137 true
3138 }
3139 PipelineDslStep::Merge(step) => step.output_as == PipelineDslMergeOutput::Predictions,
3140 PipelineDslStep::Branch(step) => step
3141 .branches
3142 .iter()
3143 .any(|branch| branch.steps.iter().any(step_has_prediction)),
3144 PipelineDslStep::Generator(step) => generator_step_has_prediction(step),
3145 PipelineDslStep::Sequential(step) => step.steps.iter().any(step_has_prediction),
3146 _ => false,
3147 }
3148}
3149
3150fn generator_step_has_prediction(generator: &PipelineDslGeneratorStep) -> bool {
3151 generator
3152 .branches
3153 .iter()
3154 .any(|branch| branch.steps.iter().any(step_has_prediction))
3155 || generator.stages.iter().any(|stage| {
3156 stage
3157 .branches
3158 .iter()
3159 .any(|branch| branch.steps.iter().any(step_has_prediction))
3160 })
3161}
3162
3163fn generator_to_cartesian_stages(
3164 generator: PipelineDslGeneratorStep,
3165) -> Result<Vec<PipelineDslGeneratorStage>> {
3166 match generator.mode {
3167 PipelineDslGeneratorMode::Cartesian => Ok(generator.stages),
3168 PipelineDslGeneratorMode::Or => {
3169 if generator.pick.is_some()
3170 || generator.arrange.is_some()
3171 || generator.then_pick.is_some()
3172 || generator.then_arrange.is_some()
3173 {
3174 return Err(DagMlError::GraphValidation(format!(
3175 "nirs4all-compatible data-only generator `{}` cannot be fused across downstream models when pick/arrange selectors are present",
3176 generator.id
3177 )));
3178 }
3179 Ok(vec![PipelineDslGeneratorStage {
3180 id: sanitize_generation_label(generator.id.as_str()),
3181 selector: None,
3182 metadata: generator.metadata,
3183 branches: generator.branches,
3184 }])
3185 }
3186 }
3187}
3188
3189fn single_stage(
3190 id: String,
3191 branch_id: &str,
3192 steps: Vec<PipelineDslStep>,
3193) -> PipelineDslGeneratorStage {
3194 PipelineDslGeneratorStage {
3195 id,
3196 selector: None,
3197 metadata: BTreeMap::new(),
3198 branches: vec![PipelineDslBranch {
3199 id: branch_id.to_string(),
3200 selector: None,
3201 metadata: BTreeMap::new(),
3202 steps,
3203 }],
3204 }
3205}
3206
3207fn combined_cartesian_generator(
3208 id: NodeId,
3209 stages: Vec<PipelineDslGeneratorStage>,
3210) -> PipelineDslGeneratorStep {
3211 PipelineDslGeneratorStep {
3212 id,
3213 mode: PipelineDslGeneratorMode::Cartesian,
3214 branches: Vec::new(),
3215 stages,
3216 pick: None,
3217 arrange: None,
3218 then_pick: None,
3219 then_arrange: None,
3220 count: None,
3221 metadata: BTreeMap::from([(
3222 "dsl_compat_generator".to_string(),
3223 serde_json::Value::String("fused_data_to_prediction".to_string()),
3224 )]),
3225 }
3226}
3227
3228fn compat_grid_rows(
3229 value: &serde_json::Value,
3230 path: &str,
3231) -> Result<Vec<BTreeMap<String, serde_json::Value>>> {
3232 let object = value
3233 .as_object()
3234 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._grid_ must be an object")))?;
3235 if object.is_empty() {
3236 return Err(DagMlError::GraphValidation(format!(
3237 "{path}._grid_ must contain at least one parameter"
3238 )));
3239 }
3240 let entries = object
3241 .iter()
3242 .map(|(key, value)| {
3243 let values = match value {
3244 serde_json::Value::Array(values) => values.clone(),
3245 _ => vec![value.clone()],
3246 };
3247 if values.is_empty() {
3248 return Err(DagMlError::GraphValidation(format!(
3249 "{path}._grid_.{key} has no values"
3250 )));
3251 }
3252 Ok((key.clone(), values))
3253 })
3254 .collect::<Result<Vec<_>>>()?;
3255 let mut rows = Vec::new();
3256 build_compat_grid_rows(&entries, 0, &mut BTreeMap::new(), &mut rows);
3257 Ok(rows)
3258}
3259
3260fn build_compat_grid_rows(
3261 entries: &[(String, Vec<serde_json::Value>)],
3262 index: usize,
3263 current: &mut BTreeMap<String, serde_json::Value>,
3264 rows: &mut Vec<BTreeMap<String, serde_json::Value>>,
3265) {
3266 if index == entries.len() {
3267 rows.push(current.clone());
3268 return;
3269 }
3270 let (key, values) = &entries[index];
3271 for value in values {
3272 current.insert(key.clone(), value.clone());
3273 build_compat_grid_rows(entries, index + 1, current, rows);
3274 current.remove(key);
3275 }
3276}
3277
3278fn compat_range_generator(
3279 value: &serde_json::Value,
3280 object: &serde_json::Map<String, serde_json::Value>,
3281 path: &str,
3282) -> Result<PipelineDslParamGenerator> {
3283 let param = object
3284 .get("param")
3285 .and_then(serde_json::Value::as_str)
3286 .unwrap_or("n_components")
3287 .to_string();
3288 let (start, stop, step) = if let Some(values) = value.as_array() {
3289 if values.len() != 3 {
3290 return Err(DagMlError::GraphValidation(format!(
3291 "{path}._range_ array must be [start, stop, step]"
3292 )));
3293 }
3294 (
3295 json_f64(&values[0], path, "_range_[0]")?,
3296 json_f64(&values[1], path, "_range_[1]")?,
3297 json_f64(&values[2], path, "_range_[2]")?,
3298 )
3299 } else if let Some(spec) = value.as_object() {
3300 (
3301 json_f64(
3302 spec.get("start").ok_or_else(|| {
3303 DagMlError::GraphValidation(format!("{path}._range_ lacks start"))
3304 })?,
3305 path,
3306 "start",
3307 )?,
3308 json_f64(
3309 spec.get("stop").ok_or_else(|| {
3310 DagMlError::GraphValidation(format!("{path}._range_ lacks stop"))
3311 })?,
3312 path,
3313 "stop",
3314 )?,
3315 json_f64(
3316 spec.get("step").ok_or_else(|| {
3317 DagMlError::GraphValidation(format!("{path}._range_ lacks step"))
3318 })?,
3319 path,
3320 "step",
3321 )?,
3322 )
3323 } else {
3324 return Err(DagMlError::GraphValidation(format!(
3325 "{path}._range_ must be an array or object"
3326 )));
3327 };
3328 Ok(PipelineDslParamGenerator::Range {
3329 name: optional_object_field(object, "name")?,
3330 param,
3331 start,
3332 stop,
3333 step,
3334 inclusive: object
3335 .get("inclusive")
3336 .and_then(serde_json::Value::as_bool)
3337 .unwrap_or(true),
3338 count: optional_object_field(object, "count")?,
3339 })
3340}
3341
3342fn compat_log_range_generator(
3343 value: &serde_json::Value,
3344 object: &serde_json::Map<String, serde_json::Value>,
3345 path: &str,
3346) -> Result<PipelineDslParamGenerator> {
3347 let param = object
3348 .get("param")
3349 .and_then(serde_json::Value::as_str)
3350 .unwrap_or("alpha")
3351 .to_string();
3352 let spec = value.as_object().ok_or_else(|| {
3353 DagMlError::GraphValidation(format!("{path}._log_range_ must be an object"))
3354 })?;
3355 let start = json_f64(
3356 spec.get("start")
3357 .or_else(|| spec.get("from"))
3358 .ok_or_else(|| {
3359 DagMlError::GraphValidation(format!("{path}._log_range_ lacks start/from"))
3360 })?,
3361 path,
3362 "start",
3363 )?;
3364 let stop = json_f64(
3365 spec.get("stop").or_else(|| spec.get("to")).ok_or_else(|| {
3366 DagMlError::GraphValidation(format!("{path}._log_range_ lacks stop/to"))
3367 })?,
3368 path,
3369 "stop",
3370 )?;
3371 let count = spec
3372 .get("count")
3373 .or_else(|| spec.get("num"))
3374 .and_then(serde_json::Value::as_u64)
3375 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._log_range_ lacks count/num")))?
3376 as usize;
3377 Ok(PipelineDslParamGenerator::LogRange {
3378 name: optional_object_field(object, "name")?,
3379 param,
3380 start,
3381 stop,
3382 count,
3383 base: spec
3384 .get("base")
3385 .map(|value| json_f64(value, path, "base"))
3386 .transpose()?
3387 .unwrap_or(10.0),
3388 })
3389}
3390
3391fn compat_grid_param_generator(
3392 value: &serde_json::Value,
3393 object: &serde_json::Map<String, serde_json::Value>,
3394 path: &str,
3395) -> Result<PipelineDslParamGenerator> {
3396 let grid = value
3397 .as_object()
3398 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._grid_ must be an object")))?;
3399 let params = grid
3400 .iter()
3401 .map(|(key, value)| {
3402 let values = match value {
3403 serde_json::Value::Array(values) => values.clone(),
3404 _ => vec![value.clone()],
3405 };
3406 Ok((
3407 key.clone(),
3408 values
3409 .into_iter()
3410 .map(PipelineDslGeneratorValue::Value)
3411 .collect::<Vec<_>>(),
3412 ))
3413 })
3414 .collect::<Result<BTreeMap<_, _>>>()?;
3415 Ok(PipelineDslParamGenerator::Grid {
3416 name: optional_object_field(object, "name")?,
3417 params,
3418 count: optional_object_field(object, "count")?,
3419 })
3420}
3421
3422fn compat_zip_variants(
3423 value: &serde_json::Value,
3424 path: &str,
3425) -> Result<Vec<PipelineDslVariantChoice>> {
3426 let object = value
3427 .as_object()
3428 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._zip_ must be an object")))?;
3429 let mut length = None;
3430 let mut columns = Vec::new();
3431 for (key, value) in object {
3432 let values = value.as_array().ok_or_else(|| {
3433 DagMlError::GraphValidation(format!("{path}._zip_.{key} must be an array"))
3434 })?;
3435 if let Some(expected) = length {
3436 if values.len() != expected {
3437 return Err(DagMlError::GraphValidation(format!(
3438 "{path}._zip_ arrays must have equal lengths"
3439 )));
3440 }
3441 } else {
3442 length = Some(values.len());
3443 }
3444 columns.push((key.clone(), values.clone()));
3445 }
3446 let length = length.unwrap_or(0);
3447 if length == 0 {
3448 return Err(DagMlError::GraphValidation(format!(
3449 "{path}._zip_ must contain non-empty arrays"
3450 )));
3451 }
3452 Ok((0..length)
3453 .map(|index| {
3454 let params = columns
3455 .iter()
3456 .map(|(key, values)| (key.clone(), values[index].clone()))
3457 .collect::<BTreeMap<_, _>>();
3458 PipelineDslVariantChoice {
3459 label: format!("zip{index}"),
3460 params,
3461 value: None,
3462 }
3463 })
3464 .collect())
3465}
3466
3467fn compat_sample_rows(
3468 object: &serde_json::Map<String, serde_json::Value>,
3469 path: &str,
3470) -> Result<Vec<BTreeMap<String, serde_json::Value>>> {
3471 let param_names = if let Some(param) = object.get("param").and_then(serde_json::Value::as_str) {
3472 vec![param.to_string()]
3473 } else if let Some(tune) = object.get("tune").and_then(serde_json::Value::as_array) {
3474 let params = tune
3475 .iter()
3476 .map(|value| {
3477 value.as_str().map(str::to_string).ok_or_else(|| {
3478 DagMlError::GraphValidation(format!(
3479 "{path}._sample_.tune entries must be strings"
3480 ))
3481 })
3482 })
3483 .collect::<Result<Vec<_>>>()?;
3484 if params.is_empty() {
3485 return Err(DagMlError::GraphValidation(format!(
3486 "{path}._sample_.tune cannot be empty"
3487 )));
3488 }
3489 params
3490 } else {
3491 return Err(DagMlError::GraphValidation(format!(
3492 "{path}._sample_ requires `param` or `tune` for deterministic JSON lowering"
3493 )));
3494 };
3495 let from = json_f64(
3496 object
3497 .get("from")
3498 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks from")))?,
3499 path,
3500 "from",
3501 )?;
3502 let to = json_f64(
3503 object
3504 .get("to")
3505 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks to")))?,
3506 path,
3507 "to",
3508 )?;
3509 let count = object
3510 .get("num")
3511 .or_else(|| object.get("count"))
3512 .and_then(serde_json::Value::as_u64)
3513 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ lacks num/count")))?
3514 as usize;
3515 if count == 0 {
3516 return Err(DagMlError::GraphValidation(format!(
3517 "{path}._sample_ count cannot be zero"
3518 )));
3519 }
3520 let distribution = object
3521 .get("distribution")
3522 .and_then(serde_json::Value::as_str)
3523 .unwrap_or("uniform");
3524 if distribution == "log_uniform" && (from <= 0.0 || to <= 0.0) {
3525 return Err(DagMlError::GraphValidation(format!(
3526 "{path}._sample_ log_uniform requires positive from/to"
3527 )));
3528 }
3529 (0..count)
3530 .map(|index| {
3531 let ratio = if count == 1 {
3532 0.0
3533 } else {
3534 index as f64 / (count - 1) as f64
3535 };
3536 let sampled = match distribution {
3537 "uniform" => from + (to - from) * ratio,
3538 "log_uniform" => {
3539 let start = from.log10();
3540 let stop = to.log10();
3541 10f64.powf(start + (stop - start) * ratio)
3542 }
3543 other => {
3544 return Err(DagMlError::GraphValidation(format!(
3545 "{path}._sample_ unsupported deterministic distribution `{other}`"
3546 )));
3547 }
3548 };
3549 let mut row = BTreeMap::new();
3550 let value = serde_json::Value::Number(
3551 serde_json::Number::from_f64(sampled).ok_or_else(|| {
3552 DagMlError::GraphValidation(format!(
3553 "{path}._sample_ produced non-finite value"
3554 ))
3555 })?,
3556 );
3557 for param in ¶m_names {
3558 row.insert(param.clone(), value.clone());
3559 }
3560 Ok(row)
3561 })
3562 .collect()
3563}
3564
3565fn compat_sample_variants(
3566 value: &serde_json::Value,
3567 path: &str,
3568) -> Result<Vec<PipelineDslVariantChoice>> {
3569 let object = value
3570 .as_object()
3571 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}._sample_ must be an object")))?;
3572 compat_sample_rows(object, path)?
3573 .into_iter()
3574 .enumerate()
3575 .map(|(index, params)| {
3576 Ok(PipelineDslVariantChoice {
3577 label: format!("sample{index}"),
3578 params,
3579 value: None,
3580 })
3581 })
3582 .collect()
3583}
3584
3585fn json_f64(value: &serde_json::Value, path: &str, field: &str) -> Result<f64> {
3586 value
3587 .as_f64()
3588 .ok_or_else(|| DagMlError::GraphValidation(format!("{path}.{field} must be numeric")))
3589}
3590
3591impl PipelineCompiler {
3592 fn compile_top_level_step(
3593 &mut self,
3594 step: &PipelineDslStep,
3595 external_data: &DataSource,
3596 state: &mut SequenceCompileState,
3597 ) -> Result<()> {
3598 self.compile_sequence_step(step, external_data, state, None, BTreeMap::new())
3599 }
3600
3601 fn compile_sequence_step(
3602 &mut self,
3603 step: &PipelineDslStep,
3604 original_data: &DataSource,
3605 state: &mut SequenceCompileState,
3606 branch_id: Option<&str>,
3607 extra_metadata: BTreeMap<String, serde_json::Value>,
3608 ) -> Result<()> {
3609 match step {
3610 PipelineDslStep::Transform(step) => {
3611 state.current_data = self.compile_data_operator_with_extra(
3612 NodeKind::Transform,
3613 step,
3614 &state.current_data,
3615 extra_metadata,
3616 )?;
3617 state.clear_pending();
3618 Ok(())
3619 }
3620 PipelineDslStep::YTransform(step) => {
3621 self.compile_y_transform_with_extra(step, extra_metadata)?;
3622 state.clear_pending();
3623 Ok(())
3624 }
3625 PipelineDslStep::Tag(step) => {
3626 state.current_data = self.compile_data_operator_with_extra(
3627 NodeKind::Tag,
3628 step,
3629 &state.current_data,
3630 extra_metadata,
3631 )?;
3632 state.clear_pending();
3633 Ok(())
3634 }
3635 PipelineDslStep::Exclude(step) => {
3636 state.current_data = self.compile_data_operator_with_extra(
3637 NodeKind::Exclude,
3638 step,
3639 &state.current_data,
3640 extra_metadata,
3641 )?;
3642 state.clear_pending();
3643 Ok(())
3644 }
3645 PipelineDslStep::Filter(step) => {
3646 state.current_data = self.compile_filter_operator(
3647 "filter",
3648 step,
3649 &state.current_data,
3650 extra_metadata,
3651 )?;
3652 state.clear_pending();
3653 Ok(())
3654 }
3655 PipelineDslStep::SampleFilter(step) => {
3656 state.current_data = self.compile_filter_operator(
3657 "sample",
3658 step,
3659 &state.current_data,
3660 extra_metadata,
3661 )?;
3662 state.clear_pending();
3663 Ok(())
3664 }
3665 PipelineDslStep::Augmentation(step) => {
3666 state.current_data = self.compile_data_operator_with_extra(
3667 NodeKind::Augmentation,
3668 step,
3669 &state.current_data,
3670 extra_metadata,
3671 )?;
3672 state.clear_pending();
3673 Ok(())
3674 }
3675 PipelineDslStep::FeatureAugmentation(step) => {
3676 state.current_data = self.compile_augmentation_operator_with_extra(
3677 "feature",
3678 step,
3679 &state.current_data,
3680 extra_metadata,
3681 )?;
3682 state.clear_pending();
3683 Ok(())
3684 }
3685 PipelineDslStep::SampleAugmentation(step) => {
3686 state.current_data = self.compile_augmentation_operator_with_extra(
3687 "sample",
3688 step,
3689 &state.current_data,
3690 extra_metadata,
3691 )?;
3692 state.clear_pending();
3693 Ok(())
3694 }
3695 PipelineDslStep::DataGeneration(step) => {
3696 state.current_data = self.compile_data_generation_with_extra(
3697 step,
3698 &state.current_data,
3699 extra_metadata,
3700 )?;
3701 state.clear_pending();
3702 Ok(())
3703 }
3704 PipelineDslStep::ConcatTransform(step) => {
3705 state.current_data = self.compile_concat_transform_with_extra(
3706 step,
3707 &state.current_data,
3708 extra_metadata,
3709 )?;
3710 state.clear_pending();
3711 Ok(())
3712 }
3713 PipelineDslStep::Model(step) => {
3714 state
3715 .pending_predictions
3716 .push(self.compile_model_with_extra(
3717 step,
3718 &state.current_data,
3719 branch_id,
3720 extra_metadata,
3721 )?);
3722 Ok(())
3723 }
3724 PipelineDslStep::Tuner(step) => {
3725 state
3726 .pending_predictions
3727 .push(self.compile_tuner_with_extra(
3728 step,
3729 &state.current_data,
3730 branch_id,
3731 extra_metadata,
3732 )?);
3733 Ok(())
3734 }
3735 PipelineDslStep::Branch(step) => {
3736 let output =
3737 self.compile_branch_with_extra(step, &state.current_data, extra_metadata)?;
3738 state.pending_predictions = output.predictions;
3739 state.pending_branch_data = output.data_sources;
3740 Ok(())
3741 }
3742 PipelineDslStep::Generator(step) => {
3743 state.pending_predictions =
3744 self.compile_generator_with_extra(step, &state.current_data, extra_metadata)?;
3745 state.pending_branch_data.clear();
3746 Ok(())
3747 }
3748 PipelineDslStep::Sequential(step) => {
3749 self.compile_sequence_container(
3750 step,
3751 original_data,
3752 state,
3753 branch_id,
3754 extra_metadata,
3755 )?;
3756 Ok(())
3757 }
3758 PipelineDslStep::Merge(step) => {
3759 match self.compile_merge_with_extra(
3760 step,
3761 &state.pending_predictions,
3762 &state.pending_branch_data,
3763 original_data,
3764 extra_metadata,
3765 )? {
3766 MergeOutputSource::Data(data) => {
3767 state.current_data = data;
3768 state.clear_pending();
3769 }
3770 MergeOutputSource::Prediction(prediction) => {
3771 state.clear_pending();
3772 state.pending_predictions.push(prediction);
3773 }
3774 }
3775 Ok(())
3776 }
3777 PipelineDslStep::MergeModel(step) => {
3778 let prediction = self.compile_merge_model_with_extra(
3779 step,
3780 &state.pending_predictions,
3781 original_data,
3782 extra_metadata,
3783 )?;
3784 state.clear_pending();
3785 state.pending_predictions.push(prediction);
3786 Ok(())
3787 }
3788 PipelineDslStep::Chart(step) => {
3789 state.current_data = self.compile_data_operator_with_extra(
3790 NodeKind::Chart,
3791 step,
3792 &state.current_data,
3793 extra_metadata,
3794 )?;
3795 state.clear_pending();
3796 Ok(())
3797 }
3798 }
3799 }
3800
3801 fn compile_sequence_container(
3802 &mut self,
3803 step: &PipelineDslSequenceStep,
3804 original_data: &DataSource,
3805 state: &mut SequenceCompileState,
3806 branch_id: Option<&str>,
3807 mut extra_metadata: BTreeMap<String, serde_json::Value>,
3808 ) -> Result<()> {
3809 if step.steps.is_empty() {
3810 return Err(DagMlError::GraphValidation(
3811 "pipeline DSL sequential step has no child steps".to_string(),
3812 ));
3813 }
3814 if let Some(sequence_id) = &step.id {
3815 extra_metadata.insert(
3816 "dsl_sequence".to_string(),
3817 serde_json::Value::String(sequence_id.to_string()),
3818 );
3819 }
3820 if !step.metadata.is_empty() {
3821 extra_metadata.insert(
3822 "dsl_sequence_metadata".to_string(),
3823 serde_json::to_value(&step.metadata).map_err(|error| {
3824 DagMlError::GraphValidation(format!(
3825 "failed to serialize pipeline DSL sequential metadata: {error}"
3826 ))
3827 })?,
3828 );
3829 }
3830 for child in &step.steps {
3831 self.compile_sequence_step(
3832 child,
3833 original_data,
3834 state,
3835 branch_id,
3836 extra_metadata.clone(),
3837 )?;
3838 }
3839 Ok(())
3840 }
3841
3842 fn compile_branch_with_extra(
3843 &mut self,
3844 step: &PipelineDslBranchStep,
3845 current_data: &DataSource,
3846 extra_metadata: BTreeMap<String, serde_json::Value>,
3847 ) -> Result<BranchCompileOutput> {
3848 if step.branches.is_empty() {
3849 return Err(DagMlError::GraphValidation(format!(
3850 "pipeline DSL graph `{}` has a branch step without branches",
3851 self.graph_id
3852 )));
3853 }
3854 let mut predictions = Vec::new();
3855 let mut data_sources = Vec::new();
3856 for (index, branch) in step.branches.iter().enumerate() {
3857 validate_branch_id(&branch.id)?;
3858 if branch.steps.is_empty() {
3859 return Err(DagMlError::GraphValidation(format!(
3860 "pipeline DSL branch `{}` has no steps",
3861 branch.id
3862 )));
3863 }
3864 let branch_view_plan = compile_branch_view_plan(step, branch)?;
3865 let mut branch_state = SequenceCompileState::new(current_data.clone());
3866 let mut branch_metadata = branch_context_metadata(step, branch)?;
3867 if let Some(plan) = &branch_view_plan {
3868 branch_metadata.insert(
3869 "dsl_branch_view_plan".to_string(),
3870 serde_json::to_value(plan).map_err(|error| {
3871 DagMlError::GraphValidation(format!(
3872 "failed to serialize branch view plan for `{}`: {error}",
3873 branch.id
3874 ))
3875 })?,
3876 );
3877 }
3878 branch_metadata.extend(extra_metadata.clone());
3879 for branch_step in &branch.steps {
3880 self.compile_sequence_step(
3881 branch_step,
3882 current_data,
3883 &mut branch_state,
3884 Some(&branch.id),
3885 branch_metadata.clone(),
3886 )?;
3887 }
3888 if branch_state.pending_predictions.is_empty()
3889 && branch_state.pending_branch_data.is_empty()
3890 && same_data_source(&branch_state.current_data, current_data)
3891 {
3892 return Err(DagMlError::GraphValidation(format!(
3893 "pipeline DSL branch `{}` must produce at least one model, merge prediction or transformed data output",
3894 branch.id
3895 )));
3896 }
3897 if let Some(plan) = branch_view_plan {
3898 self.collect_branch_view_plan(plan)?;
3899 }
3900 let data_input_name = format!("{}_x", branch_input_prefix(&branch.id, index));
3901 data_sources.push(BranchDataSource {
3902 source: branch_state.current_data,
3903 input_name: data_input_name,
3904 branch_id: Some(branch.id.clone()),
3905 });
3906 data_sources.extend(branch_state.pending_branch_data);
3907 let prediction_count = branch_state.pending_predictions.len();
3908 for (prediction_index, prediction) in
3909 branch_state.pending_predictions.into_iter().enumerate()
3910 {
3911 let input_name = if prediction_count == 1 {
3912 format!("{}_oof", branch_input_prefix(&branch.id, index))
3913 } else {
3914 branch_prediction_input_name(
3915 &branch.id,
3916 index,
3917 prediction_index,
3918 &prediction.node_id,
3919 )
3920 };
3921 predictions.push(PredictionSource {
3922 input_name,
3923 ..prediction
3924 });
3925 }
3926 }
3927 Ok(BranchCompileOutput {
3928 predictions,
3929 data_sources,
3930 })
3931 }
3932
3933 fn compile_generator_with_extra(
3934 &mut self,
3935 step: &PipelineDslGeneratorStep,
3936 current_data: &DataSource,
3937 extra_metadata: BTreeMap<String, serde_json::Value>,
3938 ) -> Result<Vec<PredictionSource>> {
3939 let choices = expand_generator_sequences(step)?;
3940 if choices.is_empty() {
3941 return Err(DagMlError::GraphValidation(format!(
3942 "pipeline DSL generator `{}` produced no choices",
3943 step.id
3944 )));
3945 }
3946 let mut predictions = Vec::new();
3947 for (choice_index, choice) in choices.into_iter().enumerate() {
3948 let choice = namespace_generated_sequence(step, choice, choice_index)?;
3949 validate_branch_id(&choice.id)?;
3950 if choice.steps.is_empty() {
3951 return Err(DagMlError::GraphValidation(format!(
3952 "pipeline DSL generator `{}` choice `{}` has no steps",
3953 step.id, choice.id
3954 )));
3955 }
3956 let mut choice_state = SequenceCompileState::new(current_data.clone());
3957 let mut choice_metadata = generator_choice_metadata(step, &choice)?;
3958 choice_metadata.extend(extra_metadata.clone());
3959 for choice_step in &choice.steps {
3960 self.compile_sequence_step(
3961 choice_step,
3962 current_data,
3963 &mut choice_state,
3964 Some(&choice.id),
3965 choice_metadata.clone(),
3966 )?;
3967 }
3968 if choice_state.pending_predictions.is_empty() {
3969 return Err(DagMlError::GraphValidation(format!(
3970 "pipeline DSL generator `{}` choice `{}` must produce at least one model or merge prediction",
3971 step.id, choice.id
3972 )));
3973 }
3974 let prediction_count = choice_state.pending_predictions.len();
3975 for (prediction_index, prediction) in
3976 choice_state.pending_predictions.into_iter().enumerate()
3977 {
3978 let input_name = if prediction_count == 1 {
3979 format!("{}_oof", branch_input_prefix(&choice.id, choice_index))
3980 } else {
3981 branch_prediction_input_name(
3982 &choice.id,
3983 choice_index,
3984 prediction_index,
3985 &prediction.node_id,
3986 )
3987 };
3988 predictions.push(PredictionSource {
3989 input_name,
3990 ..prediction
3991 });
3992 }
3993 }
3994 Ok(predictions)
3995 }
3996
3997 fn compile_data_operator(
3998 &mut self,
3999 kind: NodeKind,
4000 step: &PipelineDslOperatorStep,
4001 input: &DataSource,
4002 ) -> Result<DataSource> {
4003 self.compile_data_operator_with_extra(kind, step, input, BTreeMap::new())
4004 }
4005
4006 fn compile_filter_operator(
4007 &mut self,
4008 filter_kind: &str,
4009 step: &PipelineDslOperatorStep,
4010 input: &DataSource,
4011 mut extra: BTreeMap<String, serde_json::Value>,
4012 ) -> Result<DataSource> {
4013 extra.insert(
4014 "dsl_filter_kind".to_string(),
4015 serde_json::Value::String(filter_kind.to_string()),
4016 );
4017 self.compile_data_operator_with_extra(NodeKind::Exclude, step, input, extra)
4018 }
4019
4020 fn compile_augmentation_operator_with_extra(
4021 &mut self,
4022 augmentation_kind: &str,
4023 step: &PipelineDslOperatorStep,
4024 input: &DataSource,
4025 mut extra: BTreeMap<String, serde_json::Value>,
4026 ) -> Result<DataSource> {
4027 extra.insert(
4028 "dsl_augmentation_kind".to_string(),
4029 serde_json::Value::String(augmentation_kind.to_string()),
4030 );
4031 self.compile_data_operator_with_extra(NodeKind::Augmentation, step, input, extra)
4032 }
4033
4034 fn compile_data_generation_with_extra(
4035 &mut self,
4036 step: &PipelineDslOperatorStep,
4037 input: &DataSource,
4038 mut extra: BTreeMap<String, serde_json::Value>,
4039 ) -> Result<DataSource> {
4040 if step.shape.is_none() {
4041 return Err(DagMlError::GraphValidation(format!(
4042 "pipeline DSL data_generation `{}` requires a shape plan for leakage-safe runtime generation",
4043 step.id
4044 )));
4045 }
4046 extra.insert(
4047 "dsl_generation_kind".to_string(),
4048 serde_json::Value::String("data".to_string()),
4049 );
4050 self.compile_data_operator_with_extra(NodeKind::Generator, step, input, extra)
4051 }
4052
4053 fn compile_data_operator_with_extra(
4054 &mut self,
4055 kind: NodeKind,
4056 step: &PipelineDslOperatorStep,
4057 input: &DataSource,
4058 extra_metadata: BTreeMap<String, serde_json::Value>,
4059 ) -> Result<DataSource> {
4060 if kind == NodeKind::Augmentation && step.shape.is_none() {
4061 return Err(DagMlError::GraphValidation(format!(
4062 "pipeline DSL augmentation `{}` requires a shape plan for leakage-safe scope validation",
4063 step.id
4064 )));
4065 }
4066 let representation = step
4067 .representation
4068 .clone()
4069 .or_else(|| input.representation.clone())
4070 .or_else(|| self.input_representation.clone());
4071 let mut metadata = operator_runtime_metadata(step, None)?;
4072 metadata.extend(extra_metadata);
4073 let node = NodeSpec {
4074 id: step.id.clone(),
4075 kind,
4076 operator: Some(step.operator.clone()),
4077 params: step.params.clone(),
4078 ports: PortSchema {
4079 inputs: vec![data_port("x", input.representation.clone(), "")],
4080 outputs: vec![data_port("x_out", representation.clone(), "")],
4081 },
4082 metadata,
4083 seed_label: step.seed_label.clone(),
4084 };
4085 self.push_node(node)?;
4086 self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4087 self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4088 self.connect_data(input, &step.id, "x")?;
4089 Ok(DataSource {
4090 node_id: Some(step.id.clone()),
4091 port_name: "x_out".to_string(),
4092 representation,
4093 })
4094 }
4095
4096 fn compile_y_transform_with_extra(
4097 &mut self,
4098 step: &PipelineDslOperatorStep,
4099 extra_metadata: BTreeMap<String, serde_json::Value>,
4100 ) -> Result<()> {
4101 let mut metadata = operator_runtime_metadata(step, None)?;
4102 metadata.extend(extra_metadata);
4103 let node = NodeSpec {
4104 id: step.id.clone(),
4105 kind: NodeKind::YTransform,
4106 operator: Some(step.operator.clone()),
4107 params: step.params.clone(),
4108 ports: PortSchema {
4109 inputs: vec![target_port("y", "")],
4110 outputs: vec![target_port("y_out", "")],
4111 },
4112 metadata,
4113 seed_label: step.seed_label.clone(),
4114 };
4115 self.push_node(node)?;
4116 self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4117 self.collect_shape_plan(&step.id, step.shape.as_ref())
4118 }
4119
4120 fn compile_concat_transform_with_extra(
4121 &mut self,
4122 step: &PipelineDslConcatTransformStep,
4123 input: &DataSource,
4124 extra_metadata: BTreeMap<String, serde_json::Value>,
4125 ) -> Result<DataSource> {
4126 if step.branches.is_empty() {
4127 return Err(DagMlError::GraphValidation(format!(
4128 "pipeline DSL concat_transform `{}` has no branches",
4129 step.id
4130 )));
4131 }
4132 let representation = step
4133 .representation
4134 .clone()
4135 .or_else(|| input.representation.clone())
4136 .or_else(|| self.input_representation.clone());
4137 let mut branch_outputs = Vec::with_capacity(step.branches.len());
4138 for (index, branch) in step.branches.iter().enumerate() {
4139 validate_branch_id(&branch.id)?;
4140 let mut branch_data = input.clone();
4141 for branch_step in &branch.steps {
4142 branch_data =
4143 self.compile_data_operator(NodeKind::Transform, branch_step, &branch_data)?;
4144 }
4145 let input_name = format!("{}_x", branch_input_prefix(&branch.id, index));
4146 branch_outputs.push((input_name, branch_data));
4147 }
4148 let node = NodeSpec {
4149 id: step.id.clone(),
4150 kind: NodeKind::FeatureJoin,
4151 operator: None,
4152 params: BTreeMap::new(),
4153 ports: PortSchema {
4154 inputs: branch_outputs
4155 .iter()
4156 .map(|(name, source)| data_port(name, source.representation.clone(), ""))
4157 .collect(),
4158 outputs: vec![data_port("x_out", representation.clone(), "")],
4159 },
4160 metadata: {
4161 let mut metadata = step.metadata.clone();
4162 metadata.extend(extra_metadata);
4163 metadata
4164 },
4165 seed_label: step.seed_label.clone(),
4166 };
4167 self.push_node(node)?;
4168 self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4169 self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4170 for (input_name, source) in &branch_outputs {
4171 self.connect_data_to_port(source, &step.id, input_name)?;
4172 }
4173 Ok(DataSource {
4174 node_id: Some(step.id.clone()),
4175 port_name: "x_out".to_string(),
4176 representation,
4177 })
4178 }
4179
4180 fn compile_model_with_extra(
4181 &mut self,
4182 step: &PipelineDslOperatorStep,
4183 input: &DataSource,
4184 branch_id: Option<&str>,
4185 extra_metadata: BTreeMap<String, serde_json::Value>,
4186 ) -> Result<PredictionSource> {
4187 self.compile_prediction_operator_with_extra(
4188 NodeKind::Model,
4189 step,
4190 input,
4191 branch_id,
4192 extra_metadata,
4193 )
4194 }
4195
4196 fn compile_tuner_with_extra(
4197 &mut self,
4198 step: &PipelineDslOperatorStep,
4199 input: &DataSource,
4200 branch_id: Option<&str>,
4201 extra_metadata: BTreeMap<String, serde_json::Value>,
4202 ) -> Result<PredictionSource> {
4203 self.compile_prediction_operator_with_extra(
4204 NodeKind::Tuner,
4205 step,
4206 input,
4207 branch_id,
4208 extra_metadata,
4209 )
4210 }
4211
4212 fn compile_prediction_operator_with_extra(
4213 &mut self,
4214 kind: NodeKind,
4215 step: &PipelineDslOperatorStep,
4216 input: &DataSource,
4217 branch_id: Option<&str>,
4218 extra_metadata: BTreeMap<String, serde_json::Value>,
4219 ) -> Result<PredictionSource> {
4220 let mut metadata = operator_runtime_metadata(step, branch_id)?;
4221 metadata.extend(extra_metadata);
4222 let node = NodeSpec {
4223 id: step.id.clone(),
4224 kind,
4225 operator: Some(step.operator.clone()),
4226 params: step.params.clone(),
4227 ports: PortSchema {
4228 inputs: vec![data_port("x", input.representation.clone(), "")],
4229 outputs: vec![prediction_port("oof", "")],
4230 },
4231 metadata,
4232 seed_label: step.seed_label.clone(),
4233 };
4234 self.push_node(node)?;
4235 self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4236 self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4237 self.connect_data(input, &step.id, "x")?;
4238 Ok(PredictionSource {
4239 node_id: step.id.clone(),
4240 port_name: "oof".to_string(),
4241 input_name: "oof".to_string(),
4242 branch_id: branch_id.map(str::to_string),
4243 })
4244 }
4245
    /// Compiles a merge step into a node that gathers pending branch outputs.
    ///
    /// `merge_consumes_predictions(step)` and `merge_consumes_branch_data(step)` decide whether
    /// the pending `predictions` and `branch_data` are consumed. Inputs that are not consumed
    /// are treated as empty. When `step.include_original_data` is set, `original_data` is also
    /// wired to an `x_original` port.
    ///
    /// Output depends on `step.output_as`:
    /// - `Predictions`: the node exposes a `prediction` port and a
    ///   `MergeOutputSource::Prediction` is returned.
    /// - otherwise: it exposes an `x_out` data port and a `MergeOutputSource::Data` is returned.
    ///
    /// # Errors
    /// Returns `DagMlError::GraphValidation` when there is nothing to merge (no consumed
    /// predictions, no consumed branch data and no original data), or when merge settings fail
    /// to serialize into metadata. It also propagates errors from `validate_merge_selectors`,
    /// node registration, generation/shape-plan collection and data-edge wiring.
    fn compile_merge_with_extra(
        &mut self,
        step: &PipelineDslMergeStep,
        predictions: &[PredictionSource],
        branch_data: &[BranchDataSource],
        original_data: &DataSource,
        extra_metadata: BTreeMap<String, serde_json::Value>,
    ) -> Result<MergeOutputSource> {
        let consumes_predictions = merge_consumes_predictions(step);
        let consumes_branch_data = merge_consumes_branch_data(step);
        // Inputs the merge does not consume are replaced by empty slices, so they produce no
        // ports, no edges and no metadata below.
        let prediction_inputs = if consumes_predictions {
            predictions
        } else {
            &[]
        };
        let branch_data_inputs = if consumes_branch_data {
            branch_data
        } else {
            &[]
        };
        if prediction_inputs.is_empty()
            && branch_data_inputs.is_empty()
            && !step.include_original_data
        {
            return Err(DagMlError::GraphValidation(format!(
                "pipeline DSL merge `{}` has no pending predictions, branch data or original data input",
                step.id
            )));
        }
        validate_merge_selectors(&step.id, &step.selectors, prediction_inputs)?;
        let outputs_prediction = step.output_as == PipelineDslMergeOutput::Predictions;
        // Representation of the `x_out` data output: the step's explicit choice first, then the
        // original data's, then the pipeline input's.
        let representation = step
            .representation
            .clone()
            .or_else(|| original_data.representation.clone())
            .or_else(|| self.input_representation.clone());
        // Input ports, in order: one prediction port per consumed prediction, one data port
        // per consumed branch-data source, then optionally `x_original`.
        let mut input_ports =
            Vec::with_capacity(prediction_inputs.len() + branch_data_inputs.len() + 1);
        for prediction in prediction_inputs {
            input_ports.push(prediction_port(&prediction.input_name, ""));
        }
        for branch_source in branch_data_inputs {
            input_ports.push(data_port(
                &branch_source.input_name,
                branch_source.source.representation.clone(),
                "",
            ));
        }
        if step.include_original_data {
            input_ports.push(data_port(
                "x_original",
                original_data.representation.clone(),
                "",
            ));
        }
        // Record the merge configuration in the node metadata on top of the step's own
        // metadata.
        let mut metadata = step.metadata.clone();
        metadata.insert(
            "merge_mode".to_string(),
            serde_json::Value::String(step.merge_mode.clone()),
        );
        metadata.insert(
            "output_as".to_string(),
            serde_json::to_value(step.output_as).map_err(|error| {
                DagMlError::GraphValidation(format!(
                    "failed to serialize pipeline DSL merge `{}` output mode: {error}",
                    step.id
                ))
            })?,
        );
        metadata.insert(
            "include_original_data".to_string(),
            serde_json::Value::Bool(step.include_original_data),
        );
        if let Some(on_missing) = &step.on_missing {
            metadata.insert(
                "on_missing".to_string(),
                serde_json::Value::String(on_missing.clone()),
            );
        }
        if !step.selectors.is_empty() {
            metadata.insert(
                "selectors".to_string(),
                serde_json::to_value(&step.selectors).map_err(|error| {
                    DagMlError::GraphValidation(format!(
                        "failed to serialize pipeline DSL merge `{}` selectors: {error}",
                        step.id
                    ))
                })?,
            );
        }
        // Each consumed branch-data input is described as {"input_name": ..., "branch": ...},
        // where "branch" is null when the source has no branch id.
        if !branch_data_inputs.is_empty() {
            metadata.insert(
                "branch_data_inputs".to_string(),
                serde_json::to_value(
                    branch_data_inputs
                        .iter()
                        .map(|source| {
                            BTreeMap::from([
                                (
                                    "input_name".to_string(),
                                    serde_json::Value::String(source.input_name.clone()),
                                ),
                                (
                                    "branch".to_string(),
                                    source
                                        .branch_id
                                        .as_ref()
                                        .map(|branch| serde_json::Value::String(branch.clone()))
                                        .unwrap_or(serde_json::Value::Null),
                                ),
                            ])
                        })
                        .collect::<Vec<_>>(),
                )
                .map_err(|error| {
                    DagMlError::GraphValidation(format!(
                        "failed to serialize pipeline DSL merge `{}` branch data inputs: {error}",
                        step.id
                    ))
                })?,
            );
        }
        // Read the branch id before `extra_metadata` is moved into the node metadata; caller
        // keys override the keys inserted above.
        let branch_id = branch_id_from_metadata(&extra_metadata);
        metadata.extend(extra_metadata);
        let node = NodeSpec {
            id: step.id.clone(),
            kind: merge_node_kind(
                step,
                !prediction_inputs.is_empty(),
                !branch_data_inputs.is_empty(),
            ),
            operator: None,
            params: BTreeMap::new(),
            ports: PortSchema {
                inputs: input_ports,
                outputs: if outputs_prediction {
                    vec![prediction_port("prediction", "")]
                } else {
                    vec![data_port("x_out", representation.clone(), "")]
                },
            },
            metadata,
            seed_label: step.seed_label.clone(),
        };
        self.push_node(node)?;
        self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
        self.collect_shape_plan(&step.id, step.shape.as_ref())?;
        // Prediction edges demand out-of-fold predictions and fold alignment.
        for prediction in prediction_inputs {
            self.edges.push(EdgeSpec {
                source: PortRef {
                    node_id: prediction.node_id.clone(),
                    port_name: prediction.port_name.clone(),
                },
                target: PortRef {
                    node_id: step.id.clone(),
                    port_name: prediction.input_name.clone(),
                },
                contract: EdgeContract {
                    requires_oof: true,
                    requires_fold_alignment: true,
                    ..EdgeContract::new(PortKind::Prediction, None)
                },
            });
        }
        for branch_source in branch_data_inputs {
            self.connect_data_to_port(&branch_source.source, &step.id, &branch_source.input_name)?;
        }
        if step.include_original_data {
            self.connect_data_to_port(original_data, &step.id, "x_original")?;
        }
        if outputs_prediction {
            Ok(MergeOutputSource::Prediction(PredictionSource {
                node_id: step.id.clone(),
                port_name: "prediction".to_string(),
                input_name: "oof".to_string(),
                branch_id,
            }))
        } else {
            Ok(MergeOutputSource::Data(DataSource {
                node_id: Some(step.id.clone()),
                port_name: "x_out".to_string(),
                representation,
            }))
        }
    }
4431
4432 fn compile_merge_model_with_extra(
4433 &mut self,
4434 step: &PipelineDslMergeModelStep,
4435 predictions: &[PredictionSource],
4436 external_data: &DataSource,
4437 extra_metadata: BTreeMap<String, serde_json::Value>,
4438 ) -> Result<PredictionSource> {
4439 if predictions.is_empty() {
4440 return Err(DagMlError::GraphValidation(format!(
4441 "pipeline DSL merge_model `{}` has no pending branch predictions",
4442 step.id
4443 )));
4444 }
4445 let mut input_ports = Vec::with_capacity(predictions.len() + 1);
4446 for prediction in predictions {
4447 input_ports.push(prediction_port(&prediction.input_name, ""));
4448 }
4449 if step.include_original_data {
4450 input_ports.push(data_port(
4451 "x_original",
4452 external_data.representation.clone(),
4453 "",
4454 ));
4455 }
4456 let mut metadata = step.metadata.clone();
4457 insert_training_metadata(
4458 &mut metadata,
4459 &step.train_params,
4460 step.tuning.as_ref(),
4461 step.inner_cv.as_ref(),
4462 &step.id,
4463 )?;
4464 metadata.insert(
4465 "merge_mode".to_string(),
4466 serde_json::Value::String(step.merge_mode.clone()),
4467 );
4468 let branch_id = branch_id_from_metadata(&extra_metadata);
4469 metadata.extend(extra_metadata);
4470 let node = NodeSpec {
4471 id: step.id.clone(),
4472 kind: NodeKind::Model,
4473 operator: Some(step.operator.clone()),
4474 params: step.params.clone(),
4475 ports: PortSchema {
4476 inputs: input_ports,
4477 outputs: vec![prediction_port("oof", "")],
4478 },
4479 metadata,
4480 seed_label: step.seed_label.clone(),
4481 };
4482 self.push_node(node)?;
4483 self.collect_operator_generation(&step.id, &step.variants, &step.param_generators)?;
4484 self.collect_shape_plan(&step.id, step.shape.as_ref())?;
4485 for prediction in predictions {
4486 self.edges.push(EdgeSpec {
4487 source: PortRef {
4488 node_id: prediction.node_id.clone(),
4489 port_name: prediction.port_name.clone(),
4490 },
4491 target: PortRef {
4492 node_id: step.id.clone(),
4493 port_name: prediction.input_name.clone(),
4494 },
4495 contract: EdgeContract {
4496 requires_oof: true,
4497 requires_fold_alignment: true,
4498 ..EdgeContract::new(PortKind::Prediction, None)
4499 },
4500 });
4501 }
4502 if step.include_original_data {
4503 self.connect_data_to_port(external_data, &step.id, "x_original")?;
4504 }
4505 Ok(PredictionSource {
4506 node_id: step.id.clone(),
4507 port_name: "oof".to_string(),
4508 input_name: "oof".to_string(),
4509 branch_id,
4510 })
4511 }
4512
4513 fn push_node(&mut self, node: NodeSpec) -> Result<()> {
4514 if self.nodes.iter().any(|existing| existing.id == node.id) {
4515 return Err(DagMlError::GraphValidation(format!(
4516 "pipeline DSL graph `{}` produced duplicate node `{}`",
4517 self.graph_id, node.id
4518 )));
4519 }
4520 self.nodes.push(node);
4521 Ok(())
4522 }
4523
4524 fn collect_operator_generation(
4525 &mut self,
4526 node_id: &NodeId,
4527 choices: &[PipelineDslVariantChoice],
4528 generators: &[PipelineDslParamGenerator],
4529 ) -> Result<()> {
4530 if !choices.is_empty() {
4531 self.generation_dimensions
4532 .push(compile_variant_choice_dimension(node_id, choices)?);
4533 }
4534 for generator in generators {
4535 self.generation_dimensions
4536 .push(compile_param_generator_dimension(node_id, generator)?);
4537 }
4538 Ok(())
4539 }
4540
4541 fn collect_shape_plan(
4542 &mut self,
4543 node_id: &NodeId,
4544 shape: Option<&PipelineDslShapePlan>,
4545 ) -> Result<()> {
4546 let Some(shape) = shape else {
4547 return Ok(());
4548 };
4549 let plan = shape.to_data_model_shape_plan(node_id)?;
4550 if self.shape_plans.insert(node_id.clone(), plan).is_some() {
4551 return Err(DagMlError::GraphValidation(format!(
4552 "pipeline DSL graph `{}` produced duplicate shape plan for `{node_id}`",
4553 self.graph_id
4554 )));
4555 }
4556 Ok(())
4557 }
4558
4559 fn collect_branch_view_plan(&mut self, plan: BranchViewPlan) -> Result<()> {
4560 plan.validate()
4561 .map_err(|error| DagMlError::GraphValidation(error.to_string()))?;
4562 if self
4563 .branch_view_plans
4564 .iter()
4565 .any(|existing| existing.view_id == plan.view_id)
4566 {
4567 return Err(DagMlError::GraphValidation(format!(
4568 "pipeline DSL graph `{}` produced duplicate branch view `{}`",
4569 self.graph_id, plan.view_id
4570 )));
4571 }
4572 self.branch_view_plans.push(plan);
4573 Ok(())
4574 }
4575
4576 fn connect_data(
4577 &mut self,
4578 input: &DataSource,
4579 target_id: &NodeId,
4580 target_port: &str,
4581 ) -> Result<()> {
4582 self.connect_data_to_port(input, target_id, target_port)
4583 }
4584
4585 fn connect_data_to_port(
4586 &mut self,
4587 input: &DataSource,
4588 target_id: &NodeId,
4589 target_port: &str,
4590 ) -> Result<()> {
4591 if let Some(source_id) = &input.node_id {
4592 self.edges.push(EdgeSpec {
4593 source: PortRef {
4594 node_id: source_id.clone(),
4595 port_name: input.port_name.clone(),
4596 },
4597 target: PortRef {
4598 node_id: target_id.clone(),
4599 port_name: target_port.to_string(),
4600 },
4601 contract: EdgeContract {
4602 requires_oof: false,
4603 requires_fold_alignment: true,
4604 ..EdgeContract::new(PortKind::Data, input.representation.clone())
4605 },
4606 });
4607 }
4608 Ok(())
4609 }
4610}
4611
4612impl PipelineDslShapePlan {
4613 fn to_data_model_shape_plan(&self, node_id: &NodeId) -> Result<DataModelShapePlan> {
4614 let plan = DataModelShapePlan {
4615 node_id: node_id.clone(),
4616 input_granularity: self.input_granularity.unwrap_or(Granularity::Sample),
4617 target_granularity: self.target_granularity.unwrap_or(Granularity::Sample),
4618 fit_rows: self.fit_rows.unwrap_or(FitBoundary::FoldTrain),
4619 predict_rows: self.predict_rows.unwrap_or(FitBoundary::FoldValidation),
4620 feature_namespace: self.feature_namespace.clone(),
4621 feature_schema_fingerprint: self.feature_schema_fingerprint.clone(),
4622 target_space: self
4623 .target_space
4624 .clone()
4625 .unwrap_or_else(|| "raw".to_string()),
4626 aggregation_policy: self.aggregation_policy.clone().unwrap_or_default(),
4627 augmentation_policy: self.augmentation_policy.clone().unwrap_or_default(),
4628 selection_policy: self.selection_policy.clone().unwrap_or_default(),
4629 };
4630 plan.validate()?;
4631 Ok(plan)
4632 }
4633}
4634
4635fn validate_shape_plan_targets(
4636 shape_plans: &BTreeMap<NodeId, DataModelShapePlan>,
4637 graph: &GraphSpec,
4638) -> Result<()> {
4639 for (node_id, plan) in shape_plans {
4640 if node_id != &plan.node_id {
4641 return Err(DagMlError::GraphValidation(format!(
4642 "pipeline DSL shape plan key `{node_id}` does not match node_id `{}`",
4643 plan.node_id
4644 )));
4645 }
4646 if !graph.nodes.iter().any(|node| &node.id == node_id) {
4647 return Err(DagMlError::GraphValidation(format!(
4648 "pipeline DSL shape plan references unknown node `{node_id}`"
4649 )));
4650 }
4651 }
4652 Ok(())
4653}
4654
4655fn compile_explicit_generation_dimensions(
4656 dimensions: &[PipelineDslGenerationDimension],
4657 nodes: &[NodeSpec],
4658) -> Result<Vec<GenerationDimension>> {
4659 if dimensions.is_empty() {
4660 return Ok(Vec::new());
4661 }
4662 let node_ids = nodes
4663 .iter()
4664 .map(|node| node.id.clone())
4665 .collect::<BTreeSet<_>>();
4666 dimensions
4667 .iter()
4668 .map(|dimension| compile_explicit_generation_dimension(dimension, &node_ids))
4669 .collect()
4670}
4671
4672fn compile_explicit_generation_dimension(
4673 dimension: &PipelineDslGenerationDimension,
4674 node_ids: &BTreeSet<NodeId>,
4675) -> Result<GenerationDimension> {
4676 let choices = dimension
4677 .choices
4678 .iter()
4679 .map(|choice| compile_explicit_generation_choice(&dimension.name, choice, node_ids))
4680 .collect::<Result<Vec<_>>>()?;
4681 Ok(GenerationDimension {
4682 name: dimension.name.clone(),
4683 choices,
4684 })
4685}
4686
4687fn compile_explicit_generation_choice(
4688 dimension_name: &str,
4689 choice: &PipelineDslGenerationChoice,
4690 node_ids: &BTreeSet<NodeId>,
4691) -> Result<GenerationChoice> {
4692 if choice.param_overrides.is_empty() {
4693 return Err(DagMlError::GraphValidation(format!(
4694 "pipeline DSL generation choice `{}` in dimension `{dimension_name}` has no param_overrides",
4695 choice.label
4696 )));
4697 }
4698 let param_overrides = choice
4699 .param_overrides
4700 .iter()
4701 .map(|override_spec| {
4702 if !node_ids.contains(&override_spec.node_id) {
4703 return Err(DagMlError::GraphValidation(format!(
4704 "pipeline DSL generation choice `{}` in dimension `{dimension_name}` references unknown node `{}`",
4705 choice.label, override_spec.node_id
4706 )));
4707 }
4708 Ok(GenerationParamOverride {
4709 node_id: override_spec.node_id.clone(),
4710 params: override_spec.params.clone(),
4711 })
4712 })
4713 .collect::<Result<Vec<_>>>()?;
4714 let value = match &choice.value {
4715 Some(value) => value.clone(),
4716 None => explicit_generation_choice_value(¶m_overrides)?,
4717 };
4718 Ok(GenerationChoice {
4719 label: choice.label.clone(),
4720 value,
4721 param_overrides,
4722 })
4723}
4724
4725fn explicit_generation_choice_value(
4726 param_overrides: &[GenerationParamOverride],
4727) -> Result<serde_json::Value> {
4728 let mut by_node = serde_json::Map::new();
4729 for override_spec in param_overrides {
4730 let value = serde_json::to_value(&override_spec.params).map_err(|error| {
4731 DagMlError::GraphValidation(format!(
4732 "failed to serialize DSL generation override for node `{}`: {error}",
4733 override_spec.node_id
4734 ))
4735 })?;
4736 by_node.insert(override_spec.node_id.to_string(), value);
4737 }
4738 Ok(serde_json::Value::Object(by_node))
4739}
4740
4741fn build_campaign_template(
4742 spec: &PipelineDslSpec,
4743 generation: &GenerationSpec,
4744 shape_plans: &BTreeMap<NodeId, DataModelShapePlan>,
4745 data_bindings: &BTreeMap<NodeId, Vec<DataBinding>>,
4746 branch_view_plans: &[BranchViewPlan],
4747) -> Result<CampaignSpec> {
4748 let campaign = CampaignSpec {
4749 inner_cv: spec.inner_cv.clone(),
4750 id: spec
4751 .campaign_id
4752 .clone()
4753 .unwrap_or_else(|| format!("campaign:{}", spec.id)),
4754 root_seed: spec.root_seed,
4755 leakage_policy: spec.leakage_policy.clone().unwrap_or_default(),
4756 aggregation_policy: spec.aggregation_policy.clone().unwrap_or_default(),
4757 split_invocation: spec.split_invocation.clone(),
4758 generation: generation.clone(),
4759 shape_plans: shape_plans.clone(),
4760 data_bindings: data_bindings.clone(),
4761 branch_view_plans: branch_view_plans.to_vec(),
4762 metadata: spec.campaign_metadata.clone(),
4763 };
4764 campaign.validate()?;
4765 Ok(campaign)
4766}
4767
4768fn compile_data_bindings(
4769 bindings: &[DataBinding],
4770 graph: &GraphSpec,
4771) -> Result<BTreeMap<NodeId, Vec<DataBinding>>> {
4772 let mut by_node = BTreeMap::<NodeId, Vec<DataBinding>>::new();
4773 for binding in bindings {
4774 validate_dsl_data_binding(binding, graph)?;
4775 by_node
4776 .entry(binding.node_id.clone())
4777 .or_default()
4778 .push(binding.clone());
4779 }
4780 Ok(by_node)
4781}
4782
4783fn validate_dsl_data_binding(binding: &DataBinding, graph: &GraphSpec) -> Result<()> {
4784 binding.validate()?;
4785 let node = graph
4786 .nodes
4787 .iter()
4788 .find(|node| node.id == binding.node_id)
4789 .ok_or_else(|| {
4790 DagMlError::GraphValidation(format!(
4791 "pipeline DSL data binding references unknown node `{}`",
4792 binding.node_id
4793 ))
4794 })?;
4795 let Some(input_port) = node
4796 .ports
4797 .inputs
4798 .iter()
4799 .find(|port| port.name == binding.input_name)
4800 else {
4801 return Err(DagMlError::GraphValidation(format!(
4802 "pipeline DSL data binding `{}` references unknown input port `{}` on node `{}`",
4803 binding.request_id, binding.input_name, binding.node_id
4804 )));
4805 };
4806 if input_port.kind != PortKind::Data {
4807 return Err(DagMlError::GraphValidation(format!(
4808 "pipeline DSL data binding `{}` targets non-data input `{}.{}`",
4809 binding.request_id, binding.node_id, binding.input_name
4810 )));
4811 }
4812 Ok(())
4813}
4814
4815fn compile_variant_choice_dimension(
4816 node_id: &NodeId,
4817 choices: &[PipelineDslVariantChoice],
4818) -> Result<GenerationDimension> {
4819 Ok(GenerationDimension {
4820 name: format!("{node_id}.params"),
4821 choices: choices
4822 .iter()
4823 .map(|choice| {
4824 if choice.params.is_empty() {
4825 return Err(DagMlError::GraphValidation(format!(
4826 "pipeline DSL variant `{}` for node `{node_id}` has no params",
4827 choice.label
4828 )));
4829 }
4830 let value = match &choice.value {
4831 Some(value) => value.clone(),
4832 None => serde_json::to_value(&choice.params).map_err(|error| {
4833 DagMlError::GraphValidation(format!(
4834 "failed to serialize pipeline DSL variant `{}` for node `{node_id}`: {error}",
4835 choice.label
4836 ))
4837 })?,
4838 };
4839 Ok(GenerationChoice {
4840 label: choice.label.clone(),
4841 value,
4842 param_overrides: vec![GenerationParamOverride {
4843 node_id: node_id.clone(),
4844 params: choice.params.clone(),
4845 }],
4846 })
4847 })
4848 .collect::<Result<Vec<_>>>()?,
4849 })
4850}
4851
4852fn compile_param_generator_dimension(
4853 node_id: &NodeId,
4854 generator: &PipelineDslParamGenerator,
4855) -> Result<GenerationDimension> {
4856 match generator {
4857 PipelineDslParamGenerator::Or {
4858 name,
4859 param,
4860 values,
4861 count,
4862 } => compile_or_generator(node_id, name.as_deref(), param, values, *count),
4863 PipelineDslParamGenerator::Range {
4864 name,
4865 param,
4866 start,
4867 stop,
4868 step,
4869 inclusive,
4870 count,
4871 } => compile_range_generator(RangeGeneratorSpec {
4872 node_id,
4873 name: name.as_deref(),
4874 param,
4875 start: *start,
4876 stop: *stop,
4877 step: *step,
4878 inclusive: *inclusive,
4879 count: *count,
4880 }),
4881 PipelineDslParamGenerator::LogRange {
4882 name,
4883 param,
4884 start,
4885 stop,
4886 count,
4887 base,
4888 } => compile_log_range_generator(
4889 node_id,
4890 name.as_deref(),
4891 param,
4892 *start,
4893 *stop,
4894 *count,
4895 *base,
4896 ),
4897 PipelineDslParamGenerator::Grid {
4898 name,
4899 params,
4900 count,
4901 } => compile_grid_generator(node_id, name.as_deref(), params, *count),
4902 PipelineDslParamGenerator::Pick {
4903 name,
4904 param,
4905 values,
4906 sizes,
4907 count,
4908 } => compile_pick_arrange_generator(
4909 node_id,
4910 name.as_deref(),
4911 param,
4912 values,
4913 sizes,
4914 *count,
4915 PickArrangeMode::Pick,
4916 ),
4917 PipelineDslParamGenerator::Arrange {
4918 name,
4919 param,
4920 values,
4921 sizes,
4922 count,
4923 } => compile_pick_arrange_generator(
4924 node_id,
4925 name.as_deref(),
4926 param,
4927 values,
4928 sizes,
4929 *count,
4930 PickArrangeMode::Arrange,
4931 ),
4932 }
4933}
4934
4935fn compile_or_generator(
4936 node_id: &NodeId,
4937 name: Option<&str>,
4938 param: &str,
4939 values: &[PipelineDslGeneratorValue],
4940 count: Option<usize>,
4941) -> Result<GenerationDimension> {
4942 validate_param_name(node_id, param)?;
4943 validate_count(node_id, name, count)?;
4944 if values.is_empty() {
4945 return Err(DagMlError::GraphValidation(format!(
4946 "pipeline DSL generator `{}` for node `{node_id}` has no values",
4947 generator_dimension_name(node_id, name, Some(param), "or")
4948 )));
4949 }
4950 let mut choices = values
4951 .iter()
4952 .enumerate()
4953 .map(|(index, value)| single_param_generation_choice(node_id, param, index, value))
4954 .collect::<Result<Vec<_>>>()?;
4955 apply_choice_count(&mut choices, count);
4956 Ok(GenerationDimension {
4957 name: generator_dimension_name(node_id, name, Some(param), "or"),
4958 choices,
4959 })
4960}
4961
4962struct RangeGeneratorSpec<'a> {
4963 node_id: &'a NodeId,
4964 name: Option<&'a str>,
4965 param: &'a str,
4966 start: f64,
4967 stop: f64,
4968 step: f64,
4969 inclusive: bool,
4970 count: Option<usize>,
4971}
4972
4973fn compile_range_generator(spec: RangeGeneratorSpec<'_>) -> Result<GenerationDimension> {
4974 validate_param_name(spec.node_id, spec.param)?;
4975 validate_count(spec.node_id, spec.name, spec.count)?;
4976 validate_finite(spec.node_id, spec.param, "range start", spec.start)?;
4977 validate_finite(spec.node_id, spec.param, "range stop", spec.stop)?;
4978 validate_finite(spec.node_id, spec.param, "range step", spec.step)?;
4979 if spec.step == 0.0 {
4980 return Err(DagMlError::GraphValidation(format!(
4981 "pipeline DSL range generator for `{}.{}` has zero step",
4982 spec.node_id, spec.param
4983 )));
4984 }
4985 if spec.start < spec.stop && spec.step < 0.0 {
4986 return Err(DagMlError::GraphValidation(format!(
4987 "pipeline DSL range generator for `{}.{}` steps away from stop",
4988 spec.node_id, spec.param
4989 )));
4990 }
4991 if spec.start > spec.stop && spec.step > 0.0 {
4992 return Err(DagMlError::GraphValidation(format!(
4993 "pipeline DSL range generator for `{}.{}` steps away from stop",
4994 spec.node_id, spec.param
4995 )));
4996 }
4997 let mut values = Vec::new();
4998 let mut current = spec.start;
4999 let mut guard = 0usize;
5000 while range_contains(current, spec.stop, spec.step, spec.inclusive) {
5001 values.push(json_number(current, spec.node_id, spec.param)?);
5002 current += spec.step;
5003 guard += 1;
5004 if guard > 10_000 {
5005 return Err(DagMlError::GraphValidation(format!(
5006 "pipeline DSL range generator for `{}.{}` produced more than 10000 values",
5007 spec.node_id, spec.param
5008 )));
5009 }
5010 }
5011 if values.is_empty() {
5012 return Err(DagMlError::GraphValidation(format!(
5013 "pipeline DSL range generator for `{}.{}` produced no values",
5014 spec.node_id, spec.param
5015 )));
5016 }
5017 let wrapped = values
5018 .into_iter()
5019 .map(PipelineDslGeneratorValue::Value)
5020 .collect::<Vec<_>>();
5021 compile_or_generator(spec.node_id, spec.name, spec.param, &wrapped, spec.count).map(
5022 |mut dimension| {
5023 dimension.name =
5024 generator_dimension_name(spec.node_id, spec.name, Some(spec.param), "range");
5025 dimension
5026 },
5027 )
5028}
5029
5030fn compile_log_range_generator(
5031 node_id: &NodeId,
5032 name: Option<&str>,
5033 param: &str,
5034 start: f64,
5035 stop: f64,
5036 count: usize,
5037 base: f64,
5038) -> Result<GenerationDimension> {
5039 validate_param_name(node_id, param)?;
5040 validate_finite(node_id, param, "log_range start", start)?;
5041 validate_finite(node_id, param, "log_range stop", stop)?;
5042 validate_finite(node_id, param, "log_range base", base)?;
5043 if start <= 0.0 || stop <= 0.0 {
5044 return Err(DagMlError::GraphValidation(format!(
5045 "pipeline DSL log_range generator for `{node_id}.{param}` requires positive start and stop"
5046 )));
5047 }
5048 if count == 0 {
5049 return Err(DagMlError::GraphValidation(format!(
5050 "pipeline DSL log_range generator for `{node_id}.{param}` has count=0"
5051 )));
5052 }
5053 if base <= 0.0 || (base - 1.0).abs() < f64::EPSILON {
5054 return Err(DagMlError::GraphValidation(format!(
5055 "pipeline DSL log_range generator for `{node_id}.{param}` requires base > 0 and != 1"
5056 )));
5057 }
5058 let start_log = start.log(base);
5059 let stop_log = stop.log(base);
5060 let values = if count == 1 {
5061 vec![json_number(start, node_id, param)?]
5062 } else {
5063 (0..count)
5064 .map(|index| {
5065 let ratio = index as f64 / (count - 1) as f64;
5066 json_number(
5067 base.powf(start_log + (stop_log - start_log) * ratio),
5068 node_id,
5069 param,
5070 )
5071 })
5072 .collect::<Result<Vec<_>>>()?
5073 };
5074 let wrapped = values
5075 .into_iter()
5076 .map(PipelineDslGeneratorValue::Value)
5077 .collect::<Vec<_>>();
5078 compile_or_generator(node_id, name, param, &wrapped, None).map(|mut dimension| {
5079 dimension.name = generator_dimension_name(node_id, name, Some(param), "log_range");
5080 dimension
5081 })
5082}
5083
5084fn compile_grid_generator(
5085 node_id: &NodeId,
5086 name: Option<&str>,
5087 params: &BTreeMap<String, Vec<PipelineDslGeneratorValue>>,
5088 count: Option<usize>,
5089) -> Result<GenerationDimension> {
5090 validate_count(node_id, name, count)?;
5091 if params.is_empty() {
5092 return Err(DagMlError::GraphValidation(format!(
5093 "pipeline DSL grid generator for node `{node_id}` has no params"
5094 )));
5095 }
5096 for (param, values) in params {
5097 validate_param_name(node_id, param)?;
5098 if values.is_empty() {
5099 return Err(DagMlError::GraphValidation(format!(
5100 "pipeline DSL grid generator for `{node_id}.{param}` has no values"
5101 )));
5102 }
5103 }
5104 let entries = params
5105 .iter()
5106 .map(|(param, values)| (param.as_str(), values.as_slice()))
5107 .collect::<Vec<_>>();
5108 let mut rows = Vec::<BTreeMap<String, PipelineDslGeneratorValue>>::new();
5109 build_grid_rows(&entries, 0, &mut BTreeMap::new(), &mut rows, count);
5110 let choices = rows
5111 .into_iter()
5112 .enumerate()
5113 .map(|(index, row)| multi_param_generation_choice(node_id, index, row))
5114 .collect::<Result<Vec<_>>>()?;
5115 Ok(GenerationDimension {
5116 name: generator_dimension_name(node_id, name, None, "grid"),
5117 choices,
5118 })
5119}
5120
/// Selects how `compile_pick_arrange_generator` enumerates selections of the
/// candidate values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PickArrangeMode {
    /// Unordered selections (combinations, via `build_combinations`).
    Pick,
    /// Ordered selections (permutations, via `build_permutations`).
    Arrange,
}
5126
5127fn compile_pick_arrange_generator(
5128 node_id: &NodeId,
5129 name: Option<&str>,
5130 param: &str,
5131 values: &[PipelineDslGeneratorValue],
5132 sizes: &[usize],
5133 count: Option<usize>,
5134 mode: PickArrangeMode,
5135) -> Result<GenerationDimension> {
5136 validate_param_name(node_id, param)?;
5137 validate_count(node_id, name, count)?;
5138 if values.is_empty() {
5139 return Err(DagMlError::GraphValidation(format!(
5140 "pipeline DSL {:?} generator for `{node_id}.{param}` has no values",
5141 mode
5142 )));
5143 }
5144 if sizes.is_empty() {
5145 return Err(DagMlError::GraphValidation(format!(
5146 "pipeline DSL {:?} generator for `{node_id}.{param}` has no sizes",
5147 mode
5148 )));
5149 }
5150 let mut selections = Vec::<Vec<usize>>::new();
5151 for size in sizes {
5152 if *size == 0 || *size > values.len() {
5153 return Err(DagMlError::GraphValidation(format!(
5154 "pipeline DSL {:?} generator for `{node_id}.{param}` has invalid size `{size}`",
5155 mode
5156 )));
5157 }
5158 match mode {
5159 PickArrangeMode::Pick => build_combinations(
5160 values.len(),
5161 *size,
5162 0,
5163 &mut Vec::new(),
5164 &mut selections,
5165 count,
5166 ),
5167 PickArrangeMode::Arrange => build_permutations(
5168 values.len(),
5169 *size,
5170 &mut BTreeSet::new(),
5171 &mut Vec::new(),
5172 &mut selections,
5173 count,
5174 ),
5175 }
5176 if count.is_some_and(|limit| selections.len() >= limit) {
5177 break;
5178 }
5179 }
5180 let mut choices = selections
5181 .into_iter()
5182 .enumerate()
5183 .map(|(index, selection)| {
5184 let selected_values = selection
5185 .iter()
5186 .map(|selected| values[*selected].value().clone())
5187 .collect::<Vec<_>>();
5188 let selected_labels = selection
5189 .iter()
5190 .map(|selected| values[*selected].label_fragment())
5191 .collect::<Vec<_>>();
5192 let mut params = BTreeMap::new();
5193 params.insert(param.to_string(), serde_json::Value::Array(selected_values));
5194 Ok(GenerationChoice {
5195 label: format!(
5196 "{index:04}_{}_{}",
5197 match mode {
5198 PickArrangeMode::Pick => "pick",
5199 PickArrangeMode::Arrange => "arrange",
5200 },
5201 sanitize_generation_label(&selected_labels.join("_"))
5202 ),
5203 value: serde_json::to_value(¶ms).map_err(|error| {
5204 DagMlError::GraphValidation(format!(
5205 "failed to serialize pipeline DSL {:?} generator choice for `{node_id}.{param}`: {error}",
5206 mode
5207 ))
5208 })?,
5209 param_overrides: vec![GenerationParamOverride {
5210 node_id: node_id.clone(),
5211 params,
5212 }],
5213 })
5214 })
5215 .collect::<Result<Vec<_>>>()?;
5216 apply_choice_count(&mut choices, count);
5217 Ok(GenerationDimension {
5218 name: generator_dimension_name(
5219 node_id,
5220 name,
5221 Some(param),
5222 match mode {
5223 PickArrangeMode::Pick => "pick",
5224 PickArrangeMode::Arrange => "arrange",
5225 },
5226 ),
5227 choices,
5228 })
5229}
5230
5231fn single_param_generation_choice(
5232 node_id: &NodeId,
5233 param: &str,
5234 index: usize,
5235 value: &PipelineDslGeneratorValue,
5236) -> Result<GenerationChoice> {
5237 let mut params = BTreeMap::new();
5238 params.insert(param.to_string(), value.value().clone());
5239 Ok(GenerationChoice {
5240 label: format!(
5241 "{index:04}_{}_{}",
5242 sanitize_generation_label(param),
5243 value.label_fragment()
5244 ),
5245 value: serde_json::to_value(¶ms).map_err(|error| {
5246 DagMlError::GraphValidation(format!(
5247 "failed to serialize pipeline DSL generator choice for `{node_id}.{param}`: {error}"
5248 ))
5249 })?,
5250 param_overrides: vec![GenerationParamOverride {
5251 node_id: node_id.clone(),
5252 params,
5253 }],
5254 })
5255}
5256
5257fn multi_param_generation_choice(
5258 node_id: &NodeId,
5259 index: usize,
5260 row: BTreeMap<String, PipelineDslGeneratorValue>,
5261) -> Result<GenerationChoice> {
5262 let mut params = BTreeMap::new();
5263 let mut label_parts = Vec::new();
5264 for (param, value) in row {
5265 label_parts.push(format!(
5266 "{}_{}",
5267 sanitize_generation_label(¶m),
5268 value.label_fragment()
5269 ));
5270 params.insert(param, value.value().clone());
5271 }
5272 Ok(GenerationChoice {
5273 label: format!("{index:04}_{}", label_parts.join("__")),
5274 value: serde_json::to_value(¶ms).map_err(|error| {
5275 DagMlError::GraphValidation(format!(
5276 "failed to serialize pipeline DSL grid generator choice for node `{node_id}`: {error}"
5277 ))
5278 })?,
5279 param_overrides: vec![GenerationParamOverride {
5280 node_id: node_id.clone(),
5281 params,
5282 }],
5283 })
5284}
5285
5286fn build_grid_rows(
5287 entries: &[(&str, &[PipelineDslGeneratorValue])],
5288 entry_index: usize,
5289 current: &mut BTreeMap<String, PipelineDslGeneratorValue>,
5290 rows: &mut Vec<BTreeMap<String, PipelineDslGeneratorValue>>,
5291 count: Option<usize>,
5292) {
5293 if count.is_some_and(|limit| rows.len() >= limit) {
5294 return;
5295 }
5296 if entry_index == entries.len() {
5297 rows.push(current.clone());
5298 return;
5299 }
5300 let (param, values) = entries[entry_index];
5301 for value in values {
5302 current.insert(param.to_string(), value.clone());
5303 build_grid_rows(entries, entry_index + 1, current, rows, count);
5304 current.remove(param);
5305 if count.is_some_and(|limit| rows.len() >= limit) {
5306 break;
5307 }
5308 }
5309}
5310
/// Appends to `selections` every strictly increasing index selection of length
/// `size` drawn from `0..value_count`, extending the prefix in `current`.
///
/// Selections are produced in lexicographic order. Recursion stops once
/// `selections` holds `count` entries (when `count` is set). `current` is
/// restored before returning.
fn build_combinations(
    value_count: usize,
    size: usize,
    start: usize,
    current: &mut Vec<usize>,
    selections: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    let saturated = |found: usize| matches!(count, Some(limit) if found >= limit);
    if saturated(selections.len()) {
        return;
    }
    if current.len() == size {
        selections.push(current.clone());
        return;
    }
    let remaining = size - current.len();
    // Highest index the next element may take while leaving room for the rest.
    let Some(last_start) = value_count.checked_sub(remaining) else {
        return;
    };
    for index in start..=last_start {
        current.push(index);
        build_combinations(value_count, size, index + 1, current, selections, count);
        current.pop();
        if saturated(selections.len()) {
            break;
        }
    }
}
5339
/// Appends to `selections` every ordered selection of `size` distinct indices
/// from `0..value_count`, extending the prefix in `current`.
///
/// `used` tracks the indices already present in `current`. Both are restored
/// before returning. Recursion stops once `selections` holds `count` entries
/// (when `count` is set).
fn build_permutations(
    value_count: usize,
    size: usize,
    used: &mut BTreeSet<usize>,
    current: &mut Vec<usize>,
    selections: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    let saturated = |found: usize| matches!(count, Some(limit) if found >= limit);
    if saturated(selections.len()) {
        return;
    }
    if current.len() == size {
        selections.push(current.clone());
        return;
    }
    let mut candidate = 0;
    while candidate < value_count {
        // `insert` returns false when the index is already part of the prefix.
        if used.insert(candidate) {
            current.push(candidate);
            build_permutations(value_count, size, used, current, selections, count);
            current.pop();
            used.remove(&candidate);
            if saturated(selections.len()) {
                break;
            }
        }
        candidate += 1;
    }
}
5369
5370fn apply_choice_count(choices: &mut Vec<GenerationChoice>, count: Option<usize>) {
5371 if let Some(limit) = count {
5372 choices.truncate(limit);
5373 }
5374}
5375
5376fn validate_count(node_id: &NodeId, name: Option<&str>, count: Option<usize>) -> Result<()> {
5377 if count == Some(0) {
5378 return Err(DagMlError::GraphValidation(format!(
5379 "pipeline DSL generator `{}` for node `{node_id}` has count=0",
5380 generator_dimension_name(node_id, name, None, "params")
5381 )));
5382 }
5383 Ok(())
5384}
5385
5386fn validate_param_name(node_id: &NodeId, param: &str) -> Result<()> {
5387 if param.trim().is_empty() {
5388 return Err(DagMlError::GraphValidation(format!(
5389 "pipeline DSL param generator for node `{node_id}` has an empty param name"
5390 )));
5391 }
5392 Ok(())
5393}
5394
5395fn validate_finite(node_id: &NodeId, param: &str, field: &str, value: f64) -> Result<()> {
5396 if !value.is_finite() {
5397 return Err(DagMlError::GraphValidation(format!(
5398 "pipeline DSL {field} for `{node_id}.{param}` must be finite"
5399 )));
5400 }
5401 Ok(())
5402}
5403
/// Reports whether `current` is still inside a numeric range ending at `stop`
/// and walked with `step`.
///
/// A positive `step` ascends; anything else (including zero or NaN) is treated
/// as descending. The comparison uses a small tolerance scaled by `|step|`, so
/// accumulated rounding error does not add or drop the endpoint.
fn range_contains(current: f64, stop: f64, step: f64, inclusive: bool) -> bool {
    let tolerance = f64::EPSILON + step.abs() * 1e-12;
    match (step > 0.0, inclusive) {
        (true, true) => current <= stop + tolerance,
        (true, false) => current < stop - tolerance,
        (false, true) => current >= stop - tolerance,
        (false, false) => current > stop + tolerance,
    }
}
5418
5419fn json_number(value: f64, node_id: &NodeId, param: &str) -> Result<serde_json::Value> {
5420 let number = serde_json::Number::from_f64(value).ok_or_else(|| {
5421 DagMlError::GraphValidation(format!(
5422 "pipeline DSL numeric generator for `{node_id}.{param}` produced a non-finite value"
5423 ))
5424 })?;
5425 Ok(serde_json::Value::Number(number))
5426}
5427
5428fn generator_dimension_name(
5429 node_id: &NodeId,
5430 name: Option<&str>,
5431 param: Option<&str>,
5432 suffix: &str,
5433) -> String {
5434 if let Some(name) = name {
5435 return name.to_string();
5436 }
5437 match param {
5438 Some(param) => format!("{node_id}.{param}.{suffix}"),
5439 None => format!("{node_id}.{suffix}"),
5440 }
5441}
5442
5443impl PipelineDslGeneratorValue {
5444 fn value(&self) -> &serde_json::Value {
5445 match self {
5446 Self::Labeled { value, .. } | Self::Value(value) => value,
5447 }
5448 }
5449
5450 fn label_fragment(&self) -> String {
5451 match self {
5452 Self::Labeled { label, .. } => sanitize_generation_label(label),
5453 Self::Value(value) => {
5454 let rendered = match value {
5455 serde_json::Value::String(value) => value.clone(),
5456 _ => serde_json::to_string(value).unwrap_or_else(|_| "value".to_string()),
5457 };
5458 sanitize_generation_label(&rendered)
5459 }
5460 }
5461 }
5462}
5463
/// Makes `input` safe for use in a generation choice label.
///
/// Every character other than ASCII alphanumerics, `_`, `-` and `.` becomes
/// `_`. Leading and trailing underscores are then trimmed, and `value` is
/// returned if nothing remains.
fn sanitize_generation_label(input: &str) -> String {
    let mut replaced = String::with_capacity(input.len());
    for ch in input.chars() {
        let keep = ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' || ch == '.';
        replaced.push(if keep { ch } else { '_' });
    }
    match replaced.trim_matches('_') {
        "" => String::from("value"),
        trimmed => trimmed.to_owned(),
    }
}
5483
5484fn build_generation_spec(
5485 requested_strategy: Option<GenerationStrategy>,
5486 max_variants: Option<usize>,
5487 dimensions: Vec<GenerationDimension>,
5488) -> Result<GenerationSpec> {
5489 let strategy = requested_strategy.unwrap_or(if dimensions.is_empty() {
5490 GenerationStrategy::None
5491 } else {
5492 GenerationStrategy::Cartesian
5493 });
5494 let generation = GenerationSpec {
5495 strategy,
5496 dimensions,
5497 max_variants: if strategy == GenerationStrategy::None {
5498 Some(1)
5499 } else {
5500 max_variants
5501 },
5502 };
5503 generation.validate()?;
5504 Ok(generation)
5505}
5506
5507fn operator_runtime_metadata(
5508 step: &PipelineDslOperatorStep,
5509 branch_id: Option<&str>,
5510) -> Result<BTreeMap<String, serde_json::Value>> {
5511 let mut metadata = step.metadata.clone();
5512 if let Some(branch_id) = branch_id {
5513 metadata.insert(
5514 "dsl_branch".to_string(),
5515 serde_json::Value::String(branch_id.to_string()),
5516 );
5517 }
5518 insert_training_metadata(
5519 &mut metadata,
5520 &step.train_params,
5521 step.tuning.as_ref(),
5522 step.inner_cv.as_ref(),
5523 &step.id,
5524 )?;
5525 Ok(metadata)
5526}
5527
5528fn branch_context_metadata(
5529 branch_step: &PipelineDslBranchStep,
5530 branch: &PipelineDslBranch,
5531) -> Result<BTreeMap<String, serde_json::Value>> {
5532 let mut metadata = BTreeMap::new();
5533 metadata.insert(
5534 "dsl_branch".to_string(),
5535 serde_json::Value::String(branch.id.clone()),
5536 );
5537 metadata.insert(
5538 "dsl_branch_mode".to_string(),
5539 serde_json::to_value(branch_step.mode).map_err(|error| {
5540 DagMlError::GraphValidation(format!(
5541 "failed to serialize pipeline DSL branch mode for `{}`: {error}",
5542 branch.id
5543 ))
5544 })?,
5545 );
5546 if let Some(selector) = &branch_step.selector {
5547 metadata.insert("dsl_branch_step_selector".to_string(), selector.clone());
5548 }
5549 if !branch_step.metadata.is_empty() {
5550 metadata.insert(
5551 "dsl_branch_step_metadata".to_string(),
5552 serde_json::to_value(&branch_step.metadata).map_err(|error| {
5553 DagMlError::GraphValidation(format!(
5554 "failed to serialize pipeline DSL branch step metadata for `{}`: {error}",
5555 branch.id
5556 ))
5557 })?,
5558 );
5559 }
5560 if let Some(selector) = &branch.selector {
5561 metadata.insert("dsl_branch_selector".to_string(), selector.clone());
5562 }
5563 if !branch.metadata.is_empty() {
5564 metadata.insert(
5565 "dsl_branch_metadata".to_string(),
5566 serde_json::to_value(&branch.metadata).map_err(|error| {
5567 DagMlError::GraphValidation(format!(
5568 "failed to serialize pipeline DSL branch metadata for `{}`: {error}",
5569 branch.id
5570 ))
5571 })?,
5572 );
5573 }
5574 Ok(metadata)
5575}
5576
5577fn compile_branch_view_plan(
5578 branch_step: &PipelineDslBranchStep,
5579 branch: &PipelineDslBranch,
5580) -> Result<Option<BranchViewPlan>> {
5581 let Some(mode) = branch_view_mode(branch_step.mode) else {
5582 return Ok(None);
5583 };
5584 let selector = branch_view_selector(mode, branch_step.selector.as_ref(), branch)?;
5585 let mut metadata = branch.metadata.clone();
5586 if let Some(step_selector) = &branch_step.selector {
5587 metadata.insert(
5588 "dsl_branch_step_selector".to_string(),
5589 step_selector.clone(),
5590 );
5591 }
5592 if let Some(branch_selector) = &branch.selector {
5593 metadata.insert("dsl_branch_selector".to_string(), branch_selector.clone());
5594 }
5595 if !branch_step.metadata.is_empty() {
5596 metadata.insert(
5597 "dsl_branch_step_metadata".to_string(),
5598 serde_json::to_value(&branch_step.metadata).map_err(|error| {
5599 DagMlError::GraphValidation(format!(
5600 "failed to serialize pipeline DSL branch step metadata for `{}`: {error}",
5601 branch.id
5602 ))
5603 })?,
5604 );
5605 }
5606 let plan = BranchViewPlan {
5607 view_id: format!("branch_view:{}", branch.id),
5608 branch_id: branch.id.clone(),
5609 mode,
5610 selector,
5611 allow_overlap: branch_overlap_allowed(branch_step, branch),
5612 metadata,
5613 };
5614 plan.validate()
5615 .map_err(|error| DagMlError::GraphValidation(error.to_string()))?;
5616 Ok(Some(plan))
5617}
5618
5619fn branch_view_mode(mode: PipelineDslBranchMode) -> Option<BranchViewMode> {
5620 match mode {
5621 PipelineDslBranchMode::Duplication => None,
5622 PipelineDslBranchMode::Separation => Some(BranchViewMode::Separation),
5623 PipelineDslBranchMode::BySource => Some(BranchViewMode::BySource),
5624 PipelineDslBranchMode::ByMetadata => Some(BranchViewMode::ByMetadata),
5625 PipelineDslBranchMode::ByTag => Some(BranchViewMode::ByTag),
5626 PipelineDslBranchMode::ByFilter => Some(BranchViewMode::ByFilter),
5627 }
5628}
5629
5630fn branch_view_selector(
5631 mode: BranchViewMode,
5632 step_selector: Option<&serde_json::Value>,
5633 branch: &PipelineDslBranch,
5634) -> Result<DataViewSelector> {
5635 match mode {
5636 BranchViewMode::BySource => branch_view_selector_by_source(branch),
5637 BranchViewMode::ByMetadata => branch_view_selector_by_metadata(step_selector, branch),
5638 BranchViewMode::ByTag => branch_view_selector_by_tag(branch),
5639 BranchViewMode::ByFilter => branch_view_selector_by_filter(branch),
5640 BranchViewMode::Separation => branch_view_selector_generic(step_selector, branch),
5641 }
5642}
5643
5644fn branch_view_selector_by_source(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5645 let Some(selector) = &branch.selector else {
5646 return Err(DagMlError::GraphValidation(format!(
5647 "pipeline DSL by_source branch `{}` requires a selector",
5648 branch.id
5649 )));
5650 };
5651 let source_ids = selector_strings(selector, &["source", "source_id"], &["sources", "source_ids"])
5652 .or_else(|| selector.as_str().map(|value| vec![value.to_string()]))
5653 .ok_or_else(|| {
5654 DagMlError::GraphValidation(format!(
5655 "pipeline DSL by_source branch `{}` selector must be a source string or object with source/source_ids",
5656 branch.id
5657 ))
5658 })?;
5659 Ok(DataViewSelector {
5660 source_ids,
5661 ..DataViewSelector::default()
5662 })
5663}
5664
5665fn branch_view_selector_by_metadata(
5666 step_selector: Option<&serde_json::Value>,
5667 branch: &PipelineDslBranch,
5668) -> Result<DataViewSelector> {
5669 let Some(selector) = &branch.selector else {
5670 return Err(DagMlError::GraphValidation(format!(
5671 "pipeline DSL by_metadata branch `{}` requires a selector",
5672 branch.id
5673 )));
5674 };
5675 if let Some(metadata) = selector_metadata_map(selector)? {
5676 return Ok(DataViewSelector {
5677 metadata,
5678 ..DataViewSelector::default()
5679 });
5680 }
5681 let branch_key = selector
5682 .as_object()
5683 .and_then(|_| selector_metadata_key(selector));
5684 let key = branch_key
5685 .or_else(|| step_selector.and_then(selector_metadata_key))
5686 .ok_or_else(|| {
5687 DagMlError::GraphValidation(format!(
5688 "pipeline DSL by_metadata branch `{}` requires a metadata key on the branch or branch step selector",
5689 branch.id
5690 ))
5691 })?;
5692 let value = selector_value(selector).ok_or_else(|| {
5693 DagMlError::GraphValidation(format!(
5694 "pipeline DSL by_metadata branch `{}` requires a metadata value",
5695 branch.id
5696 ))
5697 })?;
5698 Ok(DataViewSelector {
5699 metadata: BTreeMap::from([(key, value)]),
5700 ..DataViewSelector::default()
5701 })
5702}
5703
5704fn branch_view_selector_by_tag(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5705 let Some(selector) = &branch.selector else {
5706 return Err(DagMlError::GraphValidation(format!(
5707 "pipeline DSL by_tag branch `{}` requires a selector",
5708 branch.id
5709 )));
5710 };
5711 let tags = selector_strings(selector, &["tag"], &["tags"])
5712 .or_else(|| selector.as_str().map(|value| vec![value.to_string()]))
5713 .ok_or_else(|| {
5714 DagMlError::GraphValidation(format!(
5715 "pipeline DSL by_tag branch `{}` selector must be a tag string or object with tag/tags",
5716 branch.id
5717 ))
5718 })?;
5719 Ok(DataViewSelector {
5720 tags,
5721 ..DataViewSelector::default()
5722 })
5723}
5724
5725fn branch_view_selector_by_filter(branch: &PipelineDslBranch) -> Result<DataViewSelector> {
5726 let Some(selector) = &branch.selector else {
5727 return Err(DagMlError::GraphValidation(format!(
5728 "pipeline DSL by_filter branch `{}` requires a selector",
5729 branch.id
5730 )));
5731 };
5732 let filter = selector
5733 .as_object()
5734 .and_then(|object| object.get("filter").cloned())
5735 .unwrap_or_else(|| selector.clone());
5736 Ok(DataViewSelector {
5737 filter: Some(filter),
5738 ..DataViewSelector::default()
5739 })
5740}
5741
/// Infers the view selector for a `separation` branch from the shape of its
/// selector, trying the source, metadata, tag and filter forms in that order,
/// and delegates to the matching specialised builder.
///
/// # Errors
/// Returns `DagMlError::GraphValidation` when the branch has no selector, when
/// a `metadata` entry is present but not an object, when the chosen builder
/// rejects the selector, or when no form matches.
fn branch_view_selector_generic(
    step_selector: Option<&serde_json::Value>,
    branch: &PipelineDslBranch,
) -> Result<DataViewSelector> {
    let Some(selector) = &branch.selector else {
        return Err(DagMlError::GraphValidation(format!(
            "pipeline DSL separation branch `{}` requires a selector",
            branch.id
        )));
    };
    // Source form: a string / string array under a source key, or any
    // `source`/`sources` key at all, so a malformed value still reaches the
    // by_source builder and gets its error message.
    if selector_strings(
        selector,
        &["source", "source_id"],
        &["sources", "source_ids"],
    )
    .is_some()
        || selector
            .as_object()
            .is_some_and(|object| object.contains_key("source") || object.contains_key("sources"))
    {
        return branch_view_selector_by_source(branch);
    }
    // Metadata form: an explicit `metadata` map (a non-object `metadata` errors
    // here via `?`), a metadata key on an object selector, or a key supplied by
    // the branch step selector.
    if selector_metadata_map(selector)?.is_some()
        || selector
            .as_object()
            .and_then(|_| selector_metadata_key(selector))
            .is_some()
        || step_selector.and_then(selector_metadata_key).is_some()
    {
        return branch_view_selector_by_metadata(step_selector, branch);
    }
    // Tag form: an object with a string `tag` or string-array `tags`.
    if selector_strings(selector, &["tag"], &["tags"]).is_some() {
        return branch_view_selector_by_tag(branch);
    }
    // Filter form: an object with an explicit `filter` entry.
    if selector
        .as_object()
        .is_some_and(|object| object.contains_key("filter"))
    {
        return branch_view_selector_by_filter(branch);
    }
    Err(DagMlError::GraphValidation(format!(
        "pipeline DSL separation branch `{}` selector must declare source_ids, metadata, tags or filter",
        branch.id
    )))
}
5787
5788fn selector_strings(
5789 value: &serde_json::Value,
5790 singular_keys: &[&str],
5791 plural_keys: &[&str],
5792) -> Option<Vec<String>> {
5793 let object = value.as_object()?;
5794 for key in singular_keys {
5795 if let Some(text) = object.get(*key).and_then(serde_json::Value::as_str) {
5796 return Some(vec![text.to_string()]);
5797 }
5798 }
5799 for key in plural_keys {
5800 if let Some(values) = object.get(*key).and_then(serde_json::Value::as_array) {
5801 let parsed = values
5802 .iter()
5803 .filter_map(serde_json::Value::as_str)
5804 .map(str::to_string)
5805 .collect::<Vec<_>>();
5806 if parsed.len() == values.len() {
5807 return Some(parsed);
5808 }
5809 }
5810 }
5811 None
5812}
5813
5814fn selector_metadata_map(
5815 value: &serde_json::Value,
5816) -> Result<Option<BTreeMap<String, serde_json::Value>>> {
5817 let Some(object) = value.as_object() else {
5818 return Ok(None);
5819 };
5820 let Some(metadata) = object.get("metadata") else {
5821 return Ok(None);
5822 };
5823 let Some(metadata) = metadata.as_object() else {
5824 return Err(DagMlError::GraphValidation(
5825 "pipeline DSL branch metadata selector must be an object".to_string(),
5826 ));
5827 };
5828 Ok(Some(
5829 metadata
5830 .iter()
5831 .map(|(key, value)| (key.clone(), value.clone()))
5832 .collect(),
5833 ))
5834}
5835
5836fn selector_metadata_key(value: &serde_json::Value) -> Option<String> {
5837 if let Some(text) = value.as_str() {
5838 return Some(text.to_string());
5839 }
5840 let object = value.as_object()?;
5841 ["metadata_key", "column", "key", "by_metadata"]
5842 .into_iter()
5843 .find_map(|key| object.get(key).and_then(serde_json::Value::as_str))
5844 .map(str::to_string)
5845}
5846
5847fn selector_value(value: &serde_json::Value) -> Option<serde_json::Value> {
5848 match value {
5849 serde_json::Value::String(_)
5850 | serde_json::Value::Bool(_)
5851 | serde_json::Value::Number(_) => Some(value.clone()),
5852 serde_json::Value::Object(object) => object
5853 .get("value")
5854 .or_else(|| object.get("equals"))
5855 .cloned(),
5856 _ => None,
5857 }
5858}
5859
5860fn branch_overlap_allowed(branch_step: &PipelineDslBranchStep, branch: &PipelineDslBranch) -> bool {
5861 branch
5862 .metadata
5863 .get("allow_overlap")
5864 .or_else(|| branch_step.metadata.get("allow_overlap"))
5865 .and_then(serde_json::Value::as_bool)
5866 .unwrap_or(false)
5867}
5868
5869fn branch_id_from_metadata(metadata: &BTreeMap<String, serde_json::Value>) -> Option<String> {
5870 metadata
5871 .get("dsl_branch")
5872 .and_then(|value| value.as_str())
5873 .map(str::to_string)
5874}
5875
5876fn expand_generator_sequences(step: &PipelineDslGeneratorStep) -> Result<Vec<GeneratedSequence>> {
5877 if step.count == Some(0) {
5878 return Err(DagMlError::GraphValidation(format!(
5879 "pipeline DSL generator `{}` count cannot be zero",
5880 step.id
5881 )));
5882 }
5883 match step.mode {
5884 PipelineDslGeneratorMode::Or => expand_or_generator_sequences(step),
5885 PipelineDslGeneratorMode::Cartesian => expand_cartesian_generator_sequences(step),
5886 }
5887}
5888
5889fn expand_or_generator_sequences(
5890 step: &PipelineDslGeneratorStep,
5891) -> Result<Vec<GeneratedSequence>> {
5892 if !step.stages.is_empty() {
5893 return Err(DagMlError::GraphValidation(format!(
5894 "pipeline DSL generator `{}` uses mode `or` but declares cartesian stages",
5895 step.id
5896 )));
5897 }
5898 if step.branches.is_empty() {
5899 return Err(DagMlError::GraphValidation(format!(
5900 "pipeline DSL generator `{}` has no branches",
5901 step.id
5902 )));
5903 }
5904 let options = step
5905 .branches
5906 .iter()
5907 .enumerate()
5908 .map(|(index, branch)| {
5909 validate_branch_id(&branch.id)?;
5910 Ok(GeneratedSequence {
5911 id: generator_choice_id(&step.id, index),
5912 labels: vec![branch.id.clone()],
5913 steps: branch.steps.clone(),
5914 metadata: branch.metadata.clone(),
5915 })
5916 })
5917 .collect::<Result<Vec<_>>>()?;
5918
5919 let choices = if let Some(sizes) = selection_sizes(step.pick)? {
5920 generated_pick_sequences(&options, &step.id, "pick", &sizes, step.count)?
5921 } else if let Some(sizes) = selection_sizes(step.arrange)? {
5922 generated_arrange_sequences(&options, &step.id, "arrange", &sizes, step.count)?
5923 } else {
5924 truncate_generated_sequences(options, step.count)
5925 };
5926
5927 let choices = if let Some(sizes) = selection_sizes(step.then_pick)? {
5928 generated_pick_sequences(&choices, &step.id, "then_pick", &sizes, step.count)?
5929 } else if let Some(sizes) = selection_sizes(step.then_arrange)? {
5930 generated_arrange_sequences(&choices, &step.id, "then_arrange", &sizes, step.count)?
5931 } else {
5932 choices
5933 };
5934 Ok(truncate_generated_sequences(choices, step.count))
5935}
5936
5937fn expand_cartesian_generator_sequences(
5938 step: &PipelineDslGeneratorStep,
5939) -> Result<Vec<GeneratedSequence>> {
5940 if !step.branches.is_empty() {
5941 return Err(DagMlError::GraphValidation(format!(
5942 "pipeline DSL generator `{}` uses mode `cartesian` but declares direct branches",
5943 step.id
5944 )));
5945 }
5946 if step.stages.is_empty() {
5947 return Err(DagMlError::GraphValidation(format!(
5948 "pipeline DSL generator `{}` has no cartesian stages",
5949 step.id
5950 )));
5951 }
5952 if step.pick.is_some()
5953 || step.arrange.is_some()
5954 || step.then_pick.is_some()
5955 || step.then_arrange.is_some()
5956 {
5957 return Err(DagMlError::GraphValidation(format!(
5958 "pipeline DSL generator `{}` cannot combine cartesian mode with pick/arrange selectors",
5959 step.id
5960 )));
5961 }
5962
5963 let mut stage_options = Vec::<Vec<GeneratedSequence>>::new();
5964 for (stage_index, stage) in step.stages.iter().enumerate() {
5965 validate_branch_id(&stage.id)?;
5966 if stage.branches.is_empty() {
5967 return Err(DagMlError::GraphValidation(format!(
5968 "pipeline DSL generator `{}` stage `{}` has no branches",
5969 step.id, stage.id
5970 )));
5971 }
5972 let mut options = Vec::new();
5973 for branch in &stage.branches {
5974 validate_branch_id(&branch.id)?;
5975 let mut metadata = branch.metadata.clone();
5976 if let Some(selector) = &stage.selector {
5977 metadata.insert("dsl_generator_stage_selector".to_string(), selector.clone());
5978 }
5979 if !stage.metadata.is_empty() {
5980 metadata.insert(
5981 "dsl_generator_stage_metadata".to_string(),
5982 serde_json::to_value(&stage.metadata).map_err(|error| {
5983 DagMlError::GraphValidation(format!(
5984 "failed to serialize pipeline DSL generator `{}` stage `{}` metadata: {error}",
5985 step.id, stage.id
5986 ))
5987 })?,
5988 );
5989 }
5990 options.push(GeneratedSequence {
5991 id: format!("{stage_index}:{}", branch.id),
5992 labels: vec![format!("{}:{}", stage.id, branch.id)],
5993 steps: branch.steps.clone(),
5994 metadata,
5995 });
5996 }
5997 stage_options.push(options);
5998 }
5999
6000 let mut rows = Vec::<Vec<usize>>::new();
6001 build_cartesian_indices(&stage_options, 0, &mut Vec::new(), &mut rows, step.count);
6002 let mut choices = Vec::with_capacity(rows.len());
6003 for (choice_index, row) in rows.into_iter().enumerate() {
6004 let selected = row
6005 .into_iter()
6006 .enumerate()
6007 .map(|(stage_index, option_index)| stage_options[stage_index][option_index].clone())
6008 .collect::<Vec<_>>();
6009 choices.push(merge_generated_sequence(
6010 generator_choice_id(&step.id, choice_index),
6011 selected,
6012 )?);
6013 }
6014 Ok(choices)
6015}
6016
6017fn generated_pick_sequences(
6018 options: &[GeneratedSequence],
6019 generator_id: &NodeId,
6020 mode: &str,
6021 sizes: &[usize],
6022 count: Option<usize>,
6023) -> Result<Vec<GeneratedSequence>> {
6024 let mut selections = Vec::<Vec<usize>>::new();
6025 for size in sizes {
6026 if *size == 0 || *size > options.len() {
6027 return Err(DagMlError::GraphValidation(format!(
6028 "pipeline DSL generator `{generator_id}` {mode} size {size} is outside 1..={}",
6029 options.len()
6030 )));
6031 }
6032 build_combinations(
6033 options.len(),
6034 *size,
6035 0,
6036 &mut Vec::new(),
6037 &mut selections,
6038 count,
6039 );
6040 }
6041 selections
6042 .into_iter()
6043 .enumerate()
6044 .map(|(index, selection)| {
6045 let selected = selection
6046 .into_iter()
6047 .map(|option_index| options[option_index].clone())
6048 .collect::<Vec<_>>();
6049 merge_generated_sequence(generator_choice_id(generator_id, index), selected)
6050 })
6051 .collect()
6052}
6053
6054fn generated_arrange_sequences(
6055 options: &[GeneratedSequence],
6056 generator_id: &NodeId,
6057 mode: &str,
6058 sizes: &[usize],
6059 count: Option<usize>,
6060) -> Result<Vec<GeneratedSequence>> {
6061 let mut selections = Vec::<Vec<usize>>::new();
6062 for size in sizes {
6063 if *size == 0 || *size > options.len() {
6064 return Err(DagMlError::GraphValidation(format!(
6065 "pipeline DSL generator `{generator_id}` {mode} size {size} is outside 1..={}",
6066 options.len()
6067 )));
6068 }
6069 build_permutations(
6070 options.len(),
6071 *size,
6072 &mut BTreeSet::new(),
6073 &mut Vec::new(),
6074 &mut selections,
6075 count,
6076 );
6077 }
6078 selections
6079 .into_iter()
6080 .enumerate()
6081 .map(|(index, selection)| {
6082 let selected = selection
6083 .into_iter()
6084 .map(|option_index| options[option_index].clone())
6085 .collect::<Vec<_>>();
6086 merge_generated_sequence(generator_choice_id(generator_id, index), selected)
6087 })
6088 .collect()
6089}
6090
6091fn merge_generated_sequence(
6092 id: String,
6093 sequences: Vec<GeneratedSequence>,
6094) -> Result<GeneratedSequence> {
6095 if sequences.is_empty() {
6096 return Err(DagMlError::GraphValidation(format!(
6097 "pipeline DSL generated sequence `{id}` has no selected options"
6098 )));
6099 }
6100 let mut labels = Vec::new();
6101 let mut steps = Vec::new();
6102 let mut metadata = BTreeMap::new();
6103 for sequence in sequences {
6104 labels.extend(sequence.labels);
6105 steps.extend(sequence.steps);
6106 if !sequence.metadata.is_empty() {
6107 metadata.insert(
6108 format!("option:{}", metadata.len()),
6109 serde_json::to_value(sequence.metadata).map_err(|error| {
6110 DagMlError::GraphValidation(format!(
6111 "failed to serialize generated sequence `{id}` metadata: {error}"
6112 ))
6113 })?,
6114 );
6115 }
6116 }
6117 Ok(GeneratedSequence {
6118 id,
6119 labels,
6120 steps,
6121 metadata,
6122 })
6123}
6124
6125fn truncate_generated_sequences(
6126 mut sequences: Vec<GeneratedSequence>,
6127 count: Option<usize>,
6128) -> Vec<GeneratedSequence> {
6129 if let Some(limit) = count {
6130 sequences.truncate(limit);
6131 }
6132 sequences
6133}
6134
/// Appends to `rows` one index vector per element of the cartesian product of
/// `stages`, extending the prefix in `current`.
///
/// The first stage varies slowest. Recursion stops once `rows` holds `count`
/// rows (when `count` is set). `current` is restored before returning.
fn build_cartesian_indices<T>(
    stages: &[Vec<T>],
    stage_index: usize,
    current: &mut Vec<usize>,
    rows: &mut Vec<Vec<usize>>,
    count: Option<usize>,
) {
    let done = |produced: usize| matches!(count, Some(limit) if produced >= limit);
    if done(rows.len()) {
        return;
    }
    if stage_index == stages.len() {
        rows.push(current.clone());
        return;
    }
    let option_total = stages[stage_index].len();
    for option_index in 0..option_total {
        current.push(option_index);
        build_cartesian_indices(stages, stage_index + 1, current, rows, count);
        current.pop();
        if done(rows.len()) {
            break;
        }
    }
}
6158
6159fn selection_sizes(selection: Option<PipelineDslSelectionSpec>) -> Result<Option<Vec<usize>>> {
6160 selection
6161 .map(|selection| match selection {
6162 PipelineDslSelectionSpec::Single(size) => {
6163 if size == 0 {
6164 return Err(DagMlError::GraphValidation(
6165 "pipeline DSL generator selection size cannot be zero".to_string(),
6166 ));
6167 }
6168 Ok(vec![size])
6169 }
6170 PipelineDslSelectionSpec::Range([start, stop]) => {
6171 if start == 0 || stop == 0 || start > stop {
6172 return Err(DagMlError::GraphValidation(format!(
6173 "pipeline DSL generator selection range [{start}, {stop}] is invalid"
6174 )));
6175 }
6176 Ok((start..=stop).collect())
6177 }
6178 })
6179 .transpose()
6180}
6181
6182fn generator_choice_id(generator_id: &NodeId, choice_index: usize) -> String {
6183 format!("{generator_id}:choice{choice_index}")
6184}
6185
6186fn generator_choice_metadata(
6187 step: &PipelineDslGeneratorStep,
6188 choice: &GeneratedSequence,
6189) -> Result<BTreeMap<String, serde_json::Value>> {
6190 let mut metadata = step.metadata.clone();
6191 metadata.insert(
6192 "dsl_generator".to_string(),
6193 serde_json::Value::String(step.id.to_string()),
6194 );
6195 metadata.insert(
6196 "dsl_generator_mode".to_string(),
6197 serde_json::to_value(step.mode).map_err(|error| {
6198 DagMlError::GraphValidation(format!(
6199 "failed to serialize pipeline DSL generator `{}` mode: {error}",
6200 step.id
6201 ))
6202 })?,
6203 );
6204 metadata.insert(
6205 "dsl_generator_choice".to_string(),
6206 serde_json::Value::String(choice.id.clone()),
6207 );
6208 metadata.insert(
6209 "dsl_generator_labels".to_string(),
6210 serde_json::to_value(&choice.labels).map_err(|error| {
6211 DagMlError::GraphValidation(format!(
6212 "failed to serialize pipeline DSL generator `{}` choice labels: {error}",
6213 step.id
6214 ))
6215 })?,
6216 );
6217 if !choice.metadata.is_empty() {
6218 metadata.insert(
6219 "dsl_generator_choice_metadata".to_string(),
6220 serde_json::to_value(&choice.metadata).map_err(|error| {
6221 DagMlError::GraphValidation(format!(
6222 "failed to serialize pipeline DSL generator `{}` choice metadata: {error}",
6223 step.id
6224 ))
6225 })?,
6226 );
6227 }
6228 Ok(metadata)
6229}
6230
6231fn namespace_generated_sequence(
6232 generator: &PipelineDslGeneratorStep,
6233 mut choice: GeneratedSequence,
6234 choice_index: usize,
6235) -> Result<GeneratedSequence> {
6236 let mut node_map = BTreeMap::<NodeId, NodeId>::new();
6237 let mut counter = 0usize;
6238 for step in &mut choice.steps {
6239 namespace_step_ids(generator, choice_index, step, &mut counter, &mut node_map)?;
6240 }
6241 for step in &mut choice.steps {
6242 rewrite_step_node_refs(step, &node_map);
6243 }
6244 Ok(choice)
6245}
6246
6247fn namespace_step_ids(
6248 generator: &PipelineDslGeneratorStep,
6249 choice_index: usize,
6250 step: &mut PipelineDslStep,
6251 counter: &mut usize,
6252 node_map: &mut BTreeMap<NodeId, NodeId>,
6253) -> Result<()> {
6254 match step {
6255 PipelineDslStep::Transform(step)
6256 | PipelineDslStep::YTransform(step)
6257 | PipelineDslStep::Tag(step)
6258 | PipelineDslStep::Exclude(step)
6259 | PipelineDslStep::Filter(step)
6260 | PipelineDslStep::SampleFilter(step)
6261 | PipelineDslStep::Augmentation(step)
6262 | PipelineDslStep::FeatureAugmentation(step)
6263 | PipelineDslStep::SampleAugmentation(step)
6264 | PipelineDslStep::DataGeneration(step)
6265 | PipelineDslStep::Model(step)
6266 | PipelineDslStep::Tuner(step)
6267 | PipelineDslStep::Chart(step) => {
6268 namespace_operator_step_id(generator, choice_index, step, counter, node_map)?;
6269 }
6270 PipelineDslStep::ConcatTransform(step) => {
6271 namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6272 for branch in &mut step.branches {
6273 for branch_step in &mut branch.steps {
6274 namespace_operator_step_id(
6275 generator,
6276 choice_index,
6277 branch_step,
6278 counter,
6279 node_map,
6280 )?;
6281 }
6282 }
6283 }
6284 PipelineDslStep::Branch(step) => {
6285 for branch in &mut step.branches {
6286 for branch_step in &mut branch.steps {
6287 namespace_step_ids(generator, choice_index, branch_step, counter, node_map)?;
6288 }
6289 }
6290 }
6291 PipelineDslStep::Generator(step) => {
6292 namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6293 for branch in &mut step.branches {
6294 for branch_step in &mut branch.steps {
6295 namespace_step_ids(generator, choice_index, branch_step, counter, node_map)?;
6296 }
6297 }
6298 for stage in &mut step.stages {
6299 for branch in &mut stage.branches {
6300 for branch_step in &mut branch.steps {
6301 namespace_step_ids(
6302 generator,
6303 choice_index,
6304 branch_step,
6305 counter,
6306 node_map,
6307 )?;
6308 }
6309 }
6310 }
6311 }
6312 PipelineDslStep::Sequential(step) => {
6313 if let Some(id) = &mut step.id {
6314 namespace_node_id_field(generator, choice_index, id, counter, node_map)?;
6315 }
6316 for child in &mut step.steps {
6317 namespace_step_ids(generator, choice_index, child, counter, node_map)?;
6318 }
6319 }
6320 PipelineDslStep::Merge(step) => {
6321 namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6322 }
6323 PipelineDslStep::MergeModel(step) => {
6324 namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)?;
6325 }
6326 }
6327 Ok(())
6328}
6329
6330fn namespace_operator_step_id(
6331 generator: &PipelineDslGeneratorStep,
6332 choice_index: usize,
6333 step: &mut PipelineDslOperatorStep,
6334 counter: &mut usize,
6335 node_map: &mut BTreeMap<NodeId, NodeId>,
6336) -> Result<()> {
6337 namespace_node_id_field(generator, choice_index, &mut step.id, counter, node_map)
6338}
6339
6340fn namespace_node_id_field(
6341 generator: &PipelineDslGeneratorStep,
6342 choice_index: usize,
6343 node_id: &mut NodeId,
6344 counter: &mut usize,
6345 node_map: &mut BTreeMap<NodeId, NodeId>,
6346) -> Result<()> {
6347 let original = node_id.clone();
6348 if node_map.contains_key(&original) {
6349 return Err(DagMlError::GraphValidation(format!(
6350 "pipeline DSL generator `{}` choice `{}` reuses node id `{original}`; generated choices require unique node ids inside each expanded sequence",
6351 generator.id, choice_index
6352 )));
6353 }
6354 let next = namespaced_generated_node_id(&generator.id, choice_index, *counter, &original)?;
6355 *counter += 1;
6356 *node_id = next.clone();
6357 node_map.insert(original, next);
6358 Ok(())
6359}
6360
6361fn namespaced_generated_node_id(
6362 generator_id: &NodeId,
6363 choice_index: usize,
6364 node_index: usize,
6365 original: &NodeId,
6366) -> Result<NodeId> {
6367 let generator = sanitized_id_fragment(generator_id.as_str(), 32);
6368 let suffix = sanitized_id_fragment(original.as_str(), 28);
6369 NodeId::new(format!(
6370 "gen:{generator}:c{choice_index}:n{node_index}.{suffix}"
6371 ))
6372}
6373
6374fn sanitized_id_fragment(input: &str, max_len: usize) -> String {
6375 let sanitized = sanitize_generation_label(input);
6376 let mut fragment = sanitized.chars().take(max_len).collect::<String>();
6377 if fragment.is_empty() {
6378 fragment = "x".to_string();
6379 }
6380 fragment
6381}
6382
6383fn rewrite_step_node_refs(step: &mut PipelineDslStep, node_map: &BTreeMap<NodeId, NodeId>) {
6384 match step {
6385 PipelineDslStep::Transform(_)
6386 | PipelineDslStep::YTransform(_)
6387 | PipelineDslStep::Tag(_)
6388 | PipelineDslStep::Exclude(_)
6389 | PipelineDslStep::Filter(_)
6390 | PipelineDslStep::SampleFilter(_)
6391 | PipelineDslStep::Augmentation(_)
6392 | PipelineDslStep::FeatureAugmentation(_)
6393 | PipelineDslStep::SampleAugmentation(_)
6394 | PipelineDslStep::DataGeneration(_)
6395 | PipelineDslStep::Model(_)
6396 | PipelineDslStep::Tuner(_)
6397 | PipelineDslStep::Chart(_) => {}
6398 PipelineDslStep::ConcatTransform(step) => {
6399 for branch in &mut step.branches {
6400 for branch_step in &mut branch.steps {
6401 rewrite_operator_step_refs(branch_step, node_map);
6402 }
6403 }
6404 }
6405 PipelineDslStep::Branch(step) => {
6406 for branch in &mut step.branches {
6407 for branch_step in &mut branch.steps {
6408 rewrite_step_node_refs(branch_step, node_map);
6409 }
6410 }
6411 }
6412 PipelineDslStep::Generator(step) => {
6413 for branch in &mut step.branches {
6414 for branch_step in &mut branch.steps {
6415 rewrite_step_node_refs(branch_step, node_map);
6416 }
6417 }
6418 for stage in &mut step.stages {
6419 for branch in &mut stage.branches {
6420 for branch_step in &mut branch.steps {
6421 rewrite_step_node_refs(branch_step, node_map);
6422 }
6423 }
6424 }
6425 }
6426 PipelineDslStep::Sequential(step) => {
6427 for child in &mut step.steps {
6428 rewrite_step_node_refs(child, node_map);
6429 }
6430 }
6431 PipelineDslStep::Merge(step) => {
6432 rewrite_merge_selectors(&mut step.selectors, node_map);
6433 }
6434 PipelineDslStep::MergeModel(_) => {}
6435 }
6436}
6437
6438fn rewrite_operator_step_refs(
6439 _step: &mut PipelineDslOperatorStep,
6440 _node_map: &BTreeMap<NodeId, NodeId>,
6441) {
6442}
6443
6444fn rewrite_merge_selectors(
6445 selectors: &mut [PipelineDslMergeSelector],
6446 node_map: &BTreeMap<NodeId, NodeId>,
6447) {
6448 for selector in selectors {
6449 if let Some(model) = &selector.model {
6450 if let Some(rewritten) = node_map.get(model) {
6451 selector.model = Some(rewritten.clone());
6452 }
6453 }
6454 }
6455}
6456
6457fn validate_merge_selectors(
6458 merge_id: &NodeId,
6459 selectors: &[PipelineDslMergeSelector],
6460 predictions: &[PredictionSource],
6461) -> Result<()> {
6462 if selectors.is_empty() {
6463 return Ok(());
6464 }
6465 if predictions.is_empty() {
6466 return Err(DagMlError::GraphValidation(format!(
6467 "pipeline DSL merge `{merge_id}` declares selectors but has no prediction inputs"
6468 )));
6469 }
6470 for (selector_index, selector) in selectors.iter().enumerate() {
6471 let mut matched = predictions.iter().collect::<Vec<_>>();
6472 if let Some(input_name) = &selector.input_name {
6473 if input_name.trim().is_empty() {
6474 return Err(DagMlError::GraphValidation(format!(
6475 "pipeline DSL merge `{merge_id}` selector {selector_index} has an empty input_name"
6476 )));
6477 }
6478 matched.retain(|prediction| prediction.input_name == *input_name);
6479 }
6480 if let Some(branch) = &selector.branch {
6481 if branch.trim().is_empty() {
6482 return Err(DagMlError::GraphValidation(format!(
6483 "pipeline DSL merge `{merge_id}` selector {selector_index} has an empty branch"
6484 )));
6485 }
6486 matched.retain(|prediction| prediction.branch_id.as_deref() == Some(branch.as_str()));
6487 }
6488 if let Some(model) = &selector.model {
6489 matched.retain(|prediction| prediction.node_id == *model);
6490 }
6491 if matched.is_empty() {
6492 return Err(DagMlError::GraphValidation(format!(
6493 "pipeline DSL merge `{merge_id}` selector {selector_index} does not match any pending prediction input"
6494 )));
6495 }
6496 validate_merge_selector_select(merge_id, selector_index, selector, matched.len())?;
6497 }
6498 Ok(())
6499}
6500
6501fn validate_merge_selector_select(
6502 merge_id: &NodeId,
6503 selector_index: usize,
6504 selector: &PipelineDslMergeSelector,
6505 matched_count: usize,
6506) -> Result<()> {
6507 let Some(select) = &selector.select else {
6508 return Ok(());
6509 };
6510 if let Some(mode) = select.as_str() {
6511 match mode {
6512 "all" => return Ok(()),
6513 "best" => {
6514 require_selector_metric(merge_id, selector_index, selector, mode)?;
6515 return Ok(());
6516 }
6517 _ => {
6518 return Err(DagMlError::GraphValidation(format!(
6519 "pipeline DSL merge `{merge_id}` selector {selector_index} has unsupported select mode `{mode}`"
6520 )));
6521 }
6522 }
6523 }
6524 let Some(object) = select.as_object() else {
6525 return Err(DagMlError::GraphValidation(format!(
6526 "pipeline DSL merge `{merge_id}` selector {selector_index} select must be `all`, `best` or an object with `top_k`"
6527 )));
6528 };
6529 if object.len() != 1 || !object.contains_key("top_k") {
6530 return Err(DagMlError::GraphValidation(format!(
6531 "pipeline DSL merge `{merge_id}` selector {selector_index} object select currently supports only `top_k`"
6532 )));
6533 }
6534 let Some(top_k) = object.get("top_k").and_then(|value| value.as_u64()) else {
6535 return Err(DagMlError::GraphValidation(format!(
6536 "pipeline DSL merge `{merge_id}` selector {selector_index} top_k must be a positive integer"
6537 )));
6538 };
6539 if top_k == 0 {
6540 return Err(DagMlError::GraphValidation(format!(
6541 "pipeline DSL merge `{merge_id}` selector {selector_index} top_k must be positive"
6542 )));
6543 }
6544 if top_k as usize > matched_count {
6545 return Err(DagMlError::GraphValidation(format!(
6546 "pipeline DSL merge `{merge_id}` selector {selector_index} top_k={top_k} exceeds {matched_count} matched prediction inputs"
6547 )));
6548 }
6549 require_selector_metric(merge_id, selector_index, selector, "top_k")
6550}
6551
6552fn require_selector_metric(
6553 merge_id: &NodeId,
6554 selector_index: usize,
6555 selector: &PipelineDslMergeSelector,
6556 select_mode: &str,
6557) -> Result<()> {
6558 if selector
6559 .metric
6560 .as_ref()
6561 .is_some_and(|metric| !metric.trim().is_empty())
6562 {
6563 return Ok(());
6564 }
6565 Err(DagMlError::GraphValidation(format!(
6566 "pipeline DSL merge `{merge_id}` selector {selector_index} select `{select_mode}` requires a non-empty metric"
6567 )))
6568}
6569
6570fn insert_training_metadata(
6571 metadata: &mut BTreeMap<String, serde_json::Value>,
6572 train_params: &BTreeMap<String, serde_json::Value>,
6573 tuning: Option<&PipelineDslTuningSpec>,
6574 inner_cv: Option<&NestedCvSpec>,
6575 node_id: &NodeId,
6576) -> Result<()> {
6577 if let Some(inner_cv) = inner_cv {
6578 metadata.insert(
6581 "dsl_inner_cv".to_string(),
6582 serde_json::to_value(inner_cv).map_err(|error| {
6583 DagMlError::GraphValidation(format!(
6584 "failed to serialize pipeline DSL inner_cv for node `{node_id}`: {error}"
6585 ))
6586 })?,
6587 );
6588 }
6589 if !train_params.is_empty() {
6590 metadata.insert(
6591 "dsl_train_params".to_string(),
6592 serde_json::to_value(train_params).map_err(|error| {
6593 DagMlError::GraphValidation(format!(
6594 "failed to serialize pipeline DSL train params for node `{node_id}`: {error}"
6595 ))
6596 })?,
6597 );
6598 }
6599 if let Some(tuning) = tuning {
6600 metadata.insert(
6601 "dsl_tuning".to_string(),
6602 serde_json::to_value(tuning).map_err(|error| {
6603 DagMlError::GraphValidation(format!(
6604 "failed to serialize pipeline DSL tuning for node `{node_id}`: {error}"
6605 ))
6606 })?,
6607 );
6608 }
6609 Ok(())
6610}
6611
6612fn same_data_source(left: &DataSource, right: &DataSource) -> bool {
6613 left.node_id == right.node_id
6614 && left.port_name == right.port_name
6615 && left.representation == right.representation
6616}
6617
6618fn merge_consumes_predictions(step: &PipelineDslMergeStep) -> bool {
6619 match step.output_as {
6620 PipelineDslMergeOutput::Predictions => true,
6621 PipelineDslMergeOutput::Sources => false,
6622 PipelineDslMergeOutput::Features => {
6623 matches!(
6624 step.merge_mode.as_str(),
6625 "predictions" | "prediction" | "all" | "mixed" | "predictions_plus_original"
6626 ) || !step.selectors.is_empty()
6627 }
6628 }
6629}
6630
6631fn merge_consumes_branch_data(step: &PipelineDslMergeStep) -> bool {
6632 match step.output_as {
6633 PipelineDslMergeOutput::Predictions => false,
6634 PipelineDslMergeOutput::Sources => true,
6635 PipelineDslMergeOutput::Features => matches!(
6636 step.merge_mode.as_str(),
6637 "features" | "feature" | "concat" | "all" | "mixed" | "sources" | "source"
6638 ),
6639 }
6640}
6641
6642fn merge_node_kind(
6643 step: &PipelineDslMergeStep,
6644 has_predictions: bool,
6645 has_branch_data: bool,
6646) -> NodeKind {
6647 match step.output_as {
6648 PipelineDslMergeOutput::Predictions => NodeKind::PredictionJoin,
6649 PipelineDslMergeOutput::Sources => NodeKind::SourceJoin,
6650 PipelineDslMergeOutput::Features => {
6651 if has_predictions && (step.include_original_data || has_branch_data) {
6652 NodeKind::MixedJoin
6653 } else if has_predictions {
6654 NodeKind::PredictionJoin
6655 } else {
6656 NodeKind::FeatureJoin
6657 }
6658 }
6659 }
6660}
6661
6662fn data_port(name: &str, representation: Option<String>, description: &str) -> PortSpec {
6663 PortSpec {
6664 name: name.to_string(),
6665 kind: PortKind::Data,
6666 representation,
6667 cardinality: PortCardinality::One,
6668 unit_level: None,
6669 alignment_key: None,
6670 target_level: None,
6671 description: description.to_string(),
6672 }
6673}
6674
6675fn apply_data_unit_contract(port: &mut PortSpec, contract: &PipelineDslDataPort) {
6676 port.unit_level = contract.unit_level;
6677 port.alignment_key = contract.alignment_key.clone();
6678 port.target_level = contract.target_level;
6679}
6680
6681fn target_port(name: &str, description: &str) -> PortSpec {
6682 PortSpec {
6683 name: name.to_string(),
6684 kind: PortKind::Target,
6685 representation: None,
6686 cardinality: PortCardinality::One,
6687 unit_level: None,
6688 alignment_key: None,
6689 target_level: None,
6690 description: description.to_string(),
6691 }
6692}
6693
6694fn prediction_port(name: &str, description: &str) -> PortSpec {
6695 PortSpec {
6696 name: name.to_string(),
6697 kind: PortKind::Prediction,
6698 representation: None,
6699 cardinality: PortCardinality::One,
6700 unit_level: None,
6701 alignment_key: None,
6702 target_level: None,
6703 description: description.to_string(),
6704 }
6705}
6706
6707fn apply_prediction_unit_contract(port: &mut PortSpec, contract: &PipelineDslPredictionPort) {
6708 port.representation = contract.representation.clone();
6709 port.unit_level = contract.unit_level;
6710 port.alignment_key = contract.alignment_key.clone();
6711 port.target_level = contract.target_level;
6712}
6713
6714fn validate_branch_id(branch_id: &str) -> Result<()> {
6715 if branch_id.trim().is_empty() {
6716 return Err(DagMlError::GraphValidation(
6717 "pipeline DSL branch id must not be empty".to_string(),
6718 ));
6719 }
6720 if !branch_id
6721 .bytes()
6722 .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.' | b':'))
6723 {
6724 return Err(DagMlError::GraphValidation(format!(
6725 "pipeline DSL branch id `{branch_id}` contains unsupported characters"
6726 )));
6727 }
6728 Ok(())
6729}
6730
/// Derives a port-name prefix from a branch id.
///
/// Each character that is not ASCII alphanumeric or `_` becomes `_`, and
/// leading/trailing underscores are stripped. If nothing is left, the prefix
/// falls back to `branch<index>`.
fn branch_input_prefix(branch_id: &str, index: usize) -> String {
    let mut cleaned = String::with_capacity(branch_id.len());
    for character in branch_id.chars() {
        let keep = character.is_ascii_alphanumeric() || character == '_';
        cleaned.push(if keep { character } else { '_' });
    }
    match cleaned.trim_matches('_') {
        "" => format!("branch{index}"),
        trimmed => trimmed.to_string(),
    }
}
6750
6751fn branch_prediction_input_name(
6752 branch_id: &str,
6753 branch_index: usize,
6754 prediction_index: usize,
6755 node_id: &NodeId,
6756) -> String {
6757 let branch = branch_input_prefix(branch_id, branch_index);
6758 let model = node_id
6759 .as_str()
6760 .chars()
6761 .map(|character| {
6762 if character.is_ascii_alphanumeric() || character == '_' {
6763 character
6764 } else {
6765 '_'
6766 }
6767 })
6768 .collect::<String>()
6769 .trim_matches('_')
6770 .to_string();
6771 if model.is_empty() {
6772 format!("{branch}_model{prediction_index}_oof")
6773 } else {
6774 format!("{branch}_{model}_oof")
6775 }
6776}
6777
/// Returns the default DSL input name, `"x"`.
fn default_input_name() -> String {
    String::from("x")
}
6781
/// Returns the default DSL output name, `"prediction"`.
fn default_output_name() -> String {
    String::from("prediction")
}
6785
/// Returns the default data representation, `"tabular_numeric"`.
fn default_data_representation() -> String {
    String::from("tabular_numeric")
}
6789
/// Always returns `true`. It serves as a named default for boolean flags that
/// are enabled unless turned off.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
6793
/// Returns the default logarithm base, `10.0`.
fn default_log_base() -> f64 {
    10.0_f64
}
6797
/// Returns the default merge mode, `"predictions_plus_original"`.
fn default_merge_mode() -> String {
    String::from("predictions_plus_original")
}
6801
6802#[cfg(test)]
6803mod tests {
6804 use super::*;
6805 use crate::controller::{
6806 ArtifactPolicy, ControllerCapability, ControllerFitScope, ControllerManifest,
6807 OperatorSelector, RngPolicy,
6808 };
6809 use crate::phase::Phase;
6810
6811 fn registry_manifest(id: &str, kind: NodeKind, aliases: &[&str]) -> ControllerManifest {
6812 ControllerManifest {
6813 controller_id: crate::ids::ControllerId::new(id).unwrap(),
6814 controller_version: "0.1.0".to_string(),
6815 operator_kind: kind,
6816 priority: 0,
6817 supported_phases: BTreeSet::from([Phase::FitCv]),
6818 input_ports: Vec::new(),
6819 output_ports: Vec::new(),
6820 data_requirements: None,
6821 capabilities: BTreeSet::from([ControllerCapability::Deterministic]),
6822 operator_selectors: vec![OperatorSelector {
6823 aliases: aliases.iter().map(|alias| (*alias).to_string()).collect(),
6824 ..OperatorSelector::default()
6825 }],
6826 fit_scope: ControllerFitScope::FoldTrain,
6827 rng_policy: RngPolicy::UsesCoreSeed,
6828 artifact_policy: ArtifactPolicy::Serializable,
6829 }
6830 }
6831
    /// A two-step transform -> model pipeline compiles to two nodes in step
    /// order, joined by a single data edge, and the resulting graph validates.
    #[test]
    fn compiles_linear_pipeline_dsl_to_valid_graph() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-linear-smoke",
  "steps": [
    {
      "kind": "transform",
      "id": "transform:snv",
      "operator": {"type": "StandardNormalVariate"},
      "seed_label": "snv"
    },
    {
      "kind": "model",
      "id": "model:base",
      "operator": {"type": "RandomForestRegressor"},
      "params": {"n_estimators": 100},
      "seed_label": "base"
    }
  ]
}"#,
        )
        .unwrap();

        let graph = compile_pipeline_dsl(&spec).unwrap();

        assert_eq!(graph.id, "dsl-linear-smoke");
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
        assert_eq!(graph.nodes[0].kind, NodeKind::Transform);
        assert_eq!(graph.nodes[1].kind, NodeKind::Model);
        // The only edge carries data from the transform into the model.
        assert_eq!(graph.edges[0].source.node_id.as_str(), "transform:snv");
        assert_eq!(graph.edges[0].target.node_id.as_str(), "model:base");
        assert_eq!(graph.edges[0].contract.kind, PortKind::Data);
        graph.validate().unwrap();
    }
6868
    /// Unit-level, alignment-key and representation contracts declared on the
    /// DSL `input`/`output` ports are carried onto the graph interface ports.
    #[test]
    fn compiles_pipeline_dsl_unit_contracts_to_graph_interface() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-unit-contract-smoke",
  "input": {
    "name": "spectra",
    "representation": "tabular",
    "unit_level": "observation",
    "alignment_key": "sample_id",
    "target_level": "physical_sample"
  },
  "output": {
    "name": "prediction",
    "representation": "regression",
    "unit_level": "physical_sample",
    "alignment_key": "sample_id",
    "target_level": "physical_sample"
  },
  "steps": [
    {
      "kind": "model",
      "id": "model:base",
      "operator": {"type": "RandomForestRegressor"}
    }
  ]
}"#,
        )
        .unwrap();

        let graph = compile_pipeline_dsl(&spec).unwrap();

        assert_eq!(
            graph.interface.inputs[0].unit_level,
            Some(EntityUnitLevel::Observation)
        );
        assert_eq!(
            graph.interface.inputs[0].alignment_key.as_deref(),
            Some("sample_id")
        );
        assert_eq!(
            graph.interface.outputs[0].unit_level,
            Some(EntityUnitLevel::PhysicalSample)
        );
        assert_eq!(
            graph.interface.outputs[0].representation.as_deref(),
            Some("regression")
        );
    }
6918
    /// Two branches followed by a `merge_model` step compile to four nodes and
    /// three edges.
    ///
    /// The meta model's inputs are one OOF port per branch (`b0_oof`, `b1_oof`)
    /// plus `x_original`. Both prediction edges require OOF predictions and fold
    /// alignment, and the augmentation inside branch `b1` feeds that branch's
    /// model.
    #[test]
    fn compiles_branch_merge_predictions_plus_original_dsl() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-branch-merge-smoke",
  "steps": [
    {
      "kind": "branch",
      "branches": [
        {
          "id": "b0",
          "steps": [
            {
              "kind": "model",
              "id": "branch:b0.model:ridge",
              "operator": {"type": "Ridge"},
              "params": {"alpha": 0.3},
              "seed_label": "branch:b0"
            }
          ]
        },
        {
          "id": "b1",
          "steps": [
            {
              "kind": "augmentation",
              "id": "branch:b1.augment:noise",
              "operator": {"type": "GaussianNoise"},
              "params": {"scope": "train_only"},
              "seed_label": "branch:b1.augment",
              "shape": {
                "fit_rows": "fold_train",
                "predict_rows": "fold_validation",
                "augmentation_policy": {
                  "sample_scope": "train_only",
                  "feature_scope": "none",
                  "require_origin_id": true,
                  "inherit_group": true,
                  "inherit_target": true
                }
              }
            },
            {
              "kind": "model",
              "id": "branch:b1.model:rf",
              "operator": {"type": "RandomForestRegressor"},
              "params": {"n_estimators": 64},
              "seed_label": "branch:b1"
            }
          ]
        }
      ]
    },
    {
      "kind": "merge_model",
      "id": "merge:stack.pred_plus_original.meta:ridge",
      "operator": {"type": "RidgeMetaStacker"},
      "params": {"alpha": 0.2},
      "seed_label": "merge:stack"
    }
  ]
}"#,
        )
        .unwrap();

        let graph = compile_pipeline_dsl(&spec).unwrap();

        assert_eq!(graph.nodes.len(), 4);
        assert_eq!(graph.edges.len(), 3);
        let merge = graph
            .nodes
            .iter()
            .find(|node| node.id.as_str() == "merge:stack.pred_plus_original.meta:ridge")
            .unwrap();
        // One OOF input per branch, in branch order, then the original data.
        assert_eq!(merge.ports.inputs.len(), 3);
        assert_eq!(merge.ports.inputs[0].name, "b0_oof");
        assert_eq!(merge.ports.inputs[1].name, "b1_oof");
        assert_eq!(merge.ports.inputs[2].name, "x_original");
        let prediction_edges = graph
            .edges
            .iter()
            .filter(|edge| edge.contract.kind == PortKind::Prediction)
            .collect::<Vec<_>>();
        assert_eq!(prediction_edges.len(), 2);
        assert!(prediction_edges
            .iter()
            .all(|edge| edge.contract.requires_oof));
        assert!(prediction_edges
            .iter()
            .all(|edge| edge.contract.requires_fold_alignment));
        assert!(graph.edges.iter().any(|edge| edge.source.node_id.as_str()
            == "branch:b1.augment:noise"
            && edge.target.node_id.as_str() == "branch:b1.model:rf"));
        graph.validate().unwrap();
    }
7014
    /// A `by_metadata` branch produces one branch-view plan per branch.
    ///
    /// The plans are mirrored in the campaign template, each plan's selector
    /// carries its branch's metadata value (`"A"` or `{"value": "B"}`), and the
    /// branch's nodes record their plan under `dsl_branch_view_plan`.
    #[test]
    fn compiles_separation_branch_view_plans() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-separation-branch-views",
  "steps": [
    {
      "kind": "branch",
      "mode": "by_metadata",
      "selector": {"metadata_key": "site"},
      "branches": [
        {
          "id": "site_a",
          "selector": "A",
          "steps": [
            {"kind": "model", "id": "model:site.a", "operator": {"type": "PLSRegression"}}
          ]
        },
        {
          "id": "site_b",
          "selector": {"value": "B"},
          "steps": [
            {"kind": "model", "id": "model:site.b", "operator": {"type": "Ridge"}}
          ]
        }
      ]
    },
    {
      "kind": "merge_model",
      "id": "model:site.meta",
      "operator": {"type": "Ridge"},
      "include_original_data": false
    }
  ]
}"#,
        )
        .unwrap();

        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();

        assert_eq!(compiled.branch_view_plans.len(), 2);
        assert_eq!(
            compiled.campaign_template.branch_view_plans,
            compiled.branch_view_plans
        );
        assert_eq!(
            compiled.branch_view_plans[0].mode,
            BranchViewMode::ByMetadata
        );
        assert_eq!(compiled.branch_view_plans[0].selector.metadata["site"], "A");
        assert_eq!(compiled.branch_view_plans[1].selector.metadata["site"], "B");
        let site_model = compiled
            .graph
            .nodes
            .iter()
            .find(|node| node.id.as_str() == "model:site.a")
            .unwrap();
        assert_eq!(
            site_model.metadata["dsl_branch_view_plan"]["selector"]["metadata"]["site"],
            "A"
        );
    }
7077
    /// A `by_source` branch whose branch declares no selector is rejected at
    /// compile time with an error that names the offending branch.
    #[test]
    fn refuses_separation_branch_without_selector() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-bad-separation-branch",
  "steps": [
    {
      "kind": "branch",
      "mode": "by_source",
      "branches": [
        {
          "id": "nir",
          "steps": [
            {"kind": "model", "id": "model:nir", "operator": {"type": "Ridge"}}
          ]
        }
      ]
    }
  ]
}"#,
        )
        .unwrap();

        let error = compile_pipeline_dsl_with_generation(&spec)
            .unwrap_err()
            .to_string();

        assert!(error.contains("by_source branch `nir` requires a selector"));
    }
7107
    /// A `features` merge of two transform branches compiles to a `FeatureJoin`.
    ///
    /// The join has one `<branch>_x` input per branch and feeds the downstream
    /// model over a data edge. No prediction edges appear anywhere in the graph.
    #[test]
    fn compiles_branch_feature_merge_into_downstream_model() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-branch-feature-merge",
  "steps": [
    {
      "kind": "branch",
      "branches": [
        {
          "id": "snv",
          "steps": [
            {
              "kind": "transform",
              "id": "branch:snv.transform",
              "operator": {"type": "SNV"}
            }
          ]
        },
        {
          "id": "msc",
          "steps": [
            {
              "kind": "transform",
              "id": "branch:msc.transform",
              "operator": {"type": "MSC"}
            }
          ]
        }
      ]
    },
    {
      "kind": "merge",
      "id": "merge:features",
      "merge_mode": "features",
      "output_as": "features",
      "include_original_data": false
    },
    {
      "kind": "model",
      "id": "model:pls",
      "operator": {"type": "PLSRegression"}
    }
  ]
}"#,
        )
        .unwrap();

        let graph = compile_pipeline_dsl(&spec).unwrap();
        graph.validate().unwrap();
        let merge = graph
            .nodes
            .iter()
            .find(|node| node.id.as_str() == "merge:features")
            .unwrap();
        assert_eq!(merge.kind, NodeKind::FeatureJoin);
        assert_eq!(merge.ports.inputs.len(), 2);
        assert!(merge.ports.inputs.iter().any(|port| port.name == "snv_x"));
        assert!(merge.ports.inputs.iter().any(|port| port.name == "msc_x"));
        assert!(graph.edges.iter().any(|edge| {
            edge.source.node_id.as_str() == "branch:snv.transform"
                && edge.target.node_id.as_str() == "merge:features"
                && edge.target.port_name == "snv_x"
                && edge.contract.kind == PortKind::Data
        }));
        assert!(graph.edges.iter().any(|edge| {
            edge.source.node_id.as_str() == "merge:features"
                && edge.target.node_id.as_str() == "model:pls"
                && edge.contract.kind == PortKind::Data
        }));
        // A pure feature merge must not introduce prediction edges.
        assert!(!graph
            .edges
            .iter()
            .any(|edge| edge.contract.kind == PortKind::Prediction));
    }
7183
    /// A nirs4all-style stack: a duplication branch with several models per
    /// branch, followed by a separate `predictions_plus_original` merge that has
    /// `best`/`top_k` selectors and feeds two meta models.
    ///
    /// Expected outcome:
    /// - the merge is a `MixedJoin` with five inputs, four of which arrive over
    ///   OOF prediction edges;
    /// - branch metadata (branch id, mode, step and branch selectors) is stamped
    ///   on the branch models;
    /// - both meta models receive the merge's data output;
    /// - the `variants` list of `model:meta.ridge` becomes the single generation
    ///   dimension.
    #[test]
    fn compiles_nirs4all_style_multi_model_branch_and_separate_merge() {
        let spec: PipelineDslSpec = serde_json::from_str(
            r#"{
  "id": "dsl-nirs4all-branch-parity",
  "steps": [
    {
      "kind": "branch",
      "mode": "duplication",
      "selector": {"scope": "all_samples"},
      "branches": [
        {
          "id": "pls_path",
          "steps": [
            {
              "kind": "model",
              "id": "branch:pls.model:pls5",
              "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
              "params": {"n_components": 5}
            },
            {
              "kind": "model",
              "id": "branch:pls.model:pls10",
              "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
              "params": {"n_components": 10}
            }
          ]
        },
        {
          "id": "rf_path",
          "selector": {"source": "nir"},
          "steps": [
            {
              "kind": "transform",
              "id": "branch:rf.transform:snv",
              "operator": {"class": "nirs4all.operators.transforms.StandardNormalVariate"}
            },
            {
              "kind": "model",
              "id": "branch:rf.model:rf",
              "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
              "params": {"n_estimators": 64}
            },
            {
              "kind": "model",
              "id": "branch:rf.model:gbr",
              "operator": {"class": "sklearn.ensemble.GradientBoostingRegressor"},
              "params": {"n_estimators": 32}
            }
          ]
        }
      ]
    },
    {
      "kind": "merge",
      "id": "merge:stack.predictions_plus_original",
      "merge_mode": "predictions_plus_original",
      "output_as": "features",
      "include_original_data": true,
      "selectors": [
        {"branch": "pls_path", "select": "best", "metric": "rmse"},
        {"branch": "rf_path", "select": {"top_k": 2}, "metric": "r2"}
      ],
      "metadata": {"on_missing": "warn"}
    },
    {
      "kind": "model",
      "id": "model:meta.ridge",
      "operator": {"class": "sklearn.linear_model.Ridge"},
      "variants": [
        {"label": "low", "params": {"alpha": 0.1}},
        {"label": "mid", "params": {"alpha": 0.5}}
      ]
    },
    {
      "kind": "model",
      "id": "model:meta.rf",
      "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
      "params": {"n_estimators": 30}
    }
  ]
}"#,
        )
        .unwrap();

        let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
        let graph = compiled.graph;
        let merge = graph
            .nodes
            .iter()
            .find(|node| node.id.as_str() == "merge:stack.predictions_plus_original")
            .unwrap();

        assert_eq!(merge.kind, NodeKind::MixedJoin);
        assert_eq!(merge.ports.inputs.len(), 5);
        assert_eq!(merge.ports.outputs[0].kind, PortKind::Data);
        assert_eq!(merge.metadata["merge_mode"], "predictions_plus_original");
        assert_eq!(merge.metadata["selectors"][0]["branch"], "pls_path");
        let rf_model = graph
            .nodes
            .iter()
            .find(|node| node.id.as_str() == "branch:rf.model:rf")
            .unwrap();
        // Branch provenance: the branch id, its mode, the step-level selector
        // and the per-branch selector are all recorded on the model node.
        assert_eq!(rf_model.metadata["dsl_branch"], "rf_path");
        assert_eq!(rf_model.metadata["dsl_branch_mode"], "duplication");
        assert_eq!(
            rf_model.metadata["dsl_branch_step_selector"]["scope"],
            "all_samples"
        );
        assert_eq!(rf_model.metadata["dsl_branch_selector"]["source"], "nir");
        assert_eq!(
            graph
                .edges
                .iter()
                .filter(|edge| edge.target.node_id == merge.id
                    && edge.contract.kind == PortKind::Prediction
                    && edge.contract.requires_oof)
                .count(),
            4
        );
        assert!(graph
            .edges
            .iter()
            .any(|edge| edge.source.node_id == merge.id
                && edge.target.node_id.as_str() == "model:meta.ridge"
                && edge.contract.kind == PortKind::Data));
        assert!(graph
            .edges
            .iter()
            .any(|edge| edge.source.node_id == merge.id
                && edge.target.node_id.as_str() == "model:meta.rf"
                && edge.contract.kind == PortKind::Data));
        assert_eq!(compiled.generation.dimensions.len(), 1);
        assert_eq!(
            compiled.generation.dimensions[0].name,
            "model:meta.ridge.params"
        );
        graph.validate().unwrap();
    }
7323
7324 #[test]
7325 fn merge_selectors_reject_unknown_branch_and_missing_metric() {
7326 let unknown_branch: PipelineDslSpec = serde_json::from_str(
7327 r#"{
7328 "id": "dsl-bad-merge-selector-branch",
7329 "steps": [
7330 {
7331 "kind": "branch",
7332 "branches": [
7333 {
7334 "id": "known",
7335 "steps": [
7336 {
7337 "kind": "model",
7338 "id": "branch:known.model:ridge",
7339 "operator": {"type": "Ridge"}
7340 }
7341 ]
7342 }
7343 ]
7344 },
7345 {
7346 "kind": "merge",
7347 "id": "merge:bad.selector",
7348 "selectors": [
7349 {"branch": "missing", "select": "all"}
7350 ]
7351 }
7352 ]
7353}"#,
7354 )
7355 .unwrap();
7356 let error = compile_pipeline_dsl_with_generation(&unknown_branch).unwrap_err();
7357 assert!(format!("{error}").contains("does not match any pending prediction input"));
7358
7359 let missing_metric: PipelineDslSpec = serde_json::from_str(
7360 r#"{
7361 "id": "dsl-bad-merge-selector-metric",
7362 "steps": [
7363 {
7364 "kind": "branch",
7365 "branches": [
7366 {
7367 "id": "known",
7368 "steps": [
7369 {
7370 "kind": "model",
7371 "id": "branch:known.model:ridge",
7372 "operator": {"type": "Ridge"}
7373 }
7374 ]
7375 }
7376 ]
7377 },
7378 {
7379 "kind": "merge",
7380 "id": "merge:bad.metric",
7381 "selectors": [
7382 {"branch": "known", "select": "best"}
7383 ]
7384 }
7385 ]
7386}"#,
7387 )
7388 .unwrap();
7389 let error = compile_pipeline_dsl_with_generation(&missing_metric).unwrap_err();
7390 assert!(format!("{error}").contains("requires a non-empty metric"));
7391 }
7392
7393 #[test]
7394 fn merge_selectors_reject_top_k_above_scope() {
7395 let spec: PipelineDslSpec = serde_json::from_str(
7396 r#"{
7397 "id": "dsl-bad-merge-selector-top-k",
7398 "steps": [
7399 {
7400 "kind": "branch",
7401 "branches": [
7402 {
7403 "id": "known",
7404 "steps": [
7405 {
7406 "kind": "model",
7407 "id": "branch:known.model:ridge",
7408 "operator": {"type": "Ridge"}
7409 }
7410 ]
7411 }
7412 ]
7413 },
7414 {
7415 "kind": "merge",
7416 "id": "merge:bad.topk",
7417 "selectors": [
7418 {"branch": "known", "select": {"top_k": 2}, "metric": "rmse"}
7419 ]
7420 }
7421 ]
7422}"#,
7423 )
7424 .unwrap();
7425
7426 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7427 assert!(format!("{error}").contains("top_k=2 exceeds 1 matched prediction inputs"));
7428 }
7429
7430 #[test]
7431 fn compiles_nirs4all_shape_changing_and_tuning_surface() {
7432 let spec: PipelineDslSpec = serde_json::from_str(
7433 r#"{
7434 "id": "dsl-nirs4all-shape-parity",
7435 "steps": [
7436 {
7437 "kind": "y_transform",
7438 "id": "target:scale",
7439 "operator": {"class": "sklearn.preprocessing.StandardScaler"}
7440 },
7441 {
7442 "kind": "tag",
7443 "id": "tag:y_outliers",
7444 "operator": {"class": "nirs4all.filters.YOutlierFilter"},
7445 "params": {"method": "iqr"}
7446 },
7447 {
7448 "kind": "exclude",
7449 "id": "exclude:train_outliers",
7450 "operator": {"class": "nirs4all.filters.YOutlierFilter"},
7451 "params": {"mode": "any"}
7452 },
7453 {
7454 "kind": "sample_augmentation",
7455 "id": "augment:sample.noise",
7456 "operator": {"class": "nirs4all.operators.transforms.GaussianAdditiveNoise"},
7457 "params": {"count": 3, "selection": "random"},
7458 "shape": {
7459 "fit_rows": "fold_train",
7460 "predict_rows": "fold_validation",
7461 "augmentation_policy": {
7462 "sample_scope": "train_only",
7463 "feature_scope": "none",
7464 "require_origin_id": true,
7465 "inherit_group": true,
7466 "inherit_target": true
7467 }
7468 }
7469 },
7470 {
7471 "kind": "feature_augmentation",
7472 "id": "augment:feature.views",
7473 "operator": {"class": "nirs4all.operators.transforms.FeatureAugmentation"},
7474 "params": {"action": "extend"},
7475 "shape": {
7476 "fit_rows": "fold_train",
7477 "predict_rows": "fold_validation",
7478 "feature_namespace": "augmented_views",
7479 "augmentation_policy": {
7480 "sample_scope": "none",
7481 "feature_scope": "train_only",
7482 "require_origin_id": false
7483 }
7484 }
7485 },
7486 {
7487 "kind": "concat_transform",
7488 "id": "join:concat.multi_view",
7489 "branches": [
7490 {
7491 "id": "pca",
7492 "steps": [
7493 {
7494 "id": "concat:pca",
7495 "operator": {"class": "sklearn.decomposition.PCA"},
7496 "params": {"n_components": 20}
7497 }
7498 ]
7499 },
7500 {
7501 "id": "derivative_pca",
7502 "steps": [
7503 {
7504 "id": "concat:derivative",
7505 "operator": {"class": "nirs4all.operators.transforms.FirstDerivative"}
7506 },
7507 {
7508 "id": "concat:derivative.pca",
7509 "operator": {"class": "sklearn.decomposition.PCA"},
7510 "params": {"n_components": 10}
7511 }
7512 ]
7513 }
7514 ],
7515 "shape": {
7516 "fit_rows": "fold_train",
7517 "feature_namespace": "concat.multi_view",
7518 "selection_policy": {
7519 "scope": "unsupervised"
7520 }
7521 }
7522 },
7523 {
7524 "kind": "model",
7525 "id": "model:tuned",
7526 "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
7527 "finetune_params": {
7528 "n_trials": 10,
7529 "approach": "single",
7530 "eval_mode": "mean",
7531 "sampler": "random",
7532 "metric": "rmse",
7533 "model_params": {
7534 "max_depth": [3, 5, 7]
7535 }
7536 },
7537 "train_params": {
7538 "sample_weight": "balanced"
7539 }
7540 }
7541 ]
7542}"#,
7543 )
7544 .unwrap();
7545
7546 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7547 let graph = compiled.graph;
7548 let kinds = graph
7549 .nodes
7550 .iter()
7551 .map(|node| node.kind.clone())
7552 .collect::<Vec<_>>();
7553 assert!(kinds.contains(&NodeKind::YTransform));
7554 assert!(kinds.contains(&NodeKind::Tag));
7555 assert!(kinds.contains(&NodeKind::Exclude));
7556 assert!(kinds.contains(&NodeKind::Augmentation));
7557 assert!(kinds.contains(&NodeKind::FeatureJoin));
7558 assert_eq!(compiled.shape_plans.len(), 3);
7559
7560 let sample_aug = graph
7561 .nodes
7562 .iter()
7563 .find(|node| node.id.as_str() == "augment:sample.noise")
7564 .unwrap();
7565 assert_eq!(sample_aug.metadata["dsl_augmentation_kind"], "sample");
7566 let feature_aug = graph
7567 .nodes
7568 .iter()
7569 .find(|node| node.id.as_str() == "augment:feature.views")
7570 .unwrap();
7571 assert_eq!(feature_aug.metadata["dsl_augmentation_kind"], "feature");
7572 let model = graph
7573 .nodes
7574 .iter()
7575 .find(|node| node.id.as_str() == "model:tuned")
7576 .unwrap();
7577 assert_eq!(model.metadata["dsl_tuning"]["n_trials"], 10);
7578 assert_eq!(
7579 model.metadata["dsl_train_params"]["sample_weight"],
7580 "balanced"
7581 );
7582 graph.validate().unwrap();
7583 }
7584
7585 #[test]
7586 fn extracts_node_param_variants_into_generation_spec() {
7587 let spec: PipelineDslSpec = serde_json::from_str(
7588 r#"{
7589 "id": "dsl-generation-smoke",
7590 "max_variants": 4,
7591 "steps": [
7592 {
7593 "kind": "transform",
7594 "id": "transform:preprocess",
7595 "operator": {"type": "Preprocess"},
7596 "variants": [
7597 {
7598 "label": "snv",
7599 "params": {"method": "snv"}
7600 },
7601 {
7602 "label": "msc",
7603 "params": {"method": "msc"}
7604 }
7605 ]
7606 },
7607 {
7608 "kind": "model",
7609 "id": "model:base",
7610 "operator": {"type": "Ridge"},
7611 "variants": [
7612 {
7613 "label": "low",
7614 "params": {"alpha": 0.1}
7615 },
7616 {
7617 "label": "high",
7618 "params": {"alpha": 1.0}
7619 }
7620 ]
7621 }
7622 ]
7623}"#,
7624 )
7625 .unwrap();
7626
7627 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7628
7629 assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7630 assert_eq!(compiled.generation.max_variants, Some(4));
7631 assert_eq!(compiled.generation.dimensions.len(), 2);
7632 assert_eq!(
7633 compiled.generation.dimensions[0].name,
7634 "transform:preprocess.params"
7635 );
7636 assert_eq!(compiled.generation.dimensions[0].choices[0].label, "snv");
7637 assert_eq!(
7638 compiled.generation.dimensions[0].choices[0].param_overrides[0].node_id,
7639 NodeId::new("transform:preprocess").unwrap()
7640 );
7641 assert_eq!(
7642 compiled.generation.dimensions[1].choices[1].param_overrides[0].params["alpha"],
7643 1.0
7644 );
7645 assert!(compiled.generation_fingerprint.is_some());
7646 assert_eq!(
7647 compiled.graph.search_space_fingerprint,
7648 compiled.generation_fingerprint
7649 );
7650 compiled.graph.validate().unwrap();
7651 }
7652
7653 #[test]
7654 fn expands_compact_param_generators_into_generation_dimensions() {
7655 let spec: PipelineDslSpec = serde_json::from_str(
7656 r#"{
7657 "id": "dsl-compact-generation",
7658 "steps": [
7659 {
7660 "kind": "model",
7661 "id": "model:tuned",
7662 "operator": {"type": "TunedModel"},
7663 "generators": [
7664 {
7665 "kind": "or",
7666 "name": "model_family",
7667 "param": "family",
7668 "values": [
7669 {"label": "ridge", "value": "ridge"},
7670 {"label": "rf", "value": "random_forest"}
7671 ]
7672 },
7673 {
7674 "kind": "range",
7675 "param": "alpha",
7676 "start": 0.1,
7677 "stop": 0.9,
7678 "step": 0.4
7679 },
7680 {
7681 "kind": "log_range",
7682 "param": "lambda",
7683 "start": 0.01,
7684 "stop": 1.0,
7685 "count": 3
7686 },
7687 {
7688 "kind": "grid",
7689 "name": "tree_grid",
7690 "params": {
7691 "max_depth": [3, 5],
7692 "n_estimators": [50, 100]
7693 },
7694 "count": 3
7695 },
7696 {
7697 "kind": "pick",
7698 "param": "views",
7699 "values": ["snv", "msc", "derivative"],
7700 "sizes": [1, 2],
7701 "count": 4
7702 },
7703 {
7704 "kind": "arrange",
7705 "param": "chain",
7706 "values": ["snv", "pca", "pls"],
7707 "sizes": [2],
7708 "count": 3
7709 }
7710 ]
7711 }
7712 ]
7713}"#,
7714 )
7715 .unwrap();
7716
7717 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7718
7719 assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7720 assert_eq!(compiled.generation.dimensions.len(), 6);
7721 assert_eq!(compiled.generation.dimensions[0].name, "model_family");
7722 assert_eq!(compiled.generation.dimensions[0].choices.len(), 2);
7723 assert_eq!(
7724 compiled.generation.dimensions[1].name,
7725 "model:tuned.alpha.range"
7726 );
7727 assert_eq!(compiled.generation.dimensions[1].choices.len(), 3);
7728 assert_eq!(
7729 compiled.generation.dimensions[1].choices[1].param_overrides[0].params["alpha"],
7730 0.5
7731 );
7732 assert_eq!(
7733 compiled.generation.dimensions[2].name,
7734 "model:tuned.lambda.log_range"
7735 );
7736 assert_eq!(compiled.generation.dimensions[2].choices.len(), 3);
7737 assert_eq!(compiled.generation.dimensions[3].name, "tree_grid");
7738 assert_eq!(compiled.generation.dimensions[3].choices.len(), 3);
7739 assert_eq!(
7740 compiled.generation.dimensions[3].choices[2].param_overrides[0].params["n_estimators"],
7741 50
7742 );
7743 assert_eq!(
7744 compiled.generation.dimensions[4].choices[3].param_overrides[0].params["views"],
7745 serde_json::json!(["snv", "msc"])
7746 );
7747 assert_eq!(
7748 compiled.generation.dimensions[5].choices[2].param_overrides[0].params["chain"],
7749 serde_json::json!(["pca", "snv"])
7750 );
7751 assert!(compiled.generation_fingerprint.is_some());
7752 }
7753
7754 #[test]
7755 fn compact_param_generators_reject_invalid_counts() {
7756 let spec: PipelineDslSpec = serde_json::from_str(
7757 r#"{
7758 "id": "dsl-bad-compact-generation",
7759 "steps": [
7760 {
7761 "kind": "model",
7762 "id": "model:bad",
7763 "operator": {"type": "Ridge"},
7764 "generators": [
7765 {
7766 "kind": "or",
7767 "param": "alpha",
7768 "values": [0.1, 1.0],
7769 "count": 0
7770 }
7771 ]
7772 }
7773 ]
7774}"#,
7775 )
7776 .unwrap();
7777
7778 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7779 assert!(format!("{error}").contains("count=0"));
7780 }
7781
7782 #[test]
7783 fn compiles_coordinated_generation_dimensions() {
7784 let spec: PipelineDslSpec = serde_json::from_str(
7785 r#"{
7786 "id": "dsl-coordinated-generation",
7787 "max_variants": 2,
7788 "generation_dimensions": [
7789 {
7790 "name": "stack_profile",
7791 "choices": [
7792 {
7793 "label": "linear_stack",
7794 "param_overrides": [
7795 {"node_id": "branch:b0.model:ridge", "params": {"alpha": 0.1}},
7796 {"node_id": "branch:b1.model:rf", "params": {"max_depth": 4}},
7797 {"node_id": "merge:stack.pred_plus_original.meta:ridge", "params": {"alpha": 0.05}}
7798 ]
7799 },
7800 {
7801 "label": "robust_stack",
7802 "param_overrides": [
7803 {"node_id": "branch:b0.model:ridge", "params": {"alpha": 1.0}},
7804 {"node_id": "branch:b1.model:rf", "params": {"max_depth": 8}},
7805 {"node_id": "merge:stack.pred_plus_original.meta:ridge", "params": {"alpha": 0.5}}
7806 ]
7807 }
7808 ]
7809 }
7810 ],
7811 "steps": [
7812 {
7813 "kind": "branch",
7814 "branches": [
7815 {
7816 "id": "b0",
7817 "steps": [
7818 {
7819 "kind": "model",
7820 "id": "branch:b0.model:ridge",
7821 "operator": {"type": "Ridge"}
7822 }
7823 ]
7824 },
7825 {
7826 "id": "b1",
7827 "steps": [
7828 {
7829 "kind": "model",
7830 "id": "branch:b1.model:rf",
7831 "operator": {"type": "RandomForestRegressor"}
7832 }
7833 ]
7834 }
7835 ]
7836 },
7837 {
7838 "kind": "merge_model",
7839 "id": "merge:stack.pred_plus_original.meta:ridge",
7840 "operator": {"type": "RidgeMetaStacker"}
7841 }
7842 ]
7843}"#,
7844 )
7845 .unwrap();
7846
7847 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7848
7849 assert_eq!(compiled.generation.strategy, GenerationStrategy::Cartesian);
7850 assert_eq!(compiled.generation.max_variants, Some(2));
7851 assert_eq!(compiled.generation.dimensions.len(), 1);
7852 assert_eq!(compiled.generation.dimensions[0].name, "stack_profile");
7853 assert_eq!(
7854 compiled.generation.dimensions[0].choices[0]
7855 .param_overrides
7856 .len(),
7857 3
7858 );
7859 assert_eq!(
7860 compiled.generation.dimensions[0].choices[1].param_overrides[2].node_id,
7861 NodeId::new("merge:stack.pred_plus_original.meta:ridge").unwrap()
7862 );
7863 assert_eq!(
7864 compiled.generation.dimensions[0].choices[1].value
7865 ["merge:stack.pred_plus_original.meta:ridge"]["alpha"],
7866 0.5
7867 );
7868 assert_eq!(
7869 compiled.graph.search_space_fingerprint,
7870 compiled.generation_fingerprint
7871 );
7872 compiled.graph.validate().unwrap();
7873 }
7874
7875 #[test]
7876 fn refuses_coordinated_generation_for_unknown_node() {
7877 let spec: PipelineDslSpec = serde_json::from_str(
7878 r#"{
7879 "id": "dsl-bad-generation-target",
7880 "generation_dimensions": [
7881 {
7882 "name": "bad_target",
7883 "choices": [
7884 {
7885 "label": "bad",
7886 "param_overrides": [
7887 {"node_id": "model:missing", "params": {"alpha": 0.1}}
7888 ]
7889 }
7890 ]
7891 }
7892 ],
7893 "steps": [
7894 {
7895 "kind": "model",
7896 "id": "model:base",
7897 "operator": {"type": "Ridge"}
7898 }
7899 ]
7900}"#,
7901 )
7902 .unwrap();
7903
7904 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
7905 assert!(format!("{error}").contains("references unknown node `model:missing`"));
7906 }
7907
7908 #[test]
7909 fn artifact_contains_campaign_template_without_split_graph_nodes() {
7910 let spec: PipelineDslSpec = serde_json::from_str(
7911 r#"{
7912 "id": "dsl-campaign-template",
7913 "campaign_id": "campaign:dsl.template",
7914 "root_seed": 123,
7915 "leakage_policy": {
7916 "split_unit": "group",
7917 "require_group_ids": true
7918 },
7919 "split_invocation": {
7920 "id": "split:group-kfold",
7921 "leakage_policy": {
7922 "split_unit": "group",
7923 "require_group_ids": true
7924 },
7925 "params": {
7926 "n_splits": 3
7927 }
7928 },
7929 "generation_dimensions": [
7930 {
7931 "name": "model_family",
7932 "choices": [
7933 {
7934 "label": "ridge_low",
7935 "param_overrides": [
7936 {"node_id": "model:base", "params": {"alpha": 0.1}}
7937 ]
7938 },
7939 {
7940 "label": "ridge_high",
7941 "param_overrides": [
7942 {"node_id": "model:base", "params": {"alpha": 1.0}}
7943 ]
7944 }
7945 ]
7946 }
7947 ],
7948 "data_bindings": [
7949 {
7950 "node_id": "model:base",
7951 "input_name": "x",
7952 "request_id": "data:model.base.x",
7953 "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
7954 "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
7955 "relation_fingerprint": "a3a7e329df35db9f2883a17b8611b7fae6dcaa031875e3ec2c9be1b9e29cbe10",
7956 "output_representation": "tabular_numeric",
7957 "feature_set_id": "x",
7958 "source_ids": ["nir"],
7959 "require_relations": true
7960 }
7961 ],
7962 "steps": [
7963 {
7964 "kind": "model",
7965 "id": "model:base",
7966 "operator": {"type": "Ridge"}
7967 }
7968 ],
7969 "campaign_metadata": {
7970 "owner": "dsl-test"
7971 }
7972}"#,
7973 )
7974 .unwrap();
7975
7976 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
7977
7978 assert_eq!(compiled.campaign_template.id, "campaign:dsl.template");
7979 assert_eq!(compiled.campaign_template.root_seed, Some(123));
7980 assert_eq!(
7981 compiled
7982 .campaign_template
7983 .split_invocation
7984 .as_ref()
7985 .unwrap()
7986 .id,
7987 "split:group-kfold"
7988 );
7989 assert_eq!(compiled.campaign_template.generation, compiled.generation);
7990 assert_eq!(
7991 compiled.data_bindings[&NodeId::new("model:base").unwrap()][0].request_id,
7992 "data:model.base.x"
7993 );
7994 assert_eq!(
7995 compiled.campaign_template.data_bindings,
7996 compiled.data_bindings
7997 );
7998 assert_eq!(compiled.graph.nodes.len(), 1);
7999 assert!(compiled
8000 .graph
8001 .nodes
8002 .iter()
8003 .all(|node| !node.id.as_str().starts_with("split:")));
8004 }
8005
8006 #[test]
8007 fn refuses_data_binding_for_unknown_or_non_data_port() {
8008 let unknown_input_spec: PipelineDslSpec = serde_json::from_str(
8009 r#"{
8010 "id": "dsl-bad-data-binding",
8011 "data_bindings": [
8012 {
8013 "node_id": "model:base",
8014 "input_name": "missing",
8015 "request_id": "data:bad",
8016 "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
8017 "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
8018 "output_representation": "tabular_numeric"
8019 }
8020 ],
8021 "steps": [
8022 {
8023 "kind": "model",
8024 "id": "model:base",
8025 "operator": {"type": "Ridge"}
8026 }
8027 ]
8028}"#,
8029 )
8030 .unwrap();
8031 let error = compile_pipeline_dsl_with_generation(&unknown_input_spec).unwrap_err();
8032 assert!(format!("{error}").contains("unknown input port `missing`"));
8033
8034 let prediction_input_spec: PipelineDslSpec = serde_json::from_str(
8035 r#"{
8036 "id": "dsl-prediction-port-data-binding",
8037 "data_bindings": [
8038 {
8039 "node_id": "merge:stack.pred_plus_original.meta:ridge",
8040 "input_name": "b0_oof",
8041 "request_id": "data:bad.prediction-port",
8042 "schema_fingerprint": "f97b37872fa22134b508f98fd8e207e5b776b52594fb8f6f5c3e15bee212246b",
8043 "plan_fingerprint": "7c5431d85574b3f337022fa5d25971d5b5cf445b90331b49938f573ff6901e4d",
8044 "output_representation": "tabular_numeric"
8045 }
8046 ],
8047 "steps": [
8048 {
8049 "kind": "branch",
8050 "branches": [
8051 {
8052 "id": "b0",
8053 "steps": [
8054 {
8055 "kind": "model",
8056 "id": "branch:b0.model:ridge",
8057 "operator": {"type": "Ridge"}
8058 }
8059 ]
8060 }
8061 ]
8062 },
8063 {
8064 "kind": "merge_model",
8065 "id": "merge:stack.pred_plus_original.meta:ridge",
8066 "operator": {"type": "RidgeMetaStacker"}
8067 }
8068 ]
8069}"#,
8070 )
8071 .unwrap();
8072 let error = compile_pipeline_dsl_with_generation(&prediction_input_spec).unwrap_err();
8073 assert!(format!("{error}").contains("targets non-data input"));
8074 }
8075
8076 #[test]
8077 fn extracts_shape_plans_into_compiled_artifact() {
8078 let spec: PipelineDslSpec = serde_json::from_str(
8079 r#"{
8080 "id": "dsl-shape-plan-smoke",
8081 "steps": [
8082 {
8083 "kind": "augmentation",
8084 "id": "augment:synthetic",
8085 "operator": {"type": "SampleAugmenter"},
8086 "shape": {
8087 "input_granularity": "sample",
8088 "target_granularity": "sample",
8089 "fit_rows": "fold_train",
8090 "predict_rows": "fold_validation",
8091 "feature_namespace": "aug.synthetic",
8092 "augmentation_policy": {
8093 "sample_scope": "train_only",
8094 "feature_scope": "none",
8095 "require_origin_id": true,
8096 "inherit_group": true,
8097 "inherit_target": true
8098 }
8099 }
8100 },
8101 {
8102 "kind": "transform",
8103 "id": "transform:select",
8104 "operator": {"type": "SupervisedFeatureSelector"},
8105 "shape": {
8106 "fit_rows": "fold_train",
8107 "feature_namespace": "selected",
8108 "selection_policy": {
8109 "scope": "supervised_fold_train",
8110 "store_masks": true
8111 }
8112 }
8113 },
8114 {
8115 "kind": "model",
8116 "id": "model:base",
8117 "operator": {"type": "Ridge"}
8118 }
8119 ]
8120}"#,
8121 )
8122 .unwrap();
8123
8124 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8125
8126 assert_eq!(compiled.shape_plans.len(), 2);
8127 let augment_plan = compiled
8128 .shape_plans
8129 .get(&NodeId::new("augment:synthetic").unwrap())
8130 .unwrap();
8131 assert_eq!(
8132 augment_plan.feature_namespace.as_deref(),
8133 Some("aug.synthetic")
8134 );
8135 assert_eq!(
8136 augment_plan.augmentation_policy.sample_scope,
8137 crate::policy::AugmentationScope::TrainOnly
8138 );
8139 let select_plan = compiled
8140 .shape_plans
8141 .get(&NodeId::new("transform:select").unwrap())
8142 .unwrap();
8143 assert_eq!(
8144 select_plan.selection_policy.scope,
8145 crate::policy::FeatureSelectionScope::SupervisedFoldTrain
8146 );
8147 assert_eq!(compiled.generation.strategy, GenerationStrategy::None);
8148 compiled.graph.validate().unwrap();
8149 }
8150
8151 #[test]
8152 fn compiles_sequential_filter_and_or_generator_surface() {
8153 let spec: PipelineDslSpec = serde_json::from_str(
8154 r#"{
8155 "id": "dsl-generator-or-parity",
8156 "steps": [
8157 {
8158 "kind": "sequential",
8159 "id": "seq:pre",
8160 "steps": [
8161 {
8162 "kind": "sample_filter",
8163 "id": "filter:y_outlier",
8164 "operator": {"class": "nirs4all.operators.filters.YOutlierFilter"},
8165 "params": {"mode": "any"}
8166 },
8167 {
8168 "kind": "transform",
8169 "id": "transform:scale",
8170 "operator": {"class": "sklearn.preprocessing.StandardScaler"}
8171 }
8172 ]
8173 },
8174 {
8175 "kind": "generator",
8176 "id": "generator:model_choices",
8177 "mode": "or",
8178 "pick": 1,
8179 "branches": [
8180 {
8181 "id": "pls",
8182 "steps": [
8183 {
8184 "kind": "model",
8185 "id": "model:pls",
8186 "operator": {"class": "sklearn.cross_decomposition.PLSRegression"},
8187 "params": {"n_components": 8}
8188 }
8189 ]
8190 },
8191 {
8192 "id": "rf",
8193 "steps": [
8194 {
8195 "kind": "model",
8196 "id": "model:rf",
8197 "operator": {"class": "sklearn.ensemble.RandomForestRegressor"},
8198 "params": {"n_estimators": 64}
8199 }
8200 ]
8201 }
8202 ]
8203 },
8204 {
8205 "kind": "merge",
8206 "id": "merge:generated",
8207 "output_as": "features",
8208 "include_original_data": false,
8209 "selectors": [
8210 {"branch": "generator:model_choices:choice0", "select": "all"}
8211 ]
8212 }
8213 ]
8214}"#,
8215 )
8216 .unwrap();
8217
8218 let graph = compile_pipeline_dsl(&spec).unwrap();
8219 graph.validate().unwrap();
8220 let filter = graph
8221 .nodes
8222 .iter()
8223 .find(|node| node.id.as_str() == "filter:y_outlier")
8224 .unwrap();
8225 assert_eq!(filter.kind, NodeKind::Exclude);
8226 assert_eq!(filter.metadata["dsl_filter_kind"], "sample");
8227
8228 let generated_models = graph
8229 .nodes
8230 .iter()
8231 .filter(|node| node.kind == NodeKind::Model)
8232 .collect::<Vec<_>>();
8233 assert_eq!(generated_models.len(), 2);
8234 assert!(generated_models
8235 .iter()
8236 .all(|node| node.id.as_str().starts_with("gen:generator_model_choices")));
8237 assert!(generated_models.iter().all(|node| {
8238 node.metadata
8239 .get("dsl_generator")
8240 .and_then(|value| value.as_str())
8241 == Some("generator:model_choices")
8242 }));
8243
8244 let merge_inputs = graph
8245 .nodes
8246 .iter()
8247 .find(|node| node.id.as_str() == "merge:generated")
8248 .unwrap()
8249 .ports
8250 .inputs
8251 .iter()
8252 .map(|port| port.name.as_str())
8253 .collect::<Vec<_>>();
8254 assert_eq!(
8255 merge_inputs,
8256 vec![
8257 "generator_model_choices_choice0_oof",
8258 "generator_model_choices_choice1_oof"
8259 ]
8260 );
8261 }
8262
8263 #[test]
8264 fn compiles_cartesian_generator_as_explicit_prediction_choices() {
8265 let spec: PipelineDslSpec = serde_json::from_str(
8266 r#"{
8267 "id": "dsl-generator-cartesian-parity",
8268 "steps": [
8269 {
8270 "kind": "generator",
8271 "id": "generator:cartesian",
8272 "mode": "cartesian",
8273 "stages": [
8274 {
8275 "id": "preproc",
8276 "branches": [
8277 {
8278 "id": "snv",
8279 "steps": [
8280 {
8281 "kind": "transform",
8282 "id": "transform:snv",
8283 "operator": {"class": "nirs4all.operators.transforms.StandardNormalVariate"}
8284 }
8285 ]
8286 },
8287 {
8288 "id": "msc",
8289 "steps": [
8290 {
8291 "kind": "transform",
8292 "id": "transform:msc",
8293 "operator": {"class": "nirs4all.operators.transforms.MultiplicativeScatterCorrection"}
8294 }
8295 ]
8296 }
8297 ]
8298 },
8299 {
8300 "id": "model",
8301 "branches": [
8302 {
8303 "id": "ridge",
8304 "steps": [
8305 {
8306 "kind": "model",
8307 "id": "model:ridge",
8308 "operator": {"class": "sklearn.linear_model.Ridge"}
8309 }
8310 ]
8311 },
8312 {
8313 "id": "lasso",
8314 "steps": [
8315 {
8316 "kind": "model",
8317 "id": "model:lasso",
8318 "operator": {"class": "sklearn.linear_model.Lasso"}
8319 }
8320 ]
8321 }
8322 ]
8323 }
8324 ]
8325 },
8326 {
8327 "kind": "merge",
8328 "id": "merge:cartesian",
8329 "output_as": "features",
8330 "include_original_data": false
8331 }
8332 ]
8333}"#,
8334 )
8335 .unwrap();
8336
8337 let graph = compile_pipeline_dsl(&spec).unwrap();
8338 graph.validate().unwrap();
8339 let models = graph
8340 .nodes
8341 .iter()
8342 .filter(|node| node.kind == NodeKind::Model)
8343 .collect::<Vec<_>>();
8344 assert_eq!(models.len(), 4);
8345 assert!(models.iter().all(|node| {
8346 node.metadata
8347 .get("dsl_generator_mode")
8348 .and_then(|value| value.as_str())
8349 == Some("cartesian")
8350 }));
8351 let merge = graph
8352 .nodes
8353 .iter()
8354 .find(|node| node.id.as_str() == "merge:cartesian")
8355 .unwrap();
8356 assert_eq!(merge.ports.inputs.len(), 4);
8357 assert_eq!(
8358 graph
8359 .edges
8360 .iter()
8361 .filter(|edge| edge.target.node_id.as_str() == "merge:cartesian")
8362 .count(),
8363 4
8364 );
8365 }
8366
8367 #[test]
8368 fn refuses_generator_choice_without_prediction_output() {
8369 let spec: PipelineDslSpec = serde_json::from_str(
8370 r#"{
8371 "id": "dsl-generator-bad-choice",
8372 "steps": [
8373 {
8374 "kind": "generator",
8375 "id": "generator:bad",
8376 "branches": [
8377 {
8378 "id": "transform_only",
8379 "steps": [
8380 {
8381 "kind": "transform",
8382 "id": "transform:only",
8383 "operator": {"class": "sklearn.preprocessing.StandardScaler"}
8384 }
8385 ]
8386 }
8387 ]
8388 }
8389 ]
8390}"#,
8391 )
8392 .unwrap();
8393
8394 let error = compile_pipeline_dsl(&spec).unwrap_err();
8395 assert!(format!("{error}").contains("must produce at least one model or merge prediction"));
8396 }
8397
8398 #[test]
8399 fn parses_nirs4all_compat_pipeline_and_fuses_data_generators() {
8400 let spec = parse_pipeline_dsl_json(
8401 br#"{
8402 "id": "dsl-nirs4all-compat-fused",
8403 "pipeline": [
8404 {"sources": ["nir"]},
8405 {"_cartesian_": [
8406 {"_or_": ["SNV", "MSC", null]},
8407 {"_or_": [null, {"preprocessing": "SavitzkyGolay", "params": {"window": 11, "deriv": 1}}]}
8408 ]},
8409 {"split": {"type": "GroupKFold", "n_splits": 3}},
8410 {"_chain_": [
8411 {"_grid_": {"model": ["PLSRegression"], "n_components": [5, 10]}},
8412 {"_grid_": {"model": ["Ridge"], "alpha": [0.1, 1.0]}},
8413 {"_sample_": {"model": "SVR", "distribution": "log_uniform", "from": 0.001, "to": 1.0, "num": 2, "tune": ["C", "gamma"], "kernel": "rbf"}}
8414 ]},
8415 {"merge": "all"},
8416 {"model": "Ridge", "id": "model:meta", "params": {"alpha": 0.5}}
8417 ]
8418}"#,
8419 )
8420 .unwrap();
8421
8422 assert_eq!(spec.steps.len(), 2);
8423 assert_eq!(
8424 spec.split_invocation
8425 .as_ref()
8426 .unwrap()
8427 .params
8428 .get("type")
8429 .unwrap(),
8430 "GroupKFold"
8431 );
8432
8433 let graph = compile_pipeline_dsl(&spec).unwrap();
8434 graph.validate().unwrap();
8435 let meta = graph
8436 .nodes
8437 .iter()
8438 .find(|node| node.id.as_str() == "model:meta")
8439 .unwrap();
8440 assert_eq!(meta.kind, NodeKind::Model);
8441 assert!(meta
8442 .ports
8443 .inputs
8444 .iter()
8445 .any(|port| port.name == "x_original"));
8446 assert!(graph.edges.iter().any(|edge| {
8447 edge.target.node_id.as_str() == "model:meta"
8448 && edge.contract.kind == PortKind::Prediction
8449 && edge.contract.requires_oof
8450 }));
8451 assert!(graph.nodes.iter().any(|node| {
8452 node.metadata
8453 .get("dsl_compat_keyword")
8454 .and_then(serde_json::Value::as_str)
8455 == Some("preprocessing")
8456 }));
8457 assert!(graph.nodes.iter().any(|node| {
8458 node.kind == NodeKind::Model
8459 && node.params.contains_key("C")
8460 && node.params.contains_key("gamma")
8461 }));
8462 }
8463
8464 #[test]
8465 fn parses_nirs4all_range_attached_to_following_model() {
8466 let spec = parse_pipeline_dsl_json(
8467 br#"{
8468 "id": "dsl-nirs4all-compat-range",
8469 "pipeline": [
8470 {"_range_": [5, 15, 5]},
8471 {"model": "PLSRegression", "id": "model:pls"}
8472 ]
8473}"#,
8474 )
8475 .unwrap();
8476
8477 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8478 assert_eq!(compiled.generation.dimensions.len(), 1);
8479 assert_eq!(compiled.generation.dimensions[0].choices.len(), 3);
8480 assert_eq!(
8481 compiled.generation.dimensions[0].choices[0].param_overrides[0].params["n_components"],
8482 5.0
8483 );
8484 }
8485
8486 #[test]
8487 fn parses_nirs4all_minimal_aliases_plain_classes_and_split_chain() {
8488 let spec = parse_pipeline_dsl_json(
8489 br#"{
8490 "id": "dsl-nirs4all-compat-minimal-aliases",
8491 "pipeline": [
8492 "chart_2d",
8493 {"class": "sklearn.preprocessing.MinMaxScaler", "params": {"feature_range": [0, 1]}},
8494 {"class": "nirs4all.operators.splitters.SPXYGFold", "params": {"n_splits": 1, "test_size": 0.2}, "group": "Sample_ID"},
8495 {"class": "sklearn.model_selection.KFold", "params": {"n_splits": 3, "shuffle": true, "random_state": 42}},
8496 "SNV",
8497 "PLSRegression"
8498 ]
8499}"#,
8500 )
8501 .unwrap();
8502
8503 let split = spec.split_invocation.as_ref().unwrap();
8504 assert_eq!(split.id, "split:compat.chain");
8505 let chain = split.params["compat_split_chain"].as_array().unwrap();
8506 assert_eq!(chain.len(), 2);
8507 assert_eq!(
8508 chain[0]["params"]["class"],
8509 "nirs4all.operators.splitters.SPXYGFold"
8510 );
8511 assert_eq!(chain[0]["params"]["group"], "Sample_ID");
8512 assert_eq!(chain[1]["params"]["class"], "sklearn.model_selection.KFold");
8513
8514 let graph = compile_pipeline_dsl(&spec).unwrap();
8515 graph.validate().unwrap();
8516 assert!(graph.nodes.iter().any(|node| node.kind == NodeKind::Chart));
8517 assert!(graph.nodes.iter().any(|node| {
8518 node.kind == NodeKind::Transform
8519 && node.operator.as_ref().unwrap()["class"] == "sklearn.preprocessing.MinMaxScaler"
8520 }));
8521 assert!(graph.nodes.iter().any(|node| {
8522 node.kind == NodeKind::Transform
8523 && node.operator.as_ref().unwrap().as_str() == Some("SNV")
8524 }));
8525 assert!(graph.nodes.iter().any(|node| {
8526 node.kind == NodeKind::Model
8527 && node.operator.as_ref().unwrap().as_str() == Some("PLSRegression")
8528 }));
8529 }
8530
8531 #[test]
8532 fn registry_reclassifies_non_heuristic_minimal_aliases_before_compile() {
8533 let spec = parse_pipeline_dsl_json(
8534 br#"{
8535 "id": "dsl-registry-minimal-aliases",
8536 "pipeline": [
8537 "SNV",
8538 "ElasticSpectra"
8539 ]
8540}"#,
8541 )
8542 .unwrap();
8543 let mut registry = ControllerRegistry::new();
8544 registry
8545 .register(registry_manifest(
8546 "controller:transformer.mixin",
8547 NodeKind::Transform,
8548 &["SNV"],
8549 ))
8550 .unwrap();
8551 registry
8552 .register(registry_manifest(
8553 "controller:elastic.spectra",
8554 NodeKind::Model,
8555 &["ElasticSpectra"],
8556 ))
8557 .unwrap();
8558
8559 let compiled =
8560 compile_pipeline_dsl_with_generation_and_controller_registry(&spec, ®istry).unwrap();
8561 let model = compiled
8562 .graph
8563 .nodes
8564 .iter()
8565 .find(|node| {
8566 node.operator.as_ref().and_then(serde_json::Value::as_str) == Some("ElasticSpectra")
8567 })
8568 .unwrap();
8569
8570 assert_eq!(model.kind, NodeKind::Model);
8571 assert_eq!(model.metadata[DSL_REGISTRY_INFERRED_KIND], "model");
8572 assert_eq!(model.metadata[DSL_COMPAT_ORIGINAL_KEYWORD], "preprocessing");
8573 assert!(compiled.graph.nodes.iter().any(|node| {
8574 node.kind == NodeKind::Transform
8575 && node.operator.as_ref().and_then(serde_json::Value::as_str) == Some("SNV")
8576 }));
8577 }
8578
8579 #[test]
8580 fn parses_nirs4all_named_step_wrapper_and_plain_class_model() {
8581 let spec = parse_pipeline_dsl_json(
8582 br#"{
8583 "id": "dsl-nirs4all-compat-named-step",
8584 "pipeline": [
8585 {"name": "scaled", "step": {"class": "sklearn.preprocessing.StandardScaler"}},
8586 {"class": "sklearn.ensemble.RandomForestRegressor", "params": {"n_estimators": 10, "random_state": 42}}
8587 ]
8588}"#,
8589 )
8590 .unwrap();
8591
8592 let graph = compile_pipeline_dsl(&spec).unwrap();
8593 graph.validate().unwrap();
8594 let scaled = graph
8595 .nodes
8596 .iter()
8597 .find(|node| node.kind == NodeKind::Transform)
8598 .unwrap();
8599 assert_eq!(scaled.metadata["dsl_name"], "scaled");
8600 let model = graph
8601 .nodes
8602 .iter()
8603 .find(|node| node.kind == NodeKind::Model)
8604 .unwrap();
8605 assert_eq!(
8606 model.operator.as_ref().unwrap()["class"],
8607 "sklearn.ensemble.RandomForestRegressor"
8608 );
8609 assert_eq!(model.params["n_estimators"], 10);
8610 }
8611
8612 #[test]
8613 fn compiles_tuner_as_external_prediction_node() {
8614 let spec: PipelineDslSpec = serde_json::from_str(
8615 r#"{
8616 "id": "dsl-tuner",
8617 "steps": [
8618 {
8619 "kind": "tuner",
8620 "id": "tuner:optuna",
8621 "operator": "OptunaTuner",
8622 "params": {"sampler": "tpe"},
8623 "tuning": {"n_trials": 4, "metric": "rmse"}
8624 },
8625 {
8626 "kind": "merge_model",
8627 "id": "model:meta",
8628 "operator": "Ridge"
8629 }
8630 ]
8631}"#,
8632 )
8633 .unwrap();
8634
8635 let graph = compile_pipeline_dsl(&spec).unwrap();
8636 graph.validate().unwrap();
8637 let tuner = graph
8638 .nodes
8639 .iter()
8640 .find(|node| node.id.as_str() == "tuner:optuna")
8641 .unwrap();
8642 assert_eq!(tuner.kind, NodeKind::Tuner);
8643 assert_eq!(
8644 tuner.operator.as_ref().unwrap().as_str(),
8645 Some("OptunaTuner")
8646 );
8647 assert_eq!(tuner.metadata["dsl_tuning"]["n_trials"], 4);
8648 assert!(graph.edges.iter().any(|edge| {
8649 edge.source.node_id.as_str() == "tuner:optuna"
8650 && edge.source.port_name == "oof"
8651 && edge.target.node_id.as_str() == "model:meta"
8652 && edge.contract.kind == PortKind::Prediction
8653 && edge.contract.requires_oof
8654 && edge.contract.requires_fold_alignment
8655 }));
8656 }
8657
8658 #[test]
8659 fn parses_compat_tuner_minimal_alias_and_wrappers() {
8660 let spec = parse_pipeline_dsl_json(
8661 br#"{
8662 "id": "dsl-compat-tuner",
8663 "pipeline": [
8664 "SNV",
8665 {"tuner": "OptunaTuner", "id": "tuner:compat", "n_trials": 3, "metric": "rmse"},
8666 {"merge": "all"},
8667 {"model": "Ridge"}
8668 ]
8669}"#,
8670 )
8671 .unwrap();
8672
8673 let graph = compile_pipeline_dsl(&spec).unwrap();
8674 graph.validate().unwrap();
8675 let transform = graph
8676 .nodes
8677 .iter()
8678 .find(|node| node.kind == NodeKind::Transform)
8679 .unwrap();
8680 assert_eq!(transform.operator.as_ref().unwrap().as_str(), Some("SNV"));
8681 let tuner = graph
8682 .nodes
8683 .iter()
8684 .find(|node| node.id.as_str() == "tuner:compat")
8685 .unwrap();
8686 assert_eq!(tuner.kind, NodeKind::Tuner);
8687 assert_eq!(tuner.params["n_trials"], 3);
8688 assert_eq!(tuner.metadata["dsl_compat_keyword"], "tuner");
8689 }
8690
8691 #[test]
8692 fn parses_bare_tuner_alias_as_tuner_node() {
8693 let spec = parse_pipeline_dsl_json(
8694 br#"{
8695 "id": "dsl-bare-tuner-alias",
8696 "pipeline": ["SNV", "OptunaTuner"]
8697}"#,
8698 )
8699 .unwrap();
8700
8701 let graph = compile_pipeline_dsl(&spec).unwrap();
8702 graph.validate().unwrap();
8703 assert!(graph.nodes.iter().any(|node| {
8704 node.kind == NodeKind::Transform
8705 && node.operator.as_ref().unwrap().as_str() == Some("SNV")
8706 }));
8707 assert!(graph.nodes.iter().any(|node| {
8708 node.kind == NodeKind::Tuner
8709 && node.operator.as_ref().unwrap().as_str() == Some("OptunaTuner")
8710 }));
8711 }
8712
8713 #[test]
8714 fn compiles_runtime_data_generation_as_external_generator_node() {
8715 let spec: PipelineDslSpec = serde_json::from_str(
8716 r#"{
8717 "id": "dsl-runtime-data-generation",
8718 "steps": [
8719 {
8720 "kind": "generation",
8721 "id": "generator:synthetic.train",
8722 "operator": "SMOTE",
8723 "params": {"ratio": 0.5},
8724 "shape": {
8725 "fit_rows": "fold_train",
8726 "predict_rows": "fold_validation",
8727 "augmentation_policy": {
8728 "sample_scope": "train_only",
8729 "feature_scope": "none",
8730 "require_origin_id": true,
8731 "inherit_group": true,
8732 "inherit_target": true
8733 }
8734 }
8735 },
8736 {
8737 "kind": "model",
8738 "id": "model:ridge",
8739 "operator": "Ridge"
8740 }
8741 ]
8742}"#,
8743 )
8744 .unwrap();
8745
8746 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8747 compiled.graph.validate().unwrap();
8748 let generator = compiled
8749 .graph
8750 .nodes
8751 .iter()
8752 .find(|node| node.id.as_str() == "generator:synthetic.train")
8753 .unwrap();
8754 assert_eq!(generator.kind, NodeKind::Generator);
8755 assert_eq!(generator.operator.as_ref().unwrap().as_str(), Some("SMOTE"));
8756 assert_eq!(generator.metadata["dsl_generation_kind"], "data");
8757 assert!(compiled
8758 .shape_plans
8759 .contains_key(&NodeId::new("generator:synthetic.train").unwrap()));
8760 assert!(compiled.graph.edges.iter().any(|edge| {
8761 edge.source.node_id.as_str() == "generator:synthetic.train"
8762 && edge.source.port_name == "x_out"
8763 && edge.target.node_id.as_str() == "model:ridge"
8764 && edge.target.port_name == "x"
8765 && edge.contract.kind == PortKind::Data
8766 }));
8767 }
8768
8769 #[test]
8770 fn parses_compat_runtime_generation_step() {
8771 let spec = parse_pipeline_dsl_json(
8772 br#"{
8773 "id": "dsl-compat-runtime-generation",
8774 "pipeline": [
8775 {
8776 "generation": "SMOTE",
8777 "id": "generator:compat.synthetic",
8778 "generation_params": {"ratio": 0.25},
8779 "shape": {
8780 "fit_rows": "fold_train",
8781 "predict_rows": "fold_validation",
8782 "augmentation_policy": {
8783 "sample_scope": "train_only",
8784 "feature_scope": "none",
8785 "require_origin_id": true,
8786 "inherit_group": true,
8787 "inherit_target": true
8788 }
8789 }
8790 },
8791 "Ridge"
8792 ]
8793}"#,
8794 )
8795 .unwrap();
8796
8797 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
8798 let generator = compiled
8799 .graph
8800 .nodes
8801 .iter()
8802 .find(|node| node.id.as_str() == "generator:compat.synthetic")
8803 .unwrap();
8804 assert_eq!(generator.kind, NodeKind::Generator);
8805 assert_eq!(generator.params["ratio"], 0.25);
8806 assert_eq!(generator.metadata["dsl_compat_keyword"], "data_generation");
8807 }
8808
8809 #[test]
8810 fn parses_nirs4all_compat_feature_branch_merge_dict() {
8811 let spec = parse_pipeline_dsl_json(
8812 br#"{
8813 "id": "dsl-nirs4all-compat-feature-merge",
8814 "pipeline": [
8815 {
8816 "branch": {
8817 "snv": ["SNV"],
8818 "msc": ["MSC"]
8819 }
8820 },
8821 {
8822 "merge": {
8823 "features": "all",
8824 "output_as": "features",
8825 "on_missing": "error"
8826 }
8827 },
8828 "PLSRegression"
8829 ]
8830}"#,
8831 )
8832 .unwrap();
8833
8834 let graph = compile_pipeline_dsl(&spec).unwrap();
8835 graph.validate().unwrap();
8836 let merge = graph
8837 .nodes
8838 .iter()
8839 .find(|node| node.kind == NodeKind::FeatureJoin)
8840 .unwrap();
8841 assert_eq!(merge.metadata["merge_mode"], "features");
8842 assert_eq!(merge.metadata["on_missing"], "error");
8843 assert!(merge.metadata.contains_key("dsl_compat_merge"));
8844 assert!(merge.ports.inputs.iter().any(|port| port.name == "snv_x"));
8845 assert!(merge.ports.inputs.iter().any(|port| port.name == "msc_x"));
8846 assert!(graph.nodes.iter().any(|node| node.kind == NodeKind::Model
8847 && node.operator.as_ref().unwrap().as_str() == Some("PLSRegression")));
8848 }
8849
8850 #[test]
8851 fn published_pipeline_dsl_schema_declares_current_contract() {
8852 let schema: serde_json::Value = serde_json::from_str(include_str!(
8853 "../../../docs/contracts/pipeline_dsl.schema.json"
8854 ))
8855 .unwrap();
8856
8857 assert_eq!(schema["$id"], PIPELINE_DSL_SCHEMA_ID);
8858 assert!(schema["oneOf"].is_array());
8859 assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8860 .as_array()
8861 .unwrap()
8862 .iter()
8863 .any(|value| value.as_str() == Some("generator")));
8864 assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8865 .as_array()
8866 .unwrap()
8867 .iter()
8868 .any(|value| value.as_str() == Some("data_generation")));
8869 assert!(schema["$defs"]["canonical_step_kind"]["enum"]
8870 .as_array()
8871 .unwrap()
8872 .iter()
8873 .any(|value| value.as_str() == Some("tuner")));
8874 assert!(schema["$defs"]["compat_generator_key"]["enum"]
8875 .as_array()
8876 .unwrap()
8877 .iter()
8878 .any(|value| value.as_str() == Some("_cartesian_")));
8879 assert!(schema["$defs"]["compat_step_object"]["properties"]
8880 .as_object()
8881 .unwrap()
8882 .contains_key("class"));
8883 assert!(schema["$defs"]["compat_step_object"]["properties"]
8884 .as_object()
8885 .unwrap()
8886 .contains_key("step"));
8887 assert!(schema["$defs"]["pipeline_unit_contract"]["properties"]
8888 .as_object()
8889 .unwrap()
8890 .contains_key("unit_level"));
8891 assert!(schema["$defs"]["entity_unit_level"]["enum"]
8892 .as_array()
8893 .unwrap()
8894 .iter()
8895 .any(|value| value.as_str() == Some("observation")));
8896 }
8897
8898 #[test]
8899 fn refuses_unsafe_shape_plan_from_dsl() {
8900 let spec: PipelineDslSpec = serde_json::from_str(
8901 r#"{
8902 "id": "dsl-unsafe-shape-plan",
8903 "steps": [
8904 {
8905 "kind": "augmentation",
8906 "id": "augment:bad",
8907 "operator": {"type": "LeakyAugmenter"},
8908 "shape": {
8909 "augmentation_policy": {
8910 "sample_scope": "all_partitions"
8911 }
8912 }
8913 }
8914 ]
8915}"#,
8916 )
8917 .unwrap();
8918
8919 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8920 assert!(format!("{error}").contains("sample augmentation over all partitions"));
8921 }
8922
8923 #[test]
8924 fn refuses_augmentation_without_shape_plan() {
8925 let spec: PipelineDslSpec = serde_json::from_str(
8926 r#"{
8927 "id": "dsl-augmentation-without-shape",
8928 "steps": [
8929 {
8930 "kind": "augmentation",
8931 "id": "augment:missing-shape",
8932 "operator": {"type": "GaussianNoise"}
8933 }
8934 ]
8935}"#,
8936 )
8937 .unwrap();
8938
8939 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8940 assert!(format!("{error}").contains("requires a shape plan"));
8941 }
8942
8943 #[test]
8944 fn refuses_data_generation_without_shape_plan() {
8945 let spec: PipelineDslSpec = serde_json::from_str(
8946 r#"{
8947 "id": "dsl-generation-without-shape",
8948 "steps": [
8949 {
8950 "kind": "data_generation",
8951 "id": "generator:missing-shape",
8952 "operator": {"type": "SMOTE"}
8953 }
8954 ]
8955}"#,
8956 )
8957 .unwrap();
8958
8959 let error = compile_pipeline_dsl_with_generation(&spec).unwrap_err();
8960 assert!(format!("{error}").contains("requires a shape plan"));
8961 }
8962
8963 #[test]
8964 fn refuses_branch_without_prediction_or_data_output() {
8965 let spec: PipelineDslSpec = serde_json::from_str(
8966 r#"{
8967 "id": "dsl-bad-branch",
8968 "steps": [
8969 {
8970 "kind": "branch",
8971 "branches": [
8972 {
8973 "id": "b0",
8974 "steps": [
8975 {
8976 "kind": "y_transform",
8977 "id": "target:only",
8978 "operator": {"type": "StandardScaler"}
8979 }
8980 ]
8981 }
8982 ]
8983 }
8984 ]
8985}"#,
8986 )
8987 .unwrap();
8988
8989 let error = compile_pipeline_dsl(&spec).unwrap_err();
8990 assert!(format!("{error}")
8991 .contains("must produce at least one model, merge prediction or transformed data"));
8992 }
8993
8994 #[test]
8995 fn dsl_top_level_inner_cv_maps_to_campaign_template() {
8996 let spec: PipelineDslSpec = serde_json::from_str(
8997 r#"{
8998 "id": "dsl-inner-cv-campaign",
8999 "inner_cv": {"kind": "kfold", "n_splits": 4, "shuffle": true, "seed": 7},
9000 "steps": [
9001 {"kind": "model", "id": "model:base", "operator": {"type": "Ridge"}, "params": {"alpha": 0.5}}
9002 ]
9003}"#,
9004 )
9005 .unwrap();
9006
9007 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
9008 match compiled.campaign_template.inner_cv {
9009 Some(crate::fold::NestedCvSpec::KFold(ref k)) => {
9010 assert_eq!(k.n_splits, 4);
9011 assert!(k.shuffle);
9012 assert_eq!(k.seed, Some(7));
9013 }
9014 ref other => panic!("expected campaign-level KFold inner_cv, got {other:?}"),
9015 }
9016 }
9017
9018 #[test]
9019 fn dsl_model_step_inner_cv_maps_to_node_metadata() {
9020 let spec: PipelineDslSpec = serde_json::from_str(
9021 r#"{
9022 "id": "dsl-inner-cv-node",
9023 "steps": [
9024 {
9025 "kind": "model",
9026 "id": "model:meta",
9027 "operator": {"type": "Ridge"},
9028 "inner_cv": {"kind": "group_kfold", "n_splits": 3}
9029 }
9030 ]
9031}"#,
9032 )
9033 .unwrap();
9034
9035 let graph = compile_pipeline_dsl(&spec).unwrap();
9036 let node = graph
9037 .nodes
9038 .iter()
9039 .find(|node| node.id.as_str() == "model:meta")
9040 .expect("compiled model node exists");
9041 let value = node
9042 .metadata
9043 .get("dsl_inner_cv")
9044 .expect("node carries dsl_inner_cv metadata");
9045 let inner: crate::fold::NestedCvSpec = serde_json::from_value(value.clone()).unwrap();
9046 match inner {
9047 crate::fold::NestedCvSpec::GroupKFold(ref g) => assert_eq!(g.n_splits, 3),
9048 other => panic!("expected node-local GroupKFold inner_cv, got {other:?}"),
9049 }
9050 }
9051
9052 #[test]
9053 fn dsl_absent_inner_cv_leaves_campaign_and_nodes_unset() {
9054 let spec: PipelineDslSpec = serde_json::from_str(
9055 r#"{
9056 "id": "dsl-no-inner-cv",
9057 "steps": [
9058 {"kind": "model", "id": "model:base", "operator": {"type": "Ridge"}}
9059 ]
9060}"#,
9061 )
9062 .unwrap();
9063
9064 let compiled = compile_pipeline_dsl_with_generation(&spec).unwrap();
9065 assert!(compiled.campaign_template.inner_cv.is_none());
9066 for node in &compiled.graph.nodes {
9067 assert!(!node.metadata.contains_key("dsl_inner_cv"));
9068 }
9069 }
9070
9071 #[test]
9072 fn compat_pipeline_preserves_campaign_and_model_inner_cv() {
9073 let spec = parse_pipeline_dsl_json(
9076 br#"{
9077 "id": "dsl-compat-inner-cv",
9078 "inner_cv": {"kind": "kfold", "n_splits": 5, "shuffle": false, "seed": 3},
9079 "pipeline": [
9080 {"split": {"type": "KFold", "n_splits": 4}},
9081 {"model": "Ridge", "id": "model:base", "inner_cv": {"kind": "group_kfold", "n_splits": 3}}
9082 ]
9083}"#,
9084 )
9085 .unwrap();
9086
9087 match spec.inner_cv {
9088 Some(crate::fold::NestedCvSpec::KFold(ref k)) => assert_eq!(k.n_splits, 5),
9089 ref other => panic!("expected compat campaign-global KFold inner_cv, got {other:?}"),
9090 }
9091
9092 let graph = compile_pipeline_dsl(&spec).unwrap();
9093 let node = graph
9094 .nodes
9095 .iter()
9096 .find(|node| node.id.as_str() == "model:base")
9097 .expect("compat model node exists");
9098 let inner: crate::fold::NestedCvSpec =
9099 serde_json::from_value(node.metadata.get("dsl_inner_cv").cloned().unwrap()).unwrap();
9100 match inner {
9101 crate::fold::NestedCvSpec::GroupKFold(ref g) => assert_eq!(g.n_splits, 3),
9102 other => panic!("expected compat node-local GroupKFold inner_cv, got {other:?}"),
9103 }
9104 }
9105
9106 #[test]
9107 fn compat_merge_model_collapse_preserves_inner_cv() {
9108 let spec = parse_pipeline_dsl_json(
9111 br#"{
9112 "id": "dsl-compat-merge-inner-cv",
9113 "pipeline": [
9114 {"_chain_": [
9115 {"_grid_": {"model": ["PLSRegression"], "n_components": [5, 10]}},
9116 {"_grid_": {"model": ["Ridge"], "alpha": [0.1, 1.0]}}
9117 ]},
9118 {"merge": "predictions"},
9119 {"model": "Ridge", "id": "model:meta", "params": {"alpha": 0.5}, "inner_cv": {"kind": "kfold", "n_splits": 4, "shuffle": false, "seed": null}}
9120 ]
9121}"#,
9122 )
9123 .unwrap();
9124
9125 let graph = compile_pipeline_dsl(&spec).unwrap();
9126 let node = graph
9127 .nodes
9128 .iter()
9129 .find(|node| node.id.as_str() == "model:meta")
9130 .expect("compat merge-model node exists");
9131 let inner: crate::fold::NestedCvSpec =
9132 serde_json::from_value(node.metadata.get("dsl_inner_cv").cloned().unwrap()).unwrap();
9133 match inner {
9134 crate::fold::NestedCvSpec::KFold(ref k) => assert_eq!(k.n_splits, 4),
9135 other => panic!("expected merge-model KFold inner_cv, got {other:?}"),
9136 }
9137 }
9138}