// entrenar/hf_pipeline/config/distillation.rs

1//! Distillation loss configuration
2
3use serde::{Deserialize, Serialize};
4
5/// Distillation loss configuration
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(default)]
8pub struct DistillationConfig {
9    /// Temperature for softening distributions
10    pub temperature: f32,
11    /// Alpha weight for soft vs hard loss
12    pub alpha: f32,
13    /// Progressive distillation config
14    pub progressive: Option<ProgressiveConfig>,
15    /// Attention transfer config
16    pub attention_transfer: Option<AttentionTransferConfig>,
17}
18
19impl Default for DistillationConfig {
20    fn default() -> Self {
21        Self { temperature: 4.0, alpha: 0.7, progressive: None, attention_transfer: None }
22    }
23}
24
25/// Progressive distillation configuration
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProgressiveConfig {
28    /// Layer mapping [[student_layer, teacher_layer], ...]
29    pub layer_mapping: Vec<[usize; 2]>,
30    /// Weight for hidden state loss
31    #[serde(default = "default_hidden_weight")]
32    pub hidden_weight: f32,
33}
34
/// Serde default for [`ProgressiveConfig::hidden_weight`]: `1.0`.
fn default_hidden_weight() -> f32 {
    1.0
}

39/// Attention transfer configuration
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct AttentionTransferConfig {
42    /// Weight for attention transfer loss
43    #[serde(default = "default_attention_weight")]
44    pub weight: f32,
45}
46
/// Serde default for [`AttentionTransferConfig::weight`]: `0.1`.
fn default_attention_weight() -> f32 {
    0.1
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for `f32` comparisons of values that round-trip exactly.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_distillation_config_default() {
        let config = DistillationConfig::default();
        assert!((config.temperature - 4.0).abs() < EPS);
        assert!((config.alpha - 0.7).abs() < EPS);
        assert!(config.progressive.is_none());
        assert!(config.attention_transfer.is_none());
    }

    #[test]
    fn test_distillation_config_custom() {
        let config = DistillationConfig {
            temperature: 2.0,
            alpha: 0.5,
            progressive: Some(ProgressiveConfig {
                layer_mapping: vec![[0, 0], [1, 2], [2, 4]],
                hidden_weight: 0.5,
            }),
            attention_transfer: Some(AttentionTransferConfig { weight: 0.2 }),
        };

        assert!((config.temperature - 2.0).abs() < EPS);
        assert!((config.alpha - 0.5).abs() < EPS);

        let prog = config.progressive.as_ref().expect("progressive was set explicitly");
        assert_eq!(prog.layer_mapping, vec![[0, 0], [1, 2], [2, 4]]);
        assert!((prog.hidden_weight - 0.5).abs() < EPS);

        let attn = config
            .attention_transfer
            .as_ref()
            .expect("attention_transfer was set explicitly");
        assert!((attn.weight - 0.2).abs() < EPS);
    }

    #[test]
    fn test_progressive_config_layer_mapping() {
        let config =
            ProgressiveConfig { layer_mapping: vec![[0, 0], [1, 2], [2, 4]], hidden_weight: 1.0 };

        assert_eq!(config.layer_mapping.len(), 3);
        assert_eq!(config.layer_mapping[0], [0, 0]);
        assert_eq!(config.layer_mapping[1], [1, 2]);
        assert_eq!(config.layer_mapping[2], [2, 4]);
    }

    #[test]
    fn test_default_hidden_weight() {
        assert!((default_hidden_weight() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_default_attention_weight() {
        assert!((default_attention_weight() - 0.1).abs() < EPS);
    }

    #[test]
    fn test_distillation_config_serde() {
        let config = DistillationConfig {
            temperature: 3.0,
            alpha: 0.6,
            progressive: None,
            attention_transfer: None,
        };

        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
        let deserialized: DistillationConfig =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert!((config.temperature - deserialized.temperature).abs() < EPS);
        assert!((config.alpha - deserialized.alpha).abs() < EPS);
        assert!(deserialized.progressive.is_none());
        assert!(deserialized.attention_transfer.is_none());
    }

    #[test]
    fn test_distillation_config_serde_with_optional() {
        let config = DistillationConfig {
            temperature: 3.0,
            alpha: 0.6,
            progressive: Some(ProgressiveConfig {
                layer_mapping: vec![[0, 1]],
                hidden_weight: 0.8,
            }),
            attention_transfer: Some(AttentionTransferConfig { weight: 0.15 }),
        };

        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
        let deserialized: DistillationConfig =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        let prog = deserialized.progressive.expect("progressive should survive the round trip");
        assert_eq!(prog.layer_mapping, vec![[0, 1]]);
        assert!((prog.hidden_weight - 0.8).abs() < EPS);

        let attn = deserialized
            .attention_transfer
            .expect("attention_transfer should survive the round trip");
        assert!((attn.weight - 0.15).abs() < EPS);
    }

    #[test]
    fn test_distillation_config_from_partial_json() {
        // Missing fields fall back to `DistillationConfig::default()`.
        let json = r#"{"temperature": 5.0}"#;
        let config: DistillationConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert!((config.temperature - 5.0).abs() < EPS);
        assert!((config.alpha - 0.7).abs() < EPS);
        assert!(config.progressive.is_none());
        assert!(config.attention_transfer.is_none());
    }

    #[test]
    fn test_distillation_config_from_empty_json() {
        let config: DistillationConfig =
            serde_json::from_str("{}").expect("JSON deserialization should succeed");
        assert!((config.temperature - 4.0).abs() < EPS);
        assert!((config.alpha - 0.7).abs() < EPS);
        assert!(config.progressive.is_none());
        assert!(config.attention_transfer.is_none());
    }

    #[test]
    fn test_distillation_config_nested_defaults() {
        // Field-level defaults of the nested structs also apply when embedded.
        let json = r#"{"progressive": {"layer_mapping": [[0, 2]]}, "attention_transfer": {}}"#;
        let config: DistillationConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert!((config.temperature - 4.0).abs() < EPS);

        let prog = config.progressive.expect("progressive is present in the input");
        assert_eq!(prog.layer_mapping, vec![[0, 2]]);
        assert!((prog.hidden_weight - 1.0).abs() < EPS);

        let attn = config
            .attention_transfer
            .expect("attention_transfer is present in the input");
        assert!((attn.weight - 0.1).abs() < EPS);
    }

    #[test]
    fn test_progressive_config_serde_default_weight() {
        // `hidden_weight` defaults to 1.0 when not specified.
        let json = r#"{"layer_mapping": [[0, 0]]}"#;
        let config: ProgressiveConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");
        assert!((config.hidden_weight - 1.0).abs() < EPS);
    }

    #[test]
    fn test_progressive_config_missing_layer_mapping_is_error() {
        // `layer_mapping` has no serde default, so it is a required field.
        let result = serde_json::from_str::<ProgressiveConfig>(r#"{"hidden_weight": 0.5}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_attention_transfer_config_serde_default_weight() {
        // `weight` defaults to 0.1 when not specified.
        let config: AttentionTransferConfig =
            serde_json::from_str("{}").expect("JSON deserialization should succeed");
        assert!((config.weight - 0.1).abs() < EPS);
    }

    #[test]
    fn test_distillation_config_debug() {
        let config = DistillationConfig::default();
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("DistillationConfig"));
        assert!(debug_str.contains("temperature"));
    }

    #[test]
    fn test_progressive_config_clone() {
        let config = ProgressiveConfig { layer_mapping: vec![[0, 0], [1, 2]], hidden_weight: 0.5 };
        let cloned = config.clone();
        assert_eq!(config.layer_mapping, cloned.layer_mapping);
        assert!((config.hidden_weight - cloned.hidden_weight).abs() < EPS);
    }

    #[test]
    fn test_attention_transfer_config_clone() {
        let config = AttentionTransferConfig { weight: 0.2 };
        let cloned = config.clone();
        assert!((config.weight - cloned.weight).abs() < EPS);
    }
}