Skip to main content

entrenar/hf_pipeline/distillation/
utils.rs

//! Utility functions for distillation computations.
//!
//! Provides numerically stable softmax, log-softmax, KL divergence,
//! cross-entropy loss, and L2 normalization.
5
6use ndarray::{Array1, Array2};
7
8/// Softmax with numerical stability
9pub(crate) fn softmax(logits: &Array1<f32>) -> Array1<f32> {
10    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
11    let exp: Array1<f32> = logits.mapv(|x| (x - max).exp());
12    let sum = exp.sum();
13    exp / sum
14}
15
16/// Log softmax with numerical stability
17pub(crate) fn log_softmax(logits: &Array1<f32>) -> Array1<f32> {
18    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
19    let shifted = logits.mapv(|x| x - max);
20    let log_sum_exp = shifted.mapv(f32::exp).sum().max(f32::MIN_POSITIVE).ln();
21    shifted.mapv(|x| x - log_sum_exp)
22}
23
24/// KL divergence: KL(P || Q) = sum(P * log(P/Q))
25pub(crate) fn kl_divergence(log_q: &Array1<f32>, p: &Array1<f32>) -> f32 {
26    // KL(P || Q) = sum(P * (log(P) - log(Q)))
27    // Since we have log(Q), we compute: sum(P * log(P)) - sum(P * log(Q))
28    let p_log_p: f32 = p
29        .iter()
30        .map(|&pi| if pi > 1e-10 { pi * pi.max(f32::MIN_POSITIVE).ln() } else { 0.0 })
31        .sum();
32    let p_log_q: f32 = p.iter().zip(log_q.iter()).map(|(&pi, &lqi)| pi * lqi).sum();
33    p_log_p - p_log_q
34}
35
36/// Cross-entropy loss
37pub(crate) fn cross_entropy_loss(logits: &Array1<f32>, target: usize) -> f32 {
38    let log_probs = log_softmax(logits);
39    -log_probs[target]
40}
41
42/// L2 normalize a 2D array
43pub(crate) fn l2_normalize(arr: &Array2<f32>) -> Array2<f32> {
44    let norm = arr.mapv(|x| x * x).sum().sqrt();
45    if norm > 1e-10 {
46        arr / norm
47    } else {
48        arr.clone()
49    }
50}