1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DistillationCheckpoint {
13 pub teacher_model: String,
15 pub temperature: f32,
17 pub alpha: f32,
19 pub final_loss: Option<f32>,
21 pub epoch: usize,
23 pub step: usize,
25}
26
27#[allow(clippy::implicit_hasher)]
35pub fn save_student_checkpoint(
36 weights: &HashMap<String, Vec<f32>>,
37 shapes: &HashMap<String, Vec<usize>>,
38 checkpoint: &DistillationCheckpoint,
39 output_dir: impl AsRef<Path>,
40 filename: &str,
41) -> Result<PathBuf, std::io::Error> {
42 let output_dir = output_dir.as_ref();
43 std::fs::create_dir_all(output_dir)?;
44
45 use safetensors::tensor::{Dtype, TensorView};
47
48 let mut sorted_names: Vec<&String> = weights.keys().collect();
49 sorted_names.sort();
50
51 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = sorted_names
52 .iter()
53 .map(|name| {
54 let data = &weights[*name];
55 let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
56 let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
57 ((*name).clone(), bytes, shape)
58 })
59 .collect();
60
61 let views: Vec<(&str, TensorView<'_>)> = tensor_data
62 .iter()
63 .map(|(name, bytes, shape)| {
64 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
65 .expect("TensorView construction must not fail for valid F32 data");
66 (name.as_str(), view)
67 })
68 .collect();
69
70 let mut metadata = HashMap::new();
71 metadata.insert("teacher_model".to_string(), checkpoint.teacher_model.clone());
72 metadata.insert("temperature".to_string(), format!("{}", checkpoint.temperature));
73 metadata.insert("alpha".to_string(), format!("{}", checkpoint.alpha));
74 metadata.insert("epoch".to_string(), format!("{}", checkpoint.epoch));
75 metadata.insert("step".to_string(), format!("{}", checkpoint.step));
76 if let Some(loss) = checkpoint.final_loss {
77 metadata.insert("final_loss".to_string(), format!("{loss}"));
78 }
79
80 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
81 .map_err(|e| std::io::Error::other(e.to_string()))?;
82
83 let weights_path = output_dir.join(filename);
84 std::fs::write(&weights_path, safetensor_bytes)?;
85
86 let metadata_json = serde_json::to_string_pretty(checkpoint)
88 .map_err(|e| std::io::Error::other(e.to_string()))?;
89 std::fs::write(output_dir.join("distillation_metadata.json"), metadata_json)?;
90
91 Ok(weights_path)
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::Dtype;
    use tempfile::TempDir;

    /// Fixture: an 8x8 weight, an 8-element bias, and a fully populated checkpoint.
    fn make_test_data(
    ) -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>, DistillationCheckpoint) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();

        weights.insert("student.layer.0.weight".to_string(), vec![1.0; 64]);
        shapes.insert("student.layer.0.weight".to_string(), vec![8, 8]);
        weights.insert("student.layer.0.bias".to_string(), vec![0.1; 8]);
        shapes.insert("student.layer.0.bias".to_string(), vec![8]);

        let checkpoint = DistillationCheckpoint {
            teacher_model: "bert-base-uncased".to_string(),
            temperature: 3.0,
            alpha: 0.5,
            final_loss: Some(1.23),
            epoch: 5,
            step: 10000,
        };

        (weights, shapes, checkpoint)
    }

    /// Decodes a little-endian `f32` byte payload back into values.
    fn decode_f32_le(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("chunk is 4 bytes")))
            .collect()
    }

    #[test]
    fn test_save_checkpoint_creates_files() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        assert_eq!(path, tmp.path().join("student.safetensors"));
        assert!(path.exists());
        assert!(tmp.path().join("distillation_metadata.json").exists());
    }

    #[test]
    fn test_save_checkpoint_safetensors_valid() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        let data = std::fs::read(&path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 2);

        let names = loaded.names();
        assert!(names.contains(&"student.layer.0.weight"));
        assert!(names.contains(&"student.layer.0.bias"));

        let weight = loaded.tensor("student.layer.0.weight").expect("weight should exist");
        assert_eq!(weight.dtype(), Dtype::F32);
        assert_eq!(weight.shape().to_vec(), vec![8, 8]);
        assert_eq!(decode_f32_le(weight.data()), vec![1.0; 64]);

        let bias = loaded.tensor("student.layer.0.bias").expect("bias should exist");
        assert_eq!(bias.dtype(), Dtype::F32);
        assert_eq!(bias.shape().to_vec(), vec![8]);
        assert_eq!(decode_f32_le(bias.data()), vec![0.1; 8]);
    }

    #[test]
    fn test_save_checkpoint_missing_shape_defaults_to_1d() {
        let (mut weights, shapes, checkpoint) = make_test_data();
        weights.insert("student.extra".to_string(), vec![2.0; 3]);
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        let data = std::fs::read(&path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let extra = loaded.tensor("student.extra").expect("extra tensor should exist");
        assert_eq!(extra.shape().to_vec(), vec![3]);
        assert_eq!(decode_f32_le(extra.data()), vec![2.0; 3]);
    }

    #[test]
    fn test_save_checkpoint_metadata_in_safetensors() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        let data = std::fs::read(&path).expect("file read should succeed");
        let (_, st_meta) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");
        let meta = st_meta.metadata().as_ref().expect("metadata should be present");

        assert_eq!(meta.get("teacher_model").expect("key should exist"), "bert-base-uncased");
        assert_eq!(meta.get("temperature").expect("key should exist"), "3");
        assert_eq!(meta.get("alpha").expect("key should exist"), "0.5");
        assert_eq!(meta.get("epoch").expect("key should exist"), "5");
        assert_eq!(meta.get("step").expect("key should exist"), "10000");
        assert_eq!(meta.get("final_loss").expect("key should exist"), "1.23");
    }

    #[test]
    fn test_save_checkpoint_distillation_metadata() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        save_student_checkpoint(&weights, &shapes, &checkpoint, tmp.path(), "student.safetensors")
            .expect("operation should succeed");

        let json = std::fs::read_to_string(tmp.path().join("distillation_metadata.json"))
            .expect("file read should succeed");
        let loaded: DistillationCheckpoint =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        assert_eq!(loaded.teacher_model, "bert-base-uncased");
        assert_eq!(loaded.temperature, 3.0);
        assert_eq!(loaded.alpha, 0.5);
        assert_eq!(loaded.final_loss, Some(1.23));
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.step, 10000);
    }

    #[test]
    fn test_save_checkpoint_no_loss() {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();
        weights.insert("w".to_string(), vec![1.0; 4]);
        shapes.insert("w".to_string(), vec![2, 2]);

        let checkpoint = DistillationCheckpoint {
            teacher_model: "gpt2".to_string(),
            temperature: 2.0,
            alpha: 0.7,
            final_loss: None,
            epoch: 0,
            step: 0,
        };

        let tmp = TempDir::new().expect("temp file creation should succeed");
        let path =
            save_student_checkpoint(&weights, &shapes, &checkpoint, tmp.path(), "ckpt.safetensors")
                .expect("operation should succeed");
        assert!(path.exists());

        // A missing loss must not appear in the safetensors header at all.
        let data = std::fs::read(&path).expect("file read should succeed");
        let (_, st_meta) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");
        let meta = st_meta.metadata().as_ref().expect("metadata should be present");
        assert!(!meta.contains_key("final_loss"));
        assert_eq!(meta.get("step").expect("key should exist"), "0");

        // The JSON sidecar must round-trip the missing loss as `None`.
        let json = std::fs::read_to_string(tmp.path().join("distillation_metadata.json"))
            .expect("file read should succeed");
        let loaded: DistillationCheckpoint =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(loaded.final_loss, None);
    }
}