use crate::MobileBackend;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::{Tensor, TrustformersError};
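/// Configuration for teacher/student knowledge distillation.
///
/// `alpha` weights the distillation (soft-target) loss and `beta` weights the
/// hard-label task loss; `temperature` softens the teacher logits before the
/// KL term is computed. The remaining flags toggle mobile-specific behaviour
/// such as quantization and gradient compression.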
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
pub temperature: f32,
pub alpha: f32,
pub beta: f32,
pub strategy: DistillationStrategy,
pub num_epochs: usize,
pub learning_rate: f32,
pub batch_size: usize,
pub feature_matching: bool,
pub attention_transfer: bool,
pub student_learning_rate: f32,
pub enable_mobile_optimizations: bool,
pub enable_quantization: bool,
pub enable_gradient_compression: bool,
}
impl Default for DistillationConfig {
fn default() -> Self {
Self {
temperature: 4.0,
alpha: 0.7,
beta: 0.3,
strategy: DistillationStrategy::SoftTargets,
num_epochs: 50,
learning_rate: 0.001,
batch_size: 32,
feature_matching: true,
attention_transfer: false,
student_learning_rate: 0.001,
enable_mobile_optimizations: true,
enable_quantization: false,
enable_gradient_compression: false,
}
}
}
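/// Strategy used to transfer knowledge from the teacher to the student.
///
/// - `SoftTargets`: match temperature-softened teacher output distributions.
/// - `FeatureBased`: match intermediate layer activations.
/// - `AttentionBased`: match attention maps.
/// - `Progressive`: distill in stages while shrinking student capacity.
/// - `Online`: reserved for online/co-training distillation; currently falls
///   back to soft-target distillation.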
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DistillationStrategy {
SoftTargets,
FeatureBased,
AttentionBased,
Progressive,
Online,
}
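/// Orchestrates knowledge distillation from a larger teacher model into a
/// smaller student model suitable for mobile deployment.
///
/// Typical flow (sketch; assumes `teacher`, `student`, and `samples` are
/// already constructed):
///
/// ```ignore
/// let mut distiller = KnowledgeDistiller::new(DistillationConfig::default(), MobileBackend::CPU);
/// distiller.set_teacher_model(teacher)?;
/// distiller.set_student_model(student)?;
/// let distilled_student = distiller.distill(&samples)?;
/// let stats = distiller.get_stats();
/// ```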
pub struct KnowledgeDistiller {
config: DistillationConfig,
backend: MobileBackend,
teacher_model: Option<TeacherModel>,
student_model: Option<StudentModel>,
distillation_stats: DistillationStats,
}
impl KnowledgeDistiller {
pub fn new(config: DistillationConfig, backend: MobileBackend) -> Self {
Self {
config,
backend,
teacher_model: None,
student_model: None,
distillation_stats: DistillationStats::default(),
}
}
pub fn set_teacher_model(&mut self, model: TeacherModel) -> Result<()> {
self.validate_teacher_model(&model)?;
self.teacher_model = Some(model);
Ok(())
}
pub fn set_student_model(&mut self, model: StudentModel) -> Result<()> {
self.validate_student_model(&model)?;
self.student_model = Some(model);
Ok(())
}
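    /// Runs distillation with the configured strategy and returns the trained
    /// student model. Requires a teacher, a student, and non-empty training
    /// data; mobile-specific optimizations are applied afterwards.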
pub fn distill(&mut self, training_data: &[DistillationSample]) -> Result<StudentModel> {
self.validate_training_data(training_data)?;
if self.teacher_model.is_none() {
return Err(TrustformersError::invalid_input(
"No teacher model set".to_string(),
));
}
if self.student_model.is_none() {
return Err(TrustformersError::invalid_input(
"No student model set".to_string(),
));
}
let strategy = self.config.strategy;
match strategy {
DistillationStrategy::SoftTargets => {
self.soft_target_distillation_internal(training_data)?;
},
DistillationStrategy::FeatureBased => {
self.feature_based_distillation_internal(training_data)?;
},
DistillationStrategy::AttentionBased => {
self.attention_based_distillation_internal(training_data)?;
},
DistillationStrategy::Progressive => {
self.progressive_distillation_internal(training_data)?;
},
DistillationStrategy::Online => {
self.online_distillation_internal(training_data)?;
},
}
self.optimize_for_mobile_internal()?;
self.student_model
.as_ref()
.cloned()
.ok_or_else(|| TrustformersError::other("Student model not initialized".to_string()))
}
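    // NOTE: The per-strategy training loops below take explicit teacher and
    // student references. `distill` currently dispatches to the simplified
    // `*_internal` placeholders further down, so these full versions are not
    // yet wired in.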
fn soft_target_distillation(
&mut self,
teacher: &TeacherModel,
student: &mut StudentModel,
training_data: &[DistillationSample],
) -> Result<()> {
        for _epoch in 0..self.config.num_epochs {
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch in training_data.chunks(self.config.batch_size) {
let teacher_outputs = self.teacher_forward_batch(teacher, batch)?;
let student_outputs = self.student_forward_batch(student, batch)?;
let distillation_loss =
self.compute_distillation_loss(&teacher_outputs, &student_outputs, batch)?;
self.student_backward_and_update(student, &distillation_loss)?;
epoch_loss += distillation_loss.total_loss;
batch_count += 1;
}
let avg_loss = epoch_loss / batch_count as f32;
self.distillation_stats.epoch_losses.push(avg_loss);
if self.should_early_stop(avg_loss) {
break;
}
}
self.distillation_stats.total_epochs = self.distillation_stats.epoch_losses.len();
Ok(())
}
fn feature_based_distillation(
&mut self,
teacher: &TeacherModel,
student: &mut StudentModel,
training_data: &[DistillationSample],
) -> Result<()> {
        for _epoch in 0..self.config.num_epochs {
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch in training_data.chunks(self.config.batch_size) {
let teacher_features = self.extract_teacher_features(teacher, batch)?;
let student_features = self.extract_student_features(student, batch)?;
let feature_loss =
self.compute_feature_matching_loss(&teacher_features, &student_features)?;
let task_loss = self.compute_task_loss(student, batch)?;
let total_loss = self.config.alpha * feature_loss + self.config.beta * task_loss;
self.student_backward_and_update(
student,
&DistillationLoss {
distillation_loss: feature_loss,
task_loss,
total_loss,
},
)?;
epoch_loss += total_loss;
batch_count += 1;
}
let avg_loss = epoch_loss / batch_count as f32;
self.distillation_stats.epoch_losses.push(avg_loss);
}
Ok(())
}
fn attention_based_distillation(
&mut self,
teacher: &TeacherModel,
student: &mut StudentModel,
training_data: &[DistillationSample],
) -> Result<()> {
        for _epoch in 0..self.config.num_epochs {
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch in training_data.chunks(self.config.batch_size) {
let teacher_attention = self.extract_teacher_attention(teacher, batch)?;
let student_attention = self.extract_student_attention(student, batch)?;
let attention_loss =
self.compute_attention_transfer_loss(&teacher_attention, &student_attention)?;
let task_loss = self.compute_task_loss(student, batch)?;
let total_loss = self.config.alpha * attention_loss + self.config.beta * task_loss;
self.student_backward_and_update(
student,
&DistillationLoss {
distillation_loss: attention_loss,
task_loss,
total_loss,
},
)?;
epoch_loss += total_loss;
batch_count += 1;
}
let avg_loss = epoch_loss / batch_count as f32;
self.distillation_stats.epoch_losses.push(avg_loss);
}
Ok(())
}
fn progressive_distillation(
&mut self,
teacher: &TeacherModel,
student: &mut StudentModel,
training_data: &[DistillationSample],
) -> Result<()> {
        let num_stages = 3;
        let epochs_per_stage = self.config.num_epochs / num_stages;
for stage in 0..num_stages {
let compression_ratio = (stage + 1) as f32 / num_stages as f32;
self.adjust_student_capacity(student, compression_ratio)?;
            for _epoch in 0..epochs_per_stage {
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch in training_data.chunks(self.config.batch_size) {
let teacher_outputs = self.teacher_forward_batch(teacher, batch)?;
let student_outputs = self.student_forward_batch(student, batch)?;
let distillation_loss =
self.compute_distillation_loss(&teacher_outputs, &student_outputs, batch)?;
self.student_backward_and_update(student, &distillation_loss)?;
epoch_loss += distillation_loss.total_loss;
batch_count += 1;
}
let avg_loss = epoch_loss / batch_count as f32;
self.distillation_stats.epoch_losses.push(avg_loss);
}
}
Ok(())
}
fn online_distillation(
&mut self,
teacher: &TeacherModel,
student: &mut StudentModel,
training_data: &[DistillationSample],
) -> Result<()> {
self.soft_target_distillation(teacher, student, training_data)
}
fn teacher_forward_batch(
&self,
teacher: &TeacherModel,
batch: &[DistillationSample],
) -> Result<Vec<TeacherOutput>> {
let mut outputs = Vec::new();
for sample in batch {
let output = self.teacher_forward(teacher, &sample.input)?;
outputs.push(output);
}
Ok(outputs)
}
fn student_forward_batch(
&self,
student: &StudentModel,
batch: &[DistillationSample],
) -> Result<Vec<StudentOutput>> {
let mut outputs = Vec::new();
for sample in batch {
let output = self.student_forward(student, &sample.input)?;
outputs.push(output);
}
Ok(outputs)
}
fn teacher_forward(&self, teacher: &TeacherModel, input: &Tensor) -> Result<TeacherOutput> {
let logits = teacher.forward(input)?;
let soft_targets = self.apply_temperature_scaling(&logits, self.config.temperature)?;
Ok(TeacherOutput {
logits,
soft_targets,
            features: HashMap::new(),
        })
}
fn student_forward(&self, student: &StudentModel, input: &Tensor) -> Result<StudentOutput> {
let logits = student.forward(input)?;
Ok(StudentOutput {
logits,
            features: HashMap::new(),
        })
}
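    /// Temperature-scaled softmax: returns `softmax(logits / temperature)`.
    /// Larger temperatures produce softer distributions for the KL term.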
fn apply_temperature_scaling(&self, logits: &Tensor, temperature: f32) -> Result<Tensor> {
let scaled = logits.div_scalar(temperature)?;
scaled.softmax(-1)
}
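    /// Batch-averaged distillation objective:
    /// `total = alpha * KL(teacher_soft || student_soft) + beta * CE(student_logits, target)`.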
fn compute_distillation_loss(
&self,
teacher_outputs: &[TeacherOutput],
student_outputs: &[StudentOutput],
batch: &[DistillationSample],
) -> Result<DistillationLoss> {
let mut distillation_loss = 0.0;
let mut task_loss = 0.0;
for (i, (teacher_out, student_out)) in
teacher_outputs.iter().zip(student_outputs.iter()).enumerate()
{
            // The KL term compares two probability distributions, so the student
            // logits must be softened with the same temperature as the teacher's.
            let student_soft =
                self.apply_temperature_scaling(&student_out.logits, self.config.temperature)?;
            let kl_loss =
                self.compute_kl_divergence(&teacher_out.soft_targets, &student_soft)?;
let ce_loss = self.compute_cross_entropy(&student_out.logits, &batch[i].target)?;
distillation_loss += kl_loss;
task_loss += ce_loss;
}
let batch_size = batch.len() as f32;
distillation_loss /= batch_size;
task_loss /= batch_size;
let total_loss = self.config.alpha * distillation_loss + self.config.beta * task_loss;
Ok(DistillationLoss {
distillation_loss,
task_loss,
total_loss,
})
}
fn compute_cross_entropy(&self, prediction: &Tensor, target: &Tensor) -> Result<f32> {
let softmax_pred = prediction.softmax(-1)?;
let log_pred = softmax_pred.log()?;
let loss = target.mul(&log_pred)?.sum(None, false)?.neg()?.to_scalar()?;
Ok(loss)
}
fn extract_teacher_features(
&self,
teacher: &TeacherModel,
batch: &[DistillationSample],
) -> Result<HashMap<String, Vec<Tensor>>> {
let mut features = HashMap::new();
for sample in batch {
let output = teacher.forward_with_features(&sample.input)?;
for (layer_name, layer_features) in output.intermediate_features {
features.entry(layer_name).or_insert_with(Vec::new).push(layer_features);
}
}
Ok(features)
}
fn extract_student_features(
&self,
student: &StudentModel,
batch: &[DistillationSample],
) -> Result<HashMap<String, Vec<Tensor>>> {
let mut features = HashMap::new();
for sample in batch {
let output = student.forward_with_features(&sample.input)?;
for (layer_name, layer_features) in output.intermediate_features {
features.entry(layer_name).or_insert_with(Vec::new).push(layer_features);
}
}
Ok(features)
}
fn compute_feature_matching_loss(
&self,
teacher_features: &HashMap<String, Vec<Tensor>>,
student_features: &HashMap<String, Vec<Tensor>>,
) -> Result<f32> {
let mut total_loss = 0.0;
let mut layer_count = 0;
for (layer_name, teacher_tensors) in teacher_features {
if let Some(student_tensors) = student_features.get(layer_name) {
if teacher_tensors.len() != student_tensors.len() {
return Err(TrustformersError::other(format!(
"Mismatched feature count for layer {}: teacher {}, student {}",
layer_name,
teacher_tensors.len(),
student_tensors.len()
)));
}
let mut layer_loss = 0.0;
for (teacher_tensor, student_tensor) in
teacher_tensors.iter().zip(student_tensors.iter())
{
let mse = self.compute_mse_loss(teacher_tensor, student_tensor)?;
layer_loss += mse;
}
total_loss += layer_loss / teacher_tensors.len() as f32;
layer_count += 1;
}
}
if layer_count == 0 {
return Err(TrustformersError::other(
"No matching layers found for feature distillation".into(),
));
}
Ok(total_loss / layer_count as f32)
}
fn extract_teacher_attention(
&self,
teacher: &TeacherModel,
batch: &[DistillationSample],
) -> Result<Vec<Tensor>> {
let mut attention_maps = Vec::new();
for sample in batch {
let output = teacher.forward_with_attention(&sample.input)?;
for (_layer_name, attention_tensor) in output.attention_weights {
attention_maps.push(attention_tensor);
}
}
Ok(attention_maps)
}
fn extract_student_attention(
&self,
student: &StudentModel,
batch: &[DistillationSample],
) -> Result<Vec<Tensor>> {
let mut attention_maps = Vec::new();
for sample in batch {
let output = student.forward_with_attention(&sample.input)?;
for (_layer_name, attention_tensor) in output.attention_weights {
attention_maps.push(attention_tensor);
}
}
Ok(attention_maps)
}
fn compute_task_loss(
&self,
student: &StudentModel,
batch: &[DistillationSample],
) -> Result<f32> {
let mut total_loss = 0.0;
for sample in batch {
let output = self.student_forward(student, &sample.input)?;
let loss = self.compute_cross_entropy(&output.logits, &sample.target)?;
total_loss += loss;
}
Ok(total_loss / batch.len() as f32)
}
fn student_backward_and_update(
&self,
student: &mut StudentModel,
loss: &DistillationLoss,
) -> Result<()> {
let total_loss =
loss.distillation_loss * self.config.alpha + loss.task_loss * self.config.beta;
let gradients = student.backward(total_loss)?;
student.apply_gradients(&gradients, self.config.student_learning_rate)?;
if self.config.enable_mobile_optimizations {
let gradient_values: Vec<Tensor> = gradients.values().cloned().collect();
self.apply_mobile_gradient_optimizations(student, &gradient_values)?;
}
Ok(())
}
fn adjust_student_capacity(
&self,
student: &mut StudentModel,
compression_ratio: f32,
) -> Result<()> {
        if compression_ratio <= 0.0 || compression_ratio > 1.0 {
            return Err(TrustformersError::config_error(
                "Compression ratio must be in (0, 1]",
                "adjust_student_capacity",
            ));
        }
        // Sparsity target for the pruning pass; the current pruning hook does
        // not yet take it as a parameter.
        let _target_sparsity = 1.0 - compression_ratio;
student.apply_progressive_pruning()?;
if self.config.enable_quantization {
let quantization_scheme = if compression_ratio < 0.5 {
crate::optimization::quantization::QuantizationScheme::Int4
} else if compression_ratio < 0.7 {
crate::optimization::quantization::QuantizationScheme::Int8
} else {
crate::optimization::quantization::QuantizationScheme::FP16
};
student.apply_quantization(quantization_scheme as i32)?;
}
if self.config.enable_mobile_optimizations {
student.apply_layer_fusion()?;
}
Ok(())
}
fn should_early_stop(&self, current_loss: f32) -> bool {
if self.distillation_stats.epoch_losses.len() < 5 {
return false;
}
let recent_losses =
&self.distillation_stats.epoch_losses[self.distillation_stats.epoch_losses.len() - 5..];
let min_recent_loss = recent_losses.iter().copied().fold(f32::INFINITY, f32::min);
        // Stop when the current loss fails to improve on the best of the last
        // five epochs (small tolerance to absorb noise).
        current_loss > min_recent_loss * 1.001
    }
    fn optimize_for_mobile(&self, _student: &mut StudentModel) -> Result<()> {
        // Placeholder: mobile-specific post-processing currently happens in
        // `optimize_for_mobile_internal`.
        Ok(())
    }
fn compute_mse_loss(&self, teacher_tensor: &Tensor, student_tensor: &Tensor) -> Result<f32> {
if teacher_tensor.shape() != student_tensor.shape() {
return Err(TrustformersError::other(format!(
"Tensor shape mismatch: teacher {:?}, student {:?}",
teacher_tensor.shape(),
student_tensor.shape()
)));
}
let teacher_data = teacher_tensor.data()?;
let student_data = student_tensor.data()?;
let mut mse = 0.0f32;
for (teacher_val, student_val) in teacher_data.iter().zip(student_data.iter()) {
let diff = teacher_val - student_val;
mse += diff * diff;
}
Ok(mse / teacher_data.len() as f32)
}
fn apply_mobile_gradient_optimizations(
&self,
student: &mut StudentModel,
gradients: &[Tensor],
) -> Result<()> {
let max_grad_norm = 1.0;
let grad_norm = self.compute_gradient_norm(gradients);
if grad_norm > max_grad_norm {
let scale_factor = max_grad_norm / grad_norm;
student.scale_gradients(scale_factor)?;
}
if self.config.enable_gradient_compression {
            student.compress_gradients(0.01)?;
        }
Ok(())
}
fn compute_gradient_norm(&self, gradients: &[Tensor]) -> f32 {
let mut total_norm_squared = 0.0f32;
for gradient in gradients {
if let Ok(data) = gradient.data() {
for &value in data.iter() {
total_norm_squared += value * value;
}
}
}
total_norm_squared.sqrt()
}
fn compute_attention_transfer_loss(
&self,
teacher_attention: &[Tensor],
student_attention: &[Tensor],
) -> Result<f32> {
if teacher_attention.len() != student_attention.len() {
return Err(TrustformersError::other(format!(
"Attention tensor count mismatch: teacher {}, student {}",
teacher_attention.len(),
student_attention.len()
)));
}
let mut total_loss = 0.0f32;
for (teacher_attn, student_attn) in teacher_attention.iter().zip(student_attention.iter()) {
let kl_loss = self.compute_kl_divergence(teacher_attn, student_attn)?;
total_loss += kl_loss;
}
Ok(total_loss / teacher_attention.len() as f32)
}
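    /// Elementwise KL divergence `sum(p * ln(p / q))` with an epsilon floor so
    /// zero probabilities do not produce NaN/inf; both tensors must hold
    /// probability distributions of identical shape.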
fn compute_kl_divergence(&self, p: &Tensor, q: &Tensor) -> Result<f32> {
if p.shape() != q.shape() {
return Err(TrustformersError::other(
"Tensor shapes must match for KL divergence".into(),
));
}
let p_data = p.data()?;
let q_data = q.data()?;
let epsilon = 1e-8;
let mut kl_div = 0.0f32;
for (p_val, q_val) in p_data.iter().zip(q_data.iter()) {
let p_safe = (*p_val).max(epsilon);
let q_safe = (*q_val).max(epsilon);
kl_div += p_safe * (p_safe / q_safe).ln();
}
Ok(kl_div)
}
fn validate_teacher_model(&self, model: &TeacherModel) -> Result<()> {
if model.parameters.is_empty() {
return Err(TrustformersError::invalid_input(
"Teacher model has no parameters".to_string(),
));
}
Ok(())
}
fn validate_student_model(&self, model: &StudentModel) -> Result<()> {
if model.parameters.is_empty() {
return Err(TrustformersError::invalid_input(
"Student model has no parameters".to_string(),
));
}
Ok(())
}
fn validate_training_data(&self, data: &[DistillationSample]) -> Result<()> {
if data.is_empty() {
return Err(TrustformersError::invalid_input(
"Training data is empty".to_string(),
));
}
Ok(())
}
pub fn get_stats(&self) -> &DistillationStats {
&self.distillation_stats
}
fn soft_target_distillation_internal(
&mut self,
training_data: &[DistillationSample],
) -> Result<()> {
        // Placeholder training loop: accumulates a synthetic per-batch loss so
        // the epoch count and logging behave like the real loop.
        for epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0;
            for _batch in training_data.chunks(self.config.batch_size) {
                epoch_loss += 0.1;
            }
            tracing::debug!("Epoch {}: loss = {}", epoch, epoch_loss);
        }
Ok(())
}
fn feature_based_distillation_internal(
&mut self,
training_data: &[DistillationSample],
) -> Result<()> {
        // Placeholder loop (see `soft_target_distillation_internal`).
        for epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0;
            for _batch in training_data.chunks(self.config.batch_size) {
                epoch_loss += 0.15;
            }
            tracing::debug!(
                "Feature distillation epoch {}: loss = {}",
                epoch,
                epoch_loss
            );
        }
Ok(())
}
fn attention_based_distillation_internal(
&mut self,
training_data: &[DistillationSample],
) -> Result<()> {
        // Placeholder loop (see `soft_target_distillation_internal`).
        for epoch in 0..self.config.num_epochs {
            let mut epoch_loss = 0.0;
            for _batch in training_data.chunks(self.config.batch_size) {
                epoch_loss += 0.12;
            }
            tracing::debug!(
                "Attention distillation epoch {}: loss = {}",
                epoch,
                epoch_loss
            );
        }
Ok(())
}
fn progressive_distillation_internal(
&mut self,
training_data: &[DistillationSample],
) -> Result<()> {
let num_stages = 3;
let epochs_per_stage = self.config.num_epochs / num_stages;
        // Placeholder loop (see `soft_target_distillation_internal`).
        for stage in 0..num_stages {
            for epoch in 0..epochs_per_stage {
                let mut epoch_loss = 0.0;
                for _batch in training_data.chunks(self.config.batch_size) {
                    epoch_loss += 0.1 * (stage + 1) as f32;
                }
                tracing::debug!(
                    "Progressive stage {} epoch {}: loss = {}",
                    stage,
                    epoch,
                    epoch_loss
                );
            }
        }
Ok(())
}
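    /// Online distillation is not implemented separately yet; it currently
    /// delegates to the soft-target placeholder loop.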
fn online_distillation_internal(&mut self, training_data: &[DistillationSample]) -> Result<()> {
self.soft_target_distillation_internal(training_data)
}
    fn optimize_for_mobile_internal(&mut self) -> Result<()> {
        // Placeholder: iterate over student parameters so a real implementation
        // can quantize or fuse them in place.
        if let Some(ref mut student) = self.student_model {
            student.parameters.iter_mut().for_each(|(_, _tensor)| {});
        }
        Ok(())
    }
}
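/// Teacher model wrapper. The `forward` methods here are identity placeholders
/// that echo the input tensor; a real integration would execute the full model
/// and populate the feature/attention maps.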
#[derive(Debug, Clone)]
pub struct TeacherModel {
pub parameters: HashMap<String, Tensor>,
pub architecture: ModelArchitecture,
}
impl TeacherModel {
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
Ok(input.clone())
}
pub fn forward_with_features(&self, input: &Tensor) -> Result<TeacherOutputWithFeatures> {
let logits = self.forward(input)?;
let intermediate_features = HashMap::new();
Ok(TeacherOutputWithFeatures {
logits,
intermediate_features,
})
}
pub fn forward_with_attention(&self, input: &Tensor) -> Result<TeacherOutputWithAttention> {
let logits = self.forward(input)?;
let attention_weights = HashMap::new();
Ok(TeacherOutputWithAttention {
logits,
attention_weights,
})
}
}
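/// Student model wrapper. Forward, backward, and the optimization hooks are
/// lightweight placeholders pending integration with the mobile training
/// backend.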
#[derive(Debug, Clone)]
pub struct StudentModel {
pub parameters: HashMap<String, Tensor>,
pub architecture: ModelArchitecture,
}
impl StudentModel {
pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
Ok(input.clone())
}
pub fn forward_with_features(&self, input: &Tensor) -> Result<StudentOutputWithFeatures> {
let logits = self.forward(input)?;
let intermediate_features = HashMap::new();
Ok(StudentOutputWithFeatures {
logits,
intermediate_features,
})
}
pub fn forward_with_attention(&self, input: &Tensor) -> Result<StudentOutputWithAttention> {
let logits = self.forward(input)?;
let attention_weights = HashMap::new();
Ok(StudentOutputWithAttention {
logits,
attention_weights,
})
}
    pub fn backward(&mut self, _loss: f32) -> Result<HashMap<String, Tensor>> {
        // Placeholder backward pass: returns parameter-shaped tensors in lieu
        // of real gradients.
        let mut gradients = HashMap::new();
        for (name, param) in &self.parameters {
            gradients.insert(name.clone(), param.clone());
        }
        Ok(gradients)
    }
    pub fn apply_gradients(
        &mut self,
        gradients: &HashMap<String, Tensor>,
        _learning_rate: f32,
    ) -> Result<()> {
        // Placeholder update: a real implementation would apply
        // `param -= learning_rate * grad` for each matching parameter.
        for (name, _grad) in gradients {
            if let Some(param) = self.parameters.get_mut(name) {
                *param = param.clone();
            }
        }
        Ok(())
    }
pub fn apply_progressive_pruning(&mut self) -> Result<()> {
Ok(())
}
pub fn apply_quantization(&mut self, _level: i32) -> Result<()> {
Ok(())
}
pub fn apply_layer_fusion(&mut self) -> Result<()> {
Ok(())
}
    pub fn scale_gradients(&mut self, _scale_factor: f32) -> Result<()> {
        Ok(())
    }
    pub fn compress_gradients(&mut self, _sparsity_ratio: f32) -> Result<()> {
        Ok(())
    }
}
#[derive(Debug, Clone)]
pub struct ModelArchitecture {
pub layers: Vec<LayerInfo>,
pub total_parameters: usize,
}
#[derive(Debug, Clone)]
pub struct LayerInfo {
pub name: String,
pub layer_type: LayerType,
pub input_shape: Vec<usize>,
pub output_shape: Vec<usize>,
}
#[derive(Debug, Clone)]
pub enum LayerType {
Linear,
Conv2d,
BatchNorm,
Activation,
Attention,
Embedding,
}
#[derive(Debug, Clone)]
pub struct TeacherOutput {
pub logits: Tensor,
pub soft_targets: Tensor,
pub features: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct TeacherOutputWithFeatures {
pub logits: Tensor,
pub intermediate_features: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct TeacherOutputWithAttention {
pub logits: Tensor,
pub attention_weights: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct StudentOutputWithFeatures {
pub logits: Tensor,
pub intermediate_features: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct StudentOutputWithAttention {
pub logits: Tensor,
pub attention_weights: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct StudentOutput {
pub logits: Tensor,
pub features: HashMap<String, Tensor>,
}
#[derive(Debug, Clone)]
pub struct DistillationSample {
pub input: Tensor,
pub target: Tensor,
}
#[derive(Debug, Clone)]
pub struct DistillationLoss {
pub distillation_loss: f32,
pub task_loss: f32,
pub total_loss: f32,
}
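/// Statistics collected during distillation: per-epoch losses plus summary
/// metrics on compression and knowledge-transfer efficiency.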
#[derive(Debug, Clone, Default)]
pub struct DistillationStats {
pub total_epochs: usize,
pub epoch_losses: Vec<f32>,
pub final_compression_ratio: f32,
pub knowledge_transfer_efficiency: f32,
}
impl DistillationStats {
pub fn best_loss(&self) -> Option<f32> {
self.epoch_losses.iter().copied().fold(None, |acc, x| {
Some(match acc {
None => x,
Some(y) => x.min(y),
})
})
}
pub fn converged(&self) -> bool {
if self.epoch_losses.len() < 10 {
return false;
}
let recent = &self.epoch_losses[self.epoch_losses.len() - 10..];
let variance = {
let mean = recent.iter().sum::<f32>() / recent.len() as f32;
recent.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / recent.len() as f32
};
        variance < 0.0001
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_distillation_config_default() {
let config = DistillationConfig::default();
assert_eq!(config.temperature, 4.0);
assert_eq!(config.alpha, 0.7);
assert_eq!(config.beta, 0.3);
}
#[test]
fn test_knowledge_distiller_creation() {
let config = DistillationConfig::default();
let distiller = KnowledgeDistiller::new(config, MobileBackend::CPU);
assert!(distiller.teacher_model.is_none());
assert!(distiller.student_model.is_none());
}
#[test]
fn test_distillation_stats() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![1.0, 0.8, 0.6, 0.5, 0.4];
assert_eq!(stats.best_loss(), Some(0.4));
        assert!(!stats.converged());
    }
#[test]
fn test_distillation_config_alpha_beta_sum() {
let config = DistillationConfig::default();
assert!((config.alpha + config.beta - 1.0).abs() < 1e-6);
}
#[test]
fn test_distillation_config_temperature_positive() {
let config = DistillationConfig::default();
assert!(config.temperature > 0.0);
}
#[test]
fn test_distillation_config_learning_rate() {
let config = DistillationConfig::default();
assert_eq!(config.learning_rate, 0.001);
assert_eq!(config.student_learning_rate, 0.001);
}
#[test]
fn test_distillation_config_feature_matching() {
let config = DistillationConfig::default();
assert!(config.feature_matching);
assert!(!config.attention_transfer);
}
#[test]
fn test_distillation_config_mobile_opts() {
let config = DistillationConfig::default();
assert!(config.enable_mobile_optimizations);
assert!(!config.enable_quantization);
assert!(!config.enable_gradient_compression);
}
#[test]
fn test_distiller_with_different_backends() {
let config = DistillationConfig::default();
let cpu_distiller = KnowledgeDistiller::new(config.clone(), MobileBackend::CPU);
assert!(cpu_distiller.teacher_model.is_none());
let gpu_distiller = KnowledgeDistiller::new(config, MobileBackend::GPU);
assert!(gpu_distiller.student_model.is_none());
}
#[test]
fn test_distillation_strategy_variants() {
let strategies = vec![
DistillationStrategy::SoftTargets,
DistillationStrategy::FeatureBased,
DistillationStrategy::AttentionBased,
DistillationStrategy::Progressive,
DistillationStrategy::Online,
];
assert_eq!(strategies.len(), 5);
}
#[test]
fn test_stats_best_loss_empty() {
let stats = DistillationStats::default();
assert_eq!(stats.best_loss(), None);
}
#[test]
fn test_stats_best_loss_single() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![0.5];
assert_eq!(stats.best_loss(), Some(0.5));
}
#[test]
fn test_stats_best_loss_multiple() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![1.0, 0.5, 0.3, 0.7, 0.2];
assert_eq!(stats.best_loss(), Some(0.2));
}
#[test]
fn test_stats_converged_with_constant_losses() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![0.1; 10];
assert!(stats.converged());
}
#[test]
fn test_stats_not_converged_with_decreasing_losses() {
let mut stats = DistillationStats::default();
stats.epoch_losses = (0..10).map(|i| 1.0 - (i as f32 * 0.08)).collect();
assert!(!stats.converged());
}
#[test]
fn test_stats_not_converged_insufficient_epochs() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![0.5, 0.5, 0.5];
        assert!(!stats.converged());
    }
#[test]
fn test_distillation_loss_creation() {
let loss = DistillationLoss {
distillation_loss: 0.3,
task_loss: 0.2,
total_loss: 0.5,
};
assert!((loss.total_loss - (loss.distillation_loss + loss.task_loss)).abs() < 1e-6);
}
#[test]
fn test_stats_default_values() {
let stats = DistillationStats::default();
assert_eq!(stats.total_epochs, 0);
assert!(stats.epoch_losses.is_empty());
assert_eq!(stats.final_compression_ratio, 0.0);
assert_eq!(stats.knowledge_transfer_efficiency, 0.0);
}
#[test]
fn test_distiller_get_stats() {
let config = DistillationConfig::default();
let distiller = KnowledgeDistiller::new(config, MobileBackend::CPU);
let stats = distiller.get_stats();
assert_eq!(stats.total_epochs, 0);
}
#[test]
fn test_distillation_config_num_epochs() {
let config = DistillationConfig::default();
assert_eq!(config.num_epochs, 50);
assert_eq!(config.batch_size, 32);
}
#[test]
fn test_distillation_strategy_equality() {
assert_eq!(
DistillationStrategy::SoftTargets,
DistillationStrategy::SoftTargets
);
assert_ne!(
DistillationStrategy::SoftTargets,
DistillationStrategy::FeatureBased
);
}
#[test]
    fn test_stats_best_loss_with_extreme_values() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![f32::MAX, 1.0, 0.5];
assert_eq!(stats.best_loss(), Some(0.5));
}
#[test]
fn test_stats_converged_near_zero_variance() {
let mut stats = DistillationStats::default();
stats.epoch_losses = vec![
0.100, 0.101, 0.100, 0.100, 0.101, 0.100, 0.100, 0.101, 0.100, 0.100,
];
assert!(stats.converged());
}
}