use crate::error::{AprenderError, Result};
/// Default softening temperature applied to teacher and student logits.
pub const DEFAULT_TEMPERATURE: f64 = 3.0;
/// Default weight of the soft (distillation) term in the combined loss.
pub const DEFAULT_ALPHA: f64 = 0.7;
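/// Hyperparameters for knowledge distillation.
///
/// * `temperature` softens the teacher and student distributions before the
///   distillation term is computed (higher values spread probability mass more evenly).
/// * `alpha` weights the soft (distillation) loss against the hard-label
///   cross-entropy: `alpha * soft_loss + (1.0 - alpha) * hard_loss`.
/// * `learning_rate` and `l2_reg` are consumed by [`LinearDistiller::train_step`].
///
/// A minimal usage sketch (the values shown are illustrative, not recommendations):
///
/// ```ignore
/// let config = DistillationConfig::default()
///     .with_temperature(4.0)
///     .with_alpha(0.9);
/// let loss = DistillationLoss::with_config(config);
/// ```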
#[derive(Debug, Clone)]
pub struct DistillationConfig {
pub temperature: f64,
pub alpha: f64,
pub learning_rate: f64,
pub l2_reg: f64,
}
impl Default for DistillationConfig {
fn default() -> Self {
Self {
temperature: DEFAULT_TEMPERATURE,
alpha: DEFAULT_ALPHA,
learning_rate: 0.01,
l2_reg: 0.0,
}
}
}
impl DistillationConfig {
#[must_use]
pub fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = temperature;
self
}
#[must_use]
pub fn with_alpha(mut self, alpha: f64) -> Self {
self.alpha = alpha;
self
}
}
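/// Softmax with temperature scaling: `p_i = exp(z_i / T) / sum_j exp(z_j / T)`.
///
/// The temperature is clamped to a small positive value to avoid division by
/// zero, and an empty `logits` slice yields an empty vector.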
#[must_use]
pub fn softmax_temperature(logits: &[f64], temperature: f64) -> Vec<f64> {
if logits.is_empty() {
return vec![];
}
let t = temperature.max(1e-10);
let scaled: Vec<f64> = logits.iter().map(|&z| z / t).collect();
crate::nn::functional::softmax_1d_f64(&scaled)
}
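/// Standard softmax, equivalent to [`softmax_temperature`] with `T = 1`.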
#[must_use]
pub fn softmax(logits: &[f64]) -> Vec<f64> {
crate::nn::functional::softmax_1d_f64(logits)
}
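/// Kullback-Leibler divergence `KL(p || q) = sum_i p_i * ln(p_i / q_i)`.
///
/// Both inputs are clamped away from 0 and 1 for numerical stability, and
/// mismatched lengths return `f64::INFINITY`.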
#[must_use]
pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
if p.len() != q.len() {
return f64::INFINITY;
}
let eps = 1e-15;
p.iter()
.zip(q.iter())
.map(|(&pi, &qi)| {
let pi = pi.clamp(eps, 1.0 - eps);
let qi = qi.clamp(eps, 1.0 - eps);
pi * (pi / qi).ln()
})
.sum()
}
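/// Cross-entropy `-sum_i y_i * ln(p_i)` of predicted probabilities `probs`
/// against target weights `targets` (typically one-hot labels).
///
/// Probabilities are clamped away from 0 and 1, and mismatched lengths return
/// `f64::INFINITY`.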
#[must_use]
pub fn cross_entropy(probs: &[f64], targets: &[f64]) -> f64 {
if probs.len() != targets.len() {
return f64::INFINITY;
}
let eps = 1e-15;
probs
.iter()
.zip(targets.iter())
.map(|(&p, &y)| -y * p.clamp(eps, 1.0 - eps).ln())
.sum()
}
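/// Binary cross-entropy for a single prediction:
/// `-y * ln(p) - (1 - y) * ln(1 - p)`, with `p` clamped away from 0 and 1.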
#[must_use]
pub fn binary_cross_entropy(prob: f64, target: f64) -> f64 {
let eps = 1e-15;
let p = prob.clamp(eps, 1.0 - eps);
-target * p.ln() - (1.0 - target) * (1.0 - p).ln()
}
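/// Produces soft targets from teacher logits via temperature-scaled softmax.
///
/// A minimal sketch (the logit values are illustrative):
///
/// ```ignore
/// let generator = SoftTargetGenerator::with_temperature(4.0);
/// let soft = generator.generate(&[2.0, 1.0, 0.1]);
/// assert!((soft.iter().sum::<f64>() - 1.0).abs() < 1e-12);
/// ```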
#[derive(Debug, Clone)]
pub struct SoftTargetGenerator {
temperature: f64,
}
impl Default for SoftTargetGenerator {
fn default() -> Self {
Self::new()
}
}
impl SoftTargetGenerator {
#[must_use]
pub fn new() -> Self {
Self {
temperature: DEFAULT_TEMPERATURE,
}
}
#[must_use]
pub fn with_temperature(temperature: f64) -> Self {
Self { temperature }
}
#[must_use]
pub fn generate(&self, logits: &[f64]) -> Vec<f64> {
softmax_temperature(logits, self.temperature)
}
#[must_use]
pub fn generate_batch(&self, logits: &[f64], n_classes: usize) -> Vec<f64> {
if logits.is_empty() || n_classes == 0 || !logits.len().is_multiple_of(n_classes) {
return vec![];
}
let n_samples = logits.len() / n_classes;
let mut result = Vec::with_capacity(logits.len());
for i in 0..n_samples {
let sample_logits = &logits[i * n_classes..(i + 1) * n_classes];
result.extend(self.generate(sample_logits));
}
result
}
}
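/// Combined distillation loss: `alpha * T^2 * KL` between the temperature-scaled
/// teacher and student distributions, plus `(1.0 - alpha)` times the cross-entropy
/// of the student's `T = 1` probabilities against the hard labels.
///
/// A minimal sketch of one evaluation (the three slices are illustrative and
/// must share a length equal to the number of classes):
///
/// ```ignore
/// let loss = DistillationLoss::new();
/// let value = loss.compute(&student_logits, &teacher_logits, &hard_labels).unwrap();
/// let grad = loss.gradient(&student_logits, &teacher_logits, &hard_labels).unwrap();
/// ```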
#[derive(Debug, Clone)]
pub struct DistillationLoss {
config: DistillationConfig,
}
impl Default for DistillationLoss {
fn default() -> Self {
Self::new()
}
}
impl DistillationLoss {
#[must_use]
pub fn new() -> Self {
Self {
config: DistillationConfig::default(),
}
}
#[must_use]
pub fn with_config(config: DistillationConfig) -> Self {
Self { config }
}
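/// Computes the combined soft + hard loss for one sample.
///
/// All three slices must have the same length (one entry per class); a
/// dimension-mismatch error is returned otherwise.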
pub fn compute(
&self,
student_logits: &[f64],
teacher_logits: &[f64],
hard_labels: &[f64],
) -> Result<f64> {
if student_logits.len() != teacher_logits.len() || student_logits.len() != hard_labels.len()
{
return Err(AprenderError::dimension_mismatch(
"logits/labels",
student_logits.len(),
teacher_logits.len(),
));
}
let t = self.config.temperature;
let teacher_soft = softmax_temperature(teacher_logits, t);
let student_soft = softmax_temperature(student_logits, t);
let student_hard = softmax(student_logits);
// Forward KL from the teacher's soft targets to the student's soft
// predictions, matching the gradient T * (student_soft - teacher_soft)
// computed in `gradient`; the T^2 factor keeps gradient magnitudes
// comparable across temperatures.
let kl_loss = kl_divergence(&teacher_soft, &student_soft);
let distill_loss = t * t * kl_loss;
let hard_loss = cross_entropy(&student_hard, hard_labels);
let total = self.config.alpha * distill_loss + (1.0 - self.config.alpha) * hard_loss;
Ok(total)
}
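/// Gradient of the combined loss with respect to the student logits:
/// `alpha * T * (student_soft - teacher_soft) + (1.0 - alpha) * (student_hard - hard_labels)`.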
pub fn gradient(
&self,
student_logits: &[f64],
teacher_logits: &[f64],
hard_labels: &[f64],
) -> Result<Vec<f64>> {
if student_logits.len() != teacher_logits.len() || student_logits.len() != hard_labels.len()
{
return Err(AprenderError::dimension_mismatch(
"logits/labels",
student_logits.len(),
teacher_logits.len(),
));
}
let t = self.config.temperature;
let teacher_soft = softmax_temperature(teacher_logits, t);
let student_soft = softmax_temperature(student_logits, t);
let student_hard = softmax(student_logits);
let grad: Vec<f64> = student_soft
.iter()
.zip(teacher_soft.iter())
.zip(student_hard.iter())
.zip(hard_labels.iter())
.map(|(((&ss, &ts), &sh), &hl)| {
let distill_grad = t * (ss - ts);
let hard_grad = sh - hl;
self.config.alpha * distill_grad + (1.0 - self.config.alpha) * hard_grad
})
.collect();
Ok(grad)
}
#[must_use]
pub fn config(&self) -> &DistillationConfig {
&self.config
}
}
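/// A linear (softmax-regression) student trained against teacher logits.
///
/// Weights are stored row-major, one row of `n_features` per class. A minimal
/// training sketch, where `features`, `teacher_logits`, and `hard_labels` are
/// illustrative slices of the appropriate lengths:
///
/// ```ignore
/// let mut student = LinearDistiller::new(4, 3);
/// for _epoch in 0..100 {
///     let _loss = student.train_step(&features, &teacher_logits, &hard_labels).unwrap();
/// }
/// let predicted_class = student.predict(&features).unwrap();
/// ```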
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct LinearDistiller {
weights: Vec<f64>,
biases: Vec<f64>,
n_features: usize,
n_classes: usize,
loss: DistillationLoss,
}
impl LinearDistiller {
#[must_use]
pub fn new(n_features: usize, n_classes: usize) -> Self {
Self {
weights: vec![0.0; n_classes * n_features],
biases: vec![0.0; n_classes],
n_features,
n_classes,
loss: DistillationLoss::new(),
}
}
#[must_use]
pub fn with_config(n_features: usize, n_classes: usize, config: DistillationConfig) -> Self {
Self {
weights: vec![0.0; n_classes * n_features],
biases: vec![0.0; n_classes],
n_features,
n_classes,
loss: DistillationLoss::with_config(config),
}
}
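/// Computes class logits `W x + b` for a single feature vector.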
pub fn forward(&self, features: &[f64]) -> Result<Vec<f64>> {
if features.len() != self.n_features {
return Err(AprenderError::dimension_mismatch(
"features",
self.n_features,
features.len(),
));
}
let mut logits = self.biases.clone();
for (c, logit) in logits.iter_mut().enumerate() {
for (f, &feat) in features.iter().enumerate() {
*logit += self.weights[c * self.n_features + f] * feat;
}
}
Ok(logits)
}
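/// Runs one gradient-descent step on a single sample: forward pass, combined
/// loss, gradient with respect to the logits, then weight and bias updates
/// with optional L2 regularization. Returns the loss measured before the update.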
pub fn train_step(
&mut self,
features: &[f64],
teacher_logits: &[f64],
hard_labels: &[f64],
) -> Result<f64> {
let student_logits = self.forward(features)?;
let loss_val = self
.loss
.compute(&student_logits, teacher_logits, hard_labels)?;
let grad = self
.loss
.gradient(&student_logits, teacher_logits, hard_labels)?;
let lr = self.loss.config().learning_rate;
let l2 = self.loss.config().l2_reg;
for (c, (&g, bias)) in grad.iter().zip(self.biases.iter_mut()).enumerate() {
for (f, &feat) in features.iter().enumerate() {
let idx = c * self.n_features + f;
let weight_grad = g * feat + l2 * self.weights[idx];
self.weights[idx] -= lr * weight_grad;
}
*bias -= lr * g;
}
Ok(loss_val)
}
#[must_use]
pub fn weights(&self) -> &[f64] {
&self.weights
}
#[must_use]
pub fn biases(&self) -> &[f64] {
&self.biases
}
pub fn predict_proba(&self, features: &[f64]) -> Result<Vec<f64>> {
let logits = self.forward(features)?;
Ok(softmax(&logits))
}
pub fn predict(&self, features: &[f64]) -> Result<usize> {
let probs = self.predict_proba(features)?;
Ok(probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0, |(i, _)| i))
}
}
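/// Summary of a distillation run: final loss, number of samples, the recorded
/// loss history, and optional accuracy on the training data.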
#[derive(Debug, Clone)]
pub struct DistillationResult {
pub final_loss: f64,
pub n_samples: usize,
pub loss_history: Vec<f64>,
pub train_accuracy: Option<f64>,
}
#[cfg(test)]
#[path = "distillation_tests.rs"]
mod tests;