Skip to main content

entrenar_bench/
sweep.rs

1//! Hyperparameter sweep executor (Kaizen principle).
2
3use entrenar_common::Result;
4
5/// Sweep configuration.
6#[derive(Debug, Clone)]
7pub struct SweepConfig {
8    /// Parameter to sweep
9    pub parameter: SweepParameter,
10    /// Number of runs per configuration
11    pub runs_per_point: usize,
12    /// Whether to use early stopping
13    pub early_stop: bool,
14    /// Random seed for reproducibility
15    pub seed: Option<u64>,
16}
17
18impl SweepConfig {
19    /// Create a temperature sweep.
20    pub fn temperature(range: std::ops::Range<f32>, step: f32) -> Self {
21        Self {
22            parameter: SweepParameter::Temperature {
23                start: range.start,
24                end: range.end,
25                step,
26            },
27            runs_per_point: 1,
28            early_stop: false,
29            seed: Some(42),
30        }
31    }
32
33    /// Create an alpha sweep.
34    pub fn alpha(range: std::ops::Range<f32>, step: f32) -> Self {
35        Self {
36            parameter: SweepParameter::Alpha {
37                start: range.start,
38                end: range.end,
39                step,
40            },
41            runs_per_point: 1,
42            early_stop: false,
43            seed: Some(42),
44        }
45    }
46
47    /// Set number of runs per point.
48    pub fn with_runs(mut self, runs: usize) -> Self {
49        self.runs_per_point = runs;
50        self
51    }
52
53    /// Enable early stopping.
54    pub fn with_early_stop(mut self) -> Self {
55        self.early_stop = true;
56        self
57    }
58
59    /// Set random seed.
60    pub fn with_seed(mut self, seed: u64) -> Self {
61        self.seed = Some(seed);
62        self
63    }
64}
65
/// Parameter being swept, together with the grid of values to try.
#[derive(Debug, Clone, PartialEq)]
pub enum SweepParameter {
    /// Temperature parameter: `start` to `end` (inclusive) in increments of `step`
    Temperature { start: f32, end: f32, step: f32 },
    /// Alpha parameter: `start` to `end` (inclusive) in increments of `step`
    Alpha { start: f32, end: f32, step: f32 },
    /// LoRA rank, as an explicit list of values
    Rank { values: Vec<u32> },
    /// Learning rate, as an explicit list of values
    LearningRate { values: Vec<f64> },
}

impl SweepParameter {
    /// Get the values to sweep over, in order.
    ///
    /// * `Temperature` / `Alpha`: `start, start + step, start + 2*step, ...` up to and including
    ///   `end`. A grid point that overshoots `end` by less than 1e-4 of a step (i.e. `f32`
    ///   rounding noise) is still included, so `0.1..0.9` with step `0.1` yields all nine
    ///   points. Returns an empty vector if `step` is not a positive finite number, if
    ///   `start`/`end` are not finite, or if `start > end`.
    /// * `Rank` / `LearningRate`: the explicit values, converted to `f64`.
    pub fn values(&self) -> Vec<f64> {
        match self {
            Self::Temperature { start, end, step } | Self::Alpha { start, end, step } => {
                Self::linear_grid(*start, *end, *step)
            }
            Self::Rank { values } => values.iter().map(|&v| f64::from(v)).collect(),
            Self::LearningRate { values } => values.clone(),
        }
    }

    /// Evenly spaced grid from `start` to `end` (inclusive) with spacing `step`.
    fn linear_grid(start: f32, end: f32, step: f32) -> Vec<f64> {
        // Fraction of a step by which the last grid point may overshoot `end` and still count
        // as reaching it; absorbs `f32` representation error in the bounds and the step.
        const END_TOLERANCE_STEPS: f64 = 1e-4;

        let valid = start.is_finite() && end.is_finite() && step.is_finite() && step > 0.0;
        if !valid || start > end {
            // A zero/negative/NaN step would never advance toward `end`.
            return Vec::new();
        }

        let span_in_steps = (f64::from(end) - f64::from(start)) / f64::from(step);
        let last_index = (span_in_steps + END_TOLERANCE_STEPS).floor() as usize;

        // Compute each point from its index instead of accumulating `v += step`, so the
        // rounding error is that of a single operation rather than growing with every step.
        (0..=last_index)
            .map(|i| f64::from(start + step * i as f32))
            .collect()
    }

    /// Get the parameter's name as used in reports (e.g. `"temperature"`).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Temperature { .. } => "temperature",
            Self::Alpha { .. } => "alpha",
            Self::Rank { .. } => "rank",
            Self::LearningRate { .. } => "learning_rate",
        }
    }
}
107
108/// Sweep executor.
109pub struct Sweeper {
110    config: SweepConfig,
111}
112
113impl Sweeper {
114    /// Create a new sweeper.
115    pub fn new(config: SweepConfig) -> Self {
116        Self { config }
117    }
118
119    /// Run the sweep.
120    pub fn run(&self) -> Result<SweepResult> {
121        let values = self.config.parameter.values();
122        let mut data_points = Vec::new();
123
124        for value in &values {
125            let mut metrics = Vec::new();
126
127            for run in 0..self.config.runs_per_point {
128                // Simulate training with this configuration
129                let result = self.simulate_training(*value, run);
130                metrics.push(result);
131            }
132
133            // Aggregate metrics across runs
134            let mean_loss = metrics.iter().map(|m| m.loss).sum::<f64>() / metrics.len() as f64;
135            let mean_accuracy =
136                metrics.iter().map(|m| m.accuracy).sum::<f64>() / metrics.len() as f64;
137            let std_loss = self.calculate_std(&metrics.iter().map(|m| m.loss).collect::<Vec<_>>());
138            let std_accuracy =
139                self.calculate_std(&metrics.iter().map(|m| m.accuracy).collect::<Vec<_>>());
140
141            data_points.push(DataPoint {
142                parameter_value: *value,
143                mean_loss,
144                std_loss,
145                mean_accuracy,
146                std_accuracy,
147                runs: metrics.len(),
148            });
149        }
150
151        // Find optimal
152        let optimal = data_points
153            .iter()
154            .min_by(|a, b| {
155                a.mean_loss
156                    .partial_cmp(&b.mean_loss)
157                    .unwrap_or(std::cmp::Ordering::Equal)
158            })
159            .cloned();
160
161        Ok(SweepResult {
162            parameter_name: self.config.parameter.name().to_string(),
163            data_points,
164            optimal,
165            config: self.config.clone(),
166        })
167    }
168
169    fn simulate_training(&self, param_value: f64, run: usize) -> TrainingMetrics {
170        // Simulated training - in real implementation would run actual training
171        // Using a simple model where:
172        // - Temperature ~4.0 is optimal
173        // - Alpha ~0.7 is optimal
174
175        let seed_offset = self.config.seed.unwrap_or(0) + run as u64;
176        let noise = (seed_offset as f64 * 0.1).sin() * 0.05; // Deterministic "randomness"
177
178        let param_name = self.config.parameter.name();
179
180        let (loss, accuracy) = match param_name {
181            "temperature" => {
182                // Optimal around 4.0
183                let deviation = (param_value - 4.0).abs();
184                let loss = 0.65 + deviation * 0.1 + noise;
185                let accuracy = 0.83 - deviation * 0.02 + noise * 0.5;
186                (loss, accuracy.clamp(0.0, 1.0))
187            }
188            "alpha" => {
189                // Optimal around 0.7
190                let deviation = (param_value - 0.7).abs();
191                let loss = 0.65 + deviation * 0.2 + noise;
192                let accuracy = 0.83 - deviation * 0.05 + noise * 0.5;
193                (loss, accuracy.clamp(0.0, 1.0))
194            }
195            _ => (0.8 + noise, 0.75 + noise * 0.5),
196        };
197
198        TrainingMetrics {
199            loss,
200            accuracy,
201            throughput: 1200.0 + noise * 100.0,
202            duration_secs: 3600.0 + noise * 600.0,
203        }
204    }
205
206    fn calculate_std(&self, values: &[f64]) -> f64 {
207        if values.len() < 2 {
208            return 0.0;
209        }
210        let mean = values.iter().sum::<f64>() / values.len() as f64;
211        let variance =
212            values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
213        variance.sqrt()
214    }
215}
216
/// Training metrics from a single run.
///
/// Plain data; derives `PartialEq` so results can be compared directly instead of
/// field by field.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingMetrics {
    /// Final loss
    pub loss: f64,
    /// Final accuracy (a fraction; multiply by 100 for a percentage)
    pub accuracy: f64,
    /// Training throughput (samples/sec)
    pub throughput: f64,
    /// Training duration in seconds
    pub duration_secs: f64,
}
229
/// A single data point in the sweep: aggregated metrics for one parameter value.
///
/// Plain data; derives `PartialEq` so sweep results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
    /// Parameter value
    pub parameter_value: f64,
    /// Mean loss across runs
    pub mean_loss: f64,
    /// Sample standard deviation of loss (0.0 when fewer than two runs)
    pub std_loss: f64,
    /// Mean accuracy across runs
    pub mean_accuracy: f64,
    /// Sample standard deviation of accuracy (0.0 when fewer than two runs)
    pub std_accuracy: f64,
    /// Number of runs
    pub runs: usize,
}
246
247/// Result of a sweep.
248#[derive(Debug, Clone)]
249pub struct SweepResult {
250    /// Parameter name
251    pub parameter_name: String,
252    /// Data points
253    pub data_points: Vec<DataPoint>,
254    /// Optimal configuration
255    pub optimal: Option<DataPoint>,
256    /// Original configuration
257    pub config: SweepConfig,
258}
259
260impl SweepResult {
261    /// Format as ASCII table.
262    pub fn to_table(&self) -> String {
263        let mut output = format!("{} Sweep Results\n", self.parameter_name);
264        output.push_str("┌─────────────┬────────────┬────────────┬────────────┐\n");
265        output.push_str("│ Value       │ Loss       │ Accuracy   │ Runs       │\n");
266        output.push_str("├─────────────┼────────────┼────────────┼────────────┤\n");
267
268        for point in &self.data_points {
269            let optimal_marker = if self.optimal.as_ref().map(|o| o.parameter_value)
270                == Some(point.parameter_value)
271            {
272                " ★"
273            } else {
274                ""
275            };
276
277            output.push_str(&format!(
278                "│ {:>10.2} │ {:>10.4} │ {:>9.1}% │ {:>10}{} │\n",
279                point.parameter_value,
280                point.mean_loss,
281                point.mean_accuracy * 100.0,
282                point.runs,
283                optimal_marker
284            ));
285        }
286
287        output.push_str("└─────────────┴────────────┴────────────┴────────────┘\n");
288
289        if let Some(optimal) = &self.optimal {
290            output.push_str(&format!(
291                "\nOptimal: {} = {:.2} (loss={:.4}, accuracy={:.1}%)\n",
292                self.parameter_name,
293                optimal.parameter_value,
294                optimal.mean_loss,
295                optimal.mean_accuracy * 100.0
296            ));
297        }
298
299        output
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sweep_config_temperature() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0);
        assert_eq!(config.parameter.name(), "temperature");

        // Although written as a half-open range, the grid includes the end: 1, 2, 3, 4, 5.
        let values = config.parameter.values();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sweep_config_alpha() {
        let config = SweepConfig::alpha(0.1..0.9, 0.1);
        assert_eq!(config.parameter.name(), "alpha");

        // Constructor defaults.
        assert_eq!(config.runs_per_point, 1);
        assert!(!config.early_stop);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_sweeper_runs() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0).with_runs(2);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        assert_eq!(result.data_points.len(), 3); // 1, 2, 3
        assert!(result.data_points.iter().all(|p| p.runs == 2));
        assert!(result.optimal.is_some());
    }

    #[test]
    fn test_sweeper_finds_optimal_temperature() {
        let config = SweepConfig::temperature(2.0..6.0, 1.0).with_runs(1);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        // The simulated loss is minimized exactly at 4.0, which lies on the grid.
        let optimal = result
            .optimal
            .expect("a non-empty grid always has an optimum");
        assert!((optimal.parameter_value - 4.0).abs() < 1e-9);
    }

    #[test]
    fn test_sweep_result_table() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        let table = result.to_table();
        assert!(table.starts_with("temperature Sweep Results"));
        assert!(table.contains("Loss"));
        assert!(table.contains("Accuracy"));
    }

    #[test]
    fn test_std_calculation() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        // Sample variance of 1..=5 is 10 / 4 = 2.5.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let std = sweeper.calculate_std(&values);
        assert!((std - 2.5_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_std_calculation_single_value() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        let values = vec![5.0];
        let std = sweeper.calculate_std(&values);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn test_std_calculation_empty() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        let values: Vec<f64> = vec![];
        let std = sweeper.calculate_std(&values);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn test_sweep_config_with_seed() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_seed(123);
        assert_eq!(config.seed, Some(123));
    }

    #[test]
    fn test_sweep_config_with_early_stop() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_early_stop();
        assert!(config.early_stop);
    }

    #[test]
    fn test_sweep_config_with_runs() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_runs(10);
        assert_eq!(config.runs_per_point, 10);
    }

    #[test]
    fn test_sweep_parameter_rank() {
        let param = SweepParameter::Rank {
            values: vec![8, 16, 32, 64],
        };
        let values = param.values();
        assert_eq!(values, vec![8.0, 16.0, 32.0, 64.0]);
        assert_eq!(param.name(), "rank");
    }

    #[test]
    fn test_sweep_parameter_learning_rate() {
        let param = SweepParameter::LearningRate {
            values: vec![1e-5, 1e-4, 1e-3],
        };
        let values = param.values();
        assert_eq!(values, vec![1e-5, 1e-4, 1e-3]);
        assert_eq!(param.name(), "learning_rate");
    }

    #[test]
    fn test_sweep_result_fields() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        assert_eq!(result.parameter_name, "temperature");
        assert!(!result.data_points.is_empty());
    }

    #[test]
    fn test_data_point_fields() {
        let point = DataPoint {
            parameter_value: 4.0,
            mean_loss: 0.65,
            std_loss: 0.02,
            mean_accuracy: 0.83,
            std_accuracy: 0.01,
            runs: 5,
        };

        assert_eq!(point.parameter_value, 4.0);
        assert_eq!(point.runs, 5);
    }

    #[test]
    fn test_training_metrics_fields() {
        let metrics = TrainingMetrics {
            loss: 0.75,
            accuracy: 0.82,
            throughput: 1200.0,
            duration_secs: 3600.0,
        };

        assert_eq!(metrics.loss, 0.75);
        assert_eq!(metrics.throughput, 1200.0);
    }

    #[test]
    fn test_sweep_result_table_optimal() {
        let config = SweepConfig::temperature(3.0..5.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        let table = result.to_table();

        // Should contain the "Optimal" summary and mark the optimal row.
        assert!(table.contains("Optimal"));
        assert!(table.contains('★'));
    }

    #[test]
    fn test_sweep_deterministic() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0).with_seed(42);
        let first = Sweeper::new(config.clone())
            .run()
            .expect("simulated sweep never fails");
        let second = Sweeper::new(config)
            .run()
            .expect("simulated sweep never fails");

        // Same seed should produce the same results at every point, not just the first.
        assert_eq!(first.data_points.len(), second.data_points.len());
        for (a, b) in first.data_points.iter().zip(&second.data_points) {
            assert_eq!(a.parameter_value, b.parameter_value);
            assert_eq!(a.mean_loss, b.mean_loss);
            assert_eq!(a.mean_accuracy, b.mean_accuracy);
        }
    }

    #[test]
    fn test_alpha_sweep_finds_optimal() {
        let config = SweepConfig::alpha(0.3..0.9, 0.2).with_runs(1);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        // The simulated loss is minimized at 0.7, which is on the grid up to f32 rounding.
        let optimal = result
            .optimal
            .expect("a non-empty grid always has an optimum");
        assert!((optimal.parameter_value - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_sweep_multiple_runs() {
        let config = SweepConfig::temperature(3.0..5.0, 1.0).with_runs(3);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("simulated sweep never fails");

        // Each data point should have 3 runs
        for point in &result.data_points {
            assert_eq!(point.runs, 3);
        }
    }
}