Skip to main content

yscv_model/
checkpoint_state.rs

1//! Save and restore full training checkpoints (model weights + optimizer state).
2//!
3//! A training checkpoint bundles model parameters and optimizer state into a
4//! single binary file so training can be resumed exactly where it left off.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use yscv_tensor::Tensor;
10
11use super::weights::{load_weights, save_weights};
12use crate::ModelError;
13
/// Prefix used to distinguish optimizer state tensors from model weights.
///
/// [`save_training_checkpoint`] prepends it to every optimizer-state key before
/// writing. [`load_training_checkpoint`] routes any key that starts with it into
/// the optimizer map, with the prefix stripped. Consequently a model weight whose
/// name already begins with this prefix would be read back as optimizer state, so
/// model parameter names must not start with it.
const OPT_PREFIX: &str = "__opt__.";
16
17/// Save a full training checkpoint: model weights + optimizer state.
18///
19/// Both maps use string keys. Optimizer state keys are automatically prefixed
20/// to avoid collisions with model weight names.
21pub fn save_training_checkpoint(
22    path: &Path,
23    model_weights: &HashMap<String, Tensor>,
24    optimizer_state: &HashMap<String, Tensor>,
25) -> Result<(), ModelError> {
26    let mut combined = model_weights.clone();
27    for (key, tensor) in optimizer_state {
28        combined.insert(format!("{OPT_PREFIX}{key}"), tensor.clone());
29    }
30    save_weights(path, &combined)
31}
32
33/// Load a full training checkpoint, splitting model weights from optimizer state.
34///
35/// Returns `(model_weights, optimizer_state)`.
36pub fn load_training_checkpoint(
37    path: &Path,
38) -> Result<(HashMap<String, Tensor>, HashMap<String, Tensor>), ModelError> {
39    let all = load_weights(path)?;
40    let mut model_weights = HashMap::new();
41    let mut optimizer_state = HashMap::new();
42
43    for (key, tensor) in all {
44        if let Some(stripped) = key.strip_prefix(OPT_PREFIX) {
45            optimizer_state.insert(stripped.to_owned(), tensor);
46        } else {
47            model_weights.insert(key, tensor);
48        }
49    }
50
51    Ok((model_weights, optimizer_state))
52}
53
54/// Flatten SGD velocity buffers into a string-keyed map for serialization.
55///
56/// Keys: `"sgd.{param_id}.velocity"`
57pub fn sgd_state_to_map(velocity: &HashMap<u64, Tensor>) -> HashMap<String, Tensor> {
58    velocity
59        .iter()
60        .map(|(id, t)| (format!("sgd.{id}.velocity"), t.clone()))
61        .collect()
62}
63
64/// Restore SGD velocity buffers from a string-keyed map.
65pub fn sgd_state_from_map(map: &HashMap<String, Tensor>) -> HashMap<u64, Tensor> {
66    let mut velocity = HashMap::new();
67    for (key, tensor) in map {
68        if let Some(rest) = key.strip_prefix("sgd.")
69            && let Some(id_str) = rest.strip_suffix(".velocity")
70            && let Ok(id) = id_str.parse::<u64>()
71        {
72            velocity.insert(id, tensor.clone());
73        }
74    }
75    velocity
76}
77
78/// Flatten Adam/AdamW state into a string-keyed map for serialization.
79///
80/// Keys: `"adam.{param_id}.m"`, `"adam.{param_id}.v"`, `"adam.{param_id}.step"`
81pub fn adam_state_to_map(state: &[(u64, Tensor, Tensor, u64)]) -> HashMap<String, Tensor> {
82    let mut map = HashMap::new();
83    for (id, m, v, step) in state {
84        map.insert(format!("adam.{id}.m"), m.clone());
85        map.insert(format!("adam.{id}.v"), v.clone());
86        // Store step as a scalar tensor.
87        map.insert(
88            format!("adam.{id}.step"),
89            Tensor::from_vec(vec![1], vec![*step as f32]).expect("scalar shape matches data"),
90        );
91    }
92    map
93}
94
95/// Restore Adam/AdamW state from a string-keyed map.
96///
97/// Returns `Vec<(param_id, first_moment, second_moment, step)>`.
98pub fn adam_state_from_map(map: &HashMap<String, Tensor>) -> Vec<(u64, Tensor, Tensor, u64)> {
99    // Collect unique param IDs from "adam.{id}.m" keys.
100    let mut ids: Vec<u64> = map
101        .keys()
102        .filter_map(|k| {
103            k.strip_prefix("adam.")
104                .and_then(|rest| rest.strip_suffix(".m"))
105                .and_then(|id_str| id_str.parse::<u64>().ok())
106        })
107        .collect();
108    ids.sort();
109    ids.dedup();
110
111    ids.into_iter()
112        .filter_map(|id| {
113            let m = map.get(&format!("adam.{id}.m"))?.clone();
114            let v = map.get(&format!("adam.{id}.v"))?.clone();
115            let step = map
116                .get(&format!("adam.{id}.step"))
117                .map(|t| t.data()[0] as u64)
118                .unwrap_or(0);
119            Some((id, m, v, step))
120        })
121        .collect()
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory unique to this process and test. Concurrent test runs
    /// never share files, and the directory is removed on drop even if an
    /// assertion panics.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "yscv_checkpoint_test_{}_{tag}",
                std::process::id()
            ));
            std::fs::create_dir_all(&dir).expect("create scratch checkpoint dir");
            ScratchDir(dir)
        }

        fn path(&self, file: &str) -> std::path::PathBuf {
            self.0.join(file)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_save_load_training_checkpoint_roundtrip() {
        let scratch = ScratchDir::new("roundtrip");
        let path = scratch.path("checkpoint.bin");

        let mut model_weights = HashMap::new();
        model_weights.insert(
            "layer.0.weight".to_string(),
            Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );

        let mut opt_state = HashMap::new();
        opt_state.insert(
            "sgd.0.velocity".to_string(),
            Tensor::from_vec(vec![2, 2], vec![0.1, 0.2, 0.3, 0.4]).unwrap(),
        );

        save_training_checkpoint(&path, &model_weights, &opt_state).unwrap();

        let (loaded_weights, loaded_opt) = load_training_checkpoint(&path).unwrap();

        // Neither half may leak entries into the other.
        assert_eq!(loaded_weights.len(), 1);
        assert_eq!(loaded_opt.len(), 1);
        assert!(loaded_weights.contains_key("layer.0.weight"));
        assert!(loaded_opt.contains_key("sgd.0.velocity"));
        assert_eq!(
            loaded_weights["layer.0.weight"].data(),
            &[1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(loaded_opt["sgd.0.velocity"].data(), &[0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_sgd_state_roundtrip() {
        let mut velocity = HashMap::new();
        velocity.insert(
            42u64,
            Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(),
        );

        let map = sgd_state_to_map(&velocity);
        let restored = sgd_state_from_map(&map);

        assert_eq!(restored.len(), 1);
        assert!(restored.contains_key(&42));
        assert_eq!(restored[&42].data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sgd_state_from_map_ignores_foreign_keys() {
        let mut map = HashMap::new();
        map.insert(
            "adam.1.m".to_string(),
            Tensor::from_vec(vec![1], vec![1.0]).unwrap(),
        );
        map.insert(
            "sgd.x.velocity".to_string(),
            Tensor::from_vec(vec![1], vec![2.0]).unwrap(),
        );

        assert!(sgd_state_from_map(&map).is_empty());
    }

    fn sample_adam_state() -> Vec<(u64, Tensor, Tensor, u64)> {
        vec![(
            7u64,
            Tensor::from_vec(vec![2], vec![0.1, 0.2]).unwrap(),
            Tensor::from_vec(vec![2], vec![0.01, 0.02]).unwrap(),
            100u64,
        )]
    }

    #[test]
    fn test_adam_state_roundtrip() {
        let map = adam_state_to_map(&sample_adam_state());
        let restored = adam_state_from_map(&map);

        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].0, 7);
        assert_eq!(restored[0].1.data(), &[0.1, 0.2]);
        assert_eq!(restored[0].2.data(), &[0.01, 0.02]);
        assert_eq!(restored[0].3, 100);
    }

    #[test]
    fn test_adam_state_skips_entry_without_second_moment() {
        let mut map = adam_state_to_map(&sample_adam_state());
        map.remove("adam.7.v");

        assert!(adam_state_from_map(&map).is_empty());
    }

    #[test]
    fn test_adam_state_missing_step_defaults_to_zero() {
        let mut map = adam_state_to_map(&sample_adam_state());
        // Drop every step representation for parameter 7.
        map.retain(|key, _| !key.starts_with("adam.7.step"));

        let restored = adam_state_from_map(&map);
        assert_eq!(restored.len(), 1);
        assert_eq!(restored[0].3, 0);
    }
}
192}