Skip to main content

entrenar/distill/
checkpoint.rs

1//! Student model checkpoint saving for knowledge distillation
2//!
3//! Saves student model weights along with distillation metadata including
4//! teacher model name, temperature, alpha, and loss.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9
10/// Distillation checkpoint metadata
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct DistillationCheckpoint {
13    /// Teacher model name or path
14    pub teacher_model: String,
15    /// Distillation temperature
16    pub temperature: f32,
17    /// KD loss weight (alpha)
18    pub alpha: f32,
19    /// Final distillation loss
20    pub final_loss: Option<f32>,
21    /// Training epoch at checkpoint
22    pub epoch: usize,
23    /// Training step at checkpoint
24    pub step: usize,
25}
26
27/// Save a student model checkpoint with distillation metadata
28///
29/// Creates:
30/// - Weight file (SafeTensors format)
31/// - `distillation_metadata.json` sidecar with teacher info, temperature, alpha, loss
32///
33/// Returns the path to the weight file.
34#[allow(clippy::implicit_hasher)]
35pub fn save_student_checkpoint(
36    weights: &HashMap<String, Vec<f32>>,
37    shapes: &HashMap<String, Vec<usize>>,
38    checkpoint: &DistillationCheckpoint,
39    output_dir: impl AsRef<Path>,
40    filename: &str,
41) -> Result<PathBuf, std::io::Error> {
42    let output_dir = output_dir.as_ref();
43    std::fs::create_dir_all(output_dir)?;
44
45    // Save weights as SafeTensors
46    use safetensors::tensor::{Dtype, TensorView};
47
48    let mut sorted_names: Vec<&String> = weights.keys().collect();
49    sorted_names.sort();
50
51    let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = sorted_names
52        .iter()
53        .map(|name| {
54            let data = &weights[*name];
55            let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
56            let shape = shapes.get(*name).cloned().unwrap_or_else(|| vec![data.len()]);
57            ((*name).clone(), bytes, shape)
58        })
59        .collect();
60
61    let views: Vec<(&str, TensorView<'_>)> = tensor_data
62        .iter()
63        .map(|(name, bytes, shape)| {
64            let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
65                .expect("TensorView construction must not fail for valid F32 data");
66            (name.as_str(), view)
67        })
68        .collect();
69
70    let mut metadata = HashMap::new();
71    metadata.insert("teacher_model".to_string(), checkpoint.teacher_model.clone());
72    metadata.insert("temperature".to_string(), format!("{}", checkpoint.temperature));
73    metadata.insert("alpha".to_string(), format!("{}", checkpoint.alpha));
74    metadata.insert("epoch".to_string(), format!("{}", checkpoint.epoch));
75    metadata.insert("step".to_string(), format!("{}", checkpoint.step));
76    if let Some(loss) = checkpoint.final_loss {
77        metadata.insert("final_loss".to_string(), format!("{loss}"));
78    }
79
80    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
81        .map_err(|e| std::io::Error::other(e.to_string()))?;
82
83    let weights_path = output_dir.join(filename);
84    std::fs::write(&weights_path, safetensor_bytes)?;
85
86    // Save distillation metadata sidecar
87    let metadata_json = serde_json::to_string_pretty(checkpoint)
88        .map_err(|e| std::io::Error::other(e.to_string()))?;
89    std::fs::write(output_dir.join("distillation_metadata.json"), metadata_json)?;
90
91    Ok(weights_path)
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn make_test_data(
    ) -> (HashMap<String, Vec<f32>>, HashMap<String, Vec<usize>>, DistillationCheckpoint) {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();

        weights.insert("student.layer.0.weight".to_string(), vec![1.0; 64]);
        shapes.insert("student.layer.0.weight".to_string(), vec![8, 8]);
        weights.insert("student.layer.0.bias".to_string(), vec![0.1; 8]);
        shapes.insert("student.layer.0.bias".to_string(), vec![8]);

        let checkpoint = DistillationCheckpoint {
            teacher_model: "bert-base-uncased".to_string(),
            temperature: 3.0,
            alpha: 0.5,
            final_loss: Some(1.23),
            epoch: 5,
            step: 10000,
        };

        (weights, shapes, checkpoint)
    }

    /// Read a saved SafeTensors file and decode one F32 tensor into `(shape, values)`.
    fn read_f32_tensor(path: &Path, name: &str) -> (Vec<usize>, Vec<f32>) {
        let data = std::fs::read(path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let tensor = loaded.tensor(name).expect("tensor should exist");
        assert_eq!(tensor.dtype(), safetensors::tensor::Dtype::F32);
        // SafeTensors data is little-endian.
        let values = tensor
            .data()
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        (tensor.shape().to_vec(), values)
    }

    #[test]
    fn test_save_checkpoint_creates_files() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        assert!(path.exists());
        assert!(tmp.path().join("distillation_metadata.json").exists());
    }

    #[test]
    fn test_save_checkpoint_safetensors_valid() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        let data = std::fs::read(&path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 2);

        let names = loaded.names();
        assert!(names.contains(&"student.layer.0.weight"));
        assert!(names.contains(&"student.layer.0.bias"));
    }

    #[test]
    fn test_save_checkpoint_round_trips_shapes_and_values() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        for name in ["student.layer.0.weight", "student.layer.0.bias"] {
            let (shape, values) = read_f32_tensor(&path, name);
            assert_eq!(shape, shapes[name]);
            assert_eq!(values, weights[name]);
        }
    }

    #[test]
    fn test_save_checkpoint_missing_shape_defaults_to_1d() {
        let mut weights = HashMap::new();
        weights.insert("flat".to_string(), vec![0.5, 1.5, 2.5, 3.5, 4.5, 5.5]);
        let shapes = HashMap::new();

        let checkpoint = DistillationCheckpoint {
            teacher_model: "gpt2".to_string(),
            temperature: 1.0,
            alpha: 0.5,
            final_loss: None,
            epoch: 1,
            step: 2,
        };

        let tmp = TempDir::new().expect("temp file creation should succeed");
        let path =
            save_student_checkpoint(&weights, &shapes, &checkpoint, tmp.path(), "flat.safetensors")
                .expect("operation should succeed");

        let (shape, values) = read_f32_tensor(&path, "flat");
        assert_eq!(shape, vec![6]);
        assert_eq!(values, weights["flat"]);
    }

    #[test]
    fn test_save_checkpoint_metadata_in_safetensors() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        let path = save_student_checkpoint(
            &weights,
            &shapes,
            &checkpoint,
            tmp.path(),
            "student.safetensors",
        )
        .expect("operation should succeed");

        let data = std::fs::read(&path).expect("file read should succeed");
        let (_, st_meta) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");
        let meta = st_meta.metadata().as_ref().expect("operation should succeed");

        assert_eq!(meta.get("teacher_model").expect("key should exist"), "bert-base-uncased");
        assert_eq!(meta.get("temperature").expect("key should exist"), "3");
        assert_eq!(meta.get("alpha").expect("key should exist"), "0.5");
        assert_eq!(meta.get("epoch").expect("key should exist"), "5");
        assert_eq!(meta.get("step").expect("key should exist"), "10000");
        assert_eq!(meta.get("final_loss").expect("key should exist"), "1.23");
    }

    #[test]
    fn test_save_checkpoint_distillation_metadata() {
        let (weights, shapes, checkpoint) = make_test_data();
        let tmp = TempDir::new().expect("temp file creation should succeed");

        save_student_checkpoint(&weights, &shapes, &checkpoint, tmp.path(), "student.safetensors")
            .expect("operation should succeed");

        let json = std::fs::read_to_string(tmp.path().join("distillation_metadata.json"))
            .expect("file read should succeed");
        let loaded: DistillationCheckpoint =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");

        assert_eq!(loaded.teacher_model, "bert-base-uncased");
        assert_eq!(loaded.temperature, 3.0);
        assert_eq!(loaded.alpha, 0.5);
        assert_eq!(loaded.final_loss, Some(1.23));
        assert_eq!(loaded.epoch, 5);
        assert_eq!(loaded.step, 10000);
    }

    #[test]
    fn test_save_checkpoint_no_loss() {
        let mut weights = HashMap::new();
        let mut shapes = HashMap::new();
        weights.insert("w".to_string(), vec![1.0; 4]);
        shapes.insert("w".to_string(), vec![2, 2]);

        let checkpoint = DistillationCheckpoint {
            teacher_model: "gpt2".to_string(),
            temperature: 2.0,
            alpha: 0.7,
            final_loss: None,
            epoch: 0,
            step: 0,
        };

        let tmp = TempDir::new().expect("temp file creation should succeed");
        let path =
            save_student_checkpoint(&weights, &shapes, &checkpoint, tmp.path(), "ckpt.safetensors")
                .expect("operation should succeed");
        assert!(path.exists());

        // `final_loss` must be absent from the SafeTensors header when it is `None`.
        let data = std::fs::read(&path).expect("file read should succeed");
        let (_, st_meta) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");
        let meta = st_meta.metadata().as_ref().expect("metadata should be present");
        assert!(!meta.contains_key("final_loss"));
        assert_eq!(meta.get("step").expect("key should exist"), "0");

        // ...and round-trip as `None` through the JSON sidecar.
        let json = std::fs::read_to_string(tmp.path().join("distillation_metadata.json"))
            .expect("file read should succeed");
        let loaded: DistillationCheckpoint =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(loaded.final_loss, None);
        assert_eq!(loaded.teacher_model, "gpt2");
    }
}