1use crate::{AlgorithmType, Config, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use rand::RngExt;
11use tracing::{info, warn};
12use rayon::prelude::*;
13use std::sync::{Arc, Mutex};
14use crate::models::NetworkConfig;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HyperparameterSpace {
19 pub learning_rate: (f64, f64),
21 pub batch_size: Vec<usize>,
22 pub gamma: (f64, f64),
23 pub epsilon_decay: (f64, f64),
24 pub priority_alpha: (f64, f64),
25 pub priority_beta: (f64, f64),
26
27 pub hidden_layer_sizes: Vec<Vec<usize>>, pub value_hidden: Vec<usize>,
30 pub advantage_hidden: Vec<usize>,
31 pub use_layer_norm: Vec<bool>,
32 pub dropout: (f32, f32),
33}
34
35impl Default for HyperparameterSpace {
36 fn default() -> Self {
37 Self {
38 learning_rate: (1e-5, 3e-3),
41 batch_size: vec![256, 512, 1024, 2048, 4096, 6144, 8192],
42 gamma: (0.85, 0.99),
43 epsilon_decay: (0.985, 0.999),
44 priority_alpha: (0.35, 0.8),
45 priority_beta: (0.3, 0.7),
46
47 hidden_layer_sizes: vec![
49 vec![256, 128], vec![512, 256, 128], vec![1024, 512, 256], vec![512, 512, 256, 128], ],
54 value_hidden: vec![32, 64, 128, 192],
55 advantage_hidden: vec![32, 64, 128, 192],
56 use_layer_norm: vec![true, false],
57 dropout: (0.0, 0.01),
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Hyperparameters {
66 pub learning_rate: f64,
68 pub batch_size: usize,
69 pub gamma: f64,
70 pub epsilon_decay: f64,
71 pub priority_alpha: f64,
72 pub priority_beta: f64,
73
74 pub network_config: NetworkConfig,
76
77 pub timestamp: String,
78 pub quality_score: f64,
79}
80
81impl Hyperparameters {
82 pub fn apply_to_config(&self, config: &mut Config) {
84 config.learning_rate = self.learning_rate;
85 config.batch_size = self.batch_size;
86 config.gamma = self.gamma;
87 config.epsilon_decay = self.epsilon_decay;
88 config.priority_alpha = self.priority_alpha;
89 config.priority_beta = self.priority_beta;
90
91 config.state_dim = self.network_config.state_dim;
93 config.num_discrete_actions = self.network_config.num_actions;
94 config.num_continuous_params = self.network_config.num_params;
95 }
96
97 pub fn save_with_algorithm(&self, base_path: &Path, algorithm: AlgorithmType) -> Result<()> {
99 let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
100 let path = base_path.parent()
101 .unwrap_or(base_path)
102 .join(filename);
103
104 let json = serde_json::to_string_pretty(self)?;
105 std::fs::write(&path, json)?;
106
107 info!("✓ Saved {} hyperparameters to: {}", algorithm, path.display());
108 Ok(())
109 }
110
111 pub fn load_for_algorithm(base_dir: &Path, algorithm: AlgorithmType) -> Result<Self> {
113 let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
114 let path = base_dir.join(&filename);
115
116 if !path.exists() {
117 return Err(crate::ExtractionError::ParseError(
118 format!("Hyperparameters file not found: {}", path.display())
119 ));
120 }
121
122 let json = std::fs::read_to_string(&path)?;
123 let params:Hyperparameters = serde_json::from_str(&json)?;
124
125 info!("✓ Loaded {} hyperparameters from: {}", algorithm, path.display());
126 info!(" Settings:");
127 info!(" learning_rate: {:.6}", params.learning_rate);
128 info!(" batch_size: {}", params.batch_size);
129 info!(" gamma: {:.3}", params.gamma);
130 info!(" epsilon_decay: {:.6}", params.epsilon_decay);
131 info!(" priority_alpha: {:.3}", params.priority_alpha);
132 info!(" priority_beta: {:.3}", params.priority_beta);
133
134 Ok(params)
135 }
136
137 pub fn save(&self, path: &Path) -> Result<()> {
139 let json = serde_json::to_string_pretty(self)?;
140 std::fs::write(path, json)?;
141 info!("Saved hyperparameters to: {}", path.display());
142 Ok(())
143 }
144
145 pub fn load(path: &Path) -> Result<Self> {
147 let json = std::fs::read_to_string(path)?;
148 let params = serde_json::from_str(&json)?;
149 info!("Loaded hyperparameters from: {}", path.display());
150 Ok(params)
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct TrialResult {
157 pub trial_number: usize,
158 pub hyperparameters: Hyperparameters,
159 pub quality_score: f64,
160 pub avg_reward: f64,
161 pub duration_seconds: f64,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct OptimizerState {
167 pub trials: Vec<TrialResult>,
168 pub n_startup_trials: usize,
169 pub space: HyperparameterSpace,
170 pub best_trial: Option<usize>,
171 pub timestamp: String,
172}
173
174impl OptimizerState {
175 pub fn save(&self, path: &Path) -> Result<()> {
177 let json = serde_json::to_string_pretty(self)?;
178 std::fs::write(path, json)?;
179 info!("Saved optimizer state to: {}", path.display());
180 Ok(())
181 }
182
183 pub fn load(path: &Path) -> Result<Self> {
185 let json = std::fs::read_to_string(path)?;
186 let state = serde_json::from_str(&json)?;
187 info!("Loaded optimizer state from: {}", path.display());
188 Ok(state)
189 }
190}
191
192pub struct TPEOptimizer {
194 space: HyperparameterSpace,
195 trials: Vec<TrialResult>,
196 n_startup_trials: usize,
197 state_path: Option<PathBuf>,
198}
199
200impl TPEOptimizer {
201 pub fn new(space: HyperparameterSpace) -> Self {
203 Self {
204 space,
205 trials: Vec::new(),
206 n_startup_trials: 5, state_path: None,
208 }
209 }
210
211 pub fn optimize_parallel(
213 &mut self,
214 n_trials: usize,
215 episodes_per_trial: usize,
216 html_samples: Vec<(String, String)>,
217 base_config: &Config,
218 n_workers: usize,
219 ) -> Result<()> {
220 info!("Starting parallel TPE optimization with {} workers", n_workers);
221
222 let pool = rayon::ThreadPoolBuilder::new()
224 .num_threads(n_workers)
225 .build()
226 .map_err(|e| crate::ExtractionError::RuntimeError(e.to_string()))?;
227
228 let mut all_trial_params = Vec::new();
230 let mut rng = rand::rng();
231 for trial_num in 0..n_trials {
232 let params = self.random_suggest(&mut rng);
234 all_trial_params.push((trial_num, params));
235 }
236
237 let results = Arc::new(Mutex::new(Vec::new()));
239 let completed_trials = Arc::new(Mutex::new(0usize));
240
241 pool.install(|| {
243 all_trial_params.par_iter().for_each(|(trial_num, params)| {
244 info!("Worker starting trial {}", trial_num);
245
246 let mut trial_config = base_config.clone();
248 params.apply_to_config(&mut trial_config);
249 trial_config.num_episodes = episodes_per_trial;
250
251 trial_config.use_cpu_for_tuning = true;
253
254 let trial_start = std::time::Instant::now();
255
256 let trial_samples: Vec<crate::TrainingSample> =
258 html_samples.clone().into_iter().map(Into::into).collect();
259 let result = crate::training::train_standard(&trial_config, trial_samples);
260
261 match result {
262 Ok((_agent, metrics)) => {
263 let duration = trial_start.elapsed();
264
265 let window = metrics.episode_qualities.len().min(50);
267 let quality = if metrics.episode_qualities.len() >= window {
268 metrics.episode_qualities[metrics.episode_qualities.len() - window..]
269 .iter()
270 .sum::<f32>() / window as f32
271 } else if !metrics.episode_qualities.is_empty() {
272 metrics.episode_qualities.iter().sum::<f32>() /
273 metrics.episode_qualities.len() as f32
274 } else {
275 0.0
276 };
277
278 let avg_reward = if !metrics.episode_rewards.is_empty() {
279 let window = metrics.episode_rewards.len().min(50);
280 if metrics.episode_rewards.len() >= window {
281 metrics.episode_rewards[metrics.episode_rewards.len() - window..]
282 .iter()
283 .sum::<f32>() / window as f32
284 } else {
285 metrics.episode_rewards.iter().sum::<f32>() /
286 metrics.episode_rewards.len() as f32
287 }
288 } else {
289 0.0
290 };
291
292 let trial_result = TrialResult {
294 trial_number: *trial_num,
295 hyperparameters: Hyperparameters {
296 quality_score: quality as f64,
297 ..params.clone()
298 },
299 quality_score: quality as f64,
300 avg_reward: avg_reward as f64,
301 duration_seconds: duration.as_secs_f64(),
302 };
303
304 {
306 let mut res = results.lock().unwrap();
307 res.push(trial_result);
308 }
309
310 {
311 let mut completed = completed_trials.lock().unwrap();
312 *completed += 1;
313 info!("Trial {} completed ({}/{}): quality={:.4}",
314 trial_num, *completed, n_trials, quality);
315 }
316 }
317 Err(e) => {
318 warn!("Trial {} failed: {}", trial_num, e);
319 }
320 }
321 });
322 });
323
324 let trial_results = results.lock().unwrap();
326 for trial_result in trial_results.iter() {
327 self.tell(trial_result.clone());
328 }
329
330 info!("Parallel optimization complete");
331 Ok(())
332 }
333
334 pub fn with_resume(space: HyperparameterSpace, state_path: PathBuf) -> Result<Self> {
336 let mut optimizer = Self {
337 space: space.clone(),
338 trials: Vec::new(),
339 n_startup_trials: 5,
340 state_path: Some(state_path.clone()),
341 };
342
343 if state_path.exists() {
345 info!("Resuming from saved state: {}", state_path.display());
346 let state = OptimizerState::load(&state_path)?;
347 optimizer.trials = state.trials;
348 optimizer.space = state.space;
349 optimizer.n_startup_trials = state.n_startup_trials;
350 info!("Resumed with {} existing trials", optimizer.trials.len());
351 }
352
353 Ok(optimizer)
354 }
355
356 pub fn random_suggest(&self, rng: &mut impl RngExt) -> Hyperparameters {
358 let hidden_layers = self.space.hidden_layer_sizes
360 .get(rng.random_range(0..self.space.hidden_layer_sizes.len()))
361 .unwrap()
362 .clone();
363
364 let value_hidden = *self.space.value_hidden
365 .get(rng.random_range(0..self.space.value_hidden.len()))
366 .unwrap();
367
368 let advantage_hidden = *self.space.advantage_hidden
369 .get(rng.random_range(0..self.space.advantage_hidden.len()))
370 .unwrap();
371
372 let use_layer_norm = *self.space.use_layer_norm
373 .get(rng.random_range(0..self.space.use_layer_norm.len()))
374 .unwrap();
375
376 let dropout = rng.random_range(self.space.dropout.0..self.space.dropout.1);
377
378 let (lr_lo, lr_hi) = self.space.learning_rate;
383 let learning_rate = (lr_lo.ln() + rng.random::<f64>() * (lr_hi.ln() - lr_lo.ln())).exp();
384
385 Hyperparameters {
386 learning_rate,
387 batch_size: *self.space.batch_size
388 .get(rng.random_range(0..self.space.batch_size.len()))
389 .unwrap(),
390 gamma: rng.random_range(self.space.gamma.0..self.space.gamma.1),
391 epsilon_decay: rng.random_range(self.space.epsilon_decay.0..self.space.epsilon_decay.1),
392 priority_alpha: rng.random_range(self.space.priority_alpha.0..self.space.priority_alpha.1),
393 priority_beta: rng.random_range(self.space.priority_beta.0..self.space.priority_beta.1),
394 network_config: NetworkConfig {
395 state_dim: 300,
396 num_actions: 16,
397 num_params: 6,
398 hidden_layers,
399 use_layer_norm,
400 dropout,
401 value_hidden,
402 advantage_hidden,
403 },
404 timestamp: chrono::Utc::now().to_rfc3339(),
405 quality_score: 0.0,
406 }
407 }
408
409 #[allow(dead_code)]
411 fn sample_tpe_categorical<T: Clone>(
412 &self,
413 good_values: Vec<&T>,
414 _bad_values: Vec<&T>,
415 choices: &[T],
416 rng: &mut impl RngExt,
417 ) -> T {
418 if good_values.is_empty() {
419 return choices[rng.random_range(0..choices.len())].clone();
420 }
421
422 let mut counts: HashMap<usize, usize> = HashMap::new();
424 for _good_val in &good_values {
425 for (i, _choice) in choices.iter().enumerate() {
426 counts.entry(i).or_insert(0);
428 }
429 }
430
431 if counts.is_empty() {
433 choices[rng.random_range(0..choices.len())].clone()
434 } else {
435 let total: usize = counts.values().sum();
436 let r: f64 = rng.random::<f64>() * total as f64;
437 let mut cumsum = 0.0;
438
439 for (idx, count) in counts.iter() {
440 cumsum += *count as f64;
441 if r <= cumsum {
442 return choices[*idx].clone();
443 }
444 }
445 choices[0].clone()
446 }
447 }
448
449 #[allow(dead_code)]
451 fn sample_tpe_boolean(
452 &self,
453 good_values: Vec<bool>,
454 _bad_values: Vec<bool>,
455 rng: &mut impl RngExt,
456 ) -> bool {
457 if good_values.is_empty() {
458 return rng.random();
459 }
460
461 let true_count = good_values.iter().filter(|&&x| x).count();
462 let probability = true_count as f64 / good_values.len() as f64;
463
464 rng.random::<f64>() < probability
465 }
466
467 #[allow(dead_code)]
468 fn good_trials(&self) -> Vec<TrialResult> {
469 let quantile = 0.25;
470 let mut sorted = self.trials.clone();
471 sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
472 let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
473 sorted[..n_good].to_vec()
474 }
475
476 #[allow(dead_code)]
477 fn bad_trials(&self) -> Vec<TrialResult> {
478 let quantile = 0.25;
479 let mut sorted = self.trials.clone();
480 sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
481 let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
482 sorted[n_good..].to_vec()
483 }
484
485 #[allow(dead_code)]
487 fn sample_tpe_continuous(
488 &self,
489 good_values: Vec<f64>,
490 _bad_values: Vec<f64>,
491 bounds: (f64, f64),
492 rng: &mut impl RngExt,
493 ) -> f64 {
494 if good_values.is_empty() {
495 return rng.random_range(bounds.0..bounds.1);
496 }
497
498 let good_mean = good_values.iter().sum::<f64>() / good_values.len() as f64;
500 let good_std = if good_values.len() > 1 {
501 let variance = good_values.iter()
502 .map(|x| (x - good_mean).powi(2))
503 .sum::<f64>() / (good_values.len() - 1) as f64;
504 variance.sqrt()
505 } else {
506 (bounds.1 - bounds.0) * 0.1
507 };
508
509 let value = self.sample_truncated_normal(good_mean, good_std, bounds, rng);
511 value.clamp(bounds.0, bounds.1)
512 }
513
514 #[allow(dead_code)]
516 fn sample_tpe_discrete(
517 &self,
518 good_values: Vec<usize>,
519 _bad_values: Vec<usize>,
520 choices: &[usize],
521 rng: &mut impl RngExt,
522 ) -> usize {
523 if good_values.is_empty() {
524 return *choices.get(rng.random_range(0..choices.len())).unwrap();
525 }
526
527 let mut counts: HashMap<usize, usize> = HashMap::new();
529 for &val in &good_values {
530 *counts.entry(val).or_insert(0) += 1;
531 }
532
533 let total: usize = counts.values().sum();
535 if total == 0 {
536 return *choices.get(rng.random_range(0..choices.len())).unwrap();
537 }
538
539 let r: f64 = rng.random::<f64>() * total as f64;
540 let mut cumsum = 0.0;
541
542 for (&val, &count) in counts.iter() {
543 cumsum += count as f64;
544 if r <= cumsum {
545 return val;
546 }
547 }
548
549 *good_values.last().unwrap()
551 }
552
553 #[allow(dead_code)]
555 fn sample_truncated_normal(
556 &self,
557 mean: f64,
558 std: f64,
559 bounds: (f64, f64),
560 rng: &mut impl RngExt,
561 ) -> f64 {
562 use rand_distr::{Normal, Distribution};
563
564 let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(mean, 0.1).unwrap());
565
566 for _ in 0..100 {
568 let value = normal.sample(rng);
569 if value >= bounds.0 && value <= bounds.1 {
570 return value;
571 }
572 }
573
574 mean.clamp(bounds.0, bounds.1)
576 }
577
578 pub fn tell(&mut self, trial: TrialResult) {
580 info!(
581 "Trial {}: quality={:.4}, lr={:.6}, batch={}, gamma={:.3}",
582 trial.trial_number,
583 trial.quality_score,
584 trial.hyperparameters.learning_rate,
585 trial.hyperparameters.batch_size,
586 trial.hyperparameters.gamma
587 );
588
589 self.trials.push(trial);
590
591 if let Some(ref path) = self.state_path {
593 let state = OptimizerState {
594 trials: self.trials.clone(),
595 n_startup_trials: self.n_startup_trials,
596 space: self.space.clone(),
597 best_trial: self.get_best_trial_idx(),
598 timestamp: chrono::Utc::now().to_rfc3339(),
599 };
600
601 if let Err(e) = state.save(path) {
602 warn!("Failed to save optimizer state: {}", e);
603 }
604 }
605 }
606
607 pub fn get_best(&self) -> Option<&Hyperparameters> {
609 self.trials.iter()
610 .max_by(|a, b| a.quality_score.partial_cmp(&b.quality_score).unwrap())
611 .map(|t| &t.hyperparameters)
612 }
613
614 fn get_best_trial_idx(&self) -> Option<usize> {
616 self.trials.iter()
617 .enumerate()
618 .max_by(|(_, a), (_, b)| a.quality_score.partial_cmp(&b.quality_score).unwrap())
619 .map(|(idx, _)| idx)
620 }
621
622 pub fn num_trials(&self) -> usize {
624 self.trials.len()
625 }
626
627 pub fn save_results_for_algorithm(&self, output_dir: &Path, algorithm: AlgorithmType) -> Result<()> {
629 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
630 let filename = format!("tuning_results_{}_{}.json",
631 algorithm.to_string().to_lowercase(),
632 timestamp
633 );
634 let path = output_dir.join(filename);
635
636 let best_trial = self.get_best_trial_idx();
637
638 let results = serde_json::json!({
639 "algorithm": algorithm.to_string(),
640 "n_trials": self.trials.len(),
641 "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
642 "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
643 "best_hyperparameters": self.get_best(),
644 "all_trials": self.trials,
645 "search_space": self.space,
646 });
647
648 let json = serde_json::to_string_pretty(&results)?;
649 std::fs::write(&path, json)?;
650
651 info!("✓ Saved {} tuning results to: {}", algorithm, path.display());
652 Ok(())
653 }
654
655 pub fn save_results(&self, path: &Path) -> Result<()> {
657 let best_trial = self.get_best_trial_idx();
658
659 let results = serde_json::json!({
660 "n_trials": self.trials.len(),
661 "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
662 "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
663 "best_hyperparameters": self.get_best(),
664 "all_trials": self.trials,
665 "search_space": self.space,
666 });
667
668 let json = serde_json::to_string_pretty(&results)?;
669 std::fs::write(path, json)?;
670 info!("Saved optimization results to: {}", path.display());
671 Ok(())
672 }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a trial from a fresh random suggestion with the given scores.
    fn scored_trial(
        optimizer: &TPEOptimizer,
        rng: &mut impl rand::RngExt,
        trial_number: usize,
        quality: f64,
        avg_reward: f64,
    ) -> TrialResult {
        let suggested = optimizer.random_suggest(rng);
        TrialResult {
            trial_number,
            hyperparameters: Hyperparameters {
                quality_score: quality,
                ..suggested
            },
            quality_score: quality,
            avg_reward,
            duration_seconds: 100.0,
        }
    }

    #[test]
    fn test_tpe_optimizer() {
        let mut optimizer = TPEOptimizer::new(HyperparameterSpace::default());
        let mut rng = rand::rng();

        // Qualities rise from 0.50 to 0.78, so the best must exceed 0.7.
        for i in 0..15 {
            let quality = 0.5 + i as f64 * 0.02;
            let trial = scored_trial(&optimizer, &mut rng, i, quality, quality * 2.0 - 1.0);
            optimizer.tell(trial);
        }

        let best = optimizer.get_best().unwrap();
        assert!(best.quality_score > 0.7);
    }

    #[test]
    fn test_optimizer_resume() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("optimizer_state.json");
        let space = HyperparameterSpace::default();

        // First session: five trials, each persisted by `tell`.
        {
            let mut optimizer =
                TPEOptimizer::with_resume(space.clone(), state_path.clone()).unwrap();
            let mut rng = rand::rng();
            for i in 0..5 {
                let trial = scored_trial(&optimizer, &mut rng, i, 0.5 + i as f64 * 0.1, 0.0);
                optimizer.tell(trial);
            }

            assert_eq!(optimizer.num_trials(), 5);
        }

        // Second session: the five saved trials come back, then five more are added.
        {
            let mut optimizer = TPEOptimizer::with_resume(space, state_path).unwrap();
            assert_eq!(optimizer.num_trials(), 5);

            let mut rng = rand::rng();
            for i in 5..10 {
                let trial = scored_trial(&optimizer, &mut rng, i, 0.5 + i as f64 * 0.1, 0.0);
                optimizer.tell(trial);
            }

            assert_eq!(optimizer.num_trials(), 10);
        }
    }
}