Skip to main content

dag_ml_core/
generation.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::ids::{NodeId, VariantId};
8use crate::rng::SeedContext;
9
10#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum GenerationStrategy {
13    #[default]
14    None,
15    Cartesian,
16    Zip,
17}
18
19#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
20pub struct GenerationChoice {
21    pub label: String,
22    pub value: serde_json::Value,
23    #[serde(default, skip_serializing_if = "Vec::is_empty")]
24    pub param_overrides: Vec<GenerationParamOverride>,
25}
26
27#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
28pub struct GenerationParamOverride {
29    pub node_id: NodeId,
30    #[serde(default)]
31    pub params: BTreeMap<String, serde_json::Value>,
32}
33
34impl GenerationChoice {
35    fn validate(&self, dimension_name: &str) -> Result<()> {
36        if self.label.trim().is_empty() {
37            return Err(DagMlError::CampaignValidation(format!(
38                "generation dimension `{dimension_name}` has an empty choice label"
39            )));
40        }
41        for override_spec in &self.param_overrides {
42            override_spec.validate(dimension_name, &self.label)?;
43        }
44        Ok(())
45    }
46}
47
48impl GenerationParamOverride {
49    fn validate(&self, dimension_name: &str, choice_label: &str) -> Result<()> {
50        if self.params.is_empty() {
51            return Err(DagMlError::CampaignValidation(format!(
52                "generation choice `{choice_label}` in dimension `{dimension_name}` has an empty param override for node `{}`",
53                self.node_id
54            )));
55        }
56        for key in self.params.keys() {
57            if key.trim().is_empty() {
58                return Err(DagMlError::CampaignValidation(format!(
59                    "generation choice `{choice_label}` in dimension `{dimension_name}` has an empty param override key for node `{}`",
60                    self.node_id
61                )));
62            }
63        }
64        Ok(())
65    }
66}
67
68#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
69pub struct GenerationDimension {
70    pub name: String,
71    #[serde(default)]
72    pub choices: Vec<GenerationChoice>,
73}
74
75impl GenerationDimension {
76    fn validate(&self) -> Result<()> {
77        if self.name.trim().is_empty() {
78            return Err(DagMlError::CampaignValidation(
79                "generation dimension name is empty".to_string(),
80            ));
81        }
82        if self.choices.is_empty() {
83            return Err(DagMlError::CampaignValidation(format!(
84                "generation dimension `{}` has no choices",
85                self.name
86            )));
87        }
88        let mut labels = BTreeSet::new();
89        for choice in &self.choices {
90            choice.validate(&self.name)?;
91            if !labels.insert(choice.label.as_str()) {
92                return Err(DagMlError::CampaignValidation(format!(
93                    "generation dimension `{}` has duplicate choice `{}`",
94                    self.name, choice.label
95                )));
96            }
97        }
98        Ok(())
99    }
100}
101
102#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
103pub struct GenerationSpec {
104    #[serde(default)]
105    pub strategy: GenerationStrategy,
106    #[serde(default)]
107    pub dimensions: Vec<GenerationDimension>,
108    #[serde(default)]
109    pub max_variants: Option<usize>,
110}
111
112impl Default for GenerationSpec {
113    fn default() -> Self {
114        Self {
115            strategy: GenerationStrategy::None,
116            dimensions: Vec::new(),
117            max_variants: Some(1),
118        }
119    }
120}
121
122impl GenerationSpec {
123    pub fn validate(&self) -> Result<()> {
124        if self.max_variants == Some(0) {
125            return Err(DagMlError::CampaignValidation(
126                "generation max_variants cannot be zero".to_string(),
127            ));
128        }
129        if self.strategy == GenerationStrategy::None {
130            if !self.dimensions.is_empty() {
131                return Err(DagMlError::CampaignValidation(
132                    "generation dimensions require cartesian or zip strategy".to_string(),
133                ));
134            }
135            return Ok(());
136        }
137
138        if self.dimensions.is_empty() {
139            return Err(DagMlError::CampaignValidation(
140                "generation strategy requires at least one dimension".to_string(),
141            ));
142        }
143        let mut names = BTreeSet::new();
144        for dimension in &self.dimensions {
145            dimension.validate()?;
146            if !names.insert(dimension.name.as_str()) {
147                return Err(DagMlError::CampaignValidation(format!(
148                    "duplicate generation dimension `{}`",
149                    dimension.name
150                )));
151            }
152        }
153        if self.strategy == GenerationStrategy::Zip {
154            let expected = self.dimensions[0].choices.len();
155            if self
156                .dimensions
157                .iter()
158                .any(|dimension| dimension.choices.len() != expected)
159            {
160                return Err(DagMlError::CampaignValidation(
161                    "zip generation requires every dimension to have the same number of choices"
162                        .to_string(),
163                ));
164            }
165        }
166        Ok(())
167    }
168}
169
170#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
171pub struct VariantPlan {
172    pub variant_id: VariantId,
173    #[serde(default)]
174    pub choices: BTreeMap<String, GenerationChoice>,
175    pub fingerprint: String,
176    pub seed: Option<u64>,
177}
178
179impl VariantPlan {
180    pub fn validate(&self) -> Result<()> {
181        if self.fingerprint.trim().is_empty() {
182            return Err(DagMlError::Planning(format!(
183                "variant `{}` has an empty fingerprint",
184                self.variant_id
185            )));
186        }
187        for (dimension_name, choice) in &self.choices {
188            choice.validate(dimension_name)?;
189        }
190        self.param_overrides_by_node()?;
191        Ok(())
192    }
193
194    pub fn effective_params_for_node(
195        &self,
196        node_id: &NodeId,
197        base_params: &BTreeMap<String, serde_json::Value>,
198    ) -> Result<BTreeMap<String, serde_json::Value>> {
199        let overrides_by_node = self.param_overrides_by_node()?;
200        let Some(overrides) = overrides_by_node.get(node_id) else {
201            return Ok(base_params.clone());
202        };
203        let mut params = base_params.clone();
204        params.extend(overrides.clone());
205        Ok(params)
206    }
207
208    pub fn param_override_targets(&self) -> Result<BTreeSet<NodeId>> {
209        Ok(self.param_overrides_by_node()?.into_keys().collect())
210    }
211
212    fn param_overrides_by_node(
213        &self,
214    ) -> Result<BTreeMap<NodeId, BTreeMap<String, serde_json::Value>>> {
215        let mut overrides = BTreeMap::<NodeId, BTreeMap<String, serde_json::Value>>::new();
216        let mut owners = BTreeMap::<(NodeId, String), String>::new();
217        for (dimension_name, choice) in &self.choices {
218            for override_spec in &choice.param_overrides {
219                for (param_key, value) in &override_spec.params {
220                    let owner_key = (override_spec.node_id.clone(), param_key.clone());
221                    if let Some(previous) =
222                        owners.insert(owner_key, format!("{dimension_name}:{}", choice.label))
223                    {
224                        return Err(DagMlError::CampaignValidation(format!(
225                            "variant `{}` has conflicting generation overrides for `{}.{}` from `{previous}` and `{}:{}`",
226                            self.variant_id,
227                            override_spec.node_id,
228                            param_key,
229                            dimension_name,
230                            choice.label
231                        )));
232                    }
233                    overrides
234                        .entry(override_spec.node_id.clone())
235                        .or_default()
236                        .insert(param_key.clone(), value.clone());
237                }
238            }
239        }
240        Ok(overrides)
241    }
242}
243
244pub fn enumerate_variants(
245    spec: &GenerationSpec,
246    root_seed: Option<u64>,
247) -> Result<Vec<VariantPlan>> {
248    spec.validate()?;
249    let mut variants = match spec.strategy {
250        GenerationStrategy::None => vec![BTreeMap::new()],
251        GenerationStrategy::Cartesian => cartesian_choices(&spec.dimensions),
252        GenerationStrategy::Zip => zip_choices(&spec.dimensions),
253    };
254    if let Some(max_variants) = spec.max_variants {
255        if variants.len() > max_variants {
256            return Err(DagMlError::CampaignValidation(format!(
257                "generation produced {} variants, above max_variants={max_variants}",
258                variants.len()
259            )));
260        }
261    }
262
263    variants
264        .drain(..)
265        .map(|choices| variant_from_choices(choices, root_seed))
266        .collect()
267}
268
269pub fn generation_spec_fingerprint(spec: &GenerationSpec) -> Result<String> {
270    spec.validate()?;
271    stable_json_fingerprint(spec)
272}
273
274fn cartesian_choices(
275    dimensions: &[GenerationDimension],
276) -> Vec<BTreeMap<String, GenerationChoice>> {
277    let mut variants = vec![BTreeMap::new()];
278    for dimension in dimensions {
279        let mut next = Vec::with_capacity(variants.len() * dimension.choices.len());
280        for existing in &variants {
281            for choice in &dimension.choices {
282                let mut merged = existing.clone();
283                merged.insert(dimension.name.clone(), choice.clone());
284                next.push(merged);
285            }
286        }
287        variants = next;
288    }
289    variants
290}
291
292fn zip_choices(dimensions: &[GenerationDimension]) -> Vec<BTreeMap<String, GenerationChoice>> {
293    let len = dimensions
294        .first()
295        .map_or(0, |dimension| dimension.choices.len());
296    (0..len)
297        .map(|idx| {
298            dimensions
299                .iter()
300                .map(|dimension| (dimension.name.clone(), dimension.choices[idx].clone()))
301                .collect::<BTreeMap<_, _>>()
302        })
303        .collect()
304}
305
306fn variant_from_choices(
307    choices: BTreeMap<String, GenerationChoice>,
308    root_seed: Option<u64>,
309) -> Result<VariantPlan> {
310    let fingerprint = stable_json_fingerprint(&choices)?;
311    let suffix = if choices.is_empty() {
312        "base".to_string()
313    } else {
314        fingerprint[..16].to_string()
315    };
316    let variant_id = VariantId::new(format!("variant:{suffix}"))?;
317    let seed = root_seed.map(|seed| {
318        SeedContext::root(seed)
319            .child(format!("variant:{variant_id}"))
320            .derive_u64("variant")
321    });
322    let variant = VariantPlan {
323        variant_id,
324        choices,
325        fingerprint,
326        seed,
327    };
328    variant.validate()?;
329    Ok(variant)
330}
331
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a choice without parameter overrides.
    fn choice(label: &str, value: serde_json::Value) -> GenerationChoice {
        GenerationChoice {
            label: label.to_owned(),
            value,
            param_overrides: vec![],
        }
    }

    /// Builds a choice whose value is its label and that overrides `params`
    /// on the node named `node_id`.
    fn override_choice(
        label: &str,
        node_id: &str,
        params: BTreeMap<String, serde_json::Value>,
    ) -> GenerationChoice {
        let target = NodeId::new(node_id).unwrap();
        GenerationChoice {
            label: label.to_owned(),
            value: json!(label),
            param_overrides: vec![GenerationParamOverride {
                node_id: target,
                params,
            }],
        }
    }

    /// Builds a dimension from a name and its choices.
    fn dimension(name: &str, choices: Vec<GenerationChoice>) -> GenerationDimension {
        GenerationDimension {
            name: name.to_owned(),
            choices,
        }
    }

    #[test]
    fn default_generation_produces_base_variant() {
        let variants = enumerate_variants(&GenerationSpec::default(), Some(7)).unwrap();

        assert_eq!(variants.len(), 1);
        let base = &variants[0];
        assert_eq!(base.variant_id.as_str(), "variant:base");
        assert!(base.choices.is_empty());
        assert!(base.seed.is_some());
    }

    #[test]
    fn cartesian_generation_is_deterministic_and_fingerprinted() {
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![
                dimension(
                    "model",
                    vec![choice("pls", json!("pls")), choice("rf", json!("rf"))],
                ),
                dimension(
                    "window",
                    vec![choice("short", json!(7)), choice("long", json!(21))],
                ),
            ],
            max_variants: Some(4),
        };

        let first_run = enumerate_variants(&spec, Some(11)).unwrap();
        let second_run = enumerate_variants(&spec, Some(11)).unwrap();

        assert_eq!(first_run.len(), 4);
        assert_eq!(first_run, second_run);

        // The spec fingerprint is stable for equal specs and sensitive to a
        // changed choice value.
        let fingerprint = generation_spec_fingerprint(&spec).unwrap();
        let mut changed_spec = spec.clone();
        changed_spec.dimensions[0].choices[0].value = json!("changed");
        assert_eq!(fingerprint, generation_spec_fingerprint(&spec).unwrap());
        assert_ne!(
            fingerprint,
            generation_spec_fingerprint(&changed_spec).unwrap()
        );

        assert_ne!(first_run[0].variant_id, first_run[1].variant_id);
        assert_eq!(first_run[0].choices["model"].label, "pls");
        assert_eq!(first_run[0].choices["window"].label, "short");
    }

    #[test]
    fn zip_generation_requires_same_choice_count() {
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Zip,
            dimensions: vec![
                dimension("a", vec![choice("a1", json!(1))]),
                dimension("b", vec![choice("b1", json!(1)), choice("b2", json!(2))]),
            ],
            max_variants: None,
        };

        assert!(spec.validate().is_err());
    }

    #[test]
    fn generation_respects_variant_limit() {
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![dimension(
                "x",
                vec![choice("a", json!(1)), choice("b", json!(2))],
            )],
            max_variants: Some(1),
        };

        assert!(enumerate_variants(&spec, None).is_err());
    }

    #[test]
    fn variant_applies_node_param_overrides() {
        let pls = override_choice(
            "pls",
            "model:base",
            BTreeMap::from([("n_components".to_string(), json!(8))]),
        );
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![dimension("model_family", vec![pls])],
            max_variants: Some(1),
        };
        let variants = enumerate_variants(&spec, Some(7)).unwrap();
        let base = BTreeMap::from([("scale".to_string(), json!(true))]);

        let target = NodeId::new("model:base").unwrap();
        let params = variants[0]
            .effective_params_for_node(&target, &base)
            .unwrap();

        assert_eq!(params["scale"], json!(true));
        assert_eq!(params["n_components"], json!(8));
    }

    #[test]
    fn variant_rejects_conflicting_param_overrides() {
        // Both dimensions set `alpha` on the same node, which must be refused.
        let pls = override_choice(
            "pls",
            "model:base",
            BTreeMap::from([("alpha".to_string(), json!(1))]),
        );
        let ridge = override_choice(
            "ridge",
            "model:base",
            BTreeMap::from([("alpha".to_string(), json!(2))]),
        );
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![
                dimension("family", vec![pls]),
                dimension("regularization", vec![ridge]),
            ],
            max_variants: Some(1),
        };

        let error = enumerate_variants(&spec, None).unwrap_err().to_string();

        assert!(error.contains("conflicting generation overrides"));
    }
}