//! Knowledge-distillation losses: temperature-scaled logit distillation,
//! intermediate feature matching, and attention transfer.

use crate::{Loss, TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

/// Response-based knowledge distillation loss in the style of Hinton et al.:
/// `L = alpha * T^2 * KL(teacher || student) + (1 - alpha) * hard_loss`.
pub struct DistillationLoss {
    /// Softmax temperature `T`; must be strictly positive.
    pub temperature: f64,
    /// Weight of the soft term in `[0, 1]`; the hard loss gets `1 - alpha`.
    pub alpha: f64,
    /// Loss applied between the student logits and the hard targets.
    pub hard_loss: Box<dyn Loss>,
}

impl DistillationLoss {
    /// Creates a distillation loss, validating that `temperature > 0` and that
    /// `alpha` lies in `[0, 1]`.
    pub fn new(temperature: f64, alpha: f64, hard_loss: Box<dyn Loss>) -> TrainResult<Self> {
        if temperature <= 0.0 {
            return Err(TrainError::ConfigError(
                "Temperature must be positive".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&alpha) {
            return Err(TrainError::ConfigError(
                "Alpha must be between 0 and 1".to_string(),
            ));
        }
        Ok(Self {
            temperature,
            alpha,
            hard_loss,
        })
    }

    /// Computes the combined distillation loss from student logits, teacher
    /// logits, and hard targets.
    pub fn compute_distillation(
        &self,
        student_logits: &ArrayView<f64, Ix2>,
        teacher_logits: &ArrayView<f64, Ix2>,
        hard_targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if student_logits.shape() != teacher_logits.shape() {
            return Err(TrainError::LossError(format!(
                "Student and teacher logits must have same shape: {:?} vs {:?}",
                student_logits.shape(),
                teacher_logits.shape()
            )));
        }
        let soft_loss =
            self.compute_kl_divergence_with_temperature(student_logits, teacher_logits)?;
        let hard_loss = self.hard_loss.compute(student_logits, hard_targets)?;
        // Scale the soft term by T^2 to compensate for the 1/T^2 factor that the
        // temperature introduces in the soft-target gradients.
        let t_squared = self.temperature * self.temperature;
        let combined_loss = self.alpha * soft_loss * t_squared + (1.0 - self.alpha) * hard_loss;
        Ok(combined_loss)
    }

    /// Per-sample mean KL divergence `KL(teacher || student)` between the
    /// temperature-softened distributions: `sum_j t_j * ln(t_j / s_j)`,
    /// averaged over the batch.
    fn compute_kl_divergence_with_temperature(
        &self,
        student_logits: &ArrayView<f64, Ix2>,
        teacher_logits: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        let t = self.temperature;
        let mut total_loss = 0.0;
        let n_samples = student_logits.nrows();
        for i in 0..n_samples {
            let student_probs = self.softmax_with_temperature(&student_logits.row(i), t);
            let teacher_probs = self.softmax_with_temperature(&teacher_logits.row(i), t);
            for j in 0..student_probs.len() {
                // Skip classes with (near-)zero teacher probability; their
                // contribution to the KL divergence vanishes in the limit.
                if teacher_probs[j] > 1e-8 {
                    let ratio = teacher_probs[j] / (student_probs[j] + 1e-8);
                    total_loss += teacher_probs[j] * ratio.ln();
                }
            }
        }
        Ok(total_loss / n_samples as f64)
    }

    /// Numerically stable softmax of `logits / temperature` (the maximum scaled
    /// logit is subtracted before exponentiation).
    fn softmax_with_temperature(
        &self,
        logits: &ArrayView<f64, scirs2_core::ndarray::Ix1>,
        temperature: f64,
    ) -> Vec<f64> {
        let scaled: Vec<f64> = logits.iter().map(|&x| x / temperature).collect();
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals: Vec<f64> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f64 = exp_vals.iter().sum();
        exp_vals.iter().map(|&x| x / sum).collect()
    }
}

/// Feature-map distillation loss: a weighted sum of per-layer mean element-wise
/// `L1` or `L2` differences between student and teacher feature maps.
pub struct FeatureDistillationLoss {
    /// Per-layer weights applied to each layer's mean difference.
    pub layer_weights: Vec<f64>,
    /// Norm used for the element-wise difference; `1.0` (L1) or `2.0` (L2).
    pub p_norm: f64,
}

impl FeatureDistillationLoss {
    /// Creates a feature distillation loss, validating that at least one layer
    /// weight is given and that `p_norm` is either `1.0` or `2.0`.
    pub fn new(layer_weights: Vec<f64>, p_norm: f64) -> TrainResult<Self> {
        if layer_weights.is_empty() {
            return Err(TrainError::ConfigError(
                "Must specify at least one layer weight".to_string(),
            ));
        }
        if p_norm != 1.0 && p_norm != 2.0 {
            return Err(TrainError::ConfigError(
                "p_norm must be 1.0 or 2.0".to_string(),
            ));
        }
        Ok(Self {
            layer_weights,
            p_norm,
        })
    }

    /// Computes the weighted feature distillation loss across all layers.
    ///
    /// Each layer contributes its mean element-wise `|s - t|^p` difference,
    /// scaled by the corresponding entry of `layer_weights`.
    pub fn compute_feature_loss(
        &self,
        student_features: &[ArrayView<f64, Ix2>],
        teacher_features: &[ArrayView<f64, Ix2>],
    ) -> TrainResult<f64> {
        if student_features.len() != teacher_features.len() {
            return Err(TrainError::LossError(
                "Number of student and teacher feature layers must match".to_string(),
            ));
        }
        if student_features.len() != self.layer_weights.len() {
            return Err(TrainError::LossError(format!(
                "Number of layers ({}) must match number of weights ({})",
                student_features.len(),
                self.layer_weights.len()
            )));
        }
        let mut total_loss = 0.0;
        for (i, (student_feat, teacher_feat)) in student_features
            .iter()
            .zip(teacher_features.iter())
            .enumerate()
        {
            if student_feat.shape() != teacher_feat.shape() {
                return Err(TrainError::LossError(format!(
                    "Layer {} shape mismatch: {:?} vs {:?}",
                    i,
                    student_feat.shape(),
                    teacher_feat.shape()
                )));
            }
            let mut layer_loss = 0.0;
            for (&s, &t) in student_feat.iter().zip(teacher_feat.iter()) {
                let diff = (s - t).abs();
                layer_loss += if self.p_norm == 2.0 {
                    diff * diff
                } else {
                    diff
                };
            }
            // Average over elements so layers of different sizes are comparable.
            let n_elements = student_feat.len() as f64;
            layer_loss /= n_elements;
            total_loss += self.layer_weights[i] * layer_loss;
        }
        Ok(total_loss)
    }
}

/// Attention transfer loss: mean squared error between row-normalized
/// `|a|^beta` attention maps of student and teacher.
pub struct AttentionTransferLoss {
    /// Exponent applied to the absolute attention values before normalization.
    pub beta: f64,
}

impl AttentionTransferLoss {
    /// Creates an attention transfer loss with the given exponent `beta`.
    pub fn new(beta: f64) -> Self {
        Self { beta }
    }

    /// Computes the mean squared difference between the normalized student and
    /// teacher attention maps.
    pub fn compute_attention_loss(
        &self,
        student_attention: &ArrayView<f64, Ix2>,
        teacher_attention: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if student_attention.shape() != teacher_attention.shape() {
            return Err(TrainError::LossError(format!(
                "Attention maps must have same shape: {:?} vs {:?}",
                student_attention.shape(),
                teacher_attention.shape()
            )));
        }
        let student_norm = self.normalize_attention(student_attention);
        let teacher_norm = self.normalize_attention(teacher_attention);
        let mut loss = 0.0;
        for (s, t) in student_norm.iter().zip(teacher_norm.iter()) {
            let diff = s - t;
            loss += diff * diff;
        }
        let n_elements = student_norm.len() as f64;
        Ok(loss / n_elements)
    }

    /// Maps each attention value to `|x|^beta` and normalizes every row to sum
    /// to one (rows with a near-zero sum are left unnormalized).
    fn normalize_attention(&self, attention: &ArrayView<f64, Ix2>) -> Array<f64, Ix2> {
        let mut normalized = attention.mapv(|x| x.abs().powf(self.beta));
        for mut row in normalized.rows_mut() {
            let sum: f64 = row.iter().sum();
            if sum > 1e-8 {
                row.mapv_inplace(|x| x / sum);
            }
        }
        normalized
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::CrossEntropyLoss;
    use scirs2_core::array;

    #[test]
    fn test_distillation_loss_creation() {
        let loss = DistillationLoss::new(3.0, 0.7, Box::new(CrossEntropyLoss::default()));
        assert!(loss.is_ok());
        let loss = loss.expect("unwrap");
        assert_eq!(loss.temperature, 3.0);
        assert_eq!(loss.alpha, 0.7);
    }

    #[test]
    fn test_distillation_invalid_temperature() {
        let result = DistillationLoss::new(0.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
        let result = DistillationLoss::new(-1.0, 0.5, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }

    #[test]
    fn test_distillation_invalid_alpha() {
        let result = DistillationLoss::new(3.0, -0.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
        let result = DistillationLoss::new(3.0, 1.1, Box::new(CrossEntropyLoss::default()));
        assert!(result.is_err());
    }
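
    // Added sketch test (not from the original suite): dividing logits by a larger
    // temperature flattens the softened distribution, so the peak probability should
    // shrink as T grows. `softmax_with_temperature` is private but reachable from
    // this child module; the logit values below are illustrative.
    #[test]
    fn test_softmax_temperature_flattens_distribution() {
        let loss =
            DistillationLoss::new(1.0, 0.5, Box::new(CrossEntropyLoss::default())).expect("unwrap");
        let logits = array![1.0, 3.0, 0.5];
        let sharp = loss.softmax_with_temperature(&logits.view(), 1.0);
        let flat = loss.softmax_with_temperature(&logits.view(), 10.0);
        // Both outputs are valid probability distributions.
        assert!((sharp.iter().sum::<f64>() - 1.0).abs() < 1e-12);
        assert!((flat.iter().sum::<f64>() - 1.0).abs() < 1e-12);
        // The high-temperature distribution has a smaller peak probability.
        let max_sharp = sharp.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let max_flat = flat.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!(max_flat < max_sharp);
    }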

    #[test]
    fn test_distillation_compute() {
        let loss =
            DistillationLoss::new(2.0, 0.5, Box::new(CrossEntropyLoss::default())).expect("unwrap");
        let student_logits = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let teacher_logits = array![[1.2, 1.8, 0.6], [0.6, 1.1, 1.9]];
        let hard_targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let result = loss.compute_distillation(
            &student_logits.view(),
            &teacher_logits.view(),
            &hard_targets.view(),
        );
        assert!(result.is_ok());
        let loss_value = result.expect("unwrap");
        assert!(loss_value > 0.0);
        assert!(loss_value.is_finite());
    }
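
    // Added sketch test: when the student logits equal the teacher logits the KL
    // term is (numerically almost) zero, so the combined loss should reduce to
    // (1 - alpha) times the hard loss computed directly on the boxed `hard_loss`.
    // The tolerance is loose because the implementation adds small epsilons.
    #[test]
    fn test_distillation_identical_logits_reduces_to_hard_loss() {
        let alpha = 0.5;
        let loss = DistillationLoss::new(2.0, alpha, Box::new(CrossEntropyLoss::default()))
            .expect("unwrap");
        let logits = array![[1.0, 2.0, 0.5], [0.5, 1.0, 2.0]];
        let hard_targets = array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let combined = loss
            .compute_distillation(&logits.view(), &logits.view(), &hard_targets.view())
            .expect("unwrap");
        let hard_only = loss
            .hard_loss
            .compute(&logits.view(), &hard_targets.view())
            .expect("unwrap");
        assert!((combined - (1.0 - alpha) * hard_only).abs() < 1e-6);
    }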

    #[test]
    fn test_feature_distillation_loss() {
        let loss = FeatureDistillationLoss::new(vec![0.5, 0.3, 0.2], 2.0).expect("unwrap");
        let s1 = array![[1.0, 2.0], [3.0, 4.0]];
        let s2 = array![[0.5, 1.5], [2.5, 3.5]];
        let s3 = array![[0.1, 0.2], [0.3, 0.4]];
        let student_features = vec![s1.view(), s2.view(), s3.view()];
        let t1 = array![[1.1, 2.1], [3.1, 4.1]];
        let t2 = array![[0.6, 1.6], [2.6, 3.6]];
        let t3 = array![[0.2, 0.3], [0.4, 0.5]];
        let teacher_features = vec![t1.view(), t2.view(), t3.view()];
        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_ok());
        let loss_value = result.expect("unwrap");
        assert!(loss_value > 0.0);
        assert!(loss_value < 1.0);
    }
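
    // Added sketch test: with `p_norm = 1.0` and a uniform element-wise gap of 0.1,
    // each layer's mean absolute difference is 0.1, so the weighted total should be
    // 0.1 * (sum of layer weights) = 0.1 here. The feature values are illustrative.
    #[test]
    fn test_feature_distillation_l1_uniform_gap() {
        let loss = FeatureDistillationLoss::new(vec![0.5, 0.5], 1.0).expect("unwrap");
        let s1 = array![[1.0, 2.0], [3.0, 4.0]];
        let s2 = array![[0.5, 1.5], [2.5, 3.5]];
        let student_features = vec![s1.view(), s2.view()];
        let t1 = array![[1.1, 2.1], [3.1, 4.1]];
        let t2 = array![[0.6, 1.6], [2.6, 3.6]];
        let teacher_features = vec![t1.view(), t2.view()];
        let value = loss
            .compute_feature_loss(&student_features, &teacher_features)
            .expect("unwrap");
        assert!((value - 0.1).abs() < 1e-9);
    }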

    #[test]
    fn test_attention_transfer_loss() {
        let loss = AttentionTransferLoss::new(2.0);
        let student_attention = array![[0.3, 0.5, 0.2], [0.4, 0.4, 0.2]];
        let teacher_attention = array![[0.35, 0.45, 0.2], [0.35, 0.45, 0.2]];
        let result =
            loss.compute_attention_loss(&student_attention.view(), &teacher_attention.view());
        assert!(result.is_ok());
        let loss_value = result.expect("unwrap");
        assert!(loss_value >= 0.0);
        assert!(loss_value.is_finite());
    }
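
    // Added sketch test: `normalize_attention` rescales each row of |a|^beta to sum
    // to one, so attention maps whose rows differ only by a positive scale factor
    // should yield (numerically) zero transfer loss. This exercises a property of
    // the implementation above rather than one asserted by the original tests.
    #[test]
    fn test_attention_transfer_row_scale_invariance() {
        let loss = AttentionTransferLoss::new(2.0);
        let student_attention = array![[0.3, 0.5, 0.2], [0.4, 0.4, 0.2]];
        // Teacher rows are the student rows scaled by 2.0 and 3.0 respectively.
        let teacher_attention = array![[0.6, 1.0, 0.4], [1.2, 1.2, 0.6]];
        let value = loss
            .compute_attention_loss(&student_attention.view(), &teacher_attention.view())
            .expect("unwrap");
        assert!(value.abs() < 1e-12);
    }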

    #[test]
    fn test_feature_distillation_shape_mismatch() {
        let loss = FeatureDistillationLoss::new(vec![1.0], 2.0).expect("unwrap");
        let s1 = array![[1.0, 2.0]];
        let student_features = vec![s1.view()];
        let t1 = array![[1.0, 2.0, 3.0]];
        let teacher_features = vec![t1.view()];
        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_err());
    }
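
    // Added sketch test covering the layer-count check: the number of feature
    // layers passed in must match the number of configured layer weights.
    #[test]
    fn test_feature_distillation_layer_count_mismatch() {
        let loss = FeatureDistillationLoss::new(vec![0.5, 0.5], 2.0).expect("unwrap");
        let s1 = array![[1.0, 2.0]];
        let t1 = array![[1.0, 2.0]];
        let student_features = vec![s1.view()];
        let teacher_features = vec![t1.view()];
        let result = loss.compute_feature_loss(&student_features, &teacher_features);
        assert!(result.is_err());
    }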
}