//! Hyperparameter tuning using TPE (Tree-structured Parzen Estimator) with resume capability
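//!
//! The optimizer keeps a history of completed trials, splits them into a
//! "good" set (top quantile by quality score) and a "bad" set, and biases new
//! suggestions toward regions that are dense in good trials. The first
//! `n_startup_trials` are sampled uniformly at random to seed the history,
//! and parallel tuning (`optimize_parallel`) uses random sampling throughout.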
// ============================================================================
// FILE: crates/content-extractor-rl/src/hyperparameter_tuner.rs
// ============================================================================

use crate::{AlgorithmType, Config, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use rand::Rng;
use tracing::{info, warn};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
use crate::models::NetworkConfig;

/// Hyperparameter search space with network architecture
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HyperparameterSpace {
    // Training hyperparameters
    pub learning_rate: (f64, f64),
    pub batch_size: Vec<usize>,
    pub gamma: (f64, f64),
    pub epsilon_decay: (f64, f64),
    pub priority_alpha: (f64, f64),
    pub priority_beta: (f64, f64),

    // Network architecture hyperparameters
    pub hidden_layer_sizes: Vec<Vec<usize>>,  // Different architectures to try
    pub value_hidden: Vec<usize>,
    pub advantage_hidden: Vec<usize>,
    pub use_layer_norm: Vec<bool>,
    pub dropout: (f32, f32),
}

impl Default for HyperparameterSpace {
    fn default() -> Self {
        Self {
            learning_rate: (1e-5, 1e-2),
            batch_size: vec![256, 512, 1024, 2048, 4096, 6144, 8192],
            gamma: (0.85, 0.99),
            epsilon_decay: (0.985, 0.999),
            priority_alpha: (0.35, 0.8),
            priority_beta: (0.3, 0.7),

            // Network architectures to try
            hidden_layer_sizes: vec![
                vec![256, 128],           // Small
                vec![512, 256, 128],      // Default
                vec![1024, 512, 256],     // Large
                vec![512, 512, 256, 128], // Deep
            ],
            value_hidden: vec![32, 64, 128, 192],
            advantage_hidden: vec![32, 64, 128, 192],
            use_layer_norm: vec![true, false],
            dropout: (0.0, 0.01),
        }
    }
}

/// Hyperparameter configuration, including training settings and network architecture
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperparameters {
    // Training hyperparameters
    pub learning_rate: f64,
    pub batch_size: usize,
    pub gamma: f64,
    pub epsilon_decay: f64,
    pub priority_alpha: f64,
    pub priority_beta: f64,

    // Network architecture
    pub network_config: NetworkConfig,

    pub timestamp: String,
    pub quality_score: f64,
}

impl Hyperparameters {
    /// Apply hyperparameters to config
    pub fn apply_to_config(&self, config: &mut Config) {
        config.learning_rate = self.learning_rate;
        config.batch_size = self.batch_size;
        config.gamma = self.gamma;
        config.epsilon_decay = self.epsilon_decay;
        config.priority_alpha = self.priority_alpha;
        config.priority_beta = self.priority_beta;

        // Apply network config
        config.state_dim = self.network_config.state_dim;
        config.num_discrete_actions = self.network_config.num_actions;
        config.num_continuous_params = self.network_config.num_params;
    }

    /// Save to algorithm-specific JSON file
    pub fn save_with_algorithm(&self, base_path: &Path, algorithm: AlgorithmType) -> Result<()> {
        let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
        let path = base_path.parent()
            .unwrap_or(base_path)
            .join(filename);

        let json = serde_json::to_string_pretty(self)?;
        std::fs::write(&path, json)?;

        info!("✓ Saved {} hyperparameters to: {}", algorithm, path.display());
        Ok(())
    }

    /// Load from algorithm-specific file
    pub fn load_for_algorithm(base_dir: &Path, algorithm: AlgorithmType) -> Result<Self> {
        let filename = format!("best_hyperparams_{}.json", algorithm.to_string().to_lowercase());
        let path = base_dir.join(&filename);

        if !path.exists() {
            return Err(crate::ExtractionError::ParseError(
                format!("Hyperparameters file not found: {}", path.display())
            ));
        }

        let json = std::fs::read_to_string(&path)?;
        let params: Hyperparameters = serde_json::from_str(&json)?;

        info!("✓ Loaded {} hyperparameters from: {}", algorithm, path.display());
        info!("  Settings:");
        info!("    learning_rate: {:.6}", params.learning_rate);
        info!("    batch_size: {}", params.batch_size);
        info!("    gamma: {:.3}", params.gamma);
        info!("    epsilon_decay: {:.6}", params.epsilon_decay);
        info!("    priority_alpha: {:.3}", params.priority_alpha);
        info!("    priority_beta: {:.3}", params.priority_beta);

        Ok(params)
    }

    /// Save to JSON file
    pub fn save(&self, path: &Path) -> Result<()> {
        let json = serde_json::to_string_pretty(self)?;
        std::fs::write(path, json)?;
        info!("Saved hyperparameters to: {}", path.display());
        Ok(())
    }

    /// Load from JSON file
    pub fn load(path: &Path) -> Result<Self> {
        let json = std::fs::read_to_string(path)?;
        let params = serde_json::from_str(&json)?;
        info!("Loaded hyperparameters from: {}", path.display());
        Ok(params)
    }
}

/// Trial result from hyperparameter optimization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialResult {
    pub trial_number: usize,
    pub hyperparameters: Hyperparameters,
    pub quality_score: f64,
    pub avg_reward: f64,
    pub duration_seconds: f64,
}

/// Optimizer state for resuming
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    pub trials: Vec<TrialResult>,
    pub n_startup_trials: usize,
    pub space: HyperparameterSpace,
    pub best_trial: Option<usize>,
    pub timestamp: String,
}

impl OptimizerState {
    /// Save state to file
    pub fn save(&self, path: &Path) -> Result<()> {
        let json = serde_json::to_string_pretty(self)?;
        std::fs::write(path, json)?;
        info!("Saved optimizer state to: {}", path.display());
        Ok(())
    }

    /// Load state from file
    pub fn load(path: &Path) -> Result<Self> {
        let json = std::fs::read_to_string(path)?;
        let state = serde_json::from_str(&json)?;
        info!("Loaded optimizer state from: {}", path.display());
        Ok(state)
    }
}

/// TPE-based hyperparameter optimizer with resume capability
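///
/// Typical sequential usage, sketched below; `run_trial` stands in for
/// whatever training call produces a quality score for a candidate
/// configuration:
///
/// ```ignore
/// let mut optimizer = TPEOptimizer::new(HyperparameterSpace::default());
/// let mut rng = rand::rng();
/// for trial_number in 0..20 {
///     let params = optimizer.random_suggest(&mut rng);
///     let quality_score = run_trial(&params)?;
///     optimizer.tell(TrialResult {
///         trial_number,
///         hyperparameters: Hyperparameters { quality_score, ..params },
///         quality_score,
///         avg_reward: 0.0,
///         duration_seconds: 0.0,
///     });
/// }
/// let best = optimizer.get_best();
/// ```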
pub struct TPEOptimizer {
    space: HyperparameterSpace,
    trials: Vec<TrialResult>,
    n_startup_trials: usize,
    state_path: Option<PathBuf>,
}

impl TPEOptimizer {
    /// Create new TPE optimizer
    pub fn new(space: HyperparameterSpace) -> Self {
        Self {
            space,
            trials: Vec::new(),
            n_startup_trials: 5, // As per requirement
            state_path: None,
        }
    }

    /// Run parallel hyperparameter optimization
    pub fn optimize_parallel(
        &mut self,
        n_trials: usize,
        episodes_per_trial: usize,
        html_samples: Vec<(String, String)>,
        base_config: &Config,
        n_workers: usize,
    ) -> Result<()> {
        info!("Starting parallel TPE optimization with {} workers", n_workers);

        // Configure rayon thread pool
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(n_workers)
            .build()
            .map_err(|e| crate::ExtractionError::RuntimeError(e.to_string()))?;

        // Generate all trial parameters upfront. In parallel mode every trial
        // uses random sampling; TPE's sequential suggest step is skipped.
        let mut all_trial_params = Vec::new();
        let mut rng = rand::rng();
        for trial_num in 0..n_trials {
            let params = self.random_suggest(&mut rng);
            all_trial_params.push((trial_num, params));
        }

        // Shared state for collecting results
        let results = Arc::new(Mutex::new(Vec::new()));
        let completed_trials = Arc::new(Mutex::new(0usize));

        // Run trials in parallel
        pool.install(|| {
            all_trial_params.par_iter().for_each(|(trial_num, params)| {
                info!("Worker starting trial {}", trial_num);

                // Each worker gets its own config and data
                let mut trial_config = base_config.clone();
                params.apply_to_config(&mut trial_config);
                trial_config.num_episodes = episodes_per_trial;

                // Use CPU for parallel trials to avoid GPU contention
                trial_config.use_cpu_for_tuning = true;

                let trial_start = std::time::Instant::now();

                // Run training
                let result = crate::training::train_standard(&trial_config, html_samples.clone());

                match result {
                    Ok((_agent, metrics)) => {
                        let duration = trial_start.elapsed();

                        // Average quality over the last (up to) 50 episodes
                        let quality = if metrics.episode_qualities.is_empty() {
                            0.0
                        } else {
                            let window = metrics.episode_qualities.len().min(50);
                            metrics.episode_qualities[metrics.episode_qualities.len() - window..]
                                .iter()
                                .sum::<f32>() / window as f32
                        };

                        // Average reward over the last (up to) 50 episodes
                        let avg_reward = if metrics.episode_rewards.is_empty() {
                            0.0
                        } else {
                            let window = metrics.episode_rewards.len().min(50);
                            metrics.episode_rewards[metrics.episode_rewards.len() - window..]
                                .iter()
                                .sum::<f32>() / window as f32
                        };

                        // Record result
                        let trial_result = TrialResult {
                            trial_number: *trial_num,
                            hyperparameters: Hyperparameters {
                                quality_score: quality as f64,
                                ..params.clone()
                            },
                            quality_score: quality as f64,
                            avg_reward: avg_reward as f64,
                            duration_seconds: duration.as_secs_f64(),
                        };

                        // Store result
                        {
                            let mut res = results.lock().unwrap();
                            res.push(trial_result);
                        }

                        {
                            let mut completed = completed_trials.lock().unwrap();
                            *completed += 1;
                            info!("Trial {} completed ({}/{}): quality={:.4}",
                                  trial_num, *completed, n_trials, quality);
                        }
                    }
                    Err(e) => {
                        warn!("Trial {} failed: {}", trial_num, e);
                    }
                }
            });
        });

        // After all trials complete, update self with results
        let trial_results = results.lock().unwrap();
        for trial_result in trial_results.iter() {
            self.tell(trial_result.clone());
        }

        info!("Parallel optimization complete");
        Ok(())
    }

    /// Create optimizer with resume capability
    pub fn with_resume(space: HyperparameterSpace, state_path: PathBuf) -> Result<Self> {
        let mut optimizer = Self {
            space: space.clone(),
            trials: Vec::new(),
            n_startup_trials: 5,
            state_path: Some(state_path.clone()),
        };

        // Try to load existing state
        if state_path.exists() {
            info!("Resuming from saved state: {}", state_path.display());
            let state = OptimizerState::load(&state_path)?;
            optimizer.trials = state.trials;
            optimizer.space = state.space;
            optimizer.n_startup_trials = state.n_startup_trials;
            info!("Resumed with {} existing trials", optimizer.trials.len());
        }

        Ok(optimizer)
    }

    /// Random hyperparameter suggestion for initial trials
    pub fn random_suggest(&self, rng: &mut impl Rng) -> Hyperparameters {
        // Sample random network architecture
        let hidden_layers = self.space.hidden_layer_sizes
            .get(rng.random_range(0..self.space.hidden_layer_sizes.len()))
            .unwrap()
            .clone();

        let value_hidden = *self.space.value_hidden
            .get(rng.random_range(0..self.space.value_hidden.len()))
            .unwrap();

        let advantage_hidden = *self.space.advantage_hidden
            .get(rng.random_range(0..self.space.advantage_hidden.len()))
            .unwrap();

        let use_layer_norm = *self.space.use_layer_norm
            .get(rng.random_range(0..self.space.use_layer_norm.len()))
            .unwrap();

        let dropout = rng.random_range(self.space.dropout.0..self.space.dropout.1);

        Hyperparameters {
            learning_rate: rng.random_range(self.space.learning_rate.0..self.space.learning_rate.1),
            batch_size: *self.space.batch_size
                .get(rng.random_range(0..self.space.batch_size.len()))
                .unwrap(),
            gamma: rng.random_range(self.space.gamma.0..self.space.gamma.1),
            epsilon_decay: rng.random_range(self.space.epsilon_decay.0..self.space.epsilon_decay.1),
            priority_alpha: rng.random_range(self.space.priority_alpha.0..self.space.priority_alpha.1),
            priority_beta: rng.random_range(self.space.priority_beta.0..self.space.priority_beta.1),
            network_config: NetworkConfig {
                state_dim: 300,
                num_actions: 16,
                num_params: 6,
                hidden_layers,
                use_layer_norm,
                dropout,
                value_hidden,
                advantage_hidden,
            },
            timestamp: chrono::Utc::now().to_rfc3339(),
            quality_score: 0.0,
        }
    }

    /// Sample categorical choice (e.g., network architecture)
    #[allow(dead_code)]
    fn sample_tpe_categorical<T: Clone + PartialEq>(
        &self,
        good_values: Vec<&T>,
        _bad_values: Vec<&T>,
        choices: &[T],
        rng: &mut impl Rng,
    ) -> T {
        if good_values.is_empty() {
            return choices[rng.random_range(0..choices.len())].clone();
        }

        // Count how often each choice appears among the good trials
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for good_val in &good_values {
            if let Some(i) = choices.iter().position(|choice| choice == *good_val) {
                *counts.entry(i).or_insert(0) += 1;
            }
        }

        // Weighted sampling based on good trial frequencies
        if counts.is_empty() {
            choices[rng.random_range(0..choices.len())].clone()
        } else {
            let total: usize = counts.values().sum();
            let r: f64 = rng.random::<f64>() * total as f64;
            let mut cumsum = 0.0;

            for (idx, count) in counts.iter() {
                cumsum += *count as f64;
                if r <= cumsum {
                    return choices[*idx].clone();
                }
            }
            choices[0].clone()
        }
    }

    /// Sample boolean parameter
    #[allow(dead_code)]
    fn sample_tpe_boolean(
        &self,
        good_values: Vec<bool>,
        _bad_values: Vec<bool>,
        rng: &mut impl Rng,
    ) -> bool {
        if good_values.is_empty() {
            return rng.random();
        }

        let true_count = good_values.iter().filter(|&&x| x).count();
        let probability = true_count as f64 / good_values.len() as f64;

        rng.random::<f64>() < probability
    }

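    /// Trials in the top quality quantile (best 25% here); `bad_trials` returns
    /// the remainder. The TPE samplers above use the good set to bias new
    /// suggestions toward well-performing regions of the search space.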
    #[allow(dead_code)]
    fn good_trials(&self) -> Vec<TrialResult> {
        let quantile = 0.25;
        let mut sorted = self.trials.clone();
        sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
        let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
        sorted[..n_good].to_vec()
    }

    #[allow(dead_code)]
    fn bad_trials(&self) -> Vec<TrialResult> {
        let quantile = 0.25;
        let mut sorted = self.trials.clone();
        sorted.sort_by(|a, b| b.quality_score.partial_cmp(&a.quality_score).unwrap());
        let n_good = (sorted.len() as f64 * quantile).ceil() as usize;
        sorted[n_good..].to_vec()
    }

    /// Sample continuous parameter using TPE
    #[allow(dead_code)]
    fn sample_tpe_continuous(
        &self,
        good_values: Vec<f64>,
        _bad_values: Vec<f64>,
        bounds: (f64, f64),
        rng: &mut impl Rng,
    ) -> f64 {
        if good_values.is_empty() {
            return rng.random_range(bounds.0..bounds.1);
        }

        // Estimate mean and std of the good-trial distribution
        let good_mean = good_values.iter().sum::<f64>() / good_values.len() as f64;
        let good_std = if good_values.len() > 1 {
            let variance = good_values.iter()
                .map(|x| (x - good_mean).powi(2))
                .sum::<f64>() / (good_values.len() - 1) as f64;
            variance.sqrt()
        } else {
            (bounds.1 - bounds.0) * 0.1
        };

        // Sample from the good distribution (truncated normal)
        let value = self.sample_truncated_normal(good_mean, good_std, bounds, rng);
        value.clamp(bounds.0, bounds.1)
    }

    /// Sample discrete parameter using TPE
    #[allow(dead_code)]
    fn sample_tpe_discrete(
        &self,
        good_values: Vec<usize>,
        _bad_values: Vec<usize>,
        choices: &[usize],
        rng: &mut impl Rng,
    ) -> usize {
        if good_values.is_empty() {
            return *choices.get(rng.random_range(0..choices.len())).unwrap();
        }

        // Count frequency in good trials
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for &val in &good_values {
            *counts.entry(val).or_insert(0) += 1;
        }

        // Choose based on frequency (weighted sampling)
        let total: usize = counts.values().sum();
        if total == 0 {
            return *choices.get(rng.random_range(0..choices.len())).unwrap();
        }

        let r: f64 = rng.random::<f64>() * total as f64;
        let mut cumsum = 0.0;

        for (&val, &count) in counts.iter() {
            cumsum += count as f64;
            if r <= cumsum {
                return val;
            }
        }

        // Fallback
        *good_values.last().unwrap()
    }

    /// Sample from truncated normal distribution
    #[allow(dead_code)]
    fn sample_truncated_normal(
        &self,
        mean: f64,
        std: f64,
        bounds: (f64, f64),
        rng: &mut impl Rng,
    ) -> f64 {
        use rand_distr::{Normal, Distribution};

        let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(mean, 0.1).unwrap());

        // Sample with rejection (max 100 attempts)
        for _ in 0..100 {
            let value = normal.sample(rng);
            if value >= bounds.0 && value <= bounds.1 {
                return value;
            }
        }

        // Fallback to clamped value
        mean.clamp(bounds.0, bounds.1)
    }

    /// Record trial result and save state
    pub fn tell(&mut self, trial: TrialResult) {
        info!(
            "Trial {}: quality={:.4}, lr={:.6}, batch={}, gamma={:.3}",
            trial.trial_number,
            trial.quality_score,
            trial.hyperparameters.learning_rate,
            trial.hyperparameters.batch_size,
            trial.hyperparameters.gamma
        );

        self.trials.push(trial);

        // Save state if path is configured
        if let Some(ref path) = self.state_path {
            let state = OptimizerState {
                trials: self.trials.clone(),
                n_startup_trials: self.n_startup_trials,
                space: self.space.clone(),
                best_trial: self.get_best_trial_idx(),
                timestamp: chrono::Utc::now().to_rfc3339(),
            };

            if let Err(e) = state.save(path) {
                warn!("Failed to save optimizer state: {}", e);
            }
        }
    }

    /// Get best hyperparameters
    pub fn get_best(&self) -> Option<&Hyperparameters> {
        self.trials.iter()
            .max_by(|a, b| a.quality_score.partial_cmp(&b.quality_score).unwrap())
            .map(|t| &t.hyperparameters)
    }

    /// Get best trial index
    fn get_best_trial_idx(&self) -> Option<usize> {
        self.trials.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.quality_score.partial_cmp(&b.quality_score).unwrap())
            .map(|(idx, _)| idx)
    }

    /// Get number of trials completed
    pub fn num_trials(&self) -> usize {
        self.trials.len()
    }

    /// Save results with algorithm-specific filename
    pub fn save_results_for_algorithm(&self, output_dir: &Path, algorithm: AlgorithmType) -> Result<()> {
        let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
        let filename = format!(
            "tuning_results_{}_{}.json",
            algorithm.to_string().to_lowercase(),
            timestamp
        );
        let path = output_dir.join(filename);

        let best_trial = self.get_best_trial_idx();

        let results = serde_json::json!({
            "algorithm": algorithm.to_string(),
            "n_trials": self.trials.len(),
            "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
            "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
            "best_hyperparameters": self.get_best(),
            "all_trials": self.trials,
            "search_space": self.space,
        });

        let json = serde_json::to_string_pretty(&results)?;
        std::fs::write(&path, json)?;

        info!("✓ Saved {} tuning results to: {}", algorithm, path.display());
        Ok(())
    }

    /// Save optimization results
    pub fn save_results(&self, path: &Path) -> Result<()> {
        let best_trial = self.get_best_trial_idx();

        let results = serde_json::json!({
            "n_trials": self.trials.len(),
            "best_quality": self.get_best().map(|h| h.quality_score).unwrap_or(0.0),
            "best_trial_number": best_trial.map(|i| self.trials[i].trial_number),
            "best_hyperparameters": self.get_best(),
            "all_trials": self.trials,
            "search_space": self.space,
        });

        let json = serde_json::to_string_pretty(&results)?;
        std::fs::write(path, json)?;
        info!("Saved optimization results to: {}", path.display());
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_tpe_optimizer() {
        let space = HyperparameterSpace::default();
        let mut optimizer = TPEOptimizer::new(space);
        let mut rng = rand::rng();

        // Simulate some trials
        for i in 0..15 {
            let params = optimizer.random_suggest(&mut rng);
            let quality = 0.5 + i as f64 * 0.02; // Simulate improving quality

            let trial = TrialResult {
                trial_number: i,
                hyperparameters: Hyperparameters {
                    quality_score: quality,
                    ..params
                },
                quality_score: quality,
                avg_reward: quality * 2.0 - 1.0,
                duration_seconds: 100.0,
            };

            optimizer.tell(trial);
        }

        let best = optimizer.get_best().unwrap();
        assert!(best.quality_score > 0.7);
    }

    #[test]
    fn test_optimizer_resume() {
        let temp_dir = TempDir::new().unwrap();
        let state_path = temp_dir.path().join("optimizer_state.json");

        let space = HyperparameterSpace::default();

        // First session
        {
            let mut optimizer = TPEOptimizer::with_resume(space.clone(), state_path.clone()).unwrap();
            let mut rng = rand::rng();
            for i in 0..5 {
                let params = optimizer.random_suggest(&mut rng);
                let trial = TrialResult {
                    trial_number: i,
                    hyperparameters: Hyperparameters {
                        quality_score: 0.5 + i as f64 * 0.1,
                        ..params
                    },
                    quality_score: 0.5 + i as f64 * 0.1,
                    avg_reward: 0.0,
                    duration_seconds: 100.0,
                };
                optimizer.tell(trial);
            }

            assert_eq!(optimizer.num_trials(), 5);
        }

        // Resume session
        {
            let mut optimizer = TPEOptimizer::with_resume(space, state_path).unwrap();
            assert_eq!(optimizer.num_trials(), 5);
            let mut rng = rand::rng();

            // Continue with more trials
            for i in 5..10 {
                let params = optimizer.random_suggest(&mut rng);
                let trial = TrialResult {
                    trial_number: i,
                    hyperparameters: Hyperparameters {
                        quality_score: 0.5 + i as f64 * 0.1,
                        ..params
                    },
                    quality_score: 0.5 + i as f64 * 0.1,
                    avg_reward: 0.0,
                    duration_seconds: 100.0,
                };
                optimizer.tell(trial);
            }

            assert_eq!(optimizer.num_trials(), 10);
        }
    }
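
    /// Round-trip check for `Hyperparameters::save` / `load`: a minimal sketch
    /// relying only on the serde derives above, spot-checking a couple of
    /// representative fields after reload.
    #[test]
    fn test_hyperparameters_save_load_roundtrip() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("hyperparams.json");

        let optimizer = TPEOptimizer::new(HyperparameterSpace::default());
        let mut rng = rand::rng();
        let params = optimizer.random_suggest(&mut rng);

        // Save to JSON and load back
        params.save(&path).unwrap();
        let loaded = Hyperparameters::load(&path).unwrap();

        assert_eq!(loaded.batch_size, params.batch_size);
        assert!((loaded.learning_rate - params.learning_rate).abs() < 1e-12);
    }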
}