entrenar/hf_pipeline/config/
distillation.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(default)]
8pub struct DistillationConfig {
9 pub temperature: f32,
11 pub alpha: f32,
13 pub progressive: Option<ProgressiveConfig>,
15 pub attention_transfer: Option<AttentionTransferConfig>,
17}
18
19impl Default for DistillationConfig {
20 fn default() -> Self {
21 Self { temperature: 4.0, alpha: 0.7, progressive: None, attention_transfer: None }
22 }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ProgressiveConfig {
28 pub layer_mapping: Vec<[usize; 2]>,
30 #[serde(default = "default_hidden_weight")]
32 pub hidden_weight: f32,
33}
34
/// Serde fallback for `ProgressiveConfig::hidden_weight`: returns `1.0`.
fn default_hidden_weight() -> f32 {
    const HIDDEN_WEIGHT: f32 = 1.0;
    HIDDEN_WEIGHT
}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct AttentionTransferConfig {
42 #[serde(default = "default_attention_weight")]
44 pub weight: f32,
45}
46
/// Serde fallback for `AttentionTransferConfig::weight`: returns `0.1`.
fn default_attention_weight() -> f32 {
    const ATTENTION_WEIGHT: f32 = 0.1;
    ATTENTION_WEIGHT
}
50
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two `f32` values differ by less than `1e-6`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test line
    /// rather than at this helper.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    /// `Default` yields the documented baseline values.
    #[test]
    fn test_distillation_config_default() {
        let DistillationConfig { temperature, alpha, progressive, attention_transfer } =
            DistillationConfig::default();

        assert_close(temperature, 4.0);
        assert_close(alpha, 0.7);
        assert!(progressive.is_none());
        assert!(attention_transfer.is_none());
    }

    /// Every field can be set explicitly, including both optional sections.
    #[test]
    fn test_distillation_config_custom() {
        let progressive =
            ProgressiveConfig { layer_mapping: vec![[0, 0], [1, 2], [2, 4]], hidden_weight: 0.5 };
        let attention_transfer = AttentionTransferConfig { weight: 0.2 };
        let config = DistillationConfig {
            temperature: 2.0,
            alpha: 0.5,
            progressive: Some(progressive),
            attention_transfer: Some(attention_transfer),
        };

        assert_close(config.temperature, 2.0);
        assert_close(config.alpha, 0.5);
        assert!(config.progressive.is_some());
        assert!(config.attention_transfer.is_some());
    }

    /// The layer mapping keeps its entries in insertion order.
    #[test]
    fn test_progressive_config_layer_mapping() {
        let config =
            ProgressiveConfig { layer_mapping: vec![[0, 0], [1, 2], [2, 4]], hidden_weight: 1.0 };
        let mapping = &config.layer_mapping;

        assert_eq!(mapping.len(), 3);
        assert_eq!(mapping[0], [0, 0]);
        assert_eq!(mapping[1], [1, 2]);
    }

    #[test]
    fn test_default_hidden_weight() {
        assert_close(default_hidden_weight(), 1.0);
    }

    #[test]
    fn test_default_attention_weight() {
        assert_close(default_attention_weight(), 0.1);
    }

    /// The scalar fields survive a JSON round trip.
    #[test]
    fn test_distillation_config_serde() {
        // The optional sections stay `None`, exactly as `Default` provides them.
        let config =
            DistillationConfig { temperature: 3.0, alpha: 0.6, ..DistillationConfig::default() };

        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
        let deserialized: DistillationConfig =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        assert_close(config.temperature, deserialized.temperature);
        assert_close(config.alpha, deserialized.alpha);
    }

    /// Both optional sections survive a JSON round trip.
    #[test]
    fn test_distillation_config_serde_with_optional() {
        let config = DistillationConfig {
            temperature: 3.0,
            alpha: 0.6,
            progressive: Some(ProgressiveConfig {
                layer_mapping: vec![[0, 1]],
                hidden_weight: 0.8,
            }),
            attention_transfer: Some(AttentionTransferConfig { weight: 0.15 }),
        };

        let json = serde_json::to_string(&config).expect("JSON serialization should succeed");
        let deserialized: DistillationConfig =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        let DistillationConfig { progressive, attention_transfer, .. } = deserialized;

        assert!(progressive.is_some());
        let prog = progressive.expect("deserialization should succeed");
        assert_eq!(prog.layer_mapping.len(), 1);
        assert_close(prog.hidden_weight, 0.8);

        assert!(attention_transfer.is_some());
        let attn = attention_transfer.expect("deserialization should succeed");
        assert_close(attn.weight, 0.15);
    }

    /// Fields missing from the JSON fall back to the `Default` values.
    #[test]
    fn test_distillation_config_from_partial_json() {
        let json = r#"{"temperature": 5.0}"#;
        let config: DistillationConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");

        assert_close(config.temperature, 5.0);
        // `alpha` is absent from the input, so the default applies.
        assert_close(config.alpha, 0.7);
    }

    /// An omitted `hidden_weight` is filled in by its serde default function.
    #[test]
    fn test_progressive_config_serde_default_weight() {
        let json = r#"{"layer_mapping": [[0, 0]]}"#;
        let config: ProgressiveConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");

        assert_close(config.hidden_weight, 1.0);
    }

    /// An empty JSON object is enough to build an `AttentionTransferConfig`.
    #[test]
    fn test_attention_transfer_config_serde_default_weight() {
        let json = "{}";
        let config: AttentionTransferConfig =
            serde_json::from_str(json).expect("JSON deserialization should succeed");

        assert_close(config.weight, 0.1);
    }

    /// The derived `Debug` output names the type and its fields.
    #[test]
    fn test_distillation_config_debug() {
        let debug_str = format!("{:?}", DistillationConfig::default());

        for needle in ["DistillationConfig", "temperature"] {
            assert!(debug_str.contains(needle));
        }
    }

    #[test]
    fn test_progressive_config_clone() {
        let config = ProgressiveConfig { layer_mapping: vec![[0, 0], [1, 2]], hidden_weight: 0.5 };
        let cloned = config.clone();

        assert_eq!(config.layer_mapping, cloned.layer_mapping);
        assert_close(config.hidden_weight, cloned.hidden_weight);
    }

    #[test]
    fn test_attention_transfer_config_clone() {
        let config = AttentionTransferConfig { weight: 0.2 };
        let cloned = config.clone();

        assert_close(config.weight, cloned.weight);
    }
}