1use serde::{Deserialize, Serialize};
7use std::error::Error;
8use std::fs;
9use std::path::Path;
10
11mod optional_u64_as_string {
15 use serde::{self, Deserialize, Deserializer, Serializer};
16
17 pub fn serialize<S>(value: &Option<u64>, serializer: S) -> Result<S::Ok, S::Error>
18 where
19 S: Serializer,
20 {
21 match value {
22 Some(v) => serializer.serialize_str(&v.to_string()),
23 None => serializer.serialize_none(),
24 }
25 }
26
27 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
28 where
29 D: Deserializer<'de>,
30 {
31 use serde::de::Error;
32
33 #[derive(Deserialize)]
35 #[serde(untagged)]
36 enum StringOrInt {
37 String(String),
38 Int(u64),
39 }
40
41 let opt: Option<StringOrInt> = Option::deserialize(deserializer)?;
42 match opt {
43 Some(StringOrInt::String(s)) => s.parse().map(Some).map_err(D::Error::custom),
44 Some(StringOrInt::Int(n)) => Ok(Some(n)),
45 None => Ok(None),
46 }
47 }
48}
49
50#[derive(Debug, Clone, Deserialize, Serialize)]
52pub struct ExperimentConfig {
53 pub name: String,
54 pub environment: String,
55 pub metadata: Metadata,
56 pub problem: Problem,
57 pub hyperparameters: HyperParams,
58 #[serde(default)]
59 pub operations: Vec<Operation>,
60}
61
62#[derive(Debug, Clone, Deserialize, Serialize)]
64pub struct Metadata {
65 pub version: String,
66 #[serde(default)]
67 pub description: Option<String>,
68 #[serde(default, skip_serializing_if = "Option::is_none")]
69 pub run_timestamp: Option<String>,
70 #[serde(default)]
71 pub title: Option<String>,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
73 pub x_label: Option<String>,
74 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub y_label: Option<String>,
76 #[serde(default, skip_serializing_if = "Vec::is_empty")]
77 pub tags: Vec<String>,
78}
79
/// Input/output dimensions of the problem, from the `[problem]` table.
///
/// The tests use 4 inputs / 3 actions for Iris and 4 inputs / 2 actions for CartPole.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Problem {
    /// Number of inputs (required).
    pub n_inputs: usize,
    /// Number of actions (required).
    pub n_actions: usize,
}
86
87#[derive(Debug, Clone, Deserialize, Serialize)]
89pub struct HyperParams {
90 pub population_size: usize,
91 pub n_generations: usize,
92 #[serde(default = "default_n_trials")]
93 pub n_trials: usize,
94 #[serde(default = "default_gap")]
95 pub gap: f64,
96 #[serde(default)]
97 pub default_fitness: f64,
98 #[serde(default, with = "optional_u64_as_string")]
101 pub seed: Option<u64>,
102 #[serde(default)]
104 pub n_threads: Option<usize>,
105 pub program: ProgramConfig,
106}
107
/// Serde fallback for `HyperParams::n_trials` when the key is omitted.
fn default_n_trials() -> usize {
    const N_TRIALS: usize = 1;
    N_TRIALS
}

/// Serde fallback for `HyperParams::gap` when the key is omitted.
fn default_gap() -> f64 {
    const GAP: f64 = 0.5;
    GAP
}
115
/// Program settings from the nested `[hyperparameters.program]` table.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProgramConfig {
    /// Instruction limit (required); presumably the maximum program length — confirm.
    pub max_instructions: usize,
    /// `default_n_extras()` (1) when omitted; meaning is not visible in this file.
    #[serde(default = "default_n_extras")]
    pub n_extras: usize,
    /// `default_external_factor()` (10.0) when omitted; meaning is not visible in this file.
    #[serde(default = "default_external_factor")]
    pub external_factor: f64,
}
125
/// Serde fallback for `ProgramConfig::n_extras` when the key is omitted.
fn default_n_extras() -> usize {
    const N_EXTRAS: usize = 1;
    N_EXTRAS
}

/// Serde fallback for `ProgramConfig::external_factor` when the key is omitted.
fn default_external_factor() -> f64 {
    const EXTERNAL_FACTOR: f64 = 10.0;
    EXTERNAL_FACTOR
}
133
134#[derive(Debug, Clone, Deserialize, Serialize)]
136#[serde(tag = "name", rename_all = "snake_case")]
137pub enum Operation {
138 Mutation { parameters: MutationParams },
139 Crossover { parameters: CrossoverParams },
140 QLearning { parameters: QLearningParams },
141}
142
/// Parameters of the `mutation` operation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MutationParams {
    /// Required. Not range-checked here; the tests use 0.5 and 1.0, so presumably a
    /// fraction in [0, 1] despite the name — confirm with the consumer.
    pub percent: f64,
}
148
/// Parameters of the `crossover` operation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CrossoverParams {
    /// Required. Not range-checked here; the tests use 0.5, so presumably a fraction
    /// in [0, 1] despite the name — confirm with the consumer.
    pub percent: f64,
}
154
/// Parameters of the `q_learning` operation. Every field is optional in the input and
/// falls back to the matching `default_*` function (the same values `Default` uses).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QLearningParams {
    /// Defaults to 0.1; presumably the learning rate — confirm.
    #[serde(default = "default_alpha")]
    pub alpha: f64,
    /// Defaults to 0.9; presumably the discount factor — confirm.
    #[serde(default = "default_gamma")]
    pub gamma: f64,
    /// Defaults to 0.05; presumably the exploration rate — confirm.
    #[serde(default = "default_epsilon")]
    pub epsilon: f64,
    /// Defaults to 0.01; how the decay is applied is not visible in this file.
    #[serde(default = "default_alpha_decay")]
    pub alpha_decay: f64,
    /// Defaults to 0.001; how the decay is applied is not visible in this file.
    #[serde(default = "default_epsilon_decay")]
    pub epsilon_decay: f64,
}
169
// Serde fallbacks for the `QLearningParams` fields; `impl Default for QLearningParams`
// builds its value from these same functions.

/// Fallback for `QLearningParams::alpha`.
fn default_alpha() -> f64 {
    const ALPHA: f64 = 0.1;
    ALPHA
}

/// Fallback for `QLearningParams::gamma`.
fn default_gamma() -> f64 {
    const GAMMA: f64 = 0.9;
    GAMMA
}

/// Fallback for `QLearningParams::epsilon`.
fn default_epsilon() -> f64 {
    const EPSILON: f64 = 0.05;
    EPSILON
}

/// Fallback for `QLearningParams::alpha_decay`.
fn default_alpha_decay() -> f64 {
    const ALPHA_DECAY: f64 = 0.01;
    ALPHA_DECAY
}

/// Fallback for `QLearningParams::epsilon_decay`.
fn default_epsilon_decay() -> f64 {
    const EPSILON_DECAY: f64 = 0.001;
    EPSILON_DECAY
}
189
190impl Default for QLearningParams {
191 fn default() -> Self {
192 Self {
193 alpha: default_alpha(),
194 gamma: default_gamma(),
195 epsilon: default_epsilon(),
196 alpha_decay: default_alpha_decay(),
197 epsilon_decay: default_epsilon_decay(),
198 }
199 }
200}
201
202impl ExperimentConfig {
203 pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
205 let content = fs::read_to_string(path)?;
206 let config: ExperimentConfig = toml::from_str(&content)?;
207 Ok(config)
208 }
209
210 pub fn save(&self, path: impl AsRef<Path>) -> Result<(), Box<dyn Error>> {
212 let content = toml::to_string_pretty(self)?;
213 fs::write(path, content)?;
214 Ok(())
215 }
216
217 pub fn with_runtime_values(&self, seed: u64, timestamp: &str) -> Self {
219 let mut config = self.clone();
220 config.metadata.run_timestamp = Some(timestamp.to_string());
221 config.hyperparameters.seed = Some(seed);
222 config
223 }
224
225 pub fn mutation_percent(&self) -> f64 {
227 self.operations
228 .iter()
229 .find_map(|op| match op {
230 Operation::Mutation { parameters } => Some(parameters.percent),
231 _ => None,
232 })
233 .unwrap_or(0.0)
234 }
235
236 pub fn crossover_percent(&self) -> f64 {
238 self.operations
239 .iter()
240 .find_map(|op| match op {
241 Operation::Crossover { parameters } => Some(parameters.percent),
242 _ => None,
243 })
244 .unwrap_or(0.0)
245 }
246
247 pub fn q_learning_params(&self) -> Option<QLearningParams> {
249 self.operations.iter().find_map(|op| match op {
250 Operation::QLearning { parameters } => Some(parameters.clone()),
251 _ => None,
252 })
253 }
254
255 pub fn has_q_learning(&self) -> bool {
257 self.q_learning_params().is_some()
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_baseline_config() {
        let toml_str = r#"
name = "iris_baseline"
environment = "Iris"

[metadata]
version = "1.0.0"
description = "Iris baseline - no genetic operators"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
n_trials = 1
gap = 0.5
default_fitness = 0.0

[hyperparameters.program]
max_instructions = 100
n_extras = 1
external_factor = 10.0
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "iris_baseline");
        assert_eq!(config.environment, "Iris");
        assert_eq!(config.problem.n_inputs, 4);
        assert_eq!(config.problem.n_actions, 3);
        assert_eq!(config.hyperparameters.population_size, 100);
        assert_eq!(config.operations.len(), 0);
        assert_eq!(config.mutation_percent(), 0.0);
        assert_eq!(config.crossover_percent(), 0.0);
    }

    #[test]
    fn test_parse_mutation_only_config() {
        let toml_str = r#"
name = "iris_mutation"
environment = "Iris"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100

[[operations]]
name = "mutation"
parameters = { percent = 1.0 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "iris_mutation");
        assert_eq!(config.operations.len(), 1);
        assert_eq!(config.mutation_percent(), 1.0);
        assert_eq!(config.crossover_percent(), 0.0);
        assert!(!config.has_q_learning());
    }

    #[test]
    fn test_parse_full_lgp_config() {
        let toml_str = r#"
name = "cart_pole_lgp"
environment = "CartPole"

[metadata]
version = "1.0.0"
description = "CartPole with mutation and crossover"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 100
n_generations = 100
n_trials = 100
gap = 0.5
default_fitness = 500.0

[hyperparameters.program]
max_instructions = 50
n_extras = 1
external_factor = 10.0

[[operations]]
name = "mutation"
parameters = { percent = 0.5 }

[[operations]]
name = "crossover"
parameters = { percent = 0.5 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "cart_pole_lgp");
        assert_eq!(config.environment, "CartPole");
        assert_eq!(config.operations.len(), 2);
        assert_eq!(config.mutation_percent(), 0.5);
        assert_eq!(config.crossover_percent(), 0.5);
        assert!(!config.has_q_learning());
    }

    #[test]
    fn test_parse_with_q_learning_config() {
        let toml_str = r#"
name = "cart_pole_with_q"
environment = "CartPole"

[metadata]
version = "1.0.0"
description = "CartPole with mutation, crossover, and Q-learning"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 100
n_generations = 100
n_trials = 100
gap = 0.5
default_fitness = 500.0

[hyperparameters.program]
max_instructions = 50
n_extras = 1
external_factor = 10.0

[[operations]]
name = "mutation"
parameters = { percent = 0.5 }

[[operations]]
name = "crossover"
parameters = { percent = 0.5 }

[[operations]]
name = "q_learning"
parameters = { alpha = 0.1, gamma = 0.9, epsilon = 0.05, alpha_decay = 0.01, epsilon_decay = 0.001 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.name, "cart_pole_with_q");
        assert!(config.has_q_learning());
        let q_params = config.q_learning_params().unwrap();
        assert_eq!(q_params.alpha, 0.1);
        assert_eq!(q_params.gamma, 0.9);
        assert_eq!(q_params.epsilon, 0.05);
    }

    /// Missing Q-learning, hyperparameter and program keys fall back to their defaults.
    #[test]
    fn test_partial_parameters_use_defaults() {
        let toml_str = r#"
name = "partial_q"
environment = "CartPole"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 2

[hyperparameters]
population_size = 10
n_generations = 5

[hyperparameters.program]
max_instructions = 20

[[operations]]
name = "q_learning"
parameters = { alpha = 0.2 }
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert!(config.has_q_learning());

        let q = config.q_learning_params().unwrap();
        let defaults = QLearningParams::default();
        assert_eq!(q.alpha, 0.2);
        assert_eq!(q.gamma, defaults.gamma);
        assert_eq!(q.epsilon, defaults.epsilon);
        assert_eq!(q.alpha_decay, defaults.alpha_decay);
        assert_eq!(q.epsilon_decay, defaults.epsilon_decay);

        assert_eq!(config.hyperparameters.n_trials, 1);
        assert_eq!(config.hyperparameters.gap, 0.5);
        assert_eq!(config.hyperparameters.default_fitness, 0.0);
        assert_eq!(config.hyperparameters.n_threads, None);
        assert_eq!(config.hyperparameters.program.n_extras, 1);
        assert_eq!(config.hyperparameters.program.external_factor, 10.0);
    }

    #[test]
    fn test_large_seed_serialization() {
        let large_seed: u64 = 16412768254277122702;
        assert!(large_seed > i64::MAX as u64);

        let toml_str = r#"
name = "test_large_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
seed = "16412768254277122702"

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, Some(large_seed));

        let serialized = toml::to_string_pretty(&config).unwrap();
        assert!(serialized.contains("seed = \"16412768254277122702\""));

        let deserialized: ExperimentConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.hyperparameters.seed, Some(large_seed));
    }

    #[test]
    fn test_seed_backwards_compatibility() {
        let toml_str = r#"
name = "test_int_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200
seed = 12345

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, Some(12345));
    }

    #[test]
    fn test_no_seed_serialization() {
        let toml_str = r#"
name = "test_no_seed"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100
"#;
        let config: ExperimentConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.hyperparameters.seed, None);
    }

    /// `with_runtime_values` leaves the template untouched and its result round-trips,
    /// including a seed at the very top of the `u64` range.
    #[test]
    fn test_with_runtime_values_round_trip() {
        let toml_str = r#"
name = "runtime_values"
environment = "Test"

[metadata]
version = "1.0.0"

[problem]
n_inputs = 4
n_actions = 3

[hyperparameters]
population_size = 100
n_generations = 200

[hyperparameters.program]
max_instructions = 100
"#;
        let template: ExperimentConfig = toml::from_str(toml_str).unwrap();
        let run = template.with_runtime_values(u64::MAX, "2024-01-01T00:00:00Z");

        assert_eq!(template.hyperparameters.seed, None);
        assert_eq!(template.metadata.run_timestamp, None);
        assert_eq!(run.hyperparameters.seed, Some(u64::MAX));
        assert_eq!(
            run.metadata.run_timestamp.as_deref(),
            Some("2024-01-01T00:00:00Z")
        );

        let serialized = toml::to_string_pretty(&run).unwrap();
        let reloaded: ExperimentConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(reloaded.hyperparameters.seed, Some(u64::MAX));
        assert_eq!(
            reloaded.metadata.run_timestamp.as_deref(),
            Some("2024-01-01T00:00:00Z")
        );
    }
}