entrenar/hf_pipeline/distillation/utils.rs
1//! Utility functions for distillation computations.
2//!
3//! Provides numerically stable softmax, log-softmax, KL divergence,
4//! cross-entropy loss, and L2 normalization.
5
6use ndarray::{Array1, Array2};
7
8/// Softmax with numerical stability
9pub(crate) fn softmax(logits: &Array1<f32>) -> Array1<f32> {
10 let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
11 let exp: Array1<f32> = logits.mapv(|x| (x - max).exp());
12 let sum = exp.sum();
13 exp / sum
14}
15
16/// Log softmax with numerical stability
17pub(crate) fn log_softmax(logits: &Array1<f32>) -> Array1<f32> {
18 let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
19 let shifted = logits.mapv(|x| x - max);
20 let log_sum_exp = shifted.mapv(f32::exp).sum().max(f32::MIN_POSITIVE).ln();
21 shifted.mapv(|x| x - log_sum_exp)
22}
23
24/// KL divergence: KL(P || Q) = sum(P * log(P/Q))
25pub(crate) fn kl_divergence(log_q: &Array1<f32>, p: &Array1<f32>) -> f32 {
26 // KL(P || Q) = sum(P * (log(P) - log(Q)))
27 // Since we have log(Q), we compute: sum(P * log(P)) - sum(P * log(Q))
28 let p_log_p: f32 = p
29 .iter()
30 .map(|&pi| if pi > 1e-10 { pi * pi.max(f32::MIN_POSITIVE).ln() } else { 0.0 })
31 .sum();
32 let p_log_q: f32 = p.iter().zip(log_q.iter()).map(|(&pi, &lqi)| pi * lqi).sum();
33 p_log_p - p_log_q
34}
35
36/// Cross-entropy loss
37pub(crate) fn cross_entropy_loss(logits: &Array1<f32>, target: usize) -> f32 {
38 let log_probs = log_softmax(logits);
39 -log_probs[target]
40}
41
42/// L2 normalize a 2D array
43pub(crate) fn l2_normalize(arr: &Array2<f32>) -> Array2<f32> {
44 let norm = arr.mapv(|x| x * x).sum().sqrt();
45 if norm > 1e-10 {
46 arr / norm
47 } else {
48 arr.clone()
49 }
50}