1use crate::stats::StatisticalAnalyzer;
4use entrenar_common::Result;
5
/// A distillation strategy together with the hyperparameters it was configured with.
///
/// Within this file the field values are only stored (see the preset constructors);
/// the simulated metrics depend on which variant is selected, not on the numbers
/// held inside it.
///
/// `PartialEq` is derived so that whole strategies can be compared with `==` /
/// `assert_eq!` instead of destructuring every field by hand.
#[derive(Debug, Clone, PartialEq)]
pub enum DistillStrategy {
    /// Strategy reported as `"KD-only"`.
    KDOnly {
        /// Distillation temperature (preset constructor uses `4.0`).
        temperature: f32,
        /// Loss weighting coefficient (preset `0.7`).
        // NOTE(review): presumably the soft-vs-hard target mix — confirm against the trainer.
        alpha: f32,
    },
    /// Strategy reported as `"Progressive"`; adds a per-layer weight.
    Progressive {
        /// Distillation temperature (preset `4.0`).
        temperature: f32,
        /// Loss weighting coefficient (preset `0.7`).
        alpha: f32,
        /// Weight of the layer-wise term (preset `0.3`).
        layer_weight: f32,
    },
    /// Strategy reported as `"Attention"`; adds an attention weight.
    Attention {
        /// Distillation temperature (preset `4.0`).
        temperature: f32,
        /// Loss weighting coefficient (preset `0.7`).
        alpha: f32,
        /// Weight of the attention term (preset `0.1`).
        attention_weight: f32,
    },
    /// Strategy reported as `"Combined"`; carries both the layer and attention weights.
    Combined {
        /// Distillation temperature (preset `4.0`).
        temperature: f32,
        /// Loss weighting coefficient (preset `0.7`).
        alpha: f32,
        /// Weight of the layer-wise term (preset `0.3`).
        layer_weight: f32,
        /// Weight of the attention term (preset `0.1`).
        attention_weight: f32,
    },
}
31
32impl DistillStrategy {
33 pub fn name(&self) -> &'static str {
35 match self {
36 Self::KDOnly { .. } => "KD-only",
37 Self::Progressive { .. } => "Progressive",
38 Self::Attention { .. } => "Attention",
39 Self::Combined { .. } => "Combined",
40 }
41 }
42
43 pub fn kd_only() -> Self {
45 Self::KDOnly {
46 temperature: 4.0,
47 alpha: 0.7,
48 }
49 }
50
51 pub fn progressive() -> Self {
53 Self::Progressive {
54 temperature: 4.0,
55 alpha: 0.7,
56 layer_weight: 0.3,
57 }
58 }
59
60 pub fn attention() -> Self {
62 Self::Attention {
63 temperature: 4.0,
64 alpha: 0.7,
65 attention_weight: 0.1,
66 }
67 }
68
69 pub fn combined() -> Self {
71 Self::Combined {
72 temperature: 4.0,
73 alpha: 0.7,
74 layer_weight: 0.3,
75 attention_weight: 0.1,
76 }
77 }
78
79 fn simulate(&self, seed: u64) -> StrategyMetrics {
81 let noise = (seed as f64 * 0.1).sin() * 0.02;
82
83 let (base_loss, base_accuracy, time_factor) = match self {
84 Self::KDOnly { .. } => (0.82, 0.782, 1.0),
85 Self::Progressive { .. } => (0.75, 0.818, 1.15),
86 Self::Attention { .. } => (0.78, 0.796, 1.08),
87 Self::Combined { .. } => (0.71, 0.831, 1.25),
88 };
89
90 StrategyMetrics {
91 final_loss: base_loss + noise,
92 final_accuracy: base_accuracy + noise * 0.5,
93 training_time_hours: 2.0 * time_factor + noise * 0.5,
94 peak_memory_gb: 16.0 + noise * 2.0,
95 }
96 }
97}
98
/// Metrics reported for a single simulated run of one strategy.
#[derive(Clone, Debug)]
pub struct StrategyMetrics {
    /// Loss at the end of the run.
    pub final_loss: f64,
    /// Accuracy at the end of the run, stored as a fraction (the simulated
    /// baselines are values such as `0.782`).
    pub final_accuracy: f64,
    /// Training time of the run, in hours.
    pub training_time_hours: f64,
    /// Peak memory of the run, in gigabytes. Not read by `compare` in this file.
    pub peak_memory_gb: f64,
}
111
112#[derive(Debug, Clone)]
114pub struct StrategyComparison {
115 pub results: Vec<StrategyResult>,
117 pub best_by_loss: Option<String>,
119 pub best_by_accuracy: Option<String>,
121 pub significance: Vec<PairwiseComparison>,
123}
124
/// Aggregated statistics for one strategy across its simulated runs.
#[derive(Clone, Debug)]
pub struct StrategyResult {
    /// Label of the strategy (what `DistillStrategy::name` returned).
    pub name: String,
    /// Mean of the final loss over all runs.
    pub mean_loss: f64,
    /// Sample standard deviation (`n - 1` denominator) of the final loss.
    pub std_loss: f64,
    /// Mean of the final accuracy over all runs, as a fraction; `to_table`
    /// multiplies it by 100 for display.
    pub mean_accuracy: f64,
    /// Sample standard deviation (`n - 1` denominator) of the final accuracy.
    pub std_accuracy: f64,
    /// Mean training time over all runs, in hours.
    pub mean_time_hours: f64,
    /// Number of runs the statistics were computed from.
    pub runs: usize,
}
143
/// Result of statistically comparing the loss samples of two strategies.
///
/// `compare` fills the numeric fields straight from
/// `StatisticalAnalyzer::welch_t_test`; the significance threshold is whatever
/// that test applies.
#[derive(Clone, Debug)]
pub struct PairwiseComparison {
    /// Label of the first strategy in the pair.
    pub strategy1: String,
    /// Label of the second strategy in the pair.
    pub strategy2: String,
    /// p-value reported by the test.
    pub p_value: f64,
    /// Whether the test flagged the difference as significant.
    pub significant: bool,
    /// Effect size reported by the test.
    pub effect_size: f64,
}
158
159pub fn compare(strategies: &[DistillStrategy]) -> Result<StrategyComparison> {
161 let runs_per_strategy = 5;
162 let mut results = Vec::new();
163 let mut all_losses: Vec<(String, Vec<f64>)> = Vec::new();
164
165 for strategy in strategies {
166 let mut losses = Vec::new();
167 let mut accuracies = Vec::new();
168 let mut times = Vec::new();
169
170 for run in 0..runs_per_strategy {
171 let metrics = strategy.simulate(run as u64);
172 losses.push(metrics.final_loss);
173 accuracies.push(metrics.final_accuracy);
174 times.push(metrics.training_time_hours);
175 }
176
177 let n = losses.len() as f64;
178 let mean_loss = losses.iter().sum::<f64>() / n;
179 let mean_accuracy = accuracies.iter().sum::<f64>() / n;
180 let mean_time = times.iter().sum::<f64>() / n;
181
182 let std_loss =
183 (losses.iter().map(|x| (x - mean_loss).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
184 let std_accuracy = (accuracies
185 .iter()
186 .map(|x| (x - mean_accuracy).powi(2))
187 .sum::<f64>()
188 / (n - 1.0))
189 .sqrt();
190
191 results.push(StrategyResult {
192 name: strategy.name().to_string(),
193 mean_loss,
194 std_loss,
195 mean_accuracy,
196 std_accuracy,
197 mean_time_hours: mean_time,
198 runs: runs_per_strategy,
199 });
200
201 all_losses.push((strategy.name().to_string(), losses));
202 }
203
204 let best_by_loss = results
206 .iter()
207 .min_by(|a, b| {
208 a.mean_loss
209 .partial_cmp(&b.mean_loss)
210 .unwrap_or(std::cmp::Ordering::Equal)
211 })
212 .map(|r| r.name.clone());
213
214 let best_by_accuracy = results
215 .iter()
216 .max_by(|a, b| {
217 a.mean_accuracy
218 .partial_cmp(&b.mean_accuracy)
219 .unwrap_or(std::cmp::Ordering::Equal)
220 })
221 .map(|r| r.name.clone());
222
223 let mut significance = Vec::new();
225 for i in 0..all_losses.len() {
226 for j in (i + 1)..all_losses.len() {
227 let (name1, losses1) = &all_losses[i];
228 let (name2, losses2) = &all_losses[j];
229
230 let test = StatisticalAnalyzer::welch_t_test(losses1, losses2);
231
232 significance.push(PairwiseComparison {
233 strategy1: name1.clone(),
234 strategy2: name2.clone(),
235 p_value: test.p_value,
236 significant: test.significant,
237 effect_size: test.effect_size,
238 });
239 }
240 }
241
242 Ok(StrategyComparison {
243 results,
244 best_by_loss,
245 best_by_accuracy,
246 significance,
247 })
248}
249
250impl StrategyComparison {
251 pub fn to_table(&self) -> String {
253 let mut output = String::from("Strategy Comparison\n");
254 output.push_str("┌──────────────┬─────────────────┬─────────────────┬────────────┐\n");
255 output.push_str("│ Strategy │ Loss │ Accuracy │ Time (h) │\n");
256 output.push_str("├──────────────┼─────────────────┼─────────────────┼────────────┤\n");
257
258 for result in &self.results {
259 let loss_marker = if self.best_by_loss.as_ref() == Some(&result.name) {
260 " ★"
261 } else {
262 ""
263 };
264 let acc_marker = if self.best_by_accuracy.as_ref() == Some(&result.name) {
265 " ★"
266 } else {
267 ""
268 };
269
270 output.push_str(&format!(
271 "│ {:12} │ {:.3} ± {:.3}{:2} │ {:.1}% ± {:.1}%{:2} │ {:>10.1} │\n",
272 result.name,
273 result.mean_loss,
274 result.std_loss,
275 loss_marker,
276 result.mean_accuracy * 100.0,
277 result.std_accuracy * 100.0,
278 acc_marker,
279 result.mean_time_hours
280 ));
281 }
282
283 output.push_str("└──────────────┴─────────────────┴─────────────────┴────────────┘\n");
284
285 output.push_str("\nStatistical Significance:\n");
287 for comp in &self.significance {
288 let sig = if comp.significant { "✓" } else { "✗" };
289 output.push_str(&format!(
290 " {} vs {}: p={:.4} {} (effect={:.2})\n",
291 comp.strategy1, comp.strategy2, comp.p_value, sig, comp.effect_size
292 ));
293 }
294
295 output
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `compare`, failing the calling test if it reports an error.
    fn run(strategies: &[DistillStrategy]) -> StrategyComparison {
        compare(strategies).expect("operation should succeed")
    }

    /// Every preset reports the expected display label.
    #[test]
    fn test_strategy_names() {
        let expected = [
            (DistillStrategy::kd_only(), "KD-only"),
            (DistillStrategy::progressive(), "Progressive"),
            (DistillStrategy::attention(), "Attention"),
            (DistillStrategy::combined(), "Combined"),
        ];

        for (strategy, label) in &expected {
            assert_eq!(strategy.name(), *label);
        }
    }

    /// Comparing three strategies yields three results and both winners.
    #[test]
    fn test_compare_strategies() {
        let report = run(&[
            DistillStrategy::kd_only(),
            DistillStrategy::progressive(),
            DistillStrategy::combined(),
        ]);

        assert_eq!(report.results.len(), 3);
        assert!(report.best_by_loss.is_some());
        assert!(report.best_by_accuracy.is_some());
    }

    /// The combined preset wins on accuracy against plain KD.
    #[test]
    fn test_combined_is_best() {
        let report = run(&[DistillStrategy::kd_only(), DistillStrategy::combined()]);

        assert_eq!(report.best_by_accuracy.as_deref(), Some("Combined"));
    }

    /// The rendered table mentions every strategy and the significance section.
    #[test]
    fn test_comparison_table() {
        let rendered = run(&[DistillStrategy::kd_only(), DistillStrategy::progressive()]).to_table();

        for needle in ["KD-only", "Progressive", "Significance"].iter() {
            assert!(rendered.contains(needle));
        }
    }

    /// Each preset constructor produces its variant with the documented values.
    #[test]
    fn test_strategy_constructors() {
        match DistillStrategy::kd_only() {
            DistillStrategy::KDOnly { temperature, alpha } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
            }
            _ => panic!("Expected KDOnly"),
        }

        match DistillStrategy::progressive() {
            DistillStrategy::Progressive {
                temperature,
                alpha,
                layer_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(layer_weight, 0.3);
            }
            _ => panic!("Expected Progressive"),
        }

        match DistillStrategy::attention() {
            DistillStrategy::Attention {
                temperature,
                alpha,
                attention_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(attention_weight, 0.1);
            }
            _ => panic!("Expected Attention"),
        }

        match DistillStrategy::combined() {
            DistillStrategy::Combined {
                temperature,
                alpha,
                layer_weight,
                attention_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(layer_weight, 0.3);
                assert_eq!(attention_weight, 0.1);
            }
            _ => panic!("Expected Combined"),
        }
    }

    /// The same seed always produces the same metrics.
    #[test]
    fn test_strategy_simulate_deterministic() {
        let strategy = DistillStrategy::kd_only();
        let first = strategy.simulate(42);
        let second = strategy.simulate(42);

        assert_eq!(first.final_loss, second.final_loss);
        assert_eq!(first.final_accuracy, second.final_accuracy);
    }

    /// Different seeds perturb the simulated loss.
    #[test]
    fn test_strategy_simulate_different_seeds() {
        let strategy = DistillStrategy::kd_only();
        let with_seed_one = strategy.simulate(1);
        let with_seed_two = strategy.simulate(2);

        assert_ne!(with_seed_one.final_loss, with_seed_two.final_loss);
    }

    /// `StrategyMetrics` stores exactly what it is given.
    #[test]
    fn test_strategy_metrics_fields() {
        let StrategyMetrics {
            final_loss,
            final_accuracy,
            training_time_hours,
            peak_memory_gb,
        } = StrategyMetrics {
            final_loss: 0.75,
            final_accuracy: 0.82,
            training_time_hours: 2.5,
            peak_memory_gb: 16.0,
        };

        assert_eq!(final_loss, 0.75);
        assert_eq!(final_accuracy, 0.82);
        assert_eq!(training_time_hours, 2.5);
        assert_eq!(peak_memory_gb, 16.0);
    }

    /// `StrategyResult` stores exactly what it is given.
    #[test]
    fn test_strategy_result_fields() {
        let summary = StrategyResult {
            name: "test".to_string(),
            mean_loss: 0.7,
            std_loss: 0.02,
            mean_accuracy: 0.85,
            std_accuracy: 0.01,
            mean_time_hours: 3.0,
            runs: 5,
        };

        assert_eq!(summary.name, "test");
        assert_eq!(summary.runs, 5);
    }

    /// `PairwiseComparison` stores exactly what it is given.
    #[test]
    fn test_pairwise_comparison_fields() {
        let pair = PairwiseComparison {
            strategy1: "A".to_string(),
            strategy2: "B".to_string(),
            p_value: 0.03,
            significant: true,
            effect_size: 0.8,
        };

        assert!(pair.significant);
        assert_eq!(pair.effect_size, 0.8);
    }

    /// Two strategies form exactly one pair.
    #[test]
    fn test_comparison_significance_markers() {
        let report = run(&[DistillStrategy::kd_only(), DistillStrategy::combined()]);

        assert_eq!(report.significance.len(), 1);
    }

    /// Four strategies form six pairs and four result rows.
    #[test]
    fn test_compare_all_strategies() {
        let report = run(&[
            DistillStrategy::kd_only(),
            DistillStrategy::progressive(),
            DistillStrategy::attention(),
            DistillStrategy::combined(),
        ]);

        assert_eq!(report.significance.len(), 6);
        assert_eq!(report.results.len(), 4);
    }

    /// The table marks the winning entries with a star.
    #[test]
    fn test_comparison_table_star_markers() {
        let rendered = run(&[DistillStrategy::kd_only(), DistillStrategy::combined()]).to_table();

        assert!(rendered.contains('★'));
    }
}