1use crate::{AlgorithmType, Config, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use rand::RngExt;
11use tracing::{info, warn};
12use rayon::prelude::*;
13use std::sync::{Arc, Mutex};
14use crate::models::NetworkConfig;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HyperparameterSpace {
19 pub learning_rate: (f64, f64),
21 pub batch_size: Vec<usize>,
22 pub gamma: (f64, f64),
23 pub epsilon_decay: (f64, f64),
24 pub priority_alpha: (f64, f64),
25 pub priority_beta: (f64, f64),
26
27 pub hidden_layer_sizes: Vec<Vec<usize>>, pub value_hidden: Vec<usize>,
30 pub advantage_hidden: Vec<usize>,
31 pub use_layer_norm: Vec<bool>,
32 pub dropout: (f32, f32),
33}
34
35impl Default for HyperparameterSpace {
36 fn default() -> Self {
37 Self {
38 learning_rate: (1e-5, 1e-2),
39 batch_size: vec![256, 512, 1024, 2048, 4096, 6144, 8192],
40 gamma: (0.85, 0.99),
41 epsilon_decay: (0.985, 0.999),
42 priority_alpha: (0.35, 0.8),
43 priority_beta: (0.3, 0.7),
44
45 hidden_layer_sizes: vec![
47 vec![256, 128], vec![512, 256, 128], vec![1024, 512, 256], vec![512, 512, 256, 128], ],
52 value_hidden: vec![32, 64, 128, 192],
53 advantage_hidden: vec![32, 64, 128, 192],
54 use_layer_norm: vec![true, false],
55 dropout: (0.0, 0.01),
56 }
57 }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct Hyperparameters {
64 pub learning_rate: f64,
66 pub batch_size: usize,
67 pub gamma: f64,
68 pub epsilon_decay: f64,
69 pub priority_alpha: f64,
70 pub priority_beta: f64,
71
72 pub network_config: NetworkConfig,
74
75 pub timestamp: String,
76 pub quality_score: f64,
77}
78
79impl Hyperparameters {
80 pub fn apply_to_config(&self, config: &mut Config) {
82 config.learning_rate = self.learning_rate;
83 config.batch_size = self.batch_size;
84 config.gamma = self.gamma;
85 config.epsilon_decay = self.epsilon_decay;
86 config.priority_alpha = self.priority_alpha;
87 config.priority_beta = self.priority_beta;
88
89 config.state_dim = self.network_config.state_dim;
91 config.num_discrete_actions = self.network_config.num_actions;
92 config.num_continuous_params = self.network_config.num_params;
93 }
94
95 pub fn save_with_algorithm(&self, base_path: &Path, algorithm: AlgorithmType) -> Result<()> {
97 let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
98 let path = base_path.parent()
99 .unwrap_or(base_path)
100 .join(filename);
101
102 let json = serde_json::to_string_pretty(self)?;
103 std::fs::write(&path, json)?;
104
105 info!("✓ Saved {} hyperparameters to: {}", algorithm, path.display());
106 Ok(())
107 }
108
109 pub fn load_for_algorithm(base_dir: &Path, algorithm: AlgorithmType) -> Result<Self> {
111 let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
112 let path = base_dir.join(&filename);
113
114 if !path.exists() {
115 return Err(crate::ExtractionError::ParseError(
116 format!("Hyperparameters file not found: {}", path.display())
117 ));
118 }
119
120 let json = std::fs::read_to_string(&path)?;
121 let params:Hyperparameters = serde_json::from_str(&json)?;
122
123 info!("✓ Loaded {} hyperparameters from: {}", algorithm, path.display());
124 info!(" Settings:");
125 info!(" learning_rate: {:.6}", params.learning_rate);
126 info!(" batch_size: {}", params.batch_size);
127 info!(" gamma: {:.3}", params.gamma);
128 info!(" epsilon_decay: {:.6}", params.epsilon_decay);
129 info!(" priority_alpha: {:.3}", params.priority_alpha);
130 info!(" priority_beta: {:.3}", params.priority_beta);
131
132 Ok(params)
133 }
134
135 pub fn save(&self, path: &Path) -> Result<()> {
137 let json = serde_json::to_string_pretty(self)?;
138 std::fs::write(path, json)?;
139 info!("Saved hyperparameters to: {}", path.display());
140 Ok(())
141 }
142
143 pub fn load(path: &Path) -> Result<Self> {
145 let json = std::fs::read_to_string(path)?;
146 let params = serde_json::from_str(&json)?;
147 info!("Loaded hyperparameters from: {}", path.display());
148 Ok(params)
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct TrialResult {
155 pub trial_number: usize,
156 pub hyperparameters: Hyperparameters,
157 pub quality_score: f64,
158 pub avg_reward: f64,
159 pub duration_seconds: f64,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct OptimizerState {
165 pub trials: Vec<TrialResult>,
166 pub n_startup_trials: usize,
167 pub space: HyperparameterSpace,
168 pub best_trial: Option<usize>,
169 pub timestamp: String,
170}
171
172impl OptimizerState {
173 pub fn save(&self, path: &Path) -> Result<()> {
175 let json = serde_json::to_string_pretty(self)?;
176 std::fs::write(path, json)?;
177 info!("Saved optimizer state to: {}", path.display());
178 Ok(())
179 }
180
181 pub fn load(path: &Path) -> Result<Self> {
183 let json = std::fs::read_to_string(path)?;
184 let state = serde_json::from_str(&json)?;
185 info!("Loaded optimizer state from: {}", path.display());
186 Ok(state)
187 }
188}
189
190pub struct TPEOptimizer {
192 space: HyperparameterSpace,
193 trials: Vec<TrialResult>,
194 n_startup_trials: usize,
195 state_path: Option<PathBuf>,
196}
197
198impl TPEOptimizer {
199 pub fn new(space: HyperparameterSpace) -> Self {
201 Self {
202 space,
203 trials: Vec::new(),
204 n_startup_trials: 5, state_path: None,
206 }
207 }
208
209 pub fn optimize_parallel(
211 &mut self,
212 n_trials: usize,
213 episodes_per_trial: usize,
214 html_samples: Vec<(String, String)>,
215 base_config: &Config,
216 n_workers: usize,
217 ) -> Result<()> {
218 info!("Starting parallel TPE optimization with {} workers", n_workers);
219
220 let pool = rayon::ThreadPoolBuilder::new()
222 .num_threads(n_workers)
223 .build()
224 .map_err(|e| crate::ExtractionError::RuntimeError(e.to_string()))?;
225
226 let mut all_trial_params = Vec::new();
228 let mut rng = rand::rng();
229 for trial_num in 0..n_trials {
230 let params = self.random_suggest(&mut rng);
232 all_trial_params.push((trial_num, params));
233 }
234
235 let results = Arc::new(Mutex::new(Vec::new()));
237 let completed_trials = Arc::new(Mutex::new(0usize));
238
239 pool.install(|| {
241 all_trial_params.par_iter().for_each(|(trial_num, params)| {
242 info!("Worker starting trial {}", trial_num);
243
244 let mut trial_config = base_config.clone();
246 params.apply_to_config(&mut trial_config);
247 trial_config.num_episodes = episodes_per_trial;
248
249 trial_config.use_cpu_for_tuning = true;
251
252 let trial_start = std::time::Instant::now();
253
254 let result = crate::training::train_standard(&trial_config, html_samples.clone());
256
257 match result {
258 Ok((_agent, metrics)) => {
259 let duration = trial_start.elapsed();
260
261 let window = metrics.episode_qualities.len().min(50);
263 let quality = if metrics.episode_qualities.len() >= window {
264 metrics.episode_qualities[metrics.episode_qualities.len() - window..]
265 .iter()
266 .sum::<f32>() / window as f32
267 } else if !metrics.episode_qualities.is_empty() {
268 metrics.episode_qualities.iter().sum::<f32>() /
269 metrics.episode_qualities.len() as f32
270 } else {
271 0.0
272 };
273
274 let avg_reward = if !metrics.episode_rewards.is_empty() {
275 let window = metrics.episode_rewards.len().min(50);
276 if metrics.episode_rewards.len() >= window {
277 metrics.episode_rewards[metrics.episode_rewards.len() - window..]
278 .iter()
279 .sum::<f32>() / window as f32
280 } else {
281 metrics.episode_rewards.iter().sum::<f32>() /
282 metrics.episode_rewards.len() as f32
283 }
284 } else {
285 0.0
286 };
287
288 let trial_result = TrialResult {
290 trial_number: *trial_num,
291 hyperparameters: Hyperparameters {
292 quality_score: quality as f64,
293 ..params.clone()
294 },
295 quality_score: quality as f64,
296 avg_reward: avg_reward as f64,
297 duration_seconds: duration.as_secs_f64(),
298 };
299
300 {
302 let mut res = results.lock().unwrap();
303 res.push(trial_result);
304 }
305
306 {
307 let mut completed = completed_trials.lock().unwrap();
308 *completed += 1;
309 info!("Trial {} completed ({}/{}): quality={:.4}",
310 trial_num, *completed, n_trials, quality);
311 }
312 }
313 Err(e) => {
314 warn!("Trial {} failed: {}", trial_num, e);
315 }
316 }
317 });
318 });
319
320 let trial_results = results.lock().unwrap();
322 for trial_result in trial_results.iter() {
323 self.tell(trial_result.clone());
324 }
325
326 info!("Parallel optimization complete");
327 Ok(())
328 }
329
330 pub fn with_resume(space: HyperparameterSpace, state_path: PathBuf) -> Result<Self> {
332 let mut optimizer = Self {
333 space: space.clone(),
334 trials: Vec::new(),
335 n_startup_trials: 5,
336 state_path: Some(state_path.clone()),
337 };
338
339 if state_path.exists() {
341 info!("Resuming from saved state: {}", state_path.display());
342 let state = OptimizerState::load(&state_path)?;
343 optimizer.trials = state.trials;
344 optimizer.space = state.space;
345 optimizer.n_startup_trials = state.n_startup_trials;
346 info!("Resumed with {} existing trials", optimizer.trials.len());
347 }
348
349 Ok(optimizer)
350 }
351
352 pub fn random_suggest(&self, rng: &mut impl RngExt) -> Hyperparameters {
354 let hidden_layers = self.space.hidden_layer_sizes
356 .get(rng.random_range(0..self.space.hidden_layer_sizes.len()))
357 .unwrap()
358 .clone();
359
360 let value_hidden = *self.space.value_hidden
361 .get(rng.random_range(0..self.space.value_hidden.len()))
362 .unwrap();
363
364 let advantage_hidden = *self.space.advantage_hidden
365 .get(rng.random_range(0..self.space.advantage_hidden.len()))
366 .unwrap();
367
368 let use_layer_norm = *self.space.use_layer_norm
369 .get(rng.random_range(0..self.space.use_layer_norm.len()))
370 .unwrap();
371
372 let dropout = rng.random_range(self.space.dropout.0..self.space.dropout.1);
373
374 Hyperparameters {
375 learning_rate: rng.random_range(self.space.learning_rate.0..self.space.learning_rate.1),
376 batch_size: *self.space.batch_size
377 .get(rng.random_range(0..self.space.batch_size.len()))
378 .unwrap(),
379 gamma: rng.random_range(self.space.gamma.0..self.space.gamma.1),
380 epsilon_decay: rng.random_range(self.space.epsilon_decay.0..self.space.epsilon_decay.1),
381 priority_alpha: rng.random_range(self.space.priority_alpha.0..self.space.priority_alpha.1),
382 priority_beta: rng.random_range(self.space.priority_beta.0..self.space.priority_beta.1),
383 network_config: NetworkConfig {
384 state_dim: 300,
385 num_actions: 16,
386 num_params: 6,
387 hidden_layers,
388 use_layer_norm,
389 dropout,
390 value_hidden,
391 advantage_hidden,
392 },
393 timestamp: chrono::Utc::now().to_rfc3339(),
394 quality_score: 0.0,
395 }
396 }
397
398 #[allow(dead_code)]
400 fn sample_tpe_categorical<T: Clone>(
401 &self,
402 good_values: Vec<&T>,
403 _bad_values: Vec<&T>,
404 choices: &[T],
405 rng: &mut impl RngExt,
406 ) -> T {
407 if good_values.is_empty() {
408 return choices[rng.random_range(0..choices.len())].clone();
409 }
410
411 let mut counts: HashMap<usize, usize> = HashMap::new();
413 for _good_val in &good_values {
414 for (i, _choice) in choices.iter().enumerate() {
415 counts.entry(i).or_insert(0);
417 }
418 }
419
420 if counts.is_empty() {
422 choices[rng.random_range(0..choices.len())].clone()
423 } else {
424 let total: usize = counts.values().sum();
425 let r: f64 = rng.random::<f64>() * total as f64;
426 let mut cumsum = 0.0;
427
428 for (idx, count) in counts.iter() {
429 cumsum += *count as f64;
430 if r <= cumsum {
431 return choices[*idx].clone();
432 }
433 }
434 choices[0].clone()
435 }
436 }
437
438 #[allow(dead_code)]
440 fn sample_tpe_boolean(
441 &self,
442 good_values: Vec<bool>,
443 _bad_values: Vec<bool>,
444 rng: &mut impl RngExt,
445 ) -> bool {
446 if good_values.is_empty() {
447 return rng.random();
448 }
449
450 let true_count = good_values.iter().filter(|&&x| x).count();
451 let probability = true_count as f64 / good_values.len() as f64;
452
453 rng.random::<f64>() < probability
454 }
455
456 #[allow(dead_code)]
457 fn good_trials(&self) -> Vec<TrialResult> {
458 let quantile = 0.25;
459 let mut sorted = self.trials.clone();
460 sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
461 let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
462 sorted[..n_good].to_vec()
463 }
464
465 #[allow(dead_code)]
466 fn bad_trials(&self) -> Vec<TrialResult> {
467 let quantile = 0.25;
468 let mut sorted = self.trials.clone();
469 sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
470 let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
471 sorted[n_good..].to_vec()
472 }
473
474 #[allow(dead_code)]
476 fn sample_tpe_continuous(
477 &self,
478 good_values: Vec<f64>,
479 _bad_values: Vec<f64>,
480 bounds: (f64, f64),
481 rng: &mut impl RngExt,
482 ) -> f64 {
483 if good_values.is_empty() {
484 return rng.random_range(bounds.0..bounds.1);
485 }
486
487 let good_mean = good_values.iter().sum::<f64>() / good_values.len() as f64;
489 let good_std = if good_values.len() > 1 {
490 let variance = good_values.iter()
491 .map(|x| (x - good_mean).powi(2))
492 .sum::<f64>() / (good_values.len() - 1) as f64;
493 variance.sqrt()
494 } else {
495 (bounds.1 - bounds.0) * 0.1
496 };
497
498 let value = self.sample_truncated_normal(good_mean, good_std, bounds, rng);
500 value.clamp(bounds.0, bounds.1)
501 }
502
503 #[allow(dead_code)]
505 fn sample_tpe_discrete(
506 &self,
507 good_values: Vec<usize>,
508 _bad_values: Vec<usize>,
509 choices: &[usize],
510 rng: &mut impl RngExt,
511 ) -> usize {
512 if good_values.is_empty() {
513 return *choices.get(rng.random_range(0..choices.len())).unwrap();
514 }
515
516 let mut counts: HashMap<usize, usize> = HashMap::new();
518 for &val in &good_values {
519 *counts.entry(val).or_insert(0) += 1;
520 }
521
522 let total: usize = counts.values().sum();
524 if total == 0 {
525 return *choices.get(rng.random_range(0..choices.len())).unwrap();
526 }
527
528 let r: f64 = rng.random::<f64>() * total as f64;
529 let mut cumsum = 0.0;
530
531 for (&val, &count) in counts.iter() {
532 cumsum += count as f64;
533 if r <= cumsum {
534 return val;
535 }
536 }
537
538 *good_values.last().unwrap()
540 }
541
542 #[allow(dead_code)]
544 fn sample_truncated_normal(
545 &self,
546 mean: f64,
547 std: f64,
548 bounds: (f64, f64),
549 rng: &mut impl RngExt,
550 ) -> f64 {
551 use rand_distr::{Normal, Distribution};
552
553 let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(mean, 0.1).unwrap());
554
555 for _ in 0..100 {
557 let value = normal.sample(rng);
558 if value >= bounds.0 && value <= bounds.1 {
559 return value;
560 }
561 }
562
563 mean.clamp(bounds.0, bounds.1)
565 }
566
567 pub fn tell(&mut self, trial: TrialResult) {
569 info!(
570 "Trial {}: quality={:.4}, lr={:.6}, batch={}, gamma={:.3}",
571 trial.trial_number,
572 trial.quality_score,
573 trial.hyperparameters.learning_rate,
574 trial.hyperparameters.batch_size,
575 trial.hyperparameters.gamma
576 );
577
578 self.trials.push(trial);
579
580 if let Some(ref path) = self.state_path {
582 let state = OptimizerState {
583 trials: self.trials.clone(),
584 n_startup_trials: self.n_startup_trials,
585 space: self.space.clone(),
586 best_trial: self.get_best_trial_idx(),
587 timestamp: chrono::Utc::now().to_rfc3339(),
588 };
589
590 if let Err(e) = state.save(path) {
591 warn!("Failed to save optimizer state: {}", e);
592 }
593 }
594 }
595
596 pub fn get_best(&self) -> Option<&Hyperparameters> {
598 self.trials.iter()
599 .max_by(|a, b| a.quality_score.partial_cmp(&b.quality_score).unwrap())
600 .map(|t| &t.hyperparameters)
601 }
602
603 fn get_best_trial_idx(&self) -> Option<usize> {
605 self.trials.iter()
606 .enumerate()
607 .max_by(|(_, a), (_, b)| a.quality_score.partial_cmp(&b.quality_score).unwrap())
608 .map(|(idx, _)| idx)
609 }
610
611 pub fn num_trials(&self) -> usize {
613 self.trials.len()
614 }
615
616 pub fn save_results_for_algorithm(&self, output_dir: &Path, algorithm: AlgorithmType) -> Result<()> {
618 let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
619 let filename = format!("tuning_results_{}_{}.json",
620 algorithm.to_string().to_lowercase(),
621 timestamp
622 );
623 let path = output_dir.join(filename);
624
625 let best_trial = self.get_best_trial_idx();
626
627 let results = serde_json::json!({
628 "algorithm": algorithm.to_string(),
629 "n_trials": self.trials.len(),
630 "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
631 "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
632 "best_hyperparameters": self.get_best(),
633 "all_trials": self.trials,
634 "search_space": self.space,
635 });
636
637 let json = serde_json::to_string_pretty(&results)?;
638 std::fs::write(&path, json)?;
639
640 info!("✓ Saved {} tuning results to: {}", algorithm, path.display());
641 Ok(())
642 }
643
644 pub fn save_results(&self, path: &Path) -> Result<()> {
646 let best_trial = self.get_best_trial_idx();
647
648 let results = serde_json::json!({
649 "n_trials": self.trials.len(),
650 "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
651 "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
652 "best_hyperparameters": self.get_best(),
653 "all_trials": self.trials,
654 "search_space": self.space,
655 });
656
657 let json = serde_json::to_string_pretty(&results)?;
658 std::fs::write(path, json)?;
659 info!("Saved optimization results to: {}", path.display());
660 Ok(())
661 }
662}
663
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a `TrialResult` around `params` with a fixed duration.
    fn trial_with_quality(
        trial_number: usize,
        params: Hyperparameters,
        quality: f64,
        avg_reward: f64,
    ) -> TrialResult {
        TrialResult {
            trial_number,
            hyperparameters: Hyperparameters {
                quality_score: quality,
                ..params
            },
            quality_score: quality,
            avg_reward,
            duration_seconds: 100.0,
        }
    }

    #[test]
    fn test_tpe_optimizer() {
        let mut optimizer = TPEOptimizer::new(HyperparameterSpace::default());
        let mut rng = rand::rng();

        // Feed 15 trials with monotonically improving quality.
        for i in 0..15 {
            let params = optimizer.random_suggest(&mut rng);
            let quality = 0.5 + i as f64 * 0.02;
            optimizer.tell(trial_with_quality(i, params, quality, quality * 2.0 - 1.0));
        }

        // The last trial (quality 0.78) should win.
        let best = optimizer.get_best().unwrap();
        assert!(best.quality_score > 0.7);
    }

    #[test]
    fn test_optimizer_resume() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("optimizer_state.json");
        let space = HyperparameterSpace::default();

        // First session: record five trials, checkpointing after each.
        {
            let mut optimizer =
                TPEOptimizer::with_resume(space.clone(), state_path.clone()).unwrap();
            let mut rng = rand::rng();
            for i in 0..5 {
                let params = optimizer.random_suggest(&mut rng);
                optimizer.tell(trial_with_quality(i, params, 0.5 + i as f64 * 0.1, 0.0));
            }
            assert_eq!(optimizer.num_trials(), 5);
        }

        // Second session: resume from disk and continue where we left off.
        {
            let mut optimizer = TPEOptimizer::with_resume(space, state_path).unwrap();
            assert_eq!(optimizer.num_trials(), 5);

            let mut rng = rand::rng();
            for i in 5..10 {
                let params = optimizer.random_suggest(&mut rng);
                optimizer.tell(trial_with_quality(i, params, 0.5 + i as f64 * 0.1, 0.0));
            }
            assert_eq!(optimizer.num_trials(), 10);
        }
    }
}