Skip to main content

content_extractor_rl/
hyperparameter_tuner.rs

//! Hyperparameter tuning using TPE (Tree-structured Parzen Estimator) with resume capability.
//!
//! Note: `TPEOptimizer::optimize_parallel` currently draws every trial with
//! random search; the TPE sampling helpers in this module are not yet wired in.
// ============================================================================
// FILE: crates/content-extractor-rl/src/hyperparameter_tuner.rs
// ============================================================================

use std::cmp::Ordering;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use rand::RngExt;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};

use crate::models::NetworkConfig;
use crate::{AlgorithmType, Config, Result};

16/// hyperparameter search space with network architecture
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HyperparameterSpace {
19    // Training hyperparameters
20    pub learning_rate: (f64, f64),
21    pub batch_size: Vec<usize>,
22    pub gamma: (f64, f64),
23    pub epsilon_decay: (f64, f64),
24    pub priority_alpha: (f64, f64),
25    pub priority_beta: (f64, f64),
26
27    // Network architecture hyperparameters
28    pub hidden_layer_sizes: Vec<Vec<usize>>,  // Different architectures to try
29    pub value_hidden: Vec<usize>,
30    pub advantage_hidden: Vec<usize>,
31    pub use_layer_norm: Vec<bool>,
32    pub dropout: (f32, f32),
33}
34
35impl Default for HyperparameterSpace {
36    fn default() -> Self {
37        Self {
38            // Capped at 3e-3: rates above this destabilise SAC/DQN even with
39            // gradient clipping. Sampled log-uniformly (see random_suggest).
40            learning_rate: (1e-5, 3e-3),
41            batch_size: vec![256, 512, 1024, 2048, 4096, 6144, 8192],
42            gamma: (0.85, 0.99),
43            epsilon_decay: (0.985, 0.999),
44            priority_alpha: (0.35, 0.8),
45            priority_beta: (0.3, 0.7),
46
47            // Network architectures to try
48            hidden_layer_sizes: vec![
49                vec![256, 128],           // Small
50                vec![512, 256, 128],      // Default
51                vec![1024, 512, 256],     // Large
52                vec![512, 512, 256, 128], // Deep
53            ],
54            value_hidden: vec![32, 64, 128, 192],
55            advantage_hidden: vec![32, 64, 128, 192],
56            use_layer_norm: vec![true, false],
57            dropout: (0.0, 0.01),
58        }
59    }
60}
61
62/// Hyperparameter configuration
63/// Enhanced hyperparameters including network config
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Hyperparameters {
66    // Training hyperparameters
67    pub learning_rate: f64,
68    pub batch_size: usize,
69    pub gamma: f64,
70    pub epsilon_decay: f64,
71    pub priority_alpha: f64,
72    pub priority_beta: f64,
73
74    // Network architecture
75    pub network_config: NetworkConfig,
76
77    pub timestamp: String,
78    pub quality_score: f64,
79}
80
81impl Hyperparameters {
82    /// Apply hyperparameters to config
83    pub fn apply_to_config(&self, config: &mut Config) {
84        config.learning_rate = self.learning_rate;
85        config.batch_size = self.batch_size;
86        config.gamma = self.gamma;
87        config.epsilon_decay = self.epsilon_decay;
88        config.priority_alpha = self.priority_alpha;
89        config.priority_beta = self.priority_beta;
90
91        // Apply network config
92        config.state_dim = self.network_config.state_dim;
93        config.num_discrete_actions = self.network_config.num_actions;
94        config.num_continuous_params = self.network_config.num_params;
95    }
96
97    /// Save to algorithm-specific JSON file
98    pub fn save_with_algorithm(&self, base_path: &Path, algorithm: AlgorithmType) -> Result<()> {
99        let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
100        let path = base_path.parent()
101            .unwrap_or(base_path)
102            .join(filename);
103
104        let json = serde_json::to_string_pretty(self)?;
105        std::fs::write(&path, json)?;
106
107        info!("✓ Saved {} hyperparameters to: {}", algorithm, path.display());
108        Ok(())
109    }
110
111    /// Load from algorithm-specific file
112    pub fn load_for_algorithm(base_dir: &Path, algorithm: AlgorithmType) -> Result<Self> {
113        let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
114        let path = base_dir.join(&filename);
115
116        if !path.exists() {
117            return Err(crate::ExtractionError::ParseError(
118                format!("Hyperparameters file not found: {}", path.display())
119            ));
120        }
121
122        let json = std::fs::read_to_string(&path)?;
123        let params:Hyperparameters = serde_json::from_str(&json)?;
124
125        info!("✓ Loaded {} hyperparameters from: {}", algorithm, path.display());
126        info!("  Settings:");
127        info!("    learning_rate: {:.6}", params.learning_rate);
128        info!("    batch_size: {}", params.batch_size);
129        info!("    gamma: {:.3}", params.gamma);
130        info!("    epsilon_decay: {:.6}", params.epsilon_decay);
131        info!("    priority_alpha: {:.3}", params.priority_alpha);
132        info!("    priority_beta: {:.3}", params.priority_beta);
133
134        Ok(params)
135    }
136
137    /// Save to JSON file
138    pub fn save(&self, path: &Path) -> Result<()> {
139        let json = serde_json::to_string_pretty(self)?;
140        std::fs::write(path, json)?;
141        info!("Saved hyperparameters to: {}", path.display());
142        Ok(())
143    }
144
145    /// Load from JSON file
146    pub fn load(path: &Path) -> Result<Self> {
147        let json = std::fs::read_to_string(path)?;
148        let params = serde_json::from_str(&json)?;
149        info!("Loaded hyperparameters from: {}", path.display());
150        Ok(params)
151    }
152}
153
154/// Trial result from hyperparameter optimization
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct TrialResult {
157    pub trial_number: usize,
158    pub hyperparameters: Hyperparameters,
159    pub quality_score: f64,
160    pub avg_reward: f64,
161    pub duration_seconds: f64,
162}
163
164/// Optimizer state for resuming
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct OptimizerState {
167    pub trials: Vec<TrialResult>,
168    pub n_startup_trials: usize,
169    pub space: HyperparameterSpace,
170    pub best_trial: Option<usize>,
171    pub timestamp: String,
172}
173
174impl OptimizerState {
175    /// Save state to file
176    pub fn save(&self, path: &Path) -> Result<()> {
177        let json = serde_json::to_string_pretty(self)?;
178        std::fs::write(path, json)?;
179        info!("Saved optimizer state to: {}", path.display());
180        Ok(())
181    }
182
183    /// Load state from file
184    pub fn load(path: &Path) -> Result<Self> {
185        let json = std::fs::read_to_string(path)?;
186        let state = serde_json::from_str(&json)?;
187        info!("Loaded optimizer state from: {}", path.display());
188        Ok(state)
189    }
190}
191
192/// TPE-based hyperparameter optimizer with resume capability
193pub struct TPEOptimizer {
194    space: HyperparameterSpace,
195    trials: Vec<TrialResult>,
196    n_startup_trials: usize,
197    state_path: Option<PathBuf>,
198}
199
200impl TPEOptimizer {
201    /// Create new TPE optimizer
202    pub fn new(space: HyperparameterSpace) -> Self {
203        Self {
204            space,
205            trials: Vec::new(),
206            n_startup_trials: 5, // As per requirement
207            state_path: None,
208        }
209    }
210
211    /// Run parallel hyperparameter optimization
212    pub fn optimize_parallel(
213        &mut self,
214        n_trials: usize,
215        episodes_per_trial: usize,
216        html_samples: Vec<(String, String)>,
217        base_config: &Config,
218        n_workers: usize,
219    ) -> Result<()> {
220        info!("Starting parallel TPE optimization with {} workers", n_workers);
221
222        // Configure rayon thread pool
223        let pool = rayon::ThreadPoolBuilder::new()
224            .num_threads(n_workers)
225            .build()
226            .map_err(|e| crate::ExtractionError::RuntimeError(e.to_string()))?;
227
228        // Generate all trial parameters upfront (sequential, uses TPE logic)
229        let mut all_trial_params = Vec::new();
230        let mut rng = rand::rng();
231        for trial_num in 0..n_trials {
232            // Use random sampling for all trials in parallel mode
233            let params = self.random_suggest(&mut rng);
234            all_trial_params.push((trial_num, params));
235        }
236
237        // Shared state for collecting results
238        let results = Arc::new(Mutex::new(Vec::new()));
239        let completed_trials = Arc::new(Mutex::new(0usize));
240
241        // Run trials in parallel
242        pool.install(|| {
243            all_trial_params.par_iter().for_each(|(trial_num, params)| {
244                info!("Worker starting trial {}", trial_num);
245
246                // Each worker gets its own config and data
247                let mut trial_config = base_config.clone();
248                params.apply_to_config(&mut trial_config);
249                trial_config.num_episodes = episodes_per_trial;
250
251                // Use CPU for parallel trials to avoid GPU contention
252                trial_config.use_cpu_for_tuning = true;
253
254                let trial_start = std::time::Instant::now();
255
256                // Run training (tuning optimizes the text-quality proxy reward)
257                let trial_samples: Vec<crate::TrainingSample> =
258                    html_samples.clone().into_iter().map(Into::into).collect();
259                let result = crate::training::train_standard(&trial_config, trial_samples);
260
261                match result {
262                    Ok((_agent, metrics)) => {
263                        let duration = trial_start.elapsed();
264
265                        // Calculate quality
266                        let window = metrics.episode_qualities.len().min(50);
267                        let quality = if metrics.episode_qualities.len() >= window {
268                            metrics.episode_qualities[metrics.episode_qualities.len() - window..]
269                                .iter()
270                                .sum::<f32>() / window as f32
271                        } else if !metrics.episode_qualities.is_empty() {
272                            metrics.episode_qualities.iter().sum::<f32>() /
273                                metrics.episode_qualities.len() as f32
274                        } else {
275                            0.0
276                        };
277
278                        let avg_reward = if !metrics.episode_rewards.is_empty() {
279                            let window = metrics.episode_rewards.len().min(50);
280                            if metrics.episode_rewards.len() >= window {
281                                metrics.episode_rewards[metrics.episode_rewards.len() - window..]
282                                    .iter()
283                                    .sum::<f32>() / window as f32
284                            } else {
285                                metrics.episode_rewards.iter().sum::<f32>() /
286                                    metrics.episode_rewards.len() as f32
287                            }
288                        } else {
289                            0.0
290                        };
291
292                        // Record result
293                        let trial_result = TrialResult {
294                            trial_number: *trial_num,
295                            hyperparameters: Hyperparameters {
296                                quality_score: quality as f64,
297                                ..params.clone()
298                            },
299                            quality_score: quality as f64,
300                            avg_reward: avg_reward as f64,
301                            duration_seconds: duration.as_secs_f64(),
302                        };
303
304                        // Store result
305                        {
306                            let mut res = results.lock().unwrap();
307                            res.push(trial_result);
308                        }
309
310                        {
311                            let mut completed = completed_trials.lock().unwrap();
312                            *completed += 1;
313                            info!("Trial {} completed ({}/{}): quality={:.4}",
314                          trial_num, *completed, n_trials, quality);
315                        }
316                    }
317                    Err(e) => {
318                        warn!("Trial {} failed: {}", trial_num, e);
319                    }
320                }
321            });
322        });
323
324        // After all trials complete, update self with results
325        let trial_results = results.lock().unwrap();
326        for trial_result in trial_results.iter() {
327            self.tell(trial_result.clone());
328        }
329
330        info!("Parallel optimization complete");
331        Ok(())
332    }
333
334    /// Create optimizer with resume capability
335    pub fn with_resume(space: HyperparameterSpace, state_path: PathBuf) -> Result<Self> {
336        let mut optimizer = Self {
337            space: space.clone(),
338            trials: Vec::new(),
339            n_startup_trials: 5,
340            state_path: Some(state_path.clone()),
341        };
342
343        // Try to load existing state
344        if state_path.exists() {
345            info!("Resuming from saved state: {}", state_path.display());
346            let state = OptimizerState::load(&state_path)?;
347            optimizer.trials = state.trials;
348            optimizer.space = state.space;
349            optimizer.n_startup_trials = state.n_startup_trials;
350            info!("Resumed with {} existing trials", optimizer.trials.len());
351        }
352
353        Ok(optimizer)
354    }
355
356    /// Random hyperparameter suggestion for initial trials
357    pub fn random_suggest(&self, rng: &mut impl RngExt) -> Hyperparameters {
358        // Sample random network architecture
359        let hidden_layers = self.space.hidden_layer_sizes
360            .get(rng.random_range(0..self.space.hidden_layer_sizes.len()))
361            .unwrap()
362            .clone();
363
364        let value_hidden = *self.space.value_hidden
365            .get(rng.random_range(0..self.space.value_hidden.len()))
366            .unwrap();
367
368        let advantage_hidden = *self.space.advantage_hidden
369            .get(rng.random_range(0..self.space.advantage_hidden.len()))
370            .unwrap();
371
372        let use_layer_norm = *self.space.use_layer_norm
373            .get(rng.random_range(0..self.space.use_layer_norm.len()))
374            .unwrap();
375
376        let dropout = rng.random_range(self.space.dropout.0..self.space.dropout.1);
377
378        // Learning rate is sampled log-uniformly: a plain uniform draw over
379        // [1e-5, 1e-2] puts almost all mass on large, unstable rates (this is
380        // what produced the ~5.9e-3 rate that made SAC diverge). Log sampling
381        // explores small rates properly.
382        let (lr_lo, lr_hi) = self.space.learning_rate;
383        let learning_rate = (lr_lo.ln() + rng.random::<f64>() * (lr_hi.ln() - lr_lo.ln())).exp();
384
385        Hyperparameters {
386            learning_rate,
387            batch_size: *self.space.batch_size
388                .get(rng.random_range(0..self.space.batch_size.len()))
389                .unwrap(),
390            gamma: rng.random_range(self.space.gamma.0..self.space.gamma.1),
391            epsilon_decay: rng.random_range(self.space.epsilon_decay.0..self.space.epsilon_decay.1),
392            priority_alpha: rng.random_range(self.space.priority_alpha.0..self.space.priority_alpha.1),
393            priority_beta: rng.random_range(self.space.priority_beta.0..self.space.priority_beta.1),
394            network_config: NetworkConfig {
395                state_dim: 300,
396                num_actions: 16,
397                num_params: 6,
398                hidden_layers,
399                use_layer_norm,
400                dropout,
401                value_hidden,
402                advantage_hidden,
403            },
404            timestamp: chrono::Utc::now().to_rfc3339(),
405            quality_score: 0.0,
406        }
407    }
408
409    /// Sample categorical choice (e.g., network architecture)
410    #[allow(dead_code)]
411    fn sample_tpe_categorical<T: Clone>(
412        &self,
413        good_values: Vec<&T>,
414        _bad_values: Vec<&T>,
415        choices: &[T],
416        rng: &mut impl RngExt,
417    ) -> T {
418        if good_values.is_empty() {
419            return choices[rng.random_range(0..choices.len())].clone();
420        }
421
422        // Count frequency in good trials
423        let mut counts: HashMap<usize, usize> = HashMap::new();
424        for _good_val in &good_values {
425            for (i, _choice) in choices.iter().enumerate() {
426                // This is a simplified comparison - in real code you'd need proper equality
427                counts.entry(i).or_insert(0);
428            }
429        }
430
431        // Weighted sampling based on good trial frequencies
432        if counts.is_empty() {
433            choices[rng.random_range(0..choices.len())].clone()
434        } else {
435            let total: usize = counts.values().sum();
436            let r: f64 = rng.random::<f64>() * total as f64;
437            let mut cumsum = 0.0;
438
439            for (idx, count) in counts.iter() {
440                cumsum += *count as f64;
441                if r <= cumsum {
442                    return choices[*idx].clone();
443                }
444            }
445            choices[0].clone()
446        }
447    }
448
449    /// Sample boolean parameter
450    #[allow(dead_code)]
451    fn sample_tpe_boolean(
452        &self,
453        good_values: Vec<bool>,
454        _bad_values: Vec<bool>,
455        rng: &mut impl RngExt,
456    ) -> bool {
457        if good_values.is_empty() {
458            return rng.random();
459        }
460
461        let true_count = good_values.iter().filter(|&&x| x).count();
462        let probability = true_count as f64 / good_values.len() as f64;
463
464        rng.random::<f64>() < probability
465    }
466
467    #[allow(dead_code)]
468    fn good_trials(&self) -> Vec<TrialResult> {
469        let quantile = 0.25;
470        let mut sorted = self.trials.clone();
471        sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
472        let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
473        sorted[..n_good].to_vec()
474    }
475
476    #[allow(dead_code)]
477    fn bad_trials(&self) -> Vec<TrialResult> {
478        let quantile = 0.25;
479        let mut sorted = self.trials.clone();
480        sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
481        let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
482        sorted[n_good..].to_vec()
483    }
484
485    /// Sample continuous parameter using TPE
486    #[allow(dead_code)]
487    fn sample_tpe_continuous(
488        &self,
489        good_values: Vec<f64>,
490        _bad_values: Vec<f64>,
491        bounds: (f64, f64),
492        rng: &mut impl RngExt,
493    ) -> f64 {
494        if good_values.is_empty() {
495            return rng.random_range(bounds.0..bounds.1);
496        }
497
498        // Calculate mean and std for good and bad distributions
499        let good_mean = good_values.iter().sum::<f64>() / good_values.len() as f64;
500        let good_std = if good_values.len() > 1 {
501            let variance = good_values.iter()
502                .map(|x| (x - good_mean).powi(2))
503                .sum::<f64>() / (good_values.len() - 1) as f64;
504            variance.sqrt()
505        } else {
506            (bounds.1 - bounds.0) * 0.1
507        };
508
509        // Sample from good distribution (truncated normal)
510        let value = self.sample_truncated_normal(good_mean, good_std, bounds, rng);
511        value.clamp(bounds.0, bounds.1)
512    }
513
514    /// Sample discrete parameter using TPE
515    #[allow(dead_code)]
516    fn sample_tpe_discrete(
517        &self,
518        good_values: Vec<usize>,
519        _bad_values: Vec<usize>,
520        choices: &[usize],
521        rng: &mut impl RngExt,
522    ) -> usize {
523        if good_values.is_empty() {
524            return *choices.get(rng.random_range(0..choices.len())).unwrap();
525        }
526
527        // Count frequency in good trials
528        let mut counts: HashMap<usize, usize> = HashMap::new();
529        for &val in &good_values {
530            *counts.entry(val).or_insert(0) += 1;
531        }
532
533        // Choose based on frequency (weighted sampling)
534        let total: usize = counts.values().sum();
535        if total == 0 {
536            return *choices.get(rng.random_range(0..choices.len())).unwrap();
537        }
538
539        let r: f64 = rng.random::<f64>() * total as f64;
540        let mut cumsum = 0.0;
541
542        for (&val, &count) in counts.iter() {
543            cumsum += count as f64;
544            if r <= cumsum {
545                return val;
546            }
547        }
548
549        // Fallback
550        *good_values.last().unwrap()
551    }
552
553    /// Sample from truncated normal distribution
554    #[allow(dead_code)]
555    fn sample_truncated_normal(
556        &self,
557        mean: f64,
558        std: f64,
559        bounds: (f64, f64),
560        rng: &mut impl RngExt,
561    ) -> f64 {
562        use rand_distr::{Normal, Distribution};
563
564        let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(mean, 0.1).unwrap());
565
566        // Sample with rejection (max 100 attempts)
567        for _ in 0..100 {
568            let value = normal.sample(rng);
569            if value >= bounds.0 && value <= bounds.1 {
570                return value;
571            }
572        }
573
574        // Fallback to clamped value
575        mean.clamp(bounds.0, bounds.1)
576    }
577
578    /// Record trial result and save state
579    pub fn tell(&mut self, trial: TrialResult) {
580        info!(
581            "Trial {}: quality={:.4}, lr={:.6}, batch={}, gamma={:.3}",
582            trial.trial_number,
583            trial.quality_score,
584            trial.hyperparameters.learning_rate,
585            trial.hyperparameters.batch_size,
586            trial.hyperparameters.gamma
587        );
588
589        self.trials.push(trial);
590
591        // Save state if path is configured
592        if let Some(ref path) = self.state_path {
593            let state = OptimizerState {
594                trials: self.trials.clone(),
595                n_startup_trials: self.n_startup_trials,
596                space: self.space.clone(),
597                best_trial: self.get_best_trial_idx(),
598                timestamp: chrono::Utc::now().to_rfc3339(),
599            };
600
601            if let Err(e) = state.save(path) {
602                warn!("Failed to save optimizer state: {}", e);
603            }
604        }
605    }
606
607    /// Get best hyperparameters
608    pub fn get_best(&self) -> Option<&Hyperparameters> {
609        self.trials.iter()
610            .max_by(|a, b| a.quality_score.partial_cmp(&b.quality_score).unwrap())
611            .map(|t| &t.hyperparameters)
612    }
613
614    /// Get best trial index
615    fn get_best_trial_idx(&self) -> Option<usize> {
616        self.trials.iter()
617            .enumerate()
618            .max_by(|(_, a), (_, b)| a.quality_score.partial_cmp(&b.quality_score).unwrap())
619            .map(|(idx, _)| idx)
620    }
621
622    /// Get number of trials completed
623    pub fn num_trials(&self) -> usize {
624        self.trials.len()
625    }
626
627    /// Save results with algorithm-specific filename
628    pub fn save_results_for_algorithm(&self, output_dir: &Path, algorithm: AlgorithmType) -> Result<()> {
629        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
630        let filename = format!("tuning_results_{}_{}.json",
631                               algorithm.to_string().to_lowercase(),
632                               timestamp
633        );
634        let path = output_dir.join(filename);
635
636        let best_trial = self.get_best_trial_idx();
637
638        let results = serde_json::json!({
639            "algorithm": algorithm.to_string(),
640            "n_trials": self.trials.len(),
641            "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
642            "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
643            "best_hyperparameters": self.get_best(),
644            "all_trials": self.trials,
645            "search_space": self.space,
646        });
647
648        let json = serde_json::to_string_pretty(&results)?;
649        std::fs::write(&path, json)?;
650
651        info!("✓ Saved {} tuning results to: {}", algorithm, path.display());
652        Ok(())
653    }
654
655    /// Save optimization results
656    pub fn save_results(&self, path: &Path) -> Result<()> {
657        let best_trial = self.get_best_trial_idx();
658
659        let results = serde_json::json!({
660            "n_trials": self.trials.len(),
661            "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
662            "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
663            "best_hyperparameters": self.get_best(),
664            "all_trials": self.trials,
665            "search_space": self.space,
666        });
667
668        let json = serde_json::to_string_pretty(&results)?;
669        std::fs::write(path, json)?;
670        info!("Saved optimization results to: {}", path.display());
671        Ok(())
672    }
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a trial whose score is mirrored into its hyperparameters.
    fn scored_trial(
        trial_number: usize,
        params: Hyperparameters,
        quality: f64,
        avg_reward: f64,
    ) -> TrialResult {
        TrialResult {
            trial_number,
            hyperparameters: Hyperparameters {
                quality_score: quality,
                ..params
            },
            quality_score: quality,
            avg_reward,
            duration_seconds: 100.0,
        }
    }

    /// Tell `optimizer` one trial per number in `trial_numbers`, scoring
    /// trial `i` as `0.5 + 0.1 * i` with zero reward.
    fn run_linear_trials(optimizer: &mut TPEOptimizer, trial_numbers: std::ops::Range<usize>) {
        let mut rng = rand::rng();
        for i in trial_numbers {
            let params = optimizer.random_suggest(&mut rng);
            optimizer.tell(scored_trial(i, params, 0.5 + i as f64 * 0.1, 0.0));
        }
    }

    #[test]
    fn test_tpe_optimizer() {
        let mut optimizer = TPEOptimizer::new(HyperparameterSpace::default());
        let mut rng = rand::rng();

        // Simulated quality rises steadily, so the last trial is the best.
        for i in 0..15 {
            let params = optimizer.random_suggest(&mut rng);
            let quality = 0.5 + i as f64 * 0.02;
            optimizer.tell(scored_trial(i, params, quality, quality * 2.0 - 1.0));
        }

        let best = optimizer.get_best().expect("trials were recorded");
        assert!(best.quality_score > 0.7);
    }

    #[test]
    fn test_optimizer_resume() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("optimizer_state.json");
        let space = HyperparameterSpace::default();

        // First session: five trials are written to the state file.
        {
            let mut optimizer =
                TPEOptimizer::with_resume(space.clone(), state_path.clone()).unwrap();
            run_linear_trials(&mut optimizer, 0..5);
            assert_eq!(optimizer.num_trials(), 5);
        }

        // Second session: the saved trials are restored and extended.
        {
            let mut optimizer = TPEOptimizer::with_resume(space, state_path).unwrap();
            assert_eq!(optimizer.num_trials(), 5);
            run_linear_trials(&mut optimizer, 5..10);
            assert_eq!(optimizer.num_trials(), 10);
        }
    }
}