oxicuda_quant/distill/feature.rs

//! Layer-wise feature distillation: weighted per-layer losses between
//! teacher and student feature maps.

use crate::distill::loss::DistilLoss;
use crate::error::{QuantError, QuantResult};

/// Distils intermediate features layer by layer, pairing each layer with
/// a scalar weight and a distillation loss.
#[derive(Debug, Clone)]
pub struct FeatureDistiller {
    /// Per-layer `(weight, loss)` pairs, indexed by layer position.
    pub layers: Vec<(f32, DistilLoss)>,
}

impl FeatureDistiller {
    /// Builds a distiller over `n_layers` layers, each using MSE loss with
    /// a uniform weight of `1 / n_layers`.
    #[must_use]
    pub fn uniform_mse(n_layers: usize) -> Self {
        // Avoid a division by zero in the degenerate zero-layer case.
        let weight = if n_layers == 0 {
            1.0
        } else {
            1.0 / n_layers as f32
        };
        let layers = (0..n_layers).map(|_| (weight, DistilLoss::mse())).collect();
        Self { layers }
    }

    /// Builds a distiller from explicit per-layer weights, applying the same
    /// loss function to every layer.
    #[must_use]
    pub fn with_weights(weights: Vec<f32>, loss: DistilLoss) -> Self {
        // Clone the loss per layer so the closure stays `FnMut`.
        let layers = weights.into_iter().map(|w| (w, loss.clone())).collect();
        Self { layers }
    }

    /// Computes the weighted distillation loss for a single layer.
    ///
    /// # Errors
    ///
    /// Returns `QuantError::DimensionMismatch` when `layer_index` is out of
    /// range, and propagates any error from the underlying loss.
    pub fn compute_layer_loss(
        &self,
        layer_index: usize,
        teacher_feat: &[f32],
        student_feat: &[f32],
    ) -> QuantResult<f32> {
        if layer_index >= self.layers.len() {
            return Err(QuantError::DimensionMismatch {
                expected: self.layers.len(),
                got: layer_index + 1,
            });
        }
        let (w, loss) = &self.layers[layer_index];
        let l = loss.compute(teacher_feat, student_feat)?;
        Ok(w * l)
    }

    /// Computes the weighted sum of per-layer losses over all layers.
    ///
    /// # Errors
    ///
    /// Returns `QuantError::DimensionMismatch` unless both feature lists
    /// contain exactly one entry per layer.
    pub fn compute_total_loss(
        &self,
        teacher_feats: &[&[f32]],
        student_feats: &[&[f32]],
    ) -> QuantResult<f32> {
        if teacher_feats.len() != self.layers.len() {
            return Err(QuantError::DimensionMismatch {
                expected: self.layers.len(),
                got: teacher_feats.len(),
            });
        }
        if student_feats.len() != self.layers.len() {
            return Err(QuantError::DimensionMismatch {
                expected: self.layers.len(),
                got: student_feats.len(),
            });
        }
        // Sum the weighted per-layer losses, stopping at the first error.
        (0..self.layers.len())
            .map(|l| self.compute_layer_loss(l, teacher_feats[l], student_feats[l]))
            .sum()
    }

    /// Returns the number of layers this distiller covers.
    #[must_use]
    pub fn n_layers(&self) -> usize {
        self.layers.len()
    }

    /// Rescales the weights so their absolute values sum to one; the `1e-12`
    /// floor guards against division by zero when every weight is zero.
    pub fn normalise_weights(&mut self) {
        let sum: f32 = self
            .layers
            .iter()
            .map(|(w, _)| w.abs())
            .sum::<f32>()
            .max(1e-12);
        for (w, _) in &mut self.layers {
            *w /= sum;
        }
    }
}

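// Usage sketch (illustrative only; the `teacher_h*` / `student_h*` feature
// slices are hypothetical and not defined in this crate):
//
//     let mut d = FeatureDistiller::with_weights(vec![1.0, 2.0, 3.0], DistilLoss::mse());
//     d.normalise_weights(); // weights become 1/6, 2/6, 3/6
//     let total = d.compute_total_loss(
//         &[teacher_h0, teacher_h1, teacher_h2],
//         &[student_h0, student_h1, student_h2],
//     )?;
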
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn uniform_mse_layer_count() {
        let d = FeatureDistiller::uniform_mse(4);
        assert_eq!(d.n_layers(), 4);
        for (w, _) in &d.layers {
            assert_abs_diff_eq!(*w, 0.25, epsilon = 1e-6);
        }
    }

    #[test]
    fn zero_loss_for_identical_features() {
        let d = FeatureDistiller::uniform_mse(2);
        let feat = vec![1.0_f32, 2.0, 3.0];
        let t0 = feat.as_slice();
        let t1 = feat.as_slice();
        let loss = d.compute_total_loss(&[t0, t1], &[t0, t1]).unwrap();
        assert_abs_diff_eq!(loss, 0.0, epsilon = 1e-5);
    }

    #[test]
    fn positive_loss_for_different_features() {
        let d = FeatureDistiller::uniform_mse(1);
        let teacher = vec![1.0_f32, 0.0, 0.0];
        let student = vec![0.0_f32, 1.0, 0.0];
        let loss = d.compute_total_loss(&[&teacher], &[&student]).unwrap();
        assert!(loss > 0.0, "loss should be positive for different features");
    }

    #[test]
    fn layer_count_mismatch_error() {
        let d = FeatureDistiller::uniform_mse(2);
        let feat = vec![1.0_f32; 4];
        // Three teacher features against a two-layer distiller must fail.
        let err = d.compute_total_loss(&[&feat, &feat, &feat], &[&feat, &feat]);
        assert!(matches!(err, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn layer_index_out_of_range_error() {
        let d = FeatureDistiller::uniform_mse(2);
        let feat = vec![1.0_f32; 4];
        assert!(matches!(
            d.compute_layer_loss(5, &feat, &feat),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn normalise_weights() {
        let mut d = FeatureDistiller::with_weights(vec![2.0, 3.0, 5.0], DistilLoss::mse());
        d.normalise_weights();
        let sum: f32 = d.layers.iter().map(|(w, _)| *w).sum();
        assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-5);
    }

    #[test]
    fn with_weights_constructs_correctly() {
        let d = FeatureDistiller::with_weights(vec![0.3, 0.7], DistilLoss::cosine());
        assert_eq!(d.n_layers(), 2);
        assert_abs_diff_eq!(d.layers[0].0, 0.3, epsilon = 1e-6);
        assert_abs_diff_eq!(d.layers[1].0, 0.7, epsilon = 1e-6);
    }

    #[test]
    fn kl_feature_distillation() {
        let loss = DistilLoss::kl_divergence(2.0);
        let d = FeatureDistiller::with_weights(vec![1.0], loss);
        let teacher = vec![0.1_f32, 0.7, 0.2];
        let student = vec![0.4_f32, 0.4, 0.2];
        let l = d.compute_total_loss(&[&teacher], &[&student]).unwrap();
        assert!(l >= 0.0, "KL loss must be non-negative");
    }
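
    // Added illustrative test (relies only on `compute_layer_loss` returning
    // `weight * loss` as implemented above): a layer's loss scales linearly
    // with its weight.
    #[test]
    fn layer_loss_scales_with_weight() {
        let base = FeatureDistiller::with_weights(vec![1.0], DistilLoss::mse());
        let scaled = FeatureDistiller::with_weights(vec![2.0], DistilLoss::mse());
        let teacher = vec![1.0_f32, 0.0];
        let student = vec![0.0_f32, 1.0];
        let l1 = base.compute_layer_loss(0, &teacher, &student).unwrap();
        let l2 = scaled.compute_layer_loss(0, &teacher, &student).unwrap();
        assert_abs_diff_eq!(l2, 2.0 * l1, epsilon = 1e-5);
    }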
}