//! Advanced distillation techniques: hidden-state matching with learned
//! projections, quantization-aware distillation, and online distillation
//! with an optional EMA-smoothed teacher.

use crate::error::{AprenderError, Result};
/// Configuration for hidden-state (intermediate-layer) distillation.
#[derive(Debug, Clone)]
pub struct HiddenStateConfig {
    /// Hidden dimension of the teacher model.
    pub teacher_dim: usize,
    /// Hidden dimension of the student model.
    pub student_dim: usize,
    /// `(teacher_layer, student_layer)` index pairs to align.
    pub layer_map: Vec<(usize, usize)>,
    /// Weight of the hidden-state loss in the combined objective.
    pub hidden_loss_weight: f64,
    /// Learning rate for the learned projection matrices.
    pub projection_lr: f64,
}
impl Default for HiddenStateConfig {
fn default() -> Self {
Self {
teacher_dim: 768,
student_dim: 256,
            layer_map: vec![(3, 1), (7, 2), (11, 3)],
            hidden_loss_weight: 0.5,
projection_lr: 0.001,
}
}
}
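/// Learned linear projection (row-major `dim_out x dim_in` weight matrix)
/// that maps teacher hidden states into the student's hidden dimension.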
#[derive(Debug, Clone)]
pub struct HiddenProjection {
weights: Vec<f64>,
dim_in: usize,
dim_out: usize,
}
impl HiddenProjection {
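    /// Creates a projection with Xavier/Glorot-style scale
    /// `sqrt(2 / (dim_in + dim_out))`, initialized by a deterministic
    /// seeded LCG so results are reproducible.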
#[must_use]
pub fn new(dim_in: usize, dim_out: usize, seed: u64) -> Self {
let scale = (2.0 / (dim_in + dim_out) as f64).sqrt();
let mut weights = Vec::with_capacity(dim_out * dim_in);
let mut state = seed;
        for _ in 0..dim_out * dim_in {
            // Knuth's 64-bit LCG constants; the high bits are the best
            // distributed, so we draw from them below.
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Map the top 32 bits to a uniform value in [-1, 1).
            let u = (state >> 32) as f64 / (1u64 << 31) as f64 - 1.0;
            weights.push(u * scale);
        }
Self {
weights,
dim_in,
dim_out,
}
}
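    /// Projects a teacher hidden vector into the student dimension; inputs
    /// shorter than `dim_in` are treated as zero-padded.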
#[must_use]
pub fn forward(&self, teacher_hidden: &[f64]) -> Vec<f64> {
let n = teacher_hidden.len().min(self.dim_in);
let mut output = vec![0.0; self.dim_out];
for i in 0..self.dim_out {
for j in 0..n {
output[i] += self.weights[i * self.dim_in + j] * teacher_hidden[j];
}
}
output
}
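    /// Applies one SGD step to the projection weights on the MSE loss
    /// between the projected teacher state and the student state; missing
    /// student entries are treated as zero.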
pub fn update(&mut self, teacher_hidden: &[f64], student_hidden: &[f64], lr: f64) {
let projected = self.forward(teacher_hidden);
let n = self.dim_out as f64;
for i in 0..self.dim_out {
let error = projected[i] - student_hidden.get(i).copied().unwrap_or(0.0);
for j in 0..self.dim_in {
let grad = 2.0 * error * teacher_hidden.get(j).copied().unwrap_or(0.0) / n;
self.weights[i * self.dim_in + j] -= lr * grad;
}
}
}
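    /// Mean squared error between the projected teacher hidden state and
    /// the student hidden state, averaged over `dim_out`.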
#[must_use]
pub fn mse_loss(&self, teacher_hidden: &[f64], student_hidden: &[f64]) -> f64 {
let projected = self.forward(teacher_hidden);
let n = self.dim_out;
if n == 0 {
return 0.0;
}
projected
.iter()
.zip(student_hidden.iter())
.map(|(&p, &s)| (p - s).powi(2))
.sum::<f64>()
/ n as f64
}
}
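/// Aligns selected teacher layers with student layers through per-pair
/// learned projections.
///
/// # Examples
///
/// A minimal sketch using the default layer map (doctest skipped since
/// import paths depend on the crate layout):
///
/// ```ignore
/// let mut distiller = HiddenStateDistiller::new(HiddenStateConfig::default());
/// let teacher_hiddens = vec![vec![0.0; 768]; 12]; // one vector per teacher layer
/// let student_hiddens = vec![vec![0.0; 256]; 4]; // one vector per student layer
/// let loss = distiller.hidden_loss(&teacher_hiddens, &student_hiddens);
/// distiller.update_projections(&teacher_hiddens, &student_hiddens);
/// ```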
#[derive(Debug, Clone)]
pub struct HiddenStateDistiller {
projections: Vec<HiddenProjection>,
config: HiddenStateConfig,
}
impl HiddenStateDistiller {
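    /// Builds one projection per `(teacher_layer, student_layer)` pair,
    /// each with a distinct deterministic seed.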
#[must_use]
pub fn new(config: HiddenStateConfig) -> Self {
let projections = config
.layer_map
.iter()
.enumerate()
.map(|(i, _)| {
HiddenProjection::new(config.teacher_dim, config.student_dim, 42 + i as u64)
})
.collect();
Self {
projections,
config,
}
}
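    /// Average projected-MSE loss over all mapped layer pairs; pairs whose
    /// layer index is out of range are skipped.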
#[must_use]
pub fn hidden_loss(&self, teacher_hiddens: &[Vec<f64>], student_hiddens: &[Vec<f64>]) -> f64 {
let mut total_loss = 0.0;
for (idx, &(t_layer, s_layer)) in self.config.layer_map.iter().enumerate() {
if let (Some(th), Some(sh)) =
(teacher_hiddens.get(t_layer), student_hiddens.get(s_layer))
{
total_loss += self.projections[idx].mse_loss(th, sh);
}
}
total_loss / self.config.layer_map.len().max(1) as f64
}
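    /// Runs one projection update (at `projection_lr`) for every mapped
    /// layer pair present in both hidden-state lists.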
pub fn update_projections(
&mut self,
teacher_hiddens: &[Vec<f64>],
student_hiddens: &[Vec<f64>],
) {
for (idx, &(t_layer, s_layer)) in self.config.layer_map.iter().enumerate() {
if let (Some(th), Some(sh)) =
(teacher_hiddens.get(t_layer), student_hiddens.get(s_layer))
{
self.projections[idx].update(th, sh, self.config.projection_lr);
}
}
}
#[must_use]
pub fn layer_map(&self) -> &[(usize, usize)] {
&self.config.layer_map
}
#[must_use]
pub fn num_projections(&self) -> usize {
self.projections.len()
}
}
/// Configuration for quantization-aware distillation.
#[derive(Debug, Clone)]
pub struct QuantAwareConfig {
    /// Quantization bit width (e.g. 4 or 8).
    pub bits: u32,
    /// Use symmetric (zero-point-free) rather than asymmetric quantization.
    pub symmetric: bool,
    /// Fraction of each rounding error carried into the next weight during
    /// diffused fake quantization.
    pub error_diffusion: f64,
    /// Degree of the polynomial used to approximate activations.
    pub poly_degree: usize,
}
impl Default for QuantAwareConfig {
fn default() -> Self {
Self {
bits: 4,
symmetric: false,
error_diffusion: 0.5,
poly_degree: 3,
}
}
}
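/// Simulates low-bit quantization ("fake quantization") during distillation
/// so the student learns weights that survive quantization.
///
/// # Examples
///
/// A minimal sketch (doctest skipped since import paths depend on the
/// crate layout):
///
/// ```ignore
/// let distiller = QuantAwareDistiller::new(QuantAwareConfig::default());
/// let weights = [0.5, -0.3, 0.8, -0.1];
/// let dequantized = distiller.fake_quantize(&weights);
/// let mse = distiller.quantization_error(&weights);
/// assert_eq!(dequantized.len(), weights.len());
/// assert!(mse >= 0.0);
/// ```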
#[derive(Debug, Clone)]
pub struct QuantAwareDistiller {
config: QuantAwareConfig,
}
impl QuantAwareDistiller {
#[must_use]
pub fn new(config: QuantAwareConfig) -> Self {
Self { config }
}
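    /// Rounds each weight to its nearest quantization level and maps it
    /// back, returning the values a quantized model would actually use.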
#[must_use]
pub fn fake_quantize(&self, weights: &[f64]) -> Vec<f64> {
if weights.is_empty() {
return vec![];
}
let (qmin, qmax, scale, zero_point) = self.compute_quant_params(weights);
weights
.iter()
.map(|&w| {
let q = (w / scale + zero_point).round().clamp(qmin, qmax);
(q - zero_point) * scale
})
.collect()
}
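    /// Like [`Self::fake_quantize`], but carries `error_diffusion` times
    /// each rounding error into the next weight, which reduces the
    /// systematic bias of plain rounding.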
#[must_use]
pub fn fake_quantize_diffused(&self, weights: &[f64]) -> Vec<f64> {
if weights.is_empty() {
return vec![];
}
let (qmin, qmax, scale, zero_point) = self.compute_quant_params(weights);
let diffusion = self.config.error_diffusion;
let mut result = Vec::with_capacity(weights.len());
let mut error_accum = 0.0;
for &w in weights {
let adjusted = w + diffusion * error_accum;
let q = (adjusted / scale + zero_point).round().clamp(qmin, qmax);
let dequantized = (q - zero_point) * scale;
error_accum = adjusted - dequantized;
result.push(dequantized);
}
result
}
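    /// Returns `(qmin, qmax, scale, zero_point)`. Symmetric mode uses the
    /// restricted range `[-(2^(bits-1) - 1), 2^(bits-1) - 1]` with zero
    /// point 0; asymmetric mode maps `[min, max]` onto `[0, 2^bits - 1]`.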
fn compute_quant_params(&self, weights: &[f64]) -> (f64, f64, f64, f64) {
let levels = (1u64 << self.config.bits) as f64;
if self.config.symmetric {
let max_abs = weights
.iter()
.map(|w| w.abs())
.fold(0.0_f64, f64::max)
.max(1e-10);
let qmax = levels / 2.0 - 1.0;
let qmin = -qmax;
let scale = max_abs / qmax;
(qmin, qmax, scale, 0.0)
} else {
let min_val = weights.iter().cloned().fold(f64::INFINITY, f64::min);
let max_val = weights.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let range = (max_val - min_val).max(1e-10);
let qmin = 0.0;
let qmax = levels - 1.0;
let scale = range / qmax;
let zero_point = (-min_val / scale).round();
(qmin, qmax, scale, zero_point)
}
}
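    /// Mean squared error introduced by fake quantization of `weights`.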
#[must_use]
pub fn quantization_error(&self, weights: &[f64]) -> f64 {
let quantized = self.fake_quantize(weights);
if weights.is_empty() {
return 0.0;
}
weights
.iter()
.zip(quantized.iter())
.map(|(&w, &q)| (w - q).powi(2))
.sum::<f64>()
/ weights.len() as f64
}
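    /// Fits a least-squares polynomial of degree `poly_degree` to the
    /// sampled activation by solving the normal equations `X^T X c = X^T y`;
    /// returns coefficients in ascending order of power.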
pub fn polynomial_activation_approx(
&self,
x_values: &[f64],
y_values: &[f64],
) -> Result<Vec<f64>> {
if x_values.len() != y_values.len() {
return Err(AprenderError::dimension_mismatch(
"x/y values",
x_values.len(),
y_values.len(),
));
}
let n = x_values.len();
let degree = self.config.poly_degree;
if n <= degree {
return Err(AprenderError::FormatError {
message: format!(
"Need at least {} data points for degree-{} polynomial, got {}",
degree + 1,
degree,
n
),
});
}
let cols = degree + 1;
let mut xtx = vec![0.0; cols * cols];
let mut xty = vec![0.0; cols];
for i in 0..n {
let mut xi_powers = vec![1.0; cols];
for j in 1..cols {
xi_powers[j] = xi_powers[j - 1] * x_values[i];
}
for r in 0..cols {
for c in 0..cols {
xtx[r * cols + c] += xi_powers[r] * xi_powers[c];
}
xty[r] += xi_powers[r] * y_values[i];
}
}
solve_linear_system(&xtx, &xty, cols)
}
#[must_use]
pub fn bits(&self) -> u32 {
self.config.bits
}
}
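/// Solves the dense row-major `n x n` system `A x = b` by Gaussian
/// elimination with partial pivoting; errors if a pivot is (near-)zero.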
fn solve_linear_system(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>> {
let mut aug = vec![0.0; n * (n + 1)];
for i in 0..n {
for j in 0..n {
aug[i * (n + 1) + j] = a[i * n + j];
}
aug[i * (n + 1) + n] = b[i];
}
for col in 0..n {
let mut max_row = col;
let mut max_val = aug[col * (n + 1) + col].abs();
for row in (col + 1)..n {
let val = aug[row * (n + 1) + col].abs();
if val > max_val {
max_val = val;
max_row = row;
}
}
if max_val < 1e-12 {
return Err(AprenderError::FormatError {
message: "Singular matrix in polynomial fit".to_string(),
});
}
if max_row != col {
for j in 0..=n {
aug.swap(col * (n + 1) + j, max_row * (n + 1) + j);
}
}
let pivot = aug[col * (n + 1) + col];
for row in (col + 1)..n {
let factor = aug[row * (n + 1) + col] / pivot;
for j in col..=n {
let above = aug[col * (n + 1) + j];
aug[row * (n + 1) + j] -= factor * above;
}
}
}
let mut x = vec![0.0; n];
for i in (0..n).rev() {
x[i] = aug[i * (n + 1) + n];
for j in (i + 1)..n {
x[i] -= aug[i * (n + 1) + j] * x[j];
}
x[i] /= aug[i * (n + 1) + i];
}
Ok(x)
}
/// Configuration for online distillation.
#[derive(Debug, Clone)]
pub struct OnlineDistillConfig {
    /// Softmax temperature applied to both teacher and student logits.
    pub temperature: f64,
    /// Mixing weight: `alpha` on the distillation loss, `1 - alpha` on the
    /// hard-label cross-entropy.
    pub alpha: f64,
    /// Optional EMA decay for smoothing teacher logits; `None` disables EMA.
    pub ema_decay: Option<f64>,
    /// Capacity of the example buffer.
    pub buffer_size: usize,
}
impl Default for OnlineDistillConfig {
fn default() -> Self {
Self {
temperature: 3.0,
alpha: 0.7,
ema_decay: Some(0.999),
buffer_size: 1024,
}
}
}
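/// Combines temperature-scaled KL distillation with hard-label
/// cross-entropy, optionally smoothing the teacher's logits with an
/// exponential moving average (EMA).
///
/// # Examples
///
/// A minimal sketch (doctest skipped since import paths depend on the
/// crate layout):
///
/// ```ignore
/// let mut distiller = OnlineDistiller::new(OnlineDistillConfig::default());
/// let loss = distiller.step(&[0.1, 0.9], &[0.2, 0.8], &[0.0, 1.0])?;
/// assert_eq!(distiller.update_count(), 1);
/// ```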
#[derive(Debug, Clone)]
pub struct OnlineDistiller {
config: OnlineDistillConfig,
ema_logits: Option<Vec<f64>>,
update_count: usize,
}
impl OnlineDistiller {
#[must_use]
pub fn new(config: OnlineDistillConfig) -> Self {
Self {
config,
ema_logits: None,
update_count: 0,
}
}
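    /// Runs one distillation step and returns the combined loss
    /// `alpha * T^2 * KL + (1 - alpha) * cross_entropy`, where the KL term
    /// compares temperature-softened student and (optionally EMA-smoothed)
    /// teacher distributions.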
pub fn step(
&mut self,
student_logits: &[f64],
teacher_logits: &[f64],
hard_labels: &[f64],
) -> Result<f64> {
        if student_logits.len() != teacher_logits.len() {
            return Err(AprenderError::dimension_mismatch(
                "student/teacher logits",
                student_logits.len(),
                teacher_logits.len(),
            ));
        }
        if student_logits.len() != hard_labels.len() {
            return Err(AprenderError::dimension_mismatch(
                "logits/labels",
                student_logits.len(),
                hard_labels.len(),
            ));
        }
let effective_teacher = if let Some(decay) = self.config.ema_decay {
self.update_ema(teacher_logits, decay);
self.ema_logits.as_deref().unwrap_or(teacher_logits)
} else {
teacher_logits
};
let t = self.config.temperature;
let teacher_soft = super::distillation::softmax_temperature(effective_teacher, t);
let student_soft = super::distillation::softmax_temperature(student_logits, t);
let student_hard = super::distillation::softmax(student_logits);
let kl_loss = super::distillation::kl_divergence(&student_soft, &teacher_soft);
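        // Scale by T^2 so the soft-target term's gradient magnitude stays
        // comparable to the hard-label term (Hinton et al., 2015).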
let distill_loss = t * t * kl_loss;
let hard_loss = super::distillation::cross_entropy(&student_hard, hard_labels);
let total = self.config.alpha * distill_loss + (1.0 - self.config.alpha) * hard_loss;
self.update_count += 1;
Ok(total)
}
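    /// Updates the EMA of the teacher logits in place, reinitializing the
    /// buffer whenever the logit length changes.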
fn update_ema(&mut self, teacher_logits: &[f64], decay: f64) {
match &mut self.ema_logits {
Some(ema) if ema.len() == teacher_logits.len() => {
for (e, &t) in ema.iter_mut().zip(teacher_logits.iter()) {
*e = decay * *e + (1.0 - decay) * t;
}
}
_ => {
self.ema_logits = Some(teacher_logits.to_vec());
}
}
}
#[must_use]
pub fn update_count(&self) -> usize {
self.update_count
}
#[must_use]
pub fn ema_logits(&self) -> Option<&[f64]> {
self.ema_logits.as_deref()
}
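    /// Clears the EMA state and the update counter.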
pub fn reset(&mut self) {
self.ema_logits = None;
self.update_count = 0;
}
}
#[cfg(test)]
#[path = "distillation_advanced_tests.rs"]
mod tests;