// oxicuda_quant/distill/feature.rs

1//! # Feature-Based Knowledge Distillation
2//!
3//! Matches the intermediate activations (feature maps) of the student to those
4//! of the teacher, enabling the student to mimic the teacher's internal
5//! representations in addition to (or instead of) its final predictions.
6//!
7//! ## Total loss
8//!
9//! ```text
10//! L = Σ_l  w_l × loss(teacher_features_l, student_features_l)
11//! ```
12//!
13//! where `w_l` is the user-specified weight for layer `l`.
14
15use crate::distill::loss::DistilLoss;
16use crate::error::{QuantError, QuantResult};
17
18// ─── FeatureDistiller ─────────────────────────────────────────────────────────
19
20/// Feature-based knowledge distillation.
21///
22/// Maintains a list of `(weight, DistilLoss)` pairs, one per distillation layer.
23#[derive(Debug, Clone)]
24pub struct FeatureDistiller {
25    /// Per-layer distillation configuration: `(weight, loss)`.
26    pub layers: Vec<(f32, DistilLoss)>,
27}
28
29impl FeatureDistiller {
30    /// Create a feature distiller with equal-weight MSE loss for each layer.
31    ///
32    /// # Parameters
33    ///
34    /// * `n_layers` — number of intermediate layers to distil.
35    #[must_use]
36    pub fn uniform_mse(n_layers: usize) -> Self {
37        let weight = if n_layers == 0 {
38            1.0
39        } else {
40            1.0 / n_layers as f32
41        };
42        let layers = (0..n_layers).map(|_| (weight, DistilLoss::mse())).collect();
43        Self { layers }
44    }
45
46    /// Create a feature distiller with custom per-layer weights and a shared loss.
47    #[must_use]
48    pub fn with_weights(weights: Vec<f32>, loss: DistilLoss) -> Self {
49        let layers = weights.into_iter().map(|w| (w, loss)).collect();
50        Self { layers }
51    }
52
53    /// Compute the distillation loss for a single layer pair.
54    ///
55    /// # Errors
56    ///
57    /// * [`QuantError::DimensionMismatch`] — `layer_index` out of range.
58    /// * Propagates errors from the underlying `DistilLoss`.
59    pub fn compute_layer_loss(
60        &self,
61        layer_index: usize,
62        teacher_feat: &[f32],
63        student_feat: &[f32],
64    ) -> QuantResult<f32> {
65        if layer_index >= self.layers.len() {
66            return Err(QuantError::DimensionMismatch {
67                expected: self.layers.len(),
68                got: layer_index + 1,
69            });
70        }
71        let (w, ref loss) = self.layers[layer_index];
72        let l = loss.compute(teacher_feat, student_feat)?;
73        Ok(w * l)
74    }
75
76    /// Compute the total feature distillation loss across all layers.
77    ///
78    /// `teacher_feats[l]` and `student_feats[l]` must have matching lengths.
79    ///
80    /// # Errors
81    ///
82    /// * [`QuantError::DimensionMismatch`] — wrong number of feature arrays.
83    /// * Propagates per-layer errors.
84    pub fn compute_total_loss(
85        &self,
86        teacher_feats: &[&[f32]],
87        student_feats: &[&[f32]],
88    ) -> QuantResult<f32> {
89        if teacher_feats.len() != self.layers.len() {
90            return Err(QuantError::DimensionMismatch {
91                expected: self.layers.len(),
92                got: teacher_feats.len(),
93            });
94        }
95        if student_feats.len() != self.layers.len() {
96            return Err(QuantError::DimensionMismatch {
97                expected: self.layers.len(),
98                got: student_feats.len(),
99            });
100        }
101        let total: f32 = (0..self.layers.len())
102            .map(|l| self.compute_layer_loss(l, teacher_feats[l], student_feats[l]))
103            .collect::<QuantResult<Vec<f32>>>()?
104            .iter()
105            .sum();
106        Ok(total)
107    }
108
109    /// Number of distillation layers.
110    #[must_use]
111    pub fn n_layers(&self) -> usize {
112        self.layers.len()
113    }
114
115    /// Normalise layer weights so they sum to 1.
116    pub fn normalise_weights(&mut self) {
117        let sum: f32 = self
118            .layers
119            .iter()
120            .map(|(w, _)| w.abs())
121            .sum::<f32>()
122            .max(1e-12);
123        for (w, _) in &mut self.layers {
124            *w /= sum;
125        }
126    }
127}
128
129// ─── Tests ───────────────────────────────────────────────────────────────────
130
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn uniform_mse_layer_count() {
        let distiller = FeatureDistiller::uniform_mse(4);
        assert_eq!(distiller.n_layers(), 4);
        // Four layers split the unit weight budget evenly.
        for &(weight, _) in &distiller.layers {
            assert_abs_diff_eq!(weight, 0.25, epsilon = 1e-6);
        }
    }

    #[test]
    fn zero_loss_for_identical_features() {
        let distiller = FeatureDistiller::uniform_mse(2);
        let features = [1.0_f32, 2.0, 3.0];
        // Identical teacher/student activations must produce zero loss.
        let loss = distiller
            .compute_total_loss(&[&features, &features], &[&features, &features])
            .unwrap();
        assert_abs_diff_eq!(loss, 0.0, epsilon = 1e-5);
    }

    #[test]
    fn positive_loss_for_different_features() {
        let distiller = FeatureDistiller::uniform_mse(1);
        let teacher = [1.0_f32, 0.0, 0.0];
        let student = [0.0_f32, 1.0, 0.0];
        let loss = distiller
            .compute_total_loss(&[&teacher], &[&student])
            .unwrap();
        assert!(loss > 0.0, "loss should be positive for different features");
    }

    #[test]
    fn layer_count_mismatch_error() {
        let distiller = FeatureDistiller::uniform_mse(2);
        let feat = [1.0_f32; 4];
        // Three teacher arrays against a two-layer distiller must be rejected.
        let result = distiller.compute_total_loss(&[&feat, &feat, &feat], &[&feat, &feat]);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn layer_index_out_of_range_error() {
        let distiller = FeatureDistiller::uniform_mse(2);
        let feat = [1.0_f32; 4];
        // Index 5 is beyond the two configured layers.
        let result = distiller.compute_layer_loss(5, &feat, &feat);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn normalise_weights() {
        let mut distiller =
            FeatureDistiller::with_weights(vec![2.0, 3.0, 5.0], DistilLoss::mse());
        distiller.normalise_weights();
        let total: f32 = distiller.layers.iter().map(|(w, _)| *w).sum();
        assert_abs_diff_eq!(total, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn with_weights_constructs_correctly() {
        let distiller = FeatureDistiller::with_weights(vec![0.3, 0.7], DistilLoss::cosine());
        assert_eq!(distiller.n_layers(), 2);
        assert_abs_diff_eq!(distiller.layers[0].0, 0.3, epsilon = 1e-6);
        assert_abs_diff_eq!(distiller.layers[1].0, 0.7, epsilon = 1e-6);
    }

    #[test]
    fn kl_feature_distillation() {
        let distiller =
            FeatureDistiller::with_weights(vec![1.0], DistilLoss::kl_divergence(2.0));
        let teacher = [0.1_f32, 0.7, 0.2];
        let student = [0.4_f32, 0.4, 0.2];
        let loss = distiller
            .compute_total_loss(&[&teacher], &[&student])
            .unwrap();
        assert!(loss >= 0.0, "KL loss must be non-negative");
    }
}