Skip to main content

lgp/core/
experiment_config.rs

1//! Experiment configuration types for TOML-based experiment definitions.
2//!
3//! This module provides the configuration structures for defining and running
4//! LGP experiments in a reproducible, versioned manner.
5
6use serde::{Deserialize, Serialize};
7use std::error::Error;
8use std::fs;
9use std::path::Path;
10
11/// Serde helper module for serializing Option<u64> as a string.
12/// This is necessary because TOML only supports signed 64-bit integers,
13/// and u64 values larger than i64::MAX would cause serialization to fail.
14mod optional_u64_as_string {
15    use serde::{self, Deserialize, Deserializer, Serializer};
16
17    pub fn serialize<S>(value: &Option<u64>, serializer: S) -> Result<S::Ok, S::Error>
18    where
19        S: Serializer,
20    {
21        match value {
22            Some(v) => serializer.serialize_str(&v.to_string()),
23            None => serializer.serialize_none(),
24        }
25    }
26
27    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
28    where
29        D: Deserializer<'de>,
30    {
31        use serde::de::Error;
32
33        // We need to handle both string and integer formats for backwards compatibility
34        #[derive(Deserialize)]
35        #[serde(untagged)]
36        enum StringOrInt {
37            String(String),
38            Int(u64),
39        }
40
41        let opt: Option<StringOrInt> = Option::deserialize(deserializer)?;
42        match opt {
43            Some(StringOrInt::String(s)) => s.parse().map(Some).map_err(D::Error::custom),
44            Some(StringOrInt::Int(n)) => Ok(Some(n)),
45            None => Ok(None),
46        }
47    }
48}
49
50/// Complete experiment configuration loaded from a TOML file.
51#[derive(Debug, Clone, Deserialize, Serialize)]
52pub struct ExperimentConfig {
53    pub name: String,
54    pub environment: String,
55    pub metadata: Metadata,
56    pub problem: Problem,
57    pub hyperparameters: HyperParams,
58    #[serde(default)]
59    pub operations: Vec<Operation>,
60}
61
62/// Metadata about the experiment.
63#[derive(Debug, Clone, Deserialize, Serialize)]
64pub struct Metadata {
65    pub version: String,
66    #[serde(default)]
67    pub description: Option<String>,
68    #[serde(default, skip_serializing_if = "Option::is_none")]
69    pub run_timestamp: Option<String>,
70    #[serde(default)]
71    pub title: Option<String>,
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    pub x_label: Option<String>,
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub y_label: Option<String>,
76    #[serde(default, skip_serializing_if = "Vec::is_empty")]
77    pub tags: Vec<String>,
78}
79
80/// Problem-specific configuration.
81#[derive(Debug, Clone, Deserialize, Serialize)]
82pub struct Problem {
83    pub n_inputs: usize,
84    pub n_actions: usize,
85}
86
87/// Hyperparameters for the genetic algorithm.
88#[derive(Debug, Clone, Deserialize, Serialize)]
89pub struct HyperParams {
90    pub population_size: usize,
91    pub n_generations: usize,
92    #[serde(default = "default_n_trials")]
93    pub n_trials: usize,
94    #[serde(default = "default_gap")]
95    pub gap: f64,
96    #[serde(default)]
97    pub default_fitness: f64,
98    /// Random seed. If None, a random seed will be generated.
99    /// Serialized as a string to support values > i64::MAX in TOML format.
100    #[serde(default, with = "optional_u64_as_string")]
101    pub seed: Option<u64>,
102    /// Number of threads for parallel evaluation. If None, uses all available cores.
103    #[serde(default)]
104    pub n_threads: Option<usize>,
105    pub program: ProgramConfig,
106}
107
/// Serde fallback for `HyperParams::n_trials` when the key is absent.
fn default_n_trials() -> usize {
    const DEFAULT_N_TRIALS: usize = 1;
    DEFAULT_N_TRIALS
}
111
/// Serde fallback for `HyperParams::gap` when the key is absent.
fn default_gap() -> f64 {
    const DEFAULT_GAP: f64 = 0.5;
    DEFAULT_GAP
}
115
116/// Program generation parameters.
117#[derive(Debug, Clone, Deserialize, Serialize)]
118pub struct ProgramConfig {
119    pub max_instructions: usize,
120    #[serde(default = "default_n_extras")]
121    pub n_extras: usize,
122    #[serde(default = "default_external_factor")]
123    pub external_factor: f64,
124}
125
/// Serde fallback for `ProgramConfig::n_extras` when the key is absent.
fn default_n_extras() -> usize {
    const DEFAULT_N_EXTRAS: usize = 1;
    DEFAULT_N_EXTRAS
}
129
/// Serde fallback for `ProgramConfig::external_factor` when the key is absent.
fn default_external_factor() -> f64 {
    const DEFAULT_EXTERNAL_FACTOR: f64 = 10.0;
    DEFAULT_EXTERNAL_FACTOR
}
133
134/// An operation that can be applied to the evolutionary process.
135#[derive(Debug, Clone, Deserialize, Serialize)]
136#[serde(tag = "name", rename_all = "snake_case")]
137pub enum Operation {
138    Mutation { parameters: MutationParams },
139    Crossover { parameters: CrossoverParams },
140    QLearning { parameters: QLearningParams },
141}
142
143/// Parameters for the mutation operation.
144#[derive(Debug, Clone, Deserialize, Serialize)]
145pub struct MutationParams {
146    pub percent: f64,
147}
148
149/// Parameters for the crossover operation.
150#[derive(Debug, Clone, Deserialize, Serialize)]
151pub struct CrossoverParams {
152    pub percent: f64,
153}
154
155/// Q-Learning specific parameters (for reinforcement learning with Q-Learning).
156#[derive(Debug, Clone, Deserialize, Serialize)]
157pub struct QLearningParams {
158    #[serde(default = "default_alpha")]
159    pub alpha: f64,
160    #[serde(default = "default_gamma")]
161    pub gamma: f64,
162    #[serde(default = "default_epsilon")]
163    pub epsilon: f64,
164    #[serde(default = "default_alpha_decay")]
165    pub alpha_decay: f64,
166    #[serde(default = "default_epsilon_decay")]
167    pub epsilon_decay: f64,
168}
169
/// Fallback for `QLearningParams::alpha` (serde default and `Default` impl).
fn default_alpha() -> f64 {
    const DEFAULT_ALPHA: f64 = 0.1;
    DEFAULT_ALPHA
}
173
/// Fallback for `QLearningParams::gamma` (serde default and `Default` impl).
fn default_gamma() -> f64 {
    const DEFAULT_GAMMA: f64 = 0.9;
    DEFAULT_GAMMA
}
177
/// Fallback for `QLearningParams::epsilon` (serde default and `Default` impl).
fn default_epsilon() -> f64 {
    const DEFAULT_EPSILON: f64 = 0.05;
    DEFAULT_EPSILON
}
181
/// Fallback for `QLearningParams::alpha_decay` (serde default and `Default` impl).
fn default_alpha_decay() -> f64 {
    const DEFAULT_ALPHA_DECAY: f64 = 0.01;
    DEFAULT_ALPHA_DECAY
}
185
/// Fallback for `QLearningParams::epsilon_decay` (serde default and `Default` impl).
fn default_epsilon_decay() -> f64 {
    const DEFAULT_EPSILON_DECAY: f64 = 0.001;
    DEFAULT_EPSILON_DECAY
}
189
190impl Default for QLearningParams {
191    fn default() -> Self {
192        Self {
193            alpha: default_alpha(),
194            gamma: default_gamma(),
195            epsilon: default_epsilon(),
196            alpha_decay: default_alpha_decay(),
197            epsilon_decay: default_epsilon_decay(),
198        }
199    }
200}
201
202impl ExperimentConfig {
203    /// Load an experiment configuration from a TOML file.
204    pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
205        let content = fs::read_to_string(path)?;
206        let config: ExperimentConfig = toml::from_str(&content)?;
207        Ok(config)
208    }
209
210    /// Save the experiment configuration to a TOML file.
211    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), Box<dyn Error>> {
212        let content = toml::to_string_pretty(self)?;
213        fs::write(path, content)?;
214        Ok(())
215    }
216
217    /// Create a copy of this config with resolved runtime values.
218    pub fn with_runtime_values(&self, seed: u64, timestamp: &str) -> Self {
219        let mut config = self.clone();
220        config.metadata.run_timestamp = Some(timestamp.to_string());
221        config.hyperparameters.seed = Some(seed);
222        config
223    }
224
225    /// Extract mutation percent from operations, defaults to 0.0 if not found.
226    pub fn mutation_percent(&self) -> f64 {
227        self.operations
228            .iter()
229            .find_map(|op| match op {
230                Operation::Mutation { parameters } => Some(parameters.percent),
231                _ => None,
232            })
233            .unwrap_or(0.0)
234    }
235
236    /// Extract crossover percent from operations, defaults to 0.0 if not found.
237    pub fn crossover_percent(&self) -> f64 {
238        self.operations
239            .iter()
240            .find_map(|op| match op {
241                Operation::Crossover { parameters } => Some(parameters.percent),
242                _ => None,
243            })
244            .unwrap_or(0.0)
245    }
246
247    /// Extract Q-Learning parameters from operations if present.
248    pub fn q_learning_params(&self) -> Option<QLearningParams> {
249        self.operations.iter().find_map(|op| match op {
250            Operation::QLearning { parameters } => Some(parameters.clone()),
251            _ => None,
252        })
253    }
254
255    /// Check if Q-Learning is enabled.
256    pub fn has_q_learning(&self) -> bool {
257        self.q_learning_params().is_some()
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_baseline_config() {
        let toml_str = r#"
name = "iris_baseline"
environment = "Iris"

[metadata]
version = "1.0.0"
description = "Iris baseline - no genetic operators"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
n_trials = 1
gap = 0.5
default_fitness = 0.0

[hyperparameters.program]
max_instructions = 100
n_extras = 1
external_factor = 10.0
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "iris_baseline");
        assert_eq!(config.environment, "Iris");
        assert_eq!(config.problem.n_inputs, 4);
        assert_eq!(config.problem.n_actions, 3);
        assert_eq!(config.hyperparameters.population_size, 100);
        assert_eq!(config.operations.len(), 0);
        assert_eq!(config.mutation_percent(), 0.0);
        assert_eq!(config.crossover_percent(), 0.0);
    }

    #[test]
    fn test_parse_mutation_only_config() {
        let toml_str = r#"
name = "iris_mutation"
environment = "Iris"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100

[[operations]]
name = "mutation"
parameters = { percent = 1.0 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "iris_mutation");
        assert_eq!(config.operations.len(), 1);
        assert_eq!(config.mutation_percent(), 1.0);
        assert_eq!(config.crossover_percent(), 0.0);
        assert!(!config.has_q_learning());

        // Omitted keys fall back to their serde defaults.
        assert_eq!(config.hyperparameters.n_trials, 1);
        assert_eq!(config.hyperparameters.gap, 0.5);
        assert_eq!(config.hyperparameters.default_fitness, 0.0);
        assert_eq!(config.hyperparameters.n_threads, None);
        assert_eq!(config.hyperparameters.program.n_extras, 1);
        assert_eq!(config.hyperparameters.program.external_factor, 10.0);
    }

    #[test]
    fn test_parse_full_lgp_config() {
        let toml_str = r#"
name = "cart_pole_lgp"
environment = "CartPole"

[metadata]
version = "1.0.0"
description = "CartPole with mutation and crossover"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 100
n_generations = 100
n_trials = 100
gap = 0.5
default_fitness = 500.0

[hyperparameters.program]
max_instructions = 50
n_extras = 1
external_factor = 10.0

[[operations]]
name = "mutation"
parameters = { percent = 0.5 }

[[operations]]
name = "crossover"
parameters = { percent = 0.5 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "cart_pole_lgp");
        assert_eq!(config.environment, "CartPole");
        assert_eq!(config.operations.len(), 2);
        assert_eq!(config.mutation_percent(), 0.5);
        assert_eq!(config.crossover_percent(), 0.5);
        assert!(!config.has_q_learning());
    }

    #[test]
    fn test_parse_with_q_learning_config() {
        let toml_str = r#"
name = "cart_pole_with_q"
environment = "CartPole"

[metadata]
version = "1.0.0"
description = "CartPole with mutation, crossover, and Q-learning"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 100
n_generations = 100
n_trials = 100
gap = 0.5
default_fitness = 500.0

[hyperparameters.program]
max_instructions = 50
n_extras = 1
external_factor = 10.0

[[operations]]
name = "mutation"
parameters = { percent = 0.5 }

[[operations]]
name = "crossover"
parameters = { percent = 0.5 }

[[operations]]
name = "q_learning"
parameters = { alpha = 0.1, gamma = 0.9, epsilon = 0.05, alpha_decay = 0.01, epsilon_decay = 0.001 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "cart_pole_with_q");
        assert!(config.has_q_learning());
        let q_params = config.q_learning_params().unwrap();
        assert_eq!(q_params.alpha, 0.1);
        assert_eq!(q_params.gamma, 0.9);
        assert_eq!(q_params.epsilon, 0.05);
    }

    #[test]
    fn test_q_learning_partial_parameters_use_defaults() {
        let toml_str = r#"
name = "q_defaults"
environment = "CartPole"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 10
n_generations = 5

[hyperparameters.program]
max_instructions = 10

[[operations]]
name = "q_learning"
parameters = { alpha = 0.2 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert!(config.has_q_learning());
        let q_params = config.q_learning_params().unwrap();
        let defaults = QLearningParams::default();
        assert_eq!(q_params.alpha, 0.2);
        assert_eq!(q_params.gamma, defaults.gamma);
        assert_eq!(q_params.epsilon, defaults.epsilon);
        assert_eq!(q_params.alpha_decay, defaults.alpha_decay);
        assert_eq!(q_params.epsilon_decay, defaults.epsilon_decay);
    }

    #[test]
    fn test_large_seed_serialization() {
        // Test that seeds larger than i64::MAX can be serialized and deserialized
        let large_seed: u64 = 16412768254277122702; // > i64::MAX (9223372036854775807)
        assert!(large_seed > i64::MAX as u64);

        let toml_str = r#"
name = "test_large_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
seed = "16412768254277122702"

[hyperparameters.program]
max_instructions = 100
"#;
        // Test deserialization from string format
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, Some(large_seed));

        // Test round-trip serialization
        let serialized = toml::to_string_pretty(&config).unwrap();
        assert!(serialized.contains("seed = \"16412768254277122702\""));

        // Test deserialization of the serialized config
        let deserialized: ExperimentConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.hyperparameters.seed, Some(large_seed));
    }

    #[test]
    fn test_seed_backwards_compatibility() {
        // Test that integer seeds (within i64 range) still work
        let toml_str = r#"
name = "test_int_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
seed = 12345

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, Some(12345));
    }

    #[test]
    fn test_no_seed_serialization() {
        // Test that configs without a seed work correctly
        let toml_str = r#"
name = "test_no_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, None);

        // A missing seed must not be written out, and must survive a round trip.
        let serialized = toml::to_string_pretty(&config).unwrap();
        assert!(!serialized.contains("seed ="));
        let deserialized: ExperimentConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.hyperparameters.seed, None);
    }

    #[test]
    fn test_with_runtime_values() {
        let toml_str = r#"
name = "test_runtime"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        let resolved = config.with_runtime_values(42, "2024-01-01T00:00:00Z");
        assert_eq!(resolved.hyperparameters.seed, Some(42));
        assert_eq!(
            resolved.metadata.run_timestamp.as_deref(),
            Some("2024-01-01T00:00:00Z")
        );

        // The source config is left untouched.
        assert_eq!(config.hyperparameters.seed, None);
        assert_eq!(config.metadata.run_timestamp, None);

        // Resolved values survive a serialization round trip.
        let serialized = toml::to_string_pretty(&resolved).unwrap();
        assert!(serialized.contains("seed = \"42\""));
        let deserialized: ExperimentConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.hyperparameters.seed, Some(42));
        assert_eq!(
            deserialized.metadata.run_timestamp.as_deref(),
            Some("2024-01-01T00:00:00Z")
        );
    }
}