//! Knowledge distillation losses and utilities.
//!
//! Implements classic logit distillation (Hinton et al., 2015), feature
//! distillation with an optional projection, attention transfer
//! (Zagoruyko & Komodakis, 2017), and Born-Again Network helpers
//! (Furlanello et al., 2018), along with temperature calibration and
//! per-batch loss bookkeeping.
use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
/// Configuration for classic logit distillation.
///
/// `temperature` (> 0) softens the softmax of both teacher and student;
/// `alpha` in `[0, 1]` weights the soft (teacher) term against the hard
/// (label) term.
#[derive(Debug, Clone)]
pub struct DistillationConfig {
pub temperature: f64,
pub alpha: f64,
}
impl Default for DistillationConfig {
fn default() -> Self {
Self {
temperature: 4.0,
alpha: 0.5,
}
}
}
impl DistillationConfig {
pub fn validate(&self) -> Result<()> {
if self.temperature <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"temperature must be > 0, got {}",
self.temperature
)));
}
if self.alpha < 0.0 || self.alpha > 1.0 {
return Err(NeuralError::InvalidArgument(format!(
"alpha must be in [0, 1], got {}",
self.alpha
)));
}
Ok(())
}
}
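/// Row-wise softmax with temperature scaling, i.e. `softmax(z / T)`,
/// computed with the max-subtraction trick for numerical stability.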
fn softmax_with_temperature(logits: ArrayView2<f64>, temperature: f64) -> Array2<f64> {
let (nrows, ncols) = (logits.nrows(), logits.ncols());
let mut out = Array2::<f64>::zeros((nrows, ncols));
for r in 0..nrows {
let row_max = logits
.row(r)
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
let mut sum_exp = 0.0f64;
for c in 0..ncols {
let e = ((logits[[r, c]] - row_max) / temperature).exp();
out[[r, c]] = e;
sum_exp += e;
}
        // For finite inputs sum_exp >= 1 (the max element contributes
        // exp(0) = 1), so this guard is purely defensive.
        let inv = if sum_exp > 0.0 { 1.0 / sum_exp } else { 1.0 };
for c in 0..ncols {
out[[r, c]] *= inv;
}
}
out
}
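/// Standard row-wise softmax (temperature fixed at 1).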
fn softmax(logits: ArrayView2<f64>) -> Array2<f64> {
softmax_with_temperature(logits, 1.0)
}
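/// Soft-target distillation loss (Hinton et al., 2015): the batch-mean KL
/// divergence `KL(p_teacher || p_student)` between temperature-softened
/// distributions, scaled by `T^2` so gradient magnitudes stay comparable
/// across temperatures.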
pub fn soft_target_loss(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
temperature: f64,
) -> Result<f64> {
validate_shapes_2d(student_logits, teacher_logits, "soft_target_loss")?;
if temperature <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"temperature must be > 0, got {temperature}"
)));
}
let p_teacher = softmax_with_temperature(teacher_logits, temperature);
let p_student = softmax_with_temperature(student_logits, temperature);
let nrows = student_logits.nrows();
let ncols = student_logits.ncols();
let mut kl_sum = 0.0f64;
for r in 0..nrows {
for c in 0..ncols {
let pt = p_teacher[[r, c]];
            // Clamp the student probability to avoid ln(0).
            let ps = p_student[[r, c]].max(1e-40);
            if pt > 0.0 {
kl_sum += pt * (pt / ps).ln();
}
}
}
Ok(kl_sum * temperature * temperature / nrows as f64)
}
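/// Cross-entropy of the student's (temperature-1) softmax against the
/// ground-truth class labels, averaged over the batch.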
pub fn hard_target_loss(
student_logits: ArrayView2<f64>,
true_labels: &[usize],
) -> Result<f64> {
let nrows = student_logits.nrows();
let ncols = student_logits.ncols();
if true_labels.len() != nrows {
return Err(NeuralError::ShapeMismatch(format!(
"true_labels length {} != batch size {}",
true_labels.len(),
nrows
)));
}
let probs = softmax(student_logits);
let mut nll_sum = 0.0f64;
for (r, &label) in true_labels.iter().enumerate() {
if label >= ncols {
return Err(NeuralError::InvalidArgument(format!(
"label {label} out of range for n_classes={ncols}"
)));
}
let p = probs[[r, label]].max(1e-40);
nll_sum += -p.ln();
}
Ok(nll_sum / nrows as f64)
}
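/// Combined distillation loss:
/// `alpha * soft_target_loss + (1 - alpha) * hard_target_loss`.
///
/// A minimal usage sketch; variable setup and the crate path are assumed:
///
/// ```ignore
/// let config = DistillationConfig { temperature: 4.0, alpha: 0.5 };
/// let loss = distillation_loss(student.view(), teacher.view(), &labels, &config)?;
/// ```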
pub fn distillation_loss(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
true_labels: &[usize],
config: &DistillationConfig,
) -> Result<f64> {
config.validate()?;
let soft = soft_target_loss(student_logits, teacher_logits, config.temperature)?;
let hard = hard_target_loss(student_logits, true_labels)?;
Ok(config.alpha * soft + (1.0 - config.alpha) * hard)
}
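/// Breakdown of a distillation loss evaluation, retaining the individual
/// soft and hard terms alongside the hyperparameters used.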
#[derive(Debug, Clone)]
pub struct DistillationLossComponents {
pub soft_loss: f64,
pub hard_loss: f64,
pub total_loss: f64,
pub alpha: f64,
pub temperature: f64,
}
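/// Like [`distillation_loss`], but returns the individual soft and hard
/// terms in addition to the weighted total.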
pub fn distillation_loss_detailed(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
true_labels: &[usize],
config: &DistillationConfig,
) -> Result<DistillationLossComponents> {
config.validate()?;
let soft_loss =
soft_target_loss(student_logits, teacher_logits, config.temperature)?;
let hard_loss = hard_target_loss(student_logits, true_labels)?;
let total_loss = config.alpha * soft_loss + (1.0 - config.alpha) * hard_loss;
Ok(DistillationLossComponents {
soft_loss,
hard_loss,
total_loss,
alpha: config.alpha,
temperature: config.temperature,
})
}
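/// Configuration for feature-map distillation: an overall loss weight and
/// whether to L2-normalize feature rows before comparing them.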
#[derive(Debug, Clone)]
pub struct FeatureDistillationConfig {
pub loss_weight: f64,
pub normalize_features: bool,
}
impl Default for FeatureDistillationConfig {
fn default() -> Self {
Self {
loss_weight: 1.0,
normalize_features: true,
}
}
}
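/// FitNets-style feature distillation (Romero et al., 2015): match
/// intermediate student features to teacher features under MSE, with an
/// optional linear projection for when the feature widths differ.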
#[derive(Debug, Clone)]
pub struct FeatureDistillation {
pub config: FeatureDistillationConfig,
pub projection: Option<Array2<f64>>,
}
impl FeatureDistillation {
pub fn new(config: FeatureDistillationConfig) -> Self {
Self {
config,
projection: None,
}
}
pub fn with_projection(mut self, proj: Array2<f64>) -> Self {
self.projection = Some(proj);
self
}
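    /// Mean-squared error between (optionally projected and
    /// row-normalized) student features and teacher features, scaled by
    /// `loss_weight`.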
pub fn loss(
&self,
student_features: ArrayView2<f64>,
teacher_features: ArrayView2<f64>,
) -> Result<f64> {
let projected: Array2<f64> = if let Some(ref proj) = self.projection {
matmul_2d(student_features, proj.view())?
} else {
if student_features.ncols() != teacher_features.ncols() {
return Err(NeuralError::DimensionMismatch(format!(
"student_features cols {} != teacher_features cols {}; supply a projection matrix",
student_features.ncols(),
teacher_features.ncols()
)));
}
student_features.to_owned()
};
let s = if self.config.normalize_features {
l2_normalize_rows(projected.view())?
} else {
projected
};
let t = if self.config.normalize_features {
l2_normalize_rows(teacher_features)?
} else {
teacher_features.to_owned()
};
validate_shapes_2d(s.view(), t.view(), "feature_distillation_loss")?;
let mse: f64 = s
.iter()
.zip(t.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
/ s.len() as f64;
Ok(mse * self.config.loss_weight)
}
}
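/// Attention transfer loss (Zagoruyko & Komodakis, 2017): MSE between the
/// normalized per-row attention norms of the student and teacher maps.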
pub fn attention_transfer_loss(
student_attn: ArrayView2<f64>,
teacher_attn: ArrayView2<f64>,
) -> Result<f64> {
validate_shapes_2d(student_attn, teacher_attn, "attention_transfer_loss")?;
let s_map = attention_map(student_attn)?;
let t_map = attention_map(teacher_attn)?;
let loss: f64 = s_map
.iter()
.zip(t_map.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
/ s_map.len() as f64;
Ok(loss)
}
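/// Reduces a 2-D activation to its per-row L2-norm vector, then normalizes
/// that vector to unit length (when its norm is non-negligible).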
fn attention_map(x: ArrayView2<f64>) -> Result<Array1<f64>> {
let nrows = x.nrows();
let mut norms = Array1::<f64>::zeros(nrows);
for r in 0..nrows {
let n: f64 = x.row(r).iter().map(|v| v * v).sum::<f64>().sqrt();
norms[r] = n;
}
    // `total` is the L2 norm of the per-row norm vector, i.e. the
    // Frobenius norm of `x`.
    let total: f64 = norms.iter().map(|v| v * v).sum::<f64>().sqrt();
if total > 1e-12 {
for v in norms.iter_mut() {
*v /= total;
}
}
Ok(norms)
}
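/// Running totals of distillation loss components across batches.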
#[derive(Debug, Clone, Default)]
pub struct DistillationStats {
pub total_soft_loss: f64,
pub total_hard_loss: f64,
pub total_loss: f64,
pub n_batches: usize,
}
impl DistillationStats {
pub fn record(&mut self, components: &DistillationLossComponents) {
self.total_soft_loss += components.soft_loss;
self.total_hard_loss += components.hard_loss;
self.total_loss += components.total_loss;
self.n_batches += 1;
}
pub fn avg_soft_loss(&self) -> f64 {
if self.n_batches == 0 {
0.0
} else {
self.total_soft_loss / self.n_batches as f64
}
}
pub fn avg_hard_loss(&self) -> f64 {
if self.n_batches == 0 {
0.0
} else {
self.total_hard_loss / self.n_batches as f64
}
}
pub fn avg_total_loss(&self) -> f64 {
if self.n_batches == 0 {
0.0
} else {
self.total_loss / self.n_batches as f64
}
}
pub fn reset(&mut self) {
*self = Self::default();
}
}
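/// Grid-searches `n_grid` evenly spaced temperatures over `temp_range` and
/// returns the one minimizing the (T^2-scaled) soft-target KL between
/// student and teacher.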
pub fn calibrate_temperature(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
temp_range: (f64, f64),
n_grid: usize,
) -> Result<f64> {
if temp_range.0 <= 0.0 || temp_range.1 <= temp_range.0 {
return Err(NeuralError::InvalidArgument(format!(
"temp_range must be (positive, larger): got {:?}",
temp_range
)));
}
if n_grid < 2 {
return Err(NeuralError::InvalidArgument(
"n_grid must be >= 2".to_string(),
));
}
let (t_min, t_max) = temp_range;
let step = (t_max - t_min) / (n_grid - 1) as f64;
let mut best_temp = t_min;
let mut best_kl = f64::INFINITY;
for i in 0..n_grid {
let t = t_min + step * i as f64;
let kl = soft_target_loss(student_logits, teacher_logits, t)?;
if kl < best_kl {
best_kl = kl;
best_temp = t;
}
}
Ok(best_temp)
}
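/// Returns a copy of `x` with each row scaled to unit L2 norm; rows with
/// near-zero norm are left unscaled.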
fn l2_normalize_rows(x: ArrayView2<f64>) -> Result<Array2<f64>> {
let (nrows, ncols) = (x.nrows(), x.ncols());
let mut out = Array2::<f64>::zeros((nrows, ncols));
for r in 0..nrows {
let norm: f64 = x.row(r).iter().map(|v| v * v).sum::<f64>().sqrt();
let inv = if norm > 1e-12 { 1.0 / norm } else { 1.0 };
for c in 0..ncols {
out[[r, c]] = x[[r, c]] * inv;
}
}
Ok(out)
}
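/// Naive O(m * n * k) dense matrix multiply; adequate for the small
/// projection matrices used here.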
fn matmul_2d(a: ArrayView2<f64>, b: ArrayView2<f64>) -> Result<Array2<f64>> {
let (m, k) = (a.nrows(), a.ncols());
let (k2, n) = (b.nrows(), b.ncols());
if k != k2 {
return Err(NeuralError::DimensionMismatch(format!(
"matmul: a.ncols={k} != b.nrows={k2}"
)));
}
let mut out = Array2::<f64>::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut s = 0.0f64;
for p in 0..k {
s += a[[i, p]] * b[[p, j]];
}
out[[i, j]] = s;
}
}
Ok(out)
}
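/// Errors unless `a` and `b` have identical shapes, tagging the message
/// with the calling context.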
fn validate_shapes_2d(a: ArrayView2<f64>, b: ArrayView2<f64>, ctx: &str) -> Result<()> {
if a.shape() != b.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"{ctx}: shape {:?} != {:?}",
a.shape(),
b.shape()
)));
}
Ok(())
}
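/// Configuration for Born-Again Network training (Furlanello et al.,
/// 2018), where each generation is distilled from the previous one.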
#[derive(Debug, Clone)]
pub struct BornAgainConfig {
pub temperature: f64,
pub alpha: f64,
pub n_generations: usize,
}
impl Default for BornAgainConfig {
fn default() -> Self {
Self {
temperature: 4.0,
alpha: 0.5,
n_generations: 3,
}
}
}
impl BornAgainConfig {
pub fn validate(&self) -> Result<()> {
if self.temperature <= 0.0 {
return Err(NeuralError::InvalidArgument(format!(
"BAN temperature must be > 0, got {}",
self.temperature
)));
}
if self.alpha < 0.0 || self.alpha > 1.0 {
return Err(NeuralError::InvalidArgument(format!(
"BAN alpha must be in [0, 1], got {}",
self.alpha
)));
}
if self.n_generations == 0 {
return Err(NeuralError::InvalidArgument(
"BAN n_generations must be >= 1".to_string(),
));
}
Ok(())
}
}
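/// Averages logits across BAN generations to form an ensemble teacher,
/// validating that all generations share the same shape.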
pub fn ban_ensemble_logits(
generation_logits: &[Array2<f64>],
) -> Result<Array2<f64>> {
if generation_logits.is_empty() {
return Err(NeuralError::InvalidArgument(
"generation_logits must contain at least one matrix".to_string(),
));
}
let (nrows, ncols) = {
let first = &generation_logits[0];
(first.nrows(), first.ncols())
};
for (i, m) in generation_logits.iter().enumerate() {
if m.nrows() != nrows || m.ncols() != ncols {
return Err(NeuralError::ShapeMismatch(format!(
"generation_logits[{}] has shape ({}, {}) but expected ({}, {})",
i, m.nrows(), m.ncols(), nrows, ncols
)));
}
}
let n = generation_logits.len() as f64;
let mut ensemble = Array2::<f64>::zeros((nrows, ncols));
for logits in generation_logits {
for r in 0..nrows {
for c in 0..ncols {
ensemble[[r, c]] += logits[[r, c]] / n;
}
}
}
Ok(ensemble)
}
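/// BAN distillation loss: delegates to [`distillation_loss`] using the
/// generation's temperature and alpha.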
pub fn ban_distillation_loss(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
true_labels: &[usize],
config: &BornAgainConfig,
) -> Result<f64> {
config.validate()?;
let kd_config = DistillationConfig {
temperature: config.temperature,
alpha: config.alpha,
};
distillation_loss(student_logits, teacher_logits, true_labels, &kd_config)
}
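/// Detailed variant of [`ban_distillation_loss`]; see
/// [`distillation_loss_detailed`].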
pub fn ban_distillation_loss_detailed(
student_logits: ArrayView2<f64>,
teacher_logits: ArrayView2<f64>,
true_labels: &[usize],
config: &BornAgainConfig,
) -> Result<DistillationLossComponents> {
config.validate()?;
let kd_config = DistillationConfig {
temperature: config.temperature,
alpha: config.alpha,
};
distillation_loss_detailed(student_logits, teacher_logits, true_labels, &kd_config)
}
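/// Per-generation accumulation of distillation statistics for BAN
/// training.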
#[derive(Debug, Clone, Default)]
pub struct BornAgainStats {
pub generations: Vec<DistillationStats>,
}
impl BornAgainStats {
pub fn new(n_generations: usize) -> Self {
Self {
generations: vec![DistillationStats::default(); n_generations],
}
}
pub fn record(
&mut self,
generation: usize,
components: &DistillationLossComponents,
) -> Result<()> {
if generation >= self.generations.len() {
return Err(NeuralError::InvalidArgument(format!(
"generation {} out of range (have {} generations)",
generation,
self.generations.len()
)));
}
self.generations[generation].record(components);
Ok(())
}
pub fn reset_all(&mut self) {
for g in &mut self.generations {
g.reset();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
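    // Shared 3-sample, 3-class logit fixture used across the tests below.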
fn batch_logits() -> (Array2<f64>, Array2<f64>) {
let student = array![
[2.0_f64, 1.0, 0.5],
[0.1, 3.0, 1.2],
[1.5, 1.5, 1.5]
];
let teacher = array![
[1.8_f64, 1.1, 0.6],
[0.0, 2.8, 1.0],
[1.0, 2.0, 0.5]
];
(student, teacher)
}
#[test]
fn test_soft_target_loss_positive() {
let (student, teacher) = batch_logits();
let loss = soft_target_loss(student.view(), teacher.view(), 4.0)
.expect("soft loss failed");
assert!(loss >= 0.0, "KL divergence must be non-negative, got {loss}");
}
#[test]
fn test_soft_target_loss_identical_is_zero() {
let (student, _) = batch_logits();
let loss = soft_target_loss(student.view(), student.view(), 4.0)
.expect("soft loss failed");
assert!(loss < 1e-9, "KL(p||p) should be ~0, got {loss}");
}
#[test]
fn test_hard_target_loss() {
let (student, _) = batch_logits();
let labels = vec![0usize, 1, 2];
let loss = hard_target_loss(student.view(), &labels).expect("hard loss failed");
assert!(loss > 0.0, "cross-entropy must be positive, got {loss}");
}
#[test]
fn test_distillation_loss_combined() {
let (student, teacher) = batch_logits();
let labels = vec![0usize, 1, 2];
let config = DistillationConfig {
temperature: 4.0,
alpha: 0.5,
};
let loss =
distillation_loss(student.view(), teacher.view(), &labels, &config)
.expect("distillation loss failed");
assert!(loss > 0.0);
}
#[test]
fn test_distillation_loss_alpha_zero_equals_hard() {
let (student, teacher) = batch_logits();
let labels = vec![0usize, 1, 2];
let config_alpha0 = DistillationConfig {
temperature: 4.0,
alpha: 0.0,
};
let loss_alpha0 =
distillation_loss(student.view(), teacher.view(), &labels, &config_alpha0)
.expect("loss alpha=0");
let hard = hard_target_loss(student.view(), &labels).expect("hard loss");
assert!((loss_alpha0 - hard).abs() < 1e-12, "alpha=0 should equal hard loss");
}
#[test]
fn test_distillation_loss_detailed() {
let (student, teacher) = batch_logits();
let labels = vec![0usize, 1, 2];
let config = DistillationConfig::default();
let detail =
distillation_loss_detailed(student.view(), teacher.view(), &labels, &config)
.expect("detailed loss failed");
assert!(detail.soft_loss >= 0.0);
assert!(detail.hard_loss > 0.0);
let expected =
config.alpha * detail.soft_loss + (1.0 - config.alpha) * detail.hard_loss;
assert!((detail.total_loss - expected).abs() < 1e-12);
}
#[test]
fn test_feature_distillation_same_dim() {
let student_feat = array![[1.0_f64, 0.0, -1.0], [0.5, 0.5, 0.0]];
let teacher_feat = array![[1.1_f64, 0.1, -0.9], [0.6, 0.4, 0.1]];
let fd = FeatureDistillation::new(FeatureDistillationConfig::default());
let loss =
fd.loss(student_feat.view(), teacher_feat.view()).expect("feature loss failed");
assert!(loss >= 0.0);
}
#[test]
fn test_feature_distillation_with_projection() {
let student_feat = array![[1.0_f64, 0.0], [0.5, 0.5]];
let teacher_feat = array![[1.0_f64, 0.5, 0.2], [0.3, 0.7, 0.1]];
        let proj = array![[1.0_f64, 0.0, 0.5], [0.0, 1.0, 0.5]];
        let fd = FeatureDistillation::new(FeatureDistillationConfig::default())
.with_projection(proj);
let loss =
fd.loss(student_feat.view(), teacher_feat.view()).expect("projected feature loss");
assert!(loss >= 0.0);
}
#[test]
fn test_attention_transfer_loss_identical_is_zero() {
let attn = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
let loss =
attention_transfer_loss(attn.view(), attn.view()).expect("attn loss identical");
assert!(loss < 1e-9, "AT loss for identical maps should be ~0, got {loss}");
}
#[test]
fn test_attention_transfer_loss_different() {
let s = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0]];
let t = array![[0.0_f64, 1.0, 0.0], [0.0, 0.0, 1.0]];
let loss = attention_transfer_loss(s.view(), t.view()).expect("attn loss different");
assert!(loss > 0.0);
}
#[test]
fn test_distillation_stats_accumulation() {
let (student, teacher) = batch_logits();
let labels = vec![0usize, 1, 2];
let config = DistillationConfig::default();
let mut stats = DistillationStats::default();
for _ in 0..5 {
let comp =
distillation_loss_detailed(student.view(), teacher.view(), &labels, &config)
.expect("detail failed");
stats.record(&comp);
}
assert_eq!(stats.n_batches, 5);
assert!(stats.avg_total_loss() > 0.0);
}
#[test]
fn test_calibrate_temperature() {
let (student, teacher) = batch_logits();
let best_t = calibrate_temperature(student.view(), teacher.view(), (1.0, 10.0), 20)
.expect("calibrate_temperature failed");
assert!(best_t >= 1.0 && best_t <= 10.0);
}
#[test]
fn test_invalid_alpha() {
let config = DistillationConfig {
temperature: 4.0,
alpha: 1.5,
};
assert!(config.validate().is_err());
}
#[test]
fn test_invalid_temperature() {
let config = DistillationConfig {
temperature: -1.0,
alpha: 0.5,
};
assert!(config.validate().is_err());
}
#[test]
fn test_shape_mismatch_error() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
let labels = vec![0usize, 1];
let config = DistillationConfig::default();
assert!(distillation_loss(a.view(), b.view(), &labels, &config).is_err());
}
#[test]
fn test_label_out_of_range_error() {
let (student, teacher) = batch_logits();
        let labels = vec![0usize, 1, 99]; // 99 is out of range for 3 classes
        let config = DistillationConfig::default();
assert!(distillation_loss(student.view(), teacher.view(), &labels, &config).is_err());
}
}