1use entrenar_common::Result;
4
5#[derive(Debug, Clone)]
7pub struct SweepConfig {
8 pub parameter: SweepParameter,
10 pub runs_per_point: usize,
12 pub early_stop: bool,
14 pub seed: Option<u64>,
16}
17
18impl SweepConfig {
19 pub fn temperature(range: std::ops::Range<f32>, step: f32) -> Self {
21 Self {
22 parameter: SweepParameter::Temperature {
23 start: range.start,
24 end: range.end,
25 step,
26 },
27 runs_per_point: 1,
28 early_stop: false,
29 seed: Some(42),
30 }
31 }
32
33 pub fn alpha(range: std::ops::Range<f32>, step: f32) -> Self {
35 Self {
36 parameter: SweepParameter::Alpha {
37 start: range.start,
38 end: range.end,
39 step,
40 },
41 runs_per_point: 1,
42 early_stop: false,
43 seed: Some(42),
44 }
45 }
46
47 pub fn with_runs(mut self, runs: usize) -> Self {
49 self.runs_per_point = runs;
50 self
51 }
52
53 pub fn with_early_stop(mut self) -> Self {
55 self.early_stop = true;
56 self
57 }
58
59 pub fn with_seed(mut self, seed: u64) -> Self {
61 self.seed = Some(seed);
62 self
63 }
64}
65
/// The hyperparameter a sweep varies, together with the values it takes.
#[derive(Debug, Clone)]
pub enum SweepParameter {
    /// `temperature` from `start` to `end` (inclusive) in increments of `step`.
    Temperature { start: f32, end: f32, step: f32 },
    /// `alpha` from `start` to `end` (inclusive) in increments of `step`.
    Alpha { start: f32, end: f32, step: f32 },
    /// An explicit list of `rank` values, swept in the given order.
    Rank { values: Vec<u32> },
    /// An explicit list of learning rates, swept in the given order.
    LearningRate { values: Vec<f64> },
}

impl SweepParameter {
    /// Expands the parameter into the concrete values to evaluate, in sweep order.
    ///
    /// Ranged variants produce `start, start + step, start + 2*step, ...` up to and including
    /// `end`. Each point is computed from its index rather than by repeated f32 addition.
    /// Rounding therefore cannot silently drop the endpoint (`0.1..0.9` by `0.1` yields nine
    /// points), and the expansion always terminates.
    ///
    /// Degenerate ranges are handled explicitly:
    ///
    /// * `start > end`, or a NaN bound, yields no values;
    /// * a step that is zero, negative, NaN or infinite (or an infinite span) yields `[start]`.
    ///
    /// List variants return their values unchanged (ranks widened to `f64`).
    pub fn values(&self) -> Vec<f64> {
        match self {
            Self::Temperature { start, end, step } | Self::Alpha { start, end, step } => {
                Self::inclusive_grid(*start, *end, *step)
            }
            Self::Rank { values } => values.iter().copied().map(f64::from).collect(),
            Self::LearningRate { values } => values.clone(),
        }
    }

    /// Builds the inclusive grid `start..=end` with spacing `step` (see [`Self::values`]).
    fn inclusive_grid(start: f32, end: f32, step: f32) -> Vec<f64> {
        let (start, end, step) = (f64::from(start), f64::from(end), f64::from(step));
        if start.is_nan() || end.is_nan() || start > end {
            return Vec::new();
        }

        let span = (end - start) / step;
        // A step that cannot advance (or a span that cannot be counted) would otherwise loop
        // forever; the only well-defined point is `start`.
        if !step.is_finite() || step <= 0.0 || !span.is_finite() {
            return vec![start];
        }

        // The bounds arrive as f32, so the quotient can land a hair below the intended integer
        // (e.g. 7.9999996 instead of 8). A small relative tolerance keeps `end` in the grid.
        let last = (span + span.max(1.0) * 1e-6).floor() as usize;
        (0..=last).map(|i| start + step * i as f64).collect()
    }

    /// Identifier of the swept parameter; the sweep uses it as its display name.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Temperature { .. } => "temperature",
            Self::Alpha { .. } => "alpha",
            Self::Rank { .. } => "rank",
            Self::LearningRate { .. } => "learning_rate",
        }
    }
}
107
108pub struct Sweeper {
110 config: SweepConfig,
111}
112
113impl Sweeper {
114 pub fn new(config: SweepConfig) -> Self {
116 Self { config }
117 }
118
119 pub fn run(&self) -> Result<SweepResult> {
121 let values = self.config.parameter.values();
122 let mut data_points = Vec::new();
123
124 for value in &values {
125 let mut metrics = Vec::new();
126
127 for run in 0..self.config.runs_per_point {
128 let result = self.simulate_training(*value, run);
130 metrics.push(result);
131 }
132
133 let mean_loss = metrics.iter().map(|m| m.loss).sum::<f64>() / metrics.len() as f64;
135 let mean_accuracy =
136 metrics.iter().map(|m| m.accuracy).sum::<f64>() / metrics.len() as f64;
137 let std_loss = self.calculate_std(&metrics.iter().map(|m| m.loss).collect::<Vec<_>>());
138 let std_accuracy =
139 self.calculate_std(&metrics.iter().map(|m| m.accuracy).collect::<Vec<_>>());
140
141 data_points.push(DataPoint {
142 parameter_value: *value,
143 mean_loss,
144 std_loss,
145 mean_accuracy,
146 std_accuracy,
147 runs: metrics.len(),
148 });
149 }
150
151 let optimal = data_points
153 .iter()
154 .min_by(|a, b| {
155 a.mean_loss
156 .partial_cmp(&b.mean_loss)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 })
159 .cloned();
160
161 Ok(SweepResult {
162 parameter_name: self.config.parameter.name().to_string(),
163 data_points,
164 optimal,
165 config: self.config.clone(),
166 })
167 }
168
169 fn simulate_training(&self, param_value: f64, run: usize) -> TrainingMetrics {
170 let seed_offset = self.config.seed.unwrap_or(0) + run as u64;
176 let noise = (seed_offset as f64 * 0.1).sin() * 0.05; let param_name = self.config.parameter.name();
179
180 let (loss, accuracy) = match param_name {
181 "temperature" => {
182 let deviation = (param_value - 4.0).abs();
184 let loss = 0.65 + deviation * 0.1 + noise;
185 let accuracy = 0.83 - deviation * 0.02 + noise * 0.5;
186 (loss, accuracy.clamp(0.0, 1.0))
187 }
188 "alpha" => {
189 let deviation = (param_value - 0.7).abs();
191 let loss = 0.65 + deviation * 0.2 + noise;
192 let accuracy = 0.83 - deviation * 0.05 + noise * 0.5;
193 (loss, accuracy.clamp(0.0, 1.0))
194 }
195 _ => (0.8 + noise, 0.75 + noise * 0.5),
196 };
197
198 TrainingMetrics {
199 loss,
200 accuracy,
201 throughput: 1200.0 + noise * 100.0,
202 duration_secs: 3600.0 + noise * 600.0,
203 }
204 }
205
206 fn calculate_std(&self, values: &[f64]) -> f64 {
207 if values.len() < 2 {
208 return 0.0;
209 }
210 let mean = values.iter().sum::<f64>() / values.len() as f64;
211 let variance =
212 values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
213 variance.sqrt()
214 }
215}
216
/// Metrics reported by a single (simulated) training run.
///
/// Equality is field-wise `f64` equality, so metrics containing NaN never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct TrainingMetrics {
    /// Final loss of the run.
    pub loss: f64,
    /// Accuracy expressed as a fraction (1.0 = 100%).
    pub accuracy: f64,
    /// Throughput reported by the run (units are not specified here).
    pub throughput: f64,
    /// Run duration in seconds.
    pub duration_secs: f64,
}
229
/// Aggregated statistics for one value of the swept parameter.
///
/// Equality is field-wise `f64` equality, so a point containing NaN never compares equal.
#[derive(Debug, Clone, PartialEq)]
pub struct DataPoint {
    /// The parameter value these statistics were collected at.
    pub parameter_value: f64,
    /// Mean loss across the runs.
    pub mean_loss: f64,
    /// Sample standard deviation of the loss (0.0 when fewer than two runs).
    pub std_loss: f64,
    /// Mean accuracy across the runs, as a fraction.
    pub mean_accuracy: f64,
    /// Sample standard deviation of the accuracy (0.0 when fewer than two runs).
    pub std_accuracy: f64,
    /// Number of runs aggregated into this point.
    pub runs: usize,
}
246
247#[derive(Debug, Clone)]
249pub struct SweepResult {
250 pub parameter_name: String,
252 pub data_points: Vec<DataPoint>,
254 pub optimal: Option<DataPoint>,
256 pub config: SweepConfig,
258}
259
260impl SweepResult {
261 pub fn to_table(&self) -> String {
263 let mut output = format!("{} Sweep Results\n", self.parameter_name);
264 output.push_str("┌─────────────┬────────────┬────────────┬────────────┐\n");
265 output.push_str("│ Value │ Loss │ Accuracy │ Runs │\n");
266 output.push_str("├─────────────┼────────────┼────────────┼────────────┤\n");
267
268 for point in &self.data_points {
269 let optimal_marker = if self.optimal.as_ref().map(|o| o.parameter_value)
270 == Some(point.parameter_value)
271 {
272 " ★"
273 } else {
274 ""
275 };
276
277 output.push_str(&format!(
278 "│ {:>10.2} │ {:>10.4} │ {:>9.1}% │ {:>10}{} │\n",
279 point.parameter_value,
280 point.mean_loss,
281 point.mean_accuracy * 100.0,
282 point.runs,
283 optimal_marker
284 ));
285 }
286
287 output.push_str("└─────────────┴────────────┴────────────┴────────────┘\n");
288
289 if let Some(optimal) = &self.optimal {
290 output.push_str(&format!(
291 "\nOptimal: {} = {:.2} (loss={:.4}, accuracy={:.1}%)\n",
292 self.parameter_name,
293 optimal.parameter_value,
294 optimal.mean_loss,
295 optimal.mean_accuracy * 100.0
296 ));
297 }
298
299 output
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sweep_config_temperature() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0);
        assert_eq!(config.parameter.name(), "temperature");

        // The range end is inclusive: 1.0..5.0 in steps of 1.0 yields five points.
        let values = config.parameter.values();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_sweep_config_alpha() {
        let config = SweepConfig::alpha(0.1..0.9, 0.1);
        assert_eq!(config.parameter.name(), "alpha");
        // Constructor defaults.
        assert_eq!(config.runs_per_point, 1);
        assert!(!config.early_stop);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_sweeper_runs() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0).with_runs(2);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        assert_eq!(result.data_points.len(), 3);
        assert!(result.optimal.is_some());
    }

    #[test]
    fn test_sweeper_finds_optimal_temperature() {
        let config = SweepConfig::temperature(2.0..6.0, 1.0).with_runs(1);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        let optimal = result.optimal.expect("operation should succeed");
        // 4.0 is on the integer grid and is the simulated optimum, so it must be chosen exactly.
        assert_eq!(optimal.parameter_value, 4.0);
    }

    #[test]
    fn test_sweep_result_table() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        let table = result.to_table();
        assert!(table.contains("temperature"));
        assert!(table.contains("Loss"));
        assert!(table.contains("Accuracy"));
    }

    #[test]
    fn test_std_calculation() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        // Sample variance of 1..=5 is 10 / 4 = 2.5.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let std = sweeper.calculate_std(&values);
        assert!((std - 2.5_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_std_calculation_single_value() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        let values = vec![5.0];
        let std = sweeper.calculate_std(&values);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn test_std_calculation_empty() {
        let sweeper = Sweeper::new(SweepConfig::temperature(1.0..2.0, 1.0));

        let values: Vec<f64> = vec![];
        let std = sweeper.calculate_std(&values);
        assert_eq!(std, 0.0);
    }

    #[test]
    fn test_sweep_config_with_seed() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_seed(123);
        assert_eq!(config.seed, Some(123));
    }

    #[test]
    fn test_sweep_config_with_early_stop() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_early_stop();
        assert!(config.early_stop);
    }

    #[test]
    fn test_sweep_config_with_runs() {
        let config = SweepConfig::temperature(1.0..5.0, 1.0).with_runs(10);
        assert_eq!(config.runs_per_point, 10);
    }

    #[test]
    fn test_sweep_parameter_rank() {
        let param = SweepParameter::Rank {
            values: vec![8, 16, 32, 64],
        };
        let values = param.values();
        assert_eq!(values, vec![8.0, 16.0, 32.0, 64.0]);
        assert_eq!(param.name(), "rank");
    }

    #[test]
    fn test_sweep_parameter_learning_rate() {
        let param = SweepParameter::LearningRate {
            values: vec![1e-5, 1e-4, 1e-3],
        };
        let values = param.values();
        assert_eq!(values, vec![1e-5, 1e-4, 1e-3]);
        assert_eq!(param.name(), "learning_rate");
    }

    #[test]
    fn test_sweep_result_fields() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        assert_eq!(result.parameter_name, "temperature");
        assert!(!result.data_points.is_empty());
    }

    #[test]
    fn test_data_point_fields() {
        let point = DataPoint {
            parameter_value: 4.0,
            mean_loss: 0.65,
            std_loss: 0.02,
            mean_accuracy: 0.83,
            std_accuracy: 0.01,
            runs: 5,
        };

        assert_eq!(point.parameter_value, 4.0);
        assert_eq!(point.runs, 5);
    }

    #[test]
    fn test_training_metrics_fields() {
        let metrics = TrainingMetrics {
            loss: 0.75,
            accuracy: 0.82,
            throughput: 1200.0,
            duration_secs: 3600.0,
        };

        assert_eq!(metrics.loss, 0.75);
        assert_eq!(metrics.throughput, 1200.0);
    }

    #[test]
    fn test_sweep_result_table_optimal() {
        let config = SweepConfig::temperature(3.0..5.0, 1.0);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        let table = result.to_table();

        assert!(table.contains("Optimal"));
        assert!(table.contains('★'));
    }

    #[test]
    fn test_sweep_deterministic() {
        let config = SweepConfig::temperature(1.0..3.0, 1.0).with_seed(42);
        let sweeper = Sweeper::new(config.clone());
        let result1 = sweeper.run().expect("operation should succeed");

        let sweeper2 = Sweeper::new(config);
        let result2 = sweeper2.run().expect("operation should succeed");

        // Every point, not just the first, must be reproduced exactly.
        assert_eq!(result1.data_points.len(), result2.data_points.len());
        for (a, b) in result1.data_points.iter().zip(&result2.data_points) {
            assert_eq!(a.parameter_value, b.parameter_value);
            assert_eq!(a.mean_loss, b.mean_loss);
            assert_eq!(a.mean_accuracy, b.mean_accuracy);
        }
    }

    #[test]
    fn test_alpha_sweep_finds_optimal() {
        let config = SweepConfig::alpha(0.3..0.9, 0.2).with_runs(1);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        let optimal = result.optimal.expect("operation should succeed");
        // 0.7 lies on the 0.3 + 0.2k grid, so only f32 representation error separates them.
        assert!((optimal.parameter_value - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_sweep_multiple_runs() {
        let config = SweepConfig::temperature(3.0..5.0, 1.0).with_runs(3);
        let sweeper = Sweeper::new(config);
        let result = sweeper.run().expect("operation should succeed");

        for point in &result.data_points {
            assert_eq!(point.runs, 3);
        }
    }
}