1use std::collections::{BTreeMap, BTreeSet};
2
3use serde::{Deserialize, Serialize};
4
5use crate::campaign::stable_json_fingerprint;
6use crate::error::{DagMlError, Result};
7use crate::ids::{NodeId, VariantId};
8use crate::rng::SeedContext;
9
10#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum GenerationStrategy {
13 #[default]
14 None,
15 Cartesian,
16 Zip,
17}
18
19#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
20pub struct GenerationChoice {
21 pub label: String,
22 pub value: serde_json::Value,
23 #[serde(default, skip_serializing_if = "Vec::is_empty")]
24 pub param_overrides: Vec<GenerationParamOverride>,
25}
26
27#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
28pub struct GenerationParamOverride {
29 pub node_id: NodeId,
30 #[serde(default)]
31 pub params: BTreeMap<String, serde_json::Value>,
32}
33
34impl GenerationChoice {
35 fn validate(&self, dimension_name: &str) -> Result<()> {
36 if self.label.trim().is_empty() {
37 return Err(DagMlError::CampaignValidation(format!(
38 "generation dimension `{dimension_name}` has an empty choice label"
39 )));
40 }
41 for override_spec in &self.param_overrides {
42 override_spec.validate(dimension_name, &self.label)?;
43 }
44 Ok(())
45 }
46}
47
48impl GenerationParamOverride {
49 fn validate(&self, dimension_name: &str, choice_label: &str) -> Result<()> {
50 if self.params.is_empty() {
51 return Err(DagMlError::CampaignValidation(format!(
52 "generation choice `{choice_label}` in dimension `{dimension_name}` has an empty param override for node `{}`",
53 self.node_id
54 )));
55 }
56 for key in self.params.keys() {
57 if key.trim().is_empty() {
58 return Err(DagMlError::CampaignValidation(format!(
59 "generation choice `{choice_label}` in dimension `{dimension_name}` has an empty param override key for node `{}`",
60 self.node_id
61 )));
62 }
63 }
64 Ok(())
65 }
66}
67
68#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
69pub struct GenerationDimension {
70 pub name: String,
71 #[serde(default)]
72 pub choices: Vec<GenerationChoice>,
73}
74
75impl GenerationDimension {
76 fn validate(&self) -> Result<()> {
77 if self.name.trim().is_empty() {
78 return Err(DagMlError::CampaignValidation(
79 "generation dimension name is empty".to_string(),
80 ));
81 }
82 if self.choices.is_empty() {
83 return Err(DagMlError::CampaignValidation(format!(
84 "generation dimension `{}` has no choices",
85 self.name
86 )));
87 }
88 let mut labels = BTreeSet::new();
89 for choice in &self.choices {
90 choice.validate(&self.name)?;
91 if !labels.insert(choice.label.as_str()) {
92 return Err(DagMlError::CampaignValidation(format!(
93 "generation dimension `{}` has duplicate choice `{}`",
94 self.name, choice.label
95 )));
96 }
97 }
98 Ok(())
99 }
100}
101
102#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
103pub struct GenerationSpec {
104 #[serde(default)]
105 pub strategy: GenerationStrategy,
106 #[serde(default)]
107 pub dimensions: Vec<GenerationDimension>,
108 #[serde(default)]
109 pub max_variants: Option<usize>,
110}
111
112impl Default for GenerationSpec {
113 fn default() -> Self {
114 Self {
115 strategy: GenerationStrategy::None,
116 dimensions: Vec::new(),
117 max_variants: Some(1),
118 }
119 }
120}
121
122impl GenerationSpec {
123 pub fn validate(&self) -> Result<()> {
124 if self.max_variants == Some(0) {
125 return Err(DagMlError::CampaignValidation(
126 "generation max_variants cannot be zero".to_string(),
127 ));
128 }
129 if self.strategy == GenerationStrategy::None {
130 if !self.dimensions.is_empty() {
131 return Err(DagMlError::CampaignValidation(
132 "generation dimensions require cartesian or zip strategy".to_string(),
133 ));
134 }
135 return Ok(());
136 }
137
138 if self.dimensions.is_empty() {
139 return Err(DagMlError::CampaignValidation(
140 "generation strategy requires at least one dimension".to_string(),
141 ));
142 }
143 let mut names = BTreeSet::new();
144 for dimension in &self.dimensions {
145 dimension.validate()?;
146 if !names.insert(dimension.name.as_str()) {
147 return Err(DagMlError::CampaignValidation(format!(
148 "duplicate generation dimension `{}`",
149 dimension.name
150 )));
151 }
152 }
153 if self.strategy == GenerationStrategy::Zip {
154 let expected = self.dimensions[0].choices.len();
155 if self
156 .dimensions
157 .iter()
158 .any(|dimension| dimension.choices.len() != expected)
159 {
160 return Err(DagMlError::CampaignValidation(
161 "zip generation requires every dimension to have the same number of choices"
162 .to_string(),
163 ));
164 }
165 }
166 Ok(())
167 }
168}
169
170#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
171pub struct VariantPlan {
172 pub variant_id: VariantId,
173 #[serde(default)]
174 pub choices: BTreeMap<String, GenerationChoice>,
175 pub fingerprint: String,
176 pub seed: Option<u64>,
177}
178
179impl VariantPlan {
180 pub fn validate(&self) -> Result<()> {
181 if self.fingerprint.trim().is_empty() {
182 return Err(DagMlError::Planning(format!(
183 "variant `{}` has an empty fingerprint",
184 self.variant_id
185 )));
186 }
187 for (dimension_name, choice) in &self.choices {
188 choice.validate(dimension_name)?;
189 }
190 self.param_overrides_by_node()?;
191 Ok(())
192 }
193
194 pub fn effective_params_for_node(
195 &self,
196 node_id: &NodeId,
197 base_params: &BTreeMap<String, serde_json::Value>,
198 ) -> Result<BTreeMap<String, serde_json::Value>> {
199 let overrides_by_node = self.param_overrides_by_node()?;
200 let Some(overrides) = overrides_by_node.get(node_id) else {
201 return Ok(base_params.clone());
202 };
203 let mut params = base_params.clone();
204 params.extend(overrides.clone());
205 Ok(params)
206 }
207
208 pub fn param_override_targets(&self) -> Result<BTreeSet<NodeId>> {
209 Ok(self.param_overrides_by_node()?.into_keys().collect())
210 }
211
212 fn param_overrides_by_node(
213 &self,
214 ) -> Result<BTreeMap<NodeId, BTreeMap<String, serde_json::Value>>> {
215 let mut overrides = BTreeMap::<NodeId, BTreeMap<String, serde_json::Value>>::new();
216 let mut owners = BTreeMap::<(NodeId, String), String>::new();
217 for (dimension_name, choice) in &self.choices {
218 for override_spec in &choice.param_overrides {
219 for (param_key, value) in &override_spec.params {
220 let owner_key = (override_spec.node_id.clone(), param_key.clone());
221 if let Some(previous) =
222 owners.insert(owner_key, format!("{dimension_name}:{}", choice.label))
223 {
224 return Err(DagMlError::CampaignValidation(format!(
225 "variant `{}` has conflicting generation overrides for `{}.{}` from `{previous}` and `{}:{}`",
226 self.variant_id,
227 override_spec.node_id,
228 param_key,
229 dimension_name,
230 choice.label
231 )));
232 }
233 overrides
234 .entry(override_spec.node_id.clone())
235 .or_default()
236 .insert(param_key.clone(), value.clone());
237 }
238 }
239 }
240 Ok(overrides)
241 }
242}
243
244pub fn enumerate_variants(
245 spec: &GenerationSpec,
246 root_seed: Option<u64>,
247) -> Result<Vec<VariantPlan>> {
248 spec.validate()?;
249 let mut variants = match spec.strategy {
250 GenerationStrategy::None => vec![BTreeMap::new()],
251 GenerationStrategy::Cartesian => cartesian_choices(&spec.dimensions),
252 GenerationStrategy::Zip => zip_choices(&spec.dimensions),
253 };
254 if let Some(max_variants) = spec.max_variants {
255 if variants.len() > max_variants {
256 return Err(DagMlError::CampaignValidation(format!(
257 "generation produced {} variants, above max_variants={max_variants}",
258 variants.len()
259 )));
260 }
261 }
262
263 variants
264 .drain(..)
265 .map(|choices| variant_from_choices(choices, root_seed))
266 .collect()
267}
268
269pub fn generation_spec_fingerprint(spec: &GenerationSpec) -> Result<String> {
270 spec.validate()?;
271 stable_json_fingerprint(spec)
272}
273
274fn cartesian_choices(
275 dimensions: &[GenerationDimension],
276) -> Vec<BTreeMap<String, GenerationChoice>> {
277 let mut variants = vec![BTreeMap::new()];
278 for dimension in dimensions {
279 let mut next = Vec::with_capacity(variants.len() * dimension.choices.len());
280 for existing in &variants {
281 for choice in &dimension.choices {
282 let mut merged = existing.clone();
283 merged.insert(dimension.name.clone(), choice.clone());
284 next.push(merged);
285 }
286 }
287 variants = next;
288 }
289 variants
290}
291
292fn zip_choices(dimensions: &[GenerationDimension]) -> Vec<BTreeMap<String, GenerationChoice>> {
293 let len = dimensions
294 .first()
295 .map_or(0, |dimension| dimension.choices.len());
296 (0..len)
297 .map(|idx| {
298 dimensions
299 .iter()
300 .map(|dimension| (dimension.name.clone(), dimension.choices[idx].clone()))
301 .collect::<BTreeMap<_, _>>()
302 })
303 .collect()
304}
305
306fn variant_from_choices(
307 choices: BTreeMap<String, GenerationChoice>,
308 root_seed: Option<u64>,
309) -> Result<VariantPlan> {
310 let fingerprint = stable_json_fingerprint(&choices)?;
311 let suffix = if choices.is_empty() {
312 "base".to_string()
313 } else {
314 fingerprint[..16].to_string()
315 };
316 let variant_id = VariantId::new(format!("variant:{suffix}"))?;
317 let seed = root_seed.map(|seed| {
318 SeedContext::root(seed)
319 .child(format!("variant:{variant_id}"))
320 .derive_u64("variant")
321 });
322 let variant = VariantPlan {
323 variant_id,
324 choices,
325 fingerprint,
326 seed,
327 };
328 variant.validate()?;
329 Ok(variant)
330}
331
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Choice carrying only a label and a value.
    fn choice(label: &str, value: serde_json::Value) -> GenerationChoice {
        GenerationChoice {
            label: label.to_owned(),
            value,
            param_overrides: vec![],
        }
    }

    /// Choice whose value mirrors its label and that overrides `params` on one node.
    fn override_choice(
        label: &str,
        node_id: &str,
        params: BTreeMap<String, serde_json::Value>,
    ) -> GenerationChoice {
        let target = NodeId::new(node_id).unwrap();
        GenerationChoice {
            label: label.to_owned(),
            value: json!(label),
            param_overrides: vec![GenerationParamOverride {
                node_id: target,
                params,
            }],
        }
    }

    /// Dimension with the given name and choices.
    fn dimension(name: &str, choices: Vec<GenerationChoice>) -> GenerationDimension {
        GenerationDimension {
            name: name.to_owned(),
            choices,
        }
    }

    #[test]
    fn default_generation_produces_base_variant() {
        let spec = GenerationSpec::default();

        let variants = enumerate_variants(&spec, Some(7)).unwrap();

        assert_eq!(variants.len(), 1);
        let base = &variants[0];
        assert_eq!(base.variant_id.as_str(), "variant:base");
        assert!(base.choices.is_empty());
        assert!(base.seed.is_some());
    }

    #[test]
    fn cartesian_generation_is_deterministic_and_fingerprinted() {
        let models = vec![choice("pls", json!("pls")), choice("rf", json!("rf"))];
        let windows = vec![choice("short", json!(7)), choice("long", json!(21))];
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![dimension("model", models), dimension("window", windows)],
            max_variants: Some(4),
        };

        let first_run = enumerate_variants(&spec, Some(11)).unwrap();
        let second_run = enumerate_variants(&spec, Some(11)).unwrap();

        assert_eq!(first_run.len(), 4);
        assert_eq!(first_run, second_run);

        let baseline = generation_spec_fingerprint(&spec).unwrap();
        let mut mutated = spec.clone();
        mutated.dimensions[0].choices[0].value = json!("changed");
        assert_eq!(baseline, generation_spec_fingerprint(&spec).unwrap());
        assert_ne!(baseline, generation_spec_fingerprint(&mutated).unwrap());

        assert_ne!(first_run[0].variant_id, first_run[1].variant_id);
        assert_eq!(first_run[0].choices["model"].label, "pls");
        assert_eq!(first_run[0].choices["window"].label, "short");
    }

    #[test]
    fn zip_generation_requires_same_choice_count() {
        let shorter = dimension("a", vec![choice("a1", json!(1))]);
        let longer = dimension("b", vec![choice("b1", json!(1)), choice("b2", json!(2))]);
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Zip,
            dimensions: vec![shorter, longer],
            max_variants: None,
        };

        assert!(spec.validate().is_err());
    }

    #[test]
    fn generation_respects_variant_limit() {
        let two_choices = vec![choice("a", json!(1)), choice("b", json!(2))];
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![dimension("x", two_choices)],
            max_variants: Some(1),
        };

        assert!(enumerate_variants(&spec, None).is_err());
    }

    #[test]
    fn variant_applies_node_param_overrides() {
        let pls = override_choice(
            "pls",
            "model:base",
            BTreeMap::from([("n_components".to_string(), json!(8))]),
        );
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![dimension("model_family", vec![pls])],
            max_variants: Some(1),
        };
        let variants = enumerate_variants(&spec, Some(7)).unwrap();
        let base = BTreeMap::from([("scale".to_string(), json!(true))]);
        let target = NodeId::new("model:base").unwrap();

        let params = variants[0]
            .effective_params_for_node(&target, &base)
            .unwrap();

        assert_eq!(params["scale"], json!(true));
        assert_eq!(params["n_components"], json!(8));
    }

    #[test]
    fn variant_rejects_conflicting_param_overrides() {
        let family = dimension(
            "family",
            vec![override_choice(
                "pls",
                "model:base",
                BTreeMap::from([("alpha".to_string(), json!(1))]),
            )],
        );
        let regularization = dimension(
            "regularization",
            vec![override_choice(
                "ridge",
                "model:base",
                BTreeMap::from([("alpha".to_string(), json!(2))]),
            )],
        );
        let spec = GenerationSpec {
            strategy: GenerationStrategy::Cartesian,
            dimensions: vec![family, regularization],
            max_variants: Some(1),
        };

        let error = enumerate_variants(&spec, None).unwrap_err().to_string();

        assert!(error.contains("conflicting generation overrides"));
    }
}