use ghostflow_core::Tensor;
use std::collections::HashMap;

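/// Configuration for teacher-student knowledge distillation.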
#[derive(Debug, Clone)]
pub struct DistillationConfig {
    /// Temperature used to soften the student and teacher logit distributions.
    pub temperature: f32,
    /// Weight of the distillation (soft-target) loss term.
    pub alpha: f32,
    /// Weight of the standard cross-entropy (hard-target) loss term.
    pub beta: f32,
    /// Which distillation strategy to apply.
    pub method: DistillationMethod,
    /// Indices of the intermediate layers matched by feature-based methods.
    pub feature_layers: Vec<usize>,
}

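/// Strategy used to transfer knowledge from the teacher to the student.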
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistillationMethod {
    /// KL divergence between temperature-softened logits plus hard-label cross-entropy.
    Standard,
    /// `Standard` plus MSE between cached intermediate features of selected layers.
    Feature,
    /// `Standard` plus a transfer loss on normalized attention maps.
    Attention,
    /// Hard-label cross-entropy plus FitNet-style hint losses on intermediate features.
    FitNet,
    /// `Standard` plus feature losses whose weights grow with layer position.
    Progressive,
}

impl Default for DistillationConfig {
    fn default() -> Self {
        DistillationConfig {
            temperature: 4.0,
            alpha: 0.7,
            beta: 0.3,
            method: DistillationMethod::Standard,
            feature_layers: vec![],
        }
    }
}

impl DistillationConfig {
    /// Classic logit distillation; `beta` is set to `1.0 - alpha`.
    pub fn standard(temperature: f32, alpha: f32) -> Self {
        DistillationConfig {
            temperature,
            alpha,
            beta: 1.0 - alpha,
            method: DistillationMethod::Standard,
            ..Default::default()
        }
    }

    /// Feature-based distillation that matches the given intermediate layers.
    pub fn feature_based(temperature: f32, feature_layers: Vec<usize>) -> Self {
        DistillationConfig {
            temperature,
            method: DistillationMethod::Feature,
            feature_layers,
            ..Default::default()
        }
    }

    /// Attention-transfer distillation.
    pub fn attention_transfer(temperature: f32) -> Self {
        DistillationConfig {
            temperature,
            method: DistillationMethod::Attention,
            ..Default::default()
        }
    }
}

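/// Computes teacher-student distillation losses for a configured [`DistillationMethod`].
///
/// Intermediate activations needed by the feature-based methods must be cached
/// beforehand via `store_teacher_output` / `store_student_output`. A rough usage
/// sketch (not compiled as a doctest; the tensor variables are placeholders):
///
/// ```ignore
/// let mut kd = KnowledgeDistillation::new(DistillationConfig::standard(4.0, 0.7));
/// kd.store_teacher_output("layer_0".to_string(), teacher_feat);
/// kd.store_student_output("layer_0".to_string(), student_feat);
/// let loss = kd.compute_loss(&student_logits, &teacher_logits, &targets)?;
/// ```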
pub struct KnowledgeDistillation {
    config: DistillationConfig,
    teacher_outputs: HashMap<String, Tensor>,
    student_outputs: HashMap<String, Tensor>,
}

impl KnowledgeDistillation {
    pub fn new(config: DistillationConfig) -> Self {
        KnowledgeDistillation {
            config,
            teacher_outputs: HashMap::new(),
            student_outputs: HashMap::new(),
        }
    }

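    /// Computes the combined distillation loss for one batch, dispatching on the
    /// configured [`DistillationMethod`].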
    pub fn compute_loss(
        &self,
        student_logits: &Tensor,
        teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        match self.config.method {
            DistillationMethod::Standard => {
                self.standard_distillation_loss(student_logits, teacher_logits, targets)
            }
            DistillationMethod::Feature => {
                self.feature_distillation_loss(student_logits, teacher_logits, targets)
            }
            DistillationMethod::Attention => {
                self.attention_distillation_loss(student_logits, teacher_logits, targets)
            }
            DistillationMethod::FitNet => {
                self.fitnet_loss(student_logits, teacher_logits, targets)
            }
            DistillationMethod::Progressive => {
                self.progressive_distillation_loss(student_logits, teacher_logits, targets)
            }
        }
    }

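    /// Classic logit distillation:
    /// `alpha * T^2 * KL(teacher_T || student_T) + beta * CE(student, targets)`,
    /// where `T` is the softening temperature. The `T^2` factor keeps the soft-target
    /// gradient magnitude comparable across temperatures.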
    fn standard_distillation_loss(
        &self,
        student_logits: &Tensor,
        teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        let student_soft = self.temperature_softmax(student_logits)?;
        let teacher_soft = self.temperature_softmax(teacher_logits)?;

        let kl_loss = self.kl_divergence(&student_soft, &teacher_soft)?;

        let student_loss = self.cross_entropy(student_logits, targets)?;

        let distill_loss = kl_loss.mul_scalar(self.config.alpha * self.config.temperature * self.config.temperature);
        let student_loss = student_loss.mul_scalar(self.config.beta);

        distill_loss.add(&student_loss)
            .map_err(|e| format!("Failed to combine losses: {:?}", e))
    }

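    /// Standard distillation loss plus an MSE feature-matching term (weighted 0.1)
    /// for every configured layer whose activations are cached for both models.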
    fn feature_distillation_loss(
        &self,
        student_logits: &Tensor,
        teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        let mut total_loss = self.standard_distillation_loss(student_logits, teacher_logits, targets)?;

        for &layer_idx in &self.config.feature_layers {
            let layer_name = format!("layer_{}", layer_idx);

            if let (Some(student_feat), Some(teacher_feat)) = (
                self.student_outputs.get(&layer_name),
                self.teacher_outputs.get(&layer_name),
            ) {
                let feature_loss = self.feature_matching_loss(student_feat, teacher_feat)?;
                total_loss = total_loss.add(&feature_loss.mul_scalar(0.1))
                    .map_err(|e| format!("Failed to add feature loss: {:?}", e))?;
            }
        }

        Ok(total_loss)
    }

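    /// Standard distillation loss plus an attention-transfer term (weighted 0.1)
    /// when both models have a cached "attention" output.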
    fn attention_distillation_loss(
        &self,
        student_logits: &Tensor,
        teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        let mut total_loss = self.standard_distillation_loss(student_logits, teacher_logits, targets)?;

        if let (Some(student_attn), Some(teacher_attn)) = (
            self.student_outputs.get("attention"),
            self.teacher_outputs.get("attention"),
        ) {
            let attention_loss = self.attention_transfer_loss(student_attn, teacher_attn)?;
            total_loss = total_loss.add(&attention_loss.mul_scalar(0.1))
                .map_err(|e| format!("Failed to add attention loss: {:?}", e))?;
        }

        Ok(total_loss)
    }

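    /// FitNet-style training: hard-label cross-entropy plus hint losses (weighted 0.5)
    /// on the cached intermediate features of the configured layers.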
    fn fitnet_loss(
        &self,
        student_logits: &Tensor,
        // The teacher contributes only through its cached features, not its logits.
        _teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        let student_loss = self.cross_entropy(student_logits, targets)?;

        let mut total_loss = student_loss;

        for &layer_idx in &self.config.feature_layers {
            let layer_name = format!("layer_{}", layer_idx);

            if let (Some(student_feat), Some(teacher_feat)) = (
                self.student_outputs.get(&layer_name),
                self.teacher_outputs.get(&layer_name),
            ) {
                let hint_loss = self.hint_loss(student_feat, teacher_feat)?;
                total_loss = total_loss.add(&hint_loss.mul_scalar(0.5))
                    .map_err(|e| format!("Failed to add hint loss: {:?}", e))?;
            }
        }

        Ok(total_loss)
    }

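    /// Standard distillation loss plus feature-matching terms whose weights grow
    /// linearly with the layer's position in `feature_layers`.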
    fn progressive_distillation_loss(
        &self,
        student_logits: &Tensor,
        teacher_logits: &Tensor,
        targets: &Tensor,
    ) -> Result<Tensor, String> {
        let base_loss = self.standard_distillation_loss(student_logits, teacher_logits, targets)?;

        let mut total_loss = base_loss;
        let num_layers = self.config.feature_layers.len();

        for (i, &layer_idx) in self.config.feature_layers.iter().enumerate() {
            let layer_name = format!("layer_{}", layer_idx);
            let weight = (i + 1) as f32 / num_layers as f32;

            if let (Some(student_feat), Some(teacher_feat)) = (
                self.student_outputs.get(&layer_name),
                self.teacher_outputs.get(&layer_name),
            ) {
                let layer_loss = self.feature_matching_loss(student_feat, teacher_feat)?;
                total_loss = total_loss.add(&layer_loss.mul_scalar(weight * 0.1))
                    .map_err(|e| format!("Failed to add progressive loss: {:?}", e))?;
            }
        }

        Ok(total_loss)
    }

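    /// Softmax of `logits / temperature` along the last dimension.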
    fn temperature_softmax(&self, logits: &Tensor) -> Result<Tensor, String> {
        let scaled = logits.div_scalar(self.config.temperature);
        Ok(scaled.softmax(-1))
    }

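    /// Element-wise KL divergence `KL(teacher || student)` between two probability
    /// tensors, averaged over all entries, with values clamped at `eps` for stability.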
    fn kl_divergence(&self, student: &Tensor, teacher: &Tensor) -> Result<Tensor, String> {
        let student_data = student.data_f32();
        let teacher_data = teacher.data_f32();

        if student_data.len() != teacher_data.len() {
            return Err("Student and teacher tensors must have same size".to_string());
        }

        let mut kl_sum = 0.0;
        let eps = 1e-8;

        for i in 0..student_data.len() {
            let p = teacher_data[i].max(eps);
            let q = student_data[i].max(eps);
            kl_sum += p * (p / q).ln();
        }

        Tensor::from_slice(&[kl_sum / student_data.len() as f32], &[1])
            .map_err(|e| format!("Failed to create KL loss: {:?}", e))
    }

    /// Mean cross-entropy of `logits` against `targets`, where targets hold
    /// per-example class indices stored as f32 (e.g. `[0.0, 1.0, 2.0]`).
    fn cross_entropy(&self, logits: &Tensor, targets: &Tensor) -> Result<Tensor, String> {
        let probs = logits.softmax(-1);
        let probs_data = probs.data_f32();
        let targets_data = targets.data_f32();
        let num_classes = *logits.dims().last().ok_or("Logits must be at least 1-D")?;

        let mut nll_sum = 0.0;
        for (i, &target) in targets_data.iter().enumerate() {
            nll_sum -= probs_data[i * num_classes + target as usize].max(1e-8).ln();
        }

        Tensor::from_slice(&[nll_sum / targets_data.len() as f32], &[1])
            .map_err(|e| format!("Failed to create CE loss: {:?}", e))
    }

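    /// Mean squared error between flattened student and teacher feature tensors.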
    fn feature_matching_loss(&self, student: &Tensor, teacher: &Tensor) -> Result<Tensor, String> {
        let student_data = student.data_f32();
        let teacher_data = teacher.data_f32();

        if student_data.len() != teacher_data.len() {
            return Err("Feature tensors must have same size".to_string());
        }

        let mut mse_sum = 0.0;
        for i in 0..student_data.len() {
            let diff = student_data[i] - teacher_data[i];
            mse_sum += diff * diff;
        }

        Tensor::from_slice(&[mse_sum / student_data.len() as f32], &[1])
            .map_err(|e| format!("Failed to create feature loss: {:?}", e))
    }

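    /// MSE between attention maps after normalizing each map to sum to one.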
    fn attention_transfer_loss(&self, student_attn: &Tensor, teacher_attn: &Tensor) -> Result<Tensor, String> {
        let student_norm = self.normalize_attention(student_attn)?;
        let teacher_norm = self.normalize_attention(teacher_attn)?;

        self.feature_matching_loss(&student_norm, &teacher_norm)
    }

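    /// FitNet hint loss; currently identical to `feature_matching_loss`.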
    fn hint_loss(&self, student_feat: &Tensor, teacher_feat: &Tensor) -> Result<Tensor, String> {
        self.feature_matching_loss(student_feat, teacher_feat)
    }

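    /// Normalizes an attention map so that its entries sum to one.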
    fn normalize_attention(&self, attention: &Tensor) -> Result<Tensor, String> {
        let data = attention.data_f32();
        let dims = attention.dims();

        let sum: f32 = data.iter().sum();
        // Guard against an all-zero map so the division below cannot produce NaN/inf.
        let denom = sum.max(1e-8);
        let normalized: Vec<f32> = data.iter().map(|&x| x / denom).collect();

        Tensor::from_slice(&normalized, dims)
            .map_err(|e| format!("Failed to normalize attention: {:?}", e))
    }

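    /// Caches an intermediate teacher activation under `layer_name` for use by
    /// the feature-based distillation losses.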
    pub fn store_teacher_output(&mut self, layer_name: String, output: Tensor) {
        self.teacher_outputs.insert(layer_name, output);
    }

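    /// Caches an intermediate student activation under `layer_name`.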
    pub fn store_student_output(&mut self, layer_name: String, output: Tensor) {
        self.student_outputs.insert(layer_name, output);
    }

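    /// Clears all cached teacher and student activations.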
    pub fn clear_outputs(&mut self) {
        self.teacher_outputs.clear();
        self.student_outputs.clear();
    }

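    /// Returns a snapshot of the current distillation configuration.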
    pub fn get_stats(&self) -> DistillationStats {
        DistillationStats {
            temperature: self.config.temperature,
            alpha: self.config.alpha,
            beta: self.config.beta,
            method: self.config.method,
            num_feature_layers: self.config.feature_layers.len(),
        }
    }
}

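/// Snapshot of a `KnowledgeDistillation` configuration, as returned by
/// `KnowledgeDistillation::get_stats`.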
#[derive(Debug, Clone)]
pub struct DistillationStats {
    pub temperature: f32,
    pub alpha: f32,
    pub beta: f32,
    pub method: DistillationMethod,
    pub num_feature_layers: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distillation_config() {
        let config = DistillationConfig::default();
        assert_eq!(config.temperature, 4.0);
        assert_eq!(config.method, DistillationMethod::Standard);

        let standard = DistillationConfig::standard(3.0, 0.8);
        assert_eq!(standard.temperature, 3.0);
        assert_eq!(standard.alpha, 0.8);
        assert!((standard.beta - 0.2).abs() < 1e-6);
    }

    #[test]
    #[ignore]
    fn test_knowledge_distillation() {
        let config = DistillationConfig::default();
        let kd = KnowledgeDistillation::new(config);

        let student_logits = Tensor::randn(&[4, 10]);
        let teacher_logits = Tensor::randn(&[4, 10]);
        let targets = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0], &[4]).unwrap();

        let loss = kd.compute_loss(&student_logits, &teacher_logits, &targets).unwrap();
        assert_eq!(loss.dims(), &[1]);
    }

    #[test]
    fn test_temperature_softmax() {
        let config = DistillationConfig::default();
        let kd = KnowledgeDistillation::new(config);

        let logits = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3]).unwrap();
        let soft = kd.temperature_softmax(&logits).unwrap();

        assert_eq!(soft.dims(), &[1, 3]);

        let data = soft.data_f32();
        let sum: f32 = data.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_kl_divergence() {
        let config = DistillationConfig::default();
        let kd = KnowledgeDistillation::new(config);

        let p = Tensor::from_slice(&[0.5f32, 0.3, 0.2], &[3]).unwrap();
        let q = Tensor::from_slice(&[0.4f32, 0.4, 0.2], &[3]).unwrap();

        let kl = kd.kl_divergence(&q, &p).unwrap();
        assert_eq!(kl.dims(), &[1]);
        assert!(kl.data_f32()[0] >= 0.0);
    }

    #[test]
    fn test_feature_matching_loss() {
        let config = DistillationConfig::default();
        let kd = KnowledgeDistillation::new(config);

        let student = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let teacher = Tensor::from_slice(&[1.1f32, 2.1, 2.9], &[3]).unwrap();

        let loss = kd.feature_matching_loss(&student, &teacher).unwrap();
        assert_eq!(loss.dims(), &[1]);
        assert!(loss.data_f32()[0] >= 0.0);
    }
}