use super::distillation::softmax_with_temp;
#[allow(clippy::wildcard_imports)]
use super::*;
use crate::autograd::Tensor;
use crate::nn::Module;
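/// [`Module`] implementation for [`DomainAdapter`]: `forward` routes input
/// through the encoder, while the parameter accessors expose both encoder
/// and discriminator weights so that both can be optimized together (e.g.,
/// for adversarial domain adaptation).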
impl<E: TransferEncoder> Module for DomainAdapter<E> {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.encode(input)
    }

    fn parameters(&self) -> Vec<&Tensor> {
        let mut params = self.encoder.parameters();
        params.extend(self.discriminator.parameters());
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params = self.encoder.parameters_mut();
        params.extend(self.discriminator.parameters_mut());
        params
    }
}
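/// Configuration for Low-Rank Adaptation (LoRA).
///
/// A frozen base weight `W` is adapted as `W + (alpha / rank) * B * A`,
/// where `A` and `B` are trainable low-rank factors; `target_modules`
/// names the layers to adapt (attention query/value projections by
/// default).
///
/// A minimal usage sketch (`ignore`d doctest; values are illustrative):
///
/// ```ignore
/// let config = LoRAConfig::new(8, 16.0).with_dropout(0.1);
/// assert!((config.scaling() - 2.0).abs() < f32::EPSILON);
/// ```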
#[derive(Debug, Clone)]
pub struct LoRAConfig {
    pub rank: usize,
    pub alpha: f32,
    pub target_modules: Vec<String>,
    pub dropout: f32,
}
impl LoRAConfig {
    #[must_use]
    pub fn new(rank: usize, alpha: f32) -> Self {
        // A zero rank would make `scaling()` divide by zero.
        assert!(rank > 0, "LoRA rank must be positive");
        Self {
            rank,
            alpha,
            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
            dropout: 0.0,
        }
    }

    #[must_use]
    pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
        self.target_modules = modules;
        self
    }

    #[must_use]
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// The factor `alpha / rank` applied to the low-rank update.
    #[must_use]
    pub fn scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }
}
impl Default for LoRAConfig {
    fn default() -> Self {
        Self::new(8, 8.0)
    }
}
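/// Trainable low-rank factors that adapt a frozen base weight without
/// modifying it. `lora_a` has shape `[rank, input_dim]` and `lora_b` has
/// shape `[output_dim, rank]`, so their product matches the base weight.
///
/// A usage sketch (`ignore`d doctest; `base_weight` is a hypothetical
/// `[4, 4]` tensor):
///
/// ```ignore
/// let adapter = LoRAAdapter::new(4, 4, LoRAConfig::new(2, 4.0));
/// let adapted = adapter.apply(&base_weight);
/// ```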
#[derive(Debug)]
pub struct LoRAAdapter {
    pub lora_a: Tensor,
    pub lora_b: Tensor,
    pub config: LoRAConfig,
}
impl LoRAAdapter {
    #[must_use]
    pub fn new(input_dim: usize, output_dim: usize, config: LoRAConfig) -> Self {
        // Deterministic, small-magnitude initialization for A. Standard LoRA
        // draws A from a Gaussian; a fixed pattern is used here so results
        // are reproducible without an RNG.
        let scale = 0.01;
        let a_data: Vec<f32> = (0..config.rank * input_dim)
            .map(|i| ((i % 7) as f32 - 3.0) * scale)
            .collect();
        let lora_a = Tensor::new(&a_data, &[config.rank, input_dim]).requires_grad();
        // B starts at zero so the initial delta `B * A` is zero and training
        // begins from the unmodified base weight.
        let lora_b = Tensor::zeros(&[output_dim, config.rank]).requires_grad();
        Self {
            lora_a,
            lora_b,
            config,
        }
    }

    /// Returns `base_weight + scaling * B * A`.
    #[must_use]
    pub fn apply(&self, base_weight: &Tensor) -> Tensor {
        let ba = self.lora_b.matmul(&self.lora_a);
        let scaled = ba.mul_scalar(self.config.scaling());
        base_weight.add(&scaled)
    }

    /// Returns the low-rank update `scaling * B * A` on its own.
    #[must_use]
    pub fn delta_weight(&self) -> Tensor {
        self.lora_b
            .matmul(&self.lora_a)
            .mul_scalar(self.config.scaling())
    }
}
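/// Hinton-style knowledge distillation: the student matches the teacher's
/// temperature-softened distribution via KL divergence, blended with the
/// hard-label task loss as `alpha * distill + (1 - alpha) * task_loss`.
///
/// A usage sketch (`ignore`d doctest; logits and task loss are
/// illustrative):
///
/// ```ignore
/// let kd = KnowledgeDistillation::new(4.0, 0.7);
/// let loss = kd.combined_loss(&[2.0, 0.5], &[1.8, 0.6], 0.42);
/// ```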
#[derive(Debug, Clone)]
pub struct KnowledgeDistillation {
    temperature: f32,
    alpha: f32,
}

impl KnowledgeDistillation {
    #[must_use]
    pub fn new(temperature: f32, alpha: f32) -> Self {
        assert!(temperature > 0.0, "Temperature must be positive");
        assert!((0.0..=1.0).contains(&alpha), "Alpha must be in [0, 1]");
        Self { temperature, alpha }
    }

    /// KL(teacher ‖ student) over temperature-softened distributions,
    /// scaled by `T^2` to keep gradient magnitudes comparable across
    /// temperatures.
    #[must_use]
    pub fn distillation_loss(&self, student_logits: &[f32], teacher_logits: &[f32]) -> f32 {
        assert_eq!(student_logits.len(), teacher_logits.len());
        let student_soft = softmax_with_temp(student_logits, self.temperature);
        let teacher_soft = softmax_with_temp(teacher_logits, self.temperature);
        let eps = 1e-10;
        let kl: f32 = teacher_soft
            .iter()
            .zip(student_soft.iter())
            .map(|(&t, &s)| t * ((t + eps) / (s + eps)).ln())
            .sum();
        kl * self.temperature * self.temperature
    }

    #[must_use]
    pub fn combined_loss(
        &self,
        student_logits: &[f32],
        teacher_logits: &[f32],
        task_loss: f32,
    ) -> f32 {
        let distill = self.distillation_loss(student_logits, teacher_logits);
        self.alpha * distill + (1.0 - self.alpha) * task_loss
    }

    #[must_use]
    pub fn temperature(&self) -> f32 {
        self.temperature
    }

    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}
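/// Matches intermediate student features to teacher features using MSE,
/// MAE, or cosine distance (`1 - cos`).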
#[derive(Debug, Clone)]
pub struct FeatureDistillation {
    loss_type: FeatureLossType,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FeatureLossType {
    MSE,
    MAE,
    Cosine,
}

impl FeatureDistillation {
    #[must_use]
    pub fn new(loss_type: FeatureLossType) -> Self {
        Self { loss_type }
    }

    #[must_use]
    pub fn compute_loss(&self, student: &[f32], teacher: &[f32]) -> f32 {
        assert_eq!(student.len(), teacher.len());
        match self.loss_type {
            FeatureLossType::MSE => {
                student
                    .iter()
                    .zip(teacher.iter())
                    .map(|(&s, &t)| (s - t).powi(2))
                    .sum::<f32>()
                    / student.len() as f32
            }
            FeatureLossType::MAE => {
                student
                    .iter()
                    .zip(teacher.iter())
                    .map(|(&s, &t)| (s - t).abs())
                    .sum::<f32>()
                    / student.len() as f32
            }
            FeatureLossType::Cosine => {
                let dot: f32 = student
                    .iter()
                    .zip(teacher.iter())
                    .map(|(&s, &t)| s * t)
                    .sum();
                let norm_s: f32 = student.iter().map(|&s| s * s).sum::<f32>().sqrt();
                let norm_t: f32 = teacher.iter().map(|&t| t * t).sum::<f32>().sqrt();
                let cosine = dot / (norm_s * norm_t + 1e-10);
                1.0 - cosine
            }
        }
    }
}
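/// Attention transfer (Zagoruyko & Komodakis, 2017): spatial attention
/// maps are formed by summing `|activation|^p` over channels and
/// L2-normalizing, then the student map is regressed onto the teacher's.
///
/// A usage sketch (`ignore`d doctest; activations are hypothetical
/// channel-major buffers with 2 channels and 4 spatial positions):
///
/// ```ignore
/// let at = AttentionTransfer::new(2);
/// let loss = at.compute_loss(&student_acts, &teacher_acts, 2, 4);
/// ```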
#[derive(Debug, Clone)]
pub struct AttentionTransfer {
    p: usize,
}

impl AttentionTransfer {
    #[must_use]
    pub fn new(p: usize) -> Self {
        Self { p }
    }

    #[allow(clippy::needless_range_loop)]
    #[must_use]
    pub fn compute_attention_map(
        &self,
        activations: &[f32],
        channels: usize,
        spatial: usize,
    ) -> Vec<f32> {
        let mut attention = vec![0.0_f32; spatial];
        for c in 0..channels {
            for s in 0..spatial {
                let idx = c * spatial + s;
                if idx < activations.len() {
                    attention[s] += activations[idx].abs().powi(self.p as i32);
                }
            }
        }
        let norm: f32 = attention.iter().map(|&a| a * a).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for a in &mut attention {
                *a /= norm;
            }
        }
        attention
    }

    #[must_use]
    pub fn compute_loss(
        &self,
        student_acts: &[f32],
        teacher_acts: &[f32],
        channels: usize,
        spatial: usize,
    ) -> f32 {
        let student_att = self.compute_attention_map(student_acts, channels, spatial);
        let teacher_att = self.compute_attention_map(teacher_acts, channels, spatial);
        student_att
            .iter()
            .zip(teacher_att.iter())
            .map(|(&s, &t)| (s - t).powi(2))
            .sum::<f32>()
            / spatial as f32
    }
}
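/// Self-distillation within one network: deeper (teacher) layers supervise
/// shallower (student) layers through the registered
/// `(teacher_idx, student_idx)` pairs.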
#[derive(Debug, Clone)]
pub struct SelfDistillation {
    temperature: f32,
    layer_pairs: Vec<(usize, usize)>,
}

impl SelfDistillation {
    #[must_use]
    pub fn new(temperature: f32) -> Self {
        Self {
            temperature,
            layer_pairs: Vec::new(),
        }
    }

    #[must_use]
    pub fn add_layer_pair(mut self, teacher_idx: usize, student_idx: usize) -> Self {
        self.layer_pairs.push((teacher_idx, student_idx));
        self
    }

    #[must_use]
    pub fn layer_pairs(&self) -> &[(usize, usize)] {
        &self.layer_pairs
    }

    /// Temperature-scaled KL divergence between one teacher/student layer
    /// pair, matching the loss used in `KnowledgeDistillation`.
    #[must_use]
    pub fn layer_loss(&self, student_output: &[f32], teacher_output: &[f32]) -> f32 {
        assert_eq!(student_output.len(), teacher_output.len());
        let student_soft = softmax_with_temp(student_output, self.temperature);
        let teacher_soft = softmax_with_temp(teacher_output, self.temperature);
        let eps = 1e-10;
        teacher_soft
            .iter()
            .zip(student_soft.iter())
            .map(|(&t, &s)| t * ((t + eps) / (s + eps)).ln())
            .sum::<f32>()
            * self.temperature
            * self.temperature
    }
}
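/// Online (mutual) distillation settings: `num_networks` peers train
/// simultaneously and regularize one another with a `mutual_weight`-scaled
/// KL term at `temperature`, with no pretrained teacher required.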
#[derive(Debug, Clone)]
pub struct OnlineDistillation {
    pub(crate) num_networks: usize,
    pub(crate) temperature: f32,
    pub(crate) mutual_weight: f32,
}
#[cfg(test)]
#[path = "tests_lora_contract.rs"]
mod tests_lora_contract;