entrenar/hf_pipeline/distillation/attention.rs
//! Attention Transfer Loss
//!
//! Distillation loss that transfers attention maps from a teacher model to a
//! student model. Based on Zagoruyko & Komodakis (2017), "Paying More Attention
//! to Attention".
5
6use ndarray::Array2;
7
8use super::utils::l2_normalize;
9
/// Configuration for the attention transfer distillation loss.
///
/// Follows the attention transfer approach of Zagoruyko & Komodakis (2017),
/// in which the student is pushed to reproduce the teacher's attention maps.
#[derive(Clone, Debug)]
pub struct AttentionTransfer {
    /// Scalar multiplier applied to the computed loss
    pub weight: f32,
}
19
20impl Default for AttentionTransfer {
21 fn default() -> Self {
22 Self { weight: 0.1 }
23 }
24}
25
26impl AttentionTransfer {
27 /// Create new attention transfer config
28 #[must_use]
29 pub fn new(weight: f32) -> Self {
30 Self { weight }
31 }
32
33 /// Compute attention transfer loss
34 ///
35 /// Uses L2 norm of normalized attention map differences.
36 pub fn loss(
37 &self,
38 student_attention: &[Array2<f32>],
39 teacher_attention: &[Array2<f32>],
40 ) -> f32 {
41 let mut total_loss = 0.0;
42 let count = student_attention.len().min(teacher_attention.len());
43
44 for (s_attn, t_attn) in student_attention.iter().zip(teacher_attention.iter()) {
45 // L2 normalize attention maps
46 let s_norm = l2_normalize(s_attn);
47 let t_norm = l2_normalize(t_attn);
48
49 // Frobenius norm of difference
50 let diff = &s_norm - &t_norm;
51 let frob = diff.mapv(|x| x * x).sum().sqrt();
52 total_loss += frob * frob;
53 }
54
55 if count > 0 {
56 self.weight * total_loss / count as f32
57 } else {
58 0.0
59 }
60 }
61}