Skip to main content

simular/edd/
experiment.rs

1//! Experiment specification for EDD - YAML-driven declarative experiments.
2//!
3//! Every experiment in the EDD framework must be declaratively specified with:
4//! - Explicit seed for reproducibility
5//! - Falsification criteria for scientific validity
6//! - Governing equation reference
7//! - Hypotheses to test
8//!
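//! # Example YAML Specification
//!
//! [`ExperimentSpec::from_yaml`] expects the experiment fields at the top level of the
//! document, using the field names shown below; `name` and `seed` are required.
//!
//! ```yaml
//! name: "Little's Law Validation"
//! seed: 42
//! emc_reference: "littles_law_v1.0"
//! hypothesis:
//!   null: "L ≠ λW"
//!   alternative: "L = λW holds under stochastic conditions"
//! falsification_criteria:
//!   - name: "Relative error too high"
//!     criterion: "relative_error > 0.05"
//!     action: reject_model
//! ```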
23
24use super::model_card::EquationModelCard;
25use serde::{Deserialize, Serialize};
26
27/// A hypothesis to test in the experiment.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ExperimentHypothesis {
30    /// The null hypothesis (what we try to disprove)
31    pub null: String,
32    /// The alternative hypothesis
33    pub alternative: String,
34    /// Significance level (α)
35    #[serde(default = "default_alpha")]
36    pub alpha: f64,
37}
38
/// Serde default for `ExperimentHypothesis::alpha`: a significance level of 0.05.
fn default_alpha() -> f64 {
    const FIVE_PERCENT: f64 = 0.05;
    FIVE_PERCENT
}
42
43impl ExperimentHypothesis {
44    /// Create a new hypothesis.
45    #[must_use]
46    pub fn new(null: &str, alternative: &str) -> Self {
47        Self {
48            null: null.to_string(),
49            alternative: alternative.to_string(),
50            alpha: 0.05,
51        }
52    }
53
54    /// Set the significance level.
55    #[must_use]
56    pub fn with_alpha(mut self, alpha: f64) -> Self {
57        self.alpha = alpha;
58        self
59    }
60}
61
62/// Action to take when falsification criterion is met.
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
64#[serde(rename_all = "snake_case")]
65pub enum FalsificationAction {
66    /// Log warning but continue
67    Warn,
68    /// Stop the experiment
69    Stop,
70    /// Reject the model
71    RejectModel,
72    /// Flag for manual review
73    FlagReview,
74}
75
76/// A criterion that would falsify the model.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct FalsificationCriterion {
79    /// Name/description of the criterion
80    pub name: String,
81    /// Mathematical expression defining the criterion
82    pub criterion: String,
83    /// Action to take if criterion is met
84    pub action: FalsificationAction,
85    /// Additional context or explanation
86    #[serde(default)]
87    pub context: String,
88}
89
90impl FalsificationCriterion {
91    /// Create a new falsification criterion.
92    #[must_use]
93    pub fn new(name: &str, criterion: &str, action: FalsificationAction) -> Self {
94        Self {
95            name: name.to_string(),
96            criterion: criterion.to_string(),
97            action,
98            context: String::new(),
99        }
100    }
101
102    /// Add context.
103    #[must_use]
104    pub fn with_context(mut self, context: &str) -> Self {
105        self.context = context.to_string();
106        self
107    }
108}
109
110/// Experiment specification following EDD principles.
111///
112/// Every experiment must have:
113/// - Explicit seed (Pillar 3: Seed It)
114/// - Falsification criteria (Pillar 4: Falsify It)
115/// - Reference to EMC (Pillar 1: Prove It)
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ExperimentSpec {
118    /// Unique name for this experiment
119    name: String,
120    /// Random seed for reproducibility
121    seed: u64,
122    /// Reference to the Equation Model Card
123    #[serde(default)]
124    emc_reference: Option<String>,
125    /// Hypothesis to test
126    #[serde(default)]
127    hypothesis: Option<ExperimentHypothesis>,
128    /// Falsification criteria
129    #[serde(default)]
130    falsification_criteria: Vec<FalsificationCriterion>,
131    /// Number of replications
132    #[serde(default = "default_replications")]
133    replications: u32,
134    /// Warmup period (time units)
135    #[serde(default)]
136    warmup: f64,
137    /// Run length (time units)
138    #[serde(default = "default_run_length")]
139    run_length: f64,
140    /// Description
141    #[serde(default)]
142    description: String,
143}
144
/// Serde default for `ExperimentSpec::replications`: 30 replications.
fn default_replications() -> u32 {
    const REPLICATIONS: u32 = 30;
    REPLICATIONS
}
148
/// Serde default for `ExperimentSpec::run_length`: 1000 time units.
fn default_run_length() -> f64 {
    const RUN_LENGTH: f64 = 1_000.0;
    RUN_LENGTH
}
152
153impl ExperimentSpec {
154    /// Create a new experiment spec builder.
155    #[must_use]
156    pub fn builder() -> ExperimentSpecBuilder {
157        ExperimentSpecBuilder::new()
158    }
159
160    /// Get the experiment name.
161    #[must_use]
162    pub fn name(&self) -> &str {
163        &self.name
164    }
165
166    /// Get the seed.
167    #[must_use]
168    pub fn seed(&self) -> u64 {
169        self.seed
170    }
171
172    /// Get the EMC reference.
173    #[must_use]
174    pub fn emc_reference(&self) -> Option<&str> {
175        self.emc_reference.as_deref()
176    }
177
178    /// Get the hypothesis.
179    #[must_use]
180    pub fn hypothesis(&self) -> Option<&ExperimentHypothesis> {
181        self.hypothesis.as_ref()
182    }
183
184    /// Get falsification criteria.
185    #[must_use]
186    pub fn falsification_criteria(&self) -> &[FalsificationCriterion] {
187        &self.falsification_criteria
188    }
189
190    /// Get number of replications.
191    #[must_use]
192    pub fn replications(&self) -> u32 {
193        self.replications
194    }
195
196    /// Get warmup period.
197    #[must_use]
198    pub fn warmup(&self) -> f64 {
199        self.warmup
200    }
201
202    /// Get run length.
203    #[must_use]
204    pub fn run_length(&self) -> f64 {
205        self.run_length
206    }
207
208    /// Get description.
209    #[must_use]
210    pub fn description(&self) -> &str {
211        &self.description
212    }
213
214    /// Parse from YAML string.
215    ///
216    /// # Errors
217    /// Returns error if YAML is invalid or missing required fields.
218    pub fn from_yaml(yaml: &str) -> Result<Self, String> {
219        serde_yaml::from_str(yaml).map_err(|e| format!("Failed to parse experiment YAML: {e}"))
220    }
221
222    /// Serialize to YAML string.
223    ///
224    /// # Errors
225    /// Returns error if serialization fails.
226    pub fn to_yaml(&self) -> Result<String, String> {
227        serde_yaml::to_string(self).map_err(|e| format!("Failed to serialize experiment: {e}"))
228    }
229
230    /// Validate the experiment specification.
231    ///
232    /// # Errors
233    /// Returns error if validation fails.
234    pub fn validate(&self) -> Result<(), Vec<String>> {
235        let mut errors = Vec::new();
236
237        if self.name.is_empty() {
238            errors.push("Experiment must have a name".to_string());
239        }
240
241        if self.replications == 0 {
242            errors.push("Experiment must have at least 1 replication".to_string());
243        }
244
245        if self.run_length <= 0.0 {
246            errors.push("Run length must be positive".to_string());
247        }
248
249        if self.warmup < 0.0 {
250            errors.push("Warmup cannot be negative".to_string());
251        }
252
253        if errors.is_empty() {
254            Ok(())
255        } else {
256            Err(errors)
257        }
258    }
259}
260
261/// Builder for `ExperimentSpec`.
262#[derive(Debug, Default)]
263pub struct ExperimentSpecBuilder {
264    name: Option<String>,
265    seed: Option<u64>,
266    emc_reference: Option<String>,
267    hypothesis: Option<ExperimentHypothesis>,
268    falsification_criteria: Vec<FalsificationCriterion>,
269    replications: u32,
270    warmup: f64,
271    run_length: f64,
272    description: String,
273}
274
275impl ExperimentSpecBuilder {
276    /// Create a new builder.
277    #[must_use]
278    pub fn new() -> Self {
279        Self {
280            replications: 30,
281            run_length: 1000.0,
282            ..Default::default()
283        }
284    }
285
286    /// Set the experiment name.
287    #[must_use]
288    pub fn name(mut self, name: &str) -> Self {
289        self.name = Some(name.to_string());
290        self
291    }
292
293    /// Set the seed (required).
294    #[must_use]
295    pub fn seed(mut self, seed: u64) -> Self {
296        self.seed = Some(seed);
297        self
298    }
299
300    /// Set the EMC reference.
301    #[must_use]
302    pub fn emc_reference(mut self, reference: &str) -> Self {
303        self.emc_reference = Some(reference.to_string());
304        self
305    }
306
307    /// Set the EMC directly (extracts reference).
308    #[must_use]
309    pub fn emc(mut self, emc: &EquationModelCard) -> Self {
310        self.emc_reference = Some(format!("{}@{}", emc.name, emc.version));
311        self
312    }
313
314    /// Set the hypothesis.
315    #[must_use]
316    pub fn hypothesis(mut self, hypothesis: ExperimentHypothesis) -> Self {
317        self.hypothesis = Some(hypothesis);
318        self
319    }
320
321    /// Add a falsification criterion.
322    #[must_use]
323    pub fn add_falsification_criterion(mut self, criterion: FalsificationCriterion) -> Self {
324        self.falsification_criteria.push(criterion);
325        self
326    }
327
328    /// Set number of replications.
329    #[must_use]
330    pub fn replications(mut self, n: u32) -> Self {
331        self.replications = n;
332        self
333    }
334
335    /// Set warmup period.
336    #[must_use]
337    pub fn warmup(mut self, warmup: f64) -> Self {
338        self.warmup = warmup;
339        self
340    }
341
342    /// Set run length.
343    #[must_use]
344    pub fn run_length(mut self, length: f64) -> Self {
345        self.run_length = length;
346        self
347    }
348
349    /// Set description.
350    #[must_use]
351    pub fn description(mut self, description: &str) -> Self {
352        self.description = description.to_string();
353        self
354    }
355
356    /// Build the experiment spec.
357    ///
358    /// # Errors
359    /// Returns error if required fields are missing.
360    pub fn build(self) -> Result<ExperimentSpec, String> {
361        let name = self.name.ok_or("Experiment must have a name")?;
362        let seed = self
363            .seed
364            .ok_or("Experiment must have an explicit seed (Pillar 3: Seed It)")?;
365
366        Ok(ExperimentSpec {
367            name,
368            seed,
369            emc_reference: self.emc_reference,
370            hypothesis: self.hypothesis,
371            falsification_criteria: self.falsification_criteria,
372            replications: self.replications,
373            warmup: self.warmup,
374            run_length: self.run_length,
375            description: self.description,
376        })
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec literal that passes `validate`; validation tests override a single field
    /// with struct-update syntax so each test isolates one failure.
    fn valid_spec() -> ExperimentSpec {
        ExperimentSpec {
            name: "Test".to_string(),
            seed: 42,
            emc_reference: None,
            hypothesis: None,
            falsification_criteria: Vec::new(),
            replications: 30,
            warmup: 0.0,
            run_length: 1000.0,
            description: String::new(),
        }
    }

    /// Build a spec named "Test" with seed 42 after applying `configure` to the builder.
    fn build_spec(
        configure: impl FnOnce(ExperimentSpecBuilder) -> ExperimentSpecBuilder,
    ) -> ExperimentSpec {
        configure(ExperimentSpec::builder().name("Test").seed(42))
            .build()
            .expect("name and seed are set, so build must succeed")
    }

    /// Return the validation errors of a spec that is expected to be invalid.
    fn validation_errors(spec: &ExperimentSpec) -> Vec<String> {
        spec.validate().expect_err("spec should fail validation")
    }

    #[test]
    fn test_experiment_requires_seed() {
        let err = ExperimentSpec::builder()
            .name("Test Experiment")
            .build()
            .expect_err("build must fail without a seed");
        assert!(err.contains("seed"), "unexpected error: {err}");
    }

    #[test]
    fn test_experiment_requires_name() {
        let err = ExperimentSpec::builder()
            .seed(42)
            .build()
            .expect_err("build must fail without a name");
        assert!(err.contains("name"), "unexpected error: {err}");
    }

    #[test]
    fn test_experiment_builds_with_required_fields() {
        let spec = ExperimentSpec::builder()
            .name("Test Experiment")
            .seed(42)
            .build()
            .expect("name and seed are set");

        assert_eq!(spec.seed(), 42);
        assert_eq!(spec.name(), "Test Experiment");
    }

    #[test]
    fn test_experiment_with_hypothesis() {
        let spec = build_spec(|b| {
            b.name("Little's Law Test")
                .hypothesis(ExperimentHypothesis::new(
                    "L ≠ λW",
                    "L = λW holds under stochastic conditions",
                ))
        });

        let hypothesis = spec.hypothesis().expect("hypothesis was set");
        assert_eq!(hypothesis.null, "L ≠ λW");
    }

    #[test]
    fn test_experiment_with_falsification() {
        let spec = build_spec(|b| {
            b.add_falsification_criterion(FalsificationCriterion::new(
                "Error too high",
                "relative_error > 0.05",
                FalsificationAction::RejectModel,
            ))
        });

        let criteria = spec.falsification_criteria();
        assert_eq!(criteria.len(), 1);
        assert_eq!(criteria[0].action, FalsificationAction::RejectModel);
    }

    #[test]
    fn test_experiment_yaml_roundtrip() {
        let spec = build_spec(|b| {
            b.name("YAML Test")
                .seed(12345)
                .replications(50)
                .warmup(100.0)
                .run_length(5000.0)
                .description("Test experiment for YAML serialization")
        });

        let yaml = spec.to_yaml().expect("serialization must succeed");
        assert!(yaml.contains("name: YAML Test"), "{yaml}");
        assert!(yaml.contains("seed: 12345"), "{yaml}");

        // Parse the output back so this is a genuine round trip, not just a format check.
        let parsed = ExperimentSpec::from_yaml(&yaml).expect("serialized YAML must parse back");
        assert_eq!(parsed.name(), spec.name());
        assert_eq!(parsed.seed(), spec.seed());
        assert_eq!(parsed.replications(), spec.replications());
        assert!((parsed.warmup() - spec.warmup()).abs() < f64::EPSILON);
        assert!((parsed.run_length() - spec.run_length()).abs() < f64::EPSILON);
        assert_eq!(parsed.description(), spec.description());
    }

    #[test]
    fn test_experiment_validation() {
        let spec = build_spec(|b| {
            b.name("Valid Experiment")
                .replications(30)
                .run_length(1000.0)
        });

        assert_eq!(spec.validate(), Ok(()));
    }

    #[test]
    fn test_hypothesis_alpha() {
        let hypothesis = ExperimentHypothesis::new("H0", "H1").with_alpha(0.01);

        assert!((hypothesis.alpha - 0.01).abs() < f64::EPSILON);
    }

    #[test]
    fn test_hypothesis_new_uses_default_alpha() {
        let hypothesis = ExperimentHypothesis::new("H0", "H1");

        assert!((hypothesis.alpha - default_alpha()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_falsification_action_serialization() {
        let criterion =
            FalsificationCriterion::new("test", "x > 0", FalsificationAction::RejectModel);

        let yaml = serde_yaml::to_string(&criterion).expect("criterion must serialize");
        assert!(yaml.contains("reject_model"), "{yaml}");
    }

    #[test]
    fn test_experiment_spec_from_yaml() {
        let yaml = r#"
name: "Test"
seed: 42
replications: 10
warmup: 50.0
run_length: 500.0
description: "A test experiment"
"#;
        let spec = ExperimentSpec::from_yaml(yaml).expect("well-formed YAML must parse");

        assert_eq!(spec.name(), "Test");
        assert_eq!(spec.seed(), 42);
        assert_eq!(spec.replications(), 10);
        assert!((spec.warmup() - 50.0).abs() < f64::EPSILON);
        assert!((spec.run_length() - 500.0).abs() < f64::EPSILON);
        assert_eq!(spec.description(), "A test experiment");
    }

    #[test]
    fn test_experiment_spec_from_yaml_invalid() {
        let err = ExperimentSpec::from_yaml("invalid: [yaml")
            .expect_err("malformed YAML must be rejected");

        assert!(err.contains("Failed to parse"), "unexpected error: {err}");
    }

    #[test]
    fn test_experiment_validation_fails_empty_name() {
        let spec = ExperimentSpec {
            name: String::new(),
            ..valid_spec()
        };

        assert!(validation_errors(&spec).iter().any(|e| e.contains("name")));
    }

    #[test]
    fn test_experiment_validation_fails_zero_replications() {
        let spec = ExperimentSpec {
            replications: 0,
            ..valid_spec()
        };

        assert!(validation_errors(&spec)
            .iter()
            .any(|e| e.contains("replication")));
    }

    #[test]
    fn test_experiment_validation_fails_negative_run_length() {
        let spec = ExperimentSpec {
            run_length: -100.0,
            ..valid_spec()
        };

        assert!(validation_errors(&spec)
            .iter()
            .any(|e| e.contains("Run length")));
    }

    #[test]
    fn test_experiment_validation_fails_negative_warmup() {
        let spec = ExperimentSpec {
            warmup: -10.0,
            ..valid_spec()
        };

        assert!(validation_errors(&spec).iter().any(|e| e.contains("Warmup")));
    }

    #[test]
    fn test_experiment_spec_emc_reference_getter() {
        let spec = build_spec(|b| b.emc_reference("test_emc@1.0"));

        assert_eq!(spec.emc_reference(), Some("test_emc@1.0"));
    }

    #[test]
    fn test_experiment_spec_builder_defaults() {
        let spec = build_spec(|b| b);

        assert_eq!(spec.replications(), 30);
        assert!((spec.run_length() - 1000.0).abs() < f64::EPSILON);
        assert!(spec.warmup().abs() < f64::EPSILON);
    }

    #[test]
    fn test_experiment_spec_builder_emc() {
        use crate::edd::equation::Citation;
        use crate::edd::model_card::{EmcBuilder, VerificationTest};

        let emc = match EmcBuilder::new()
            .name("TestEMC")
            .version("2.0.0")
            .equation("y = x")
            .citation(Citation::new(&["Author"], "Journal", 2024))
            .add_verification_test_full(
                VerificationTest::new("test", 1.0, 0.1).with_input("x", 1.0),
            )
            .build()
        {
            Ok(emc) => emc,
            Err(_) => panic!("a fully specified EMC must build"),
        };

        let spec = build_spec(|b| b.emc(&emc));

        let emc_ref = spec.emc_reference().expect("emc() sets the reference");
        assert!(emc_ref.contains("TestEMC"));
        assert!(emc_ref.contains("2.0.0"));
    }

    #[test]
    fn test_experiment_spec_builder_description() {
        let spec = build_spec(|b| b.description("A detailed description"));

        assert_eq!(spec.description(), "A detailed description");
    }

    #[test]
    fn test_experiment_spec_builder_warmup() {
        let spec = build_spec(|b| b.warmup(200.0));

        assert!((spec.warmup() - 200.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_experiment_spec_builder_run_length() {
        let spec = build_spec(|b| b.run_length(5000.0));

        assert!((spec.run_length() - 5000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_falsification_criterion_with_context() {
        let criterion = FalsificationCriterion::new("Test", "x > 0", FalsificationAction::Warn)
            .with_context("Additional context here");

        assert_eq!(criterion.context, "Additional context here");
    }

    #[test]
    fn test_falsification_action_variants() {
        assert_ne!(FalsificationAction::Warn, FalsificationAction::Stop);
        assert_ne!(FalsificationAction::Stop, FalsificationAction::RejectModel);
        assert_ne!(
            FalsificationAction::RejectModel,
            FalsificationAction::FlagReview
        );
        assert_ne!(FalsificationAction::FlagReview, FalsificationAction::Warn);
    }

    #[test]
    fn test_default_alpha() {
        assert!((default_alpha() - 0.05).abs() < f64::EPSILON);
    }

    #[test]
    fn test_default_replications() {
        assert_eq!(default_replications(), 30);
    }

    #[test]
    fn test_default_run_length() {
        assert!((default_run_length() - 1000.0).abs() < f64::EPSILON);
    }
}