Skip to main content

entrenar_bench/
strategies.rs

//! Distillation strategy comparison.
2
3use crate::stats::StatisticalAnalyzer;
4use entrenar_common::Result;
5
/// A distillation strategy to benchmark.
///
/// Each variant bundles the hyperparameters of one distillation recipe. All
/// variants carry `temperature` and `alpha`; the richer recipes add one or two
/// extra weights. Values compare field-by-field via `PartialEq`, so strategies
/// can be checked with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub enum DistillStrategy {
    /// Knowledge distillation only (soft targets)
    KDOnly {
        /// Distillation temperature.
        temperature: f32,
        /// Distillation `alpha` coefficient.
        alpha: f32,
    },
    /// Progressive distillation (hidden state matching)
    Progressive {
        /// Distillation temperature.
        temperature: f32,
        /// Distillation `alpha` coefficient.
        alpha: f32,
        /// Weight given to the layer (hidden state) matching term.
        layer_weight: f32,
    },
    /// Attention transfer
    Attention {
        /// Distillation temperature.
        temperature: f32,
        /// Distillation `alpha` coefficient.
        alpha: f32,
        /// Weight given to the attention transfer term.
        attention_weight: f32,
    },
    /// Combined approach
    Combined {
        /// Distillation temperature.
        temperature: f32,
        /// Distillation `alpha` coefficient.
        alpha: f32,
        /// Weight given to the layer (hidden state) matching term.
        layer_weight: f32,
        /// Weight given to the attention transfer term.
        attention_weight: f32,
    },
}
31
32impl DistillStrategy {
33    /// Get strategy name.
34    pub fn name(&self) -> &'static str {
35        match self {
36            Self::KDOnly { .. } => "KD-only",
37            Self::Progressive { .. } => "Progressive",
38            Self::Attention { .. } => "Attention",
39            Self::Combined { .. } => "Combined",
40        }
41    }
42
43    /// Default KD-only strategy.
44    pub fn kd_only() -> Self {
45        Self::KDOnly {
46            temperature: 4.0,
47            alpha: 0.7,
48        }
49    }
50
51    /// Default progressive strategy.
52    pub fn progressive() -> Self {
53        Self::Progressive {
54            temperature: 4.0,
55            alpha: 0.7,
56            layer_weight: 0.3,
57        }
58    }
59
60    /// Default attention strategy.
61    pub fn attention() -> Self {
62        Self::Attention {
63            temperature: 4.0,
64            alpha: 0.7,
65            attention_weight: 0.1,
66        }
67    }
68
69    /// Default combined strategy.
70    pub fn combined() -> Self {
71        Self::Combined {
72            temperature: 4.0,
73            alpha: 0.7,
74            layer_weight: 0.3,
75            attention_weight: 0.1,
76        }
77    }
78
79    /// Simulate training with this strategy.
80    fn simulate(&self, seed: u64) -> StrategyMetrics {
81        let noise = (seed as f64 * 0.1).sin() * 0.02;
82
83        let (base_loss, base_accuracy, time_factor) = match self {
84            Self::KDOnly { .. } => (0.82, 0.782, 1.0),
85            Self::Progressive { .. } => (0.75, 0.818, 1.15),
86            Self::Attention { .. } => (0.78, 0.796, 1.08),
87            Self::Combined { .. } => (0.71, 0.831, 1.25),
88        };
89
90        StrategyMetrics {
91            final_loss: base_loss + noise,
92            final_accuracy: base_accuracy + noise * 0.5,
93            training_time_hours: 2.0 * time_factor + noise * 0.5,
94            peak_memory_gb: 16.0 + noise * 2.0,
95        }
96    }
97}
98
/// Metrics from running a strategy once.
///
/// Values compare field-by-field via `PartialEq`, so two runs can be checked
/// with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct StrategyMetrics {
    /// Final training loss
    pub final_loss: f64,
    /// Final accuracy/score, as a fraction (e.g. `0.82`) rather than a percentage
    pub final_accuracy: f64,
    /// Training time in hours
    pub training_time_hours: f64,
    /// Peak memory usage in GB
    pub peak_memory_gb: f64,
}
111
112/// Result of comparing strategies.
113#[derive(Debug, Clone)]
114pub struct StrategyComparison {
115    /// Results per strategy
116    pub results: Vec<StrategyResult>,
117    /// Best strategy by loss
118    pub best_by_loss: Option<String>,
119    /// Best strategy by accuracy
120    pub best_by_accuracy: Option<String>,
121    /// Statistical significance of differences
122    pub significance: Vec<PairwiseComparison>,
123}
124
/// Aggregated result for a single strategy across its benchmark runs.
///
/// Values compare field-by-field via `PartialEq`, so results can be checked
/// with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct StrategyResult {
    /// Strategy name
    pub name: String,
    /// Mean final loss across runs
    pub mean_loss: f64,
    /// Standard deviation of the final loss across runs
    pub std_loss: f64,
    /// Mean final accuracy across runs, as a fraction rather than a percentage
    pub mean_accuracy: f64,
    /// Standard deviation of the final accuracy across runs
    pub std_accuracy: f64,
    /// Mean training time in hours
    pub mean_time_hours: f64,
    /// Number of runs the statistics were computed from
    pub runs: usize,
}
143
/// Pairwise statistical comparison between two strategies.
///
/// Values compare field-by-field via `PartialEq`, so comparisons can be
/// checked with `assert_eq!`.
#[derive(Debug, Clone, PartialEq)]
pub struct PairwiseComparison {
    /// Name of the first strategy
    pub strategy1: String,
    /// Name of the second strategy
    pub strategy2: String,
    /// P-value for the difference between the two strategies
    pub p_value: f64,
    /// Whether the difference was judged statistically significant
    pub significant: bool,
    /// Effect size of the difference
    pub effect_size: f64,
}
158
159/// Compare multiple strategies.
160pub fn compare(strategies: &[DistillStrategy]) -> Result<StrategyComparison> {
161    let runs_per_strategy = 5;
162    let mut results = Vec::new();
163    let mut all_losses: Vec<(String, Vec<f64>)> = Vec::new();
164
165    for strategy in strategies {
166        let mut losses = Vec::new();
167        let mut accuracies = Vec::new();
168        let mut times = Vec::new();
169
170        for run in 0..runs_per_strategy {
171            let metrics = strategy.simulate(run as u64);
172            losses.push(metrics.final_loss);
173            accuracies.push(metrics.final_accuracy);
174            times.push(metrics.training_time_hours);
175        }
176
177        let n = losses.len() as f64;
178        let mean_loss = losses.iter().sum::<f64>() / n;
179        let mean_accuracy = accuracies.iter().sum::<f64>() / n;
180        let mean_time = times.iter().sum::<f64>() / n;
181
182        let std_loss =
183            (losses.iter().map(|x| (x - mean_loss).powi(2)).sum::<f64>() / (n - 1.0)).sqrt();
184        let std_accuracy = (accuracies
185            .iter()
186            .map(|x| (x - mean_accuracy).powi(2))
187            .sum::<f64>()
188            / (n - 1.0))
189            .sqrt();
190
191        results.push(StrategyResult {
192            name: strategy.name().to_string(),
193            mean_loss,
194            std_loss,
195            mean_accuracy,
196            std_accuracy,
197            mean_time_hours: mean_time,
198            runs: runs_per_strategy,
199        });
200
201        all_losses.push((strategy.name().to_string(), losses));
202    }
203
204    // Find best
205    let best_by_loss = results
206        .iter()
207        .min_by(|a, b| {
208            a.mean_loss
209                .partial_cmp(&b.mean_loss)
210                .unwrap_or(std::cmp::Ordering::Equal)
211        })
212        .map(|r| r.name.clone());
213
214    let best_by_accuracy = results
215        .iter()
216        .max_by(|a, b| {
217            a.mean_accuracy
218                .partial_cmp(&b.mean_accuracy)
219                .unwrap_or(std::cmp::Ordering::Equal)
220        })
221        .map(|r| r.name.clone());
222
223    // Pairwise comparisons
224    let mut significance = Vec::new();
225    for i in 0..all_losses.len() {
226        for j in (i + 1)..all_losses.len() {
227            let (name1, losses1) = &all_losses[i];
228            let (name2, losses2) = &all_losses[j];
229
230            let test = StatisticalAnalyzer::welch_t_test(losses1, losses2);
231
232            significance.push(PairwiseComparison {
233                strategy1: name1.clone(),
234                strategy2: name2.clone(),
235                p_value: test.p_value,
236                significant: test.significant,
237                effect_size: test.effect_size,
238            });
239        }
240    }
241
242    Ok(StrategyComparison {
243        results,
244        best_by_loss,
245        best_by_accuracy,
246        significance,
247    })
248}
249
250impl StrategyComparison {
251    /// Format as ASCII table.
252    pub fn to_table(&self) -> String {
253        let mut output = String::from("Strategy Comparison\n");
254        output.push_str("┌──────────────┬─────────────────┬─────────────────┬────────────┐\n");
255        output.push_str("│ Strategy     │ Loss            │ Accuracy        │ Time (h)   │\n");
256        output.push_str("├──────────────┼─────────────────┼─────────────────┼────────────┤\n");
257
258        for result in &self.results {
259            let loss_marker = if self.best_by_loss.as_ref() == Some(&result.name) {
260                " ★"
261            } else {
262                ""
263            };
264            let acc_marker = if self.best_by_accuracy.as_ref() == Some(&result.name) {
265                " ★"
266            } else {
267                ""
268            };
269
270            output.push_str(&format!(
271                "│ {:12} │ {:.3} ± {:.3}{:2} │ {:.1}% ± {:.1}%{:2} │ {:>10.1} │\n",
272                result.name,
273                result.mean_loss,
274                result.std_loss,
275                loss_marker,
276                result.mean_accuracy * 100.0,
277                result.std_accuracy * 100.0,
278                acc_marker,
279                result.mean_time_hours
280            ));
281        }
282
283        output.push_str("└──────────────┴─────────────────┴─────────────────┴────────────┘\n");
284
285        // Significance
286        output.push_str("\nStatistical Significance:\n");
287        for comp in &self.significance {
288            let sig = if comp.significant { "✓" } else { "✗" };
289            output.push_str(&format!(
290                "  {} vs {}: p={:.4} {} (effect={:.2})\n",
291                comp.strategy1, comp.strategy2, comp.p_value, sig, comp.effect_size
292            ));
293        }
294
295        output
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `compare` on `strategies`, panicking if it reports an error.
    fn run_comparison(strategies: &[DistillStrategy]) -> StrategyComparison {
        compare(strategies).expect("operation should succeed")
    }

    #[test]
    fn test_strategy_names() {
        let expected = [
            (DistillStrategy::kd_only(), "KD-only"),
            (DistillStrategy::progressive(), "Progressive"),
            (DistillStrategy::attention(), "Attention"),
            (DistillStrategy::combined(), "Combined"),
        ];

        for (strategy, label) in &expected {
            assert_eq!(strategy.name(), *label);
        }
    }

    #[test]
    fn test_compare_strategies() {
        let comparison = run_comparison(&[
            DistillStrategy::kd_only(),
            DistillStrategy::progressive(),
            DistillStrategy::combined(),
        ]);

        assert_eq!(comparison.results.len(), 3);
        assert!(comparison.best_by_loss.is_some());
        assert!(comparison.best_by_accuracy.is_some());
    }

    #[test]
    fn test_combined_is_best() {
        let comparison =
            run_comparison(&[DistillStrategy::kd_only(), DistillStrategy::combined()]);

        // The combined strategy has the highest simulated accuracy baseline.
        assert_eq!(comparison.best_by_accuracy.as_deref(), Some("Combined"));
    }

    #[test]
    fn test_comparison_table() {
        let table =
            run_comparison(&[DistillStrategy::kd_only(), DistillStrategy::progressive()])
                .to_table();

        for needle in &["KD-only", "Progressive", "Significance"] {
            assert!(table.contains(*needle));
        }
    }

    #[test]
    fn test_strategy_constructors() {
        match DistillStrategy::kd_only() {
            DistillStrategy::KDOnly { temperature, alpha } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
            }
            _ => panic!("Expected KDOnly"),
        }

        match DistillStrategy::progressive() {
            DistillStrategy::Progressive {
                temperature,
                alpha,
                layer_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(layer_weight, 0.3);
            }
            _ => panic!("Expected Progressive"),
        }

        match DistillStrategy::attention() {
            DistillStrategy::Attention {
                temperature,
                alpha,
                attention_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(attention_weight, 0.1);
            }
            _ => panic!("Expected Attention"),
        }

        match DistillStrategy::combined() {
            DistillStrategy::Combined {
                temperature,
                alpha,
                layer_weight,
                attention_weight,
            } => {
                assert_eq!(temperature, 4.0);
                assert_eq!(alpha, 0.7);
                assert_eq!(layer_weight, 0.3);
                assert_eq!(attention_weight, 0.1);
            }
            _ => panic!("Expected Combined"),
        }
    }

    #[test]
    fn test_strategy_simulate_deterministic() {
        let strategy = DistillStrategy::kd_only();
        let (first, second) = (strategy.simulate(42), strategy.simulate(42));

        // Repeating a seed must reproduce the metrics exactly.
        assert_eq!(first.final_loss, second.final_loss);
        assert_eq!(first.final_accuracy, second.final_accuracy);
    }

    #[test]
    fn test_strategy_simulate_different_seeds() {
        let strategy = DistillStrategy::kd_only();
        let (first, second) = (strategy.simulate(1), strategy.simulate(2));

        // The seed-derived noise differs, so the losses must differ too.
        assert_ne!(first.final_loss, second.final_loss);
    }

    #[test]
    fn test_strategy_metrics_fields() {
        let StrategyMetrics {
            final_loss,
            final_accuracy,
            training_time_hours,
            peak_memory_gb,
        } = StrategyMetrics {
            final_loss: 0.75,
            final_accuracy: 0.82,
            training_time_hours: 2.5,
            peak_memory_gb: 16.0,
        };

        assert_eq!(final_loss, 0.75);
        assert_eq!(final_accuracy, 0.82);
        assert_eq!(training_time_hours, 2.5);
        assert_eq!(peak_memory_gb, 16.0);
    }

    #[test]
    fn test_strategy_result_fields() {
        let result = StrategyResult {
            name: String::from("test"),
            mean_loss: 0.7,
            std_loss: 0.02,
            mean_accuracy: 0.85,
            std_accuracy: 0.01,
            mean_time_hours: 3.0,
            runs: 5,
        };

        assert_eq!(result.name, "test");
        assert_eq!(result.runs, 5);
    }

    #[test]
    fn test_pairwise_comparison_fields() {
        let comp = PairwiseComparison {
            strategy1: String::from("A"),
            strategy2: String::from("B"),
            p_value: 0.03,
            significant: true,
            effect_size: 0.8,
        };

        assert!(comp.significant);
        assert_eq!(comp.effect_size, 0.8);
    }

    #[test]
    fn test_comparison_significance_markers() {
        let comparison =
            run_comparison(&[DistillStrategy::kd_only(), DistillStrategy::combined()]);

        // Two strategies form exactly one pair.
        assert_eq!(comparison.significance.len(), 1);
    }

    #[test]
    fn test_compare_all_strategies() {
        let comparison = run_comparison(&[
            DistillStrategy::kd_only(),
            DistillStrategy::progressive(),
            DistillStrategy::attention(),
            DistillStrategy::combined(),
        ]);

        // 4 choose 2 = 6 pairwise comparisons
        assert_eq!(comparison.significance.len(), 6);
        assert_eq!(comparison.results.len(), 4);
    }

    #[test]
    fn test_comparison_table_star_markers() {
        let table = run_comparison(&[DistillStrategy::kd_only(), DistillStrategy::combined()])
            .to_table();

        // The best strategy's row carries a star marker.
        assert!(table.contains('★'));
    }
}