// axonml_serialize/checkpoint.rs
1//! Checkpoint - Training State Persistence
2//!
3//! Provides functionality for saving and resuming training sessions,
4//! including model parameters, optimizer state, and training metrics.
5
6use crate::StateDict;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10// =============================================================================
11// TrainingState
12// =============================================================================
13
/// Training state for checkpointing.
///
/// Plain-data progress counters and metric histories; serializable as a
/// whole via serde so it can be embedded in a [`Checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TrainingState {
    /// Current epoch.
    pub epoch: usize,
    /// Current step within epoch (reset to 0 by `next_epoch`).
    pub step: usize,
    /// Global step count across all epochs (never reset).
    pub global_step: usize,
    /// Best metric value seen so far (maintained by `update_best`).
    pub best_metric: Option<f32>,
    /// Name of the metric that produced `best_metric`.
    pub best_metric_name: Option<String>,
    /// Training loss history; `record_loss` keeps only the last 1000 values.
    pub loss_history: Vec<f32>,
    /// Validation loss history (unbounded).
    pub val_loss_history: Vec<f32>,
    /// Learning rate history (unbounded).
    pub lr_history: Vec<f32>,
    /// Custom metric histories, keyed by metric name.
    pub custom_metrics: HashMap<String, Vec<f32>>,
}
36
37impl TrainingState {
38    /// Create a new training state.
39    #[must_use] pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Record a training loss value.
44    pub fn record_loss(&mut self, loss: f32) {
45        self.loss_history.push(loss);
46        // Keep last 1000 values
47        if self.loss_history.len() > 1000 {
48            self.loss_history.remove(0);
49        }
50    }
51
52    /// Record a validation loss value.
53    pub fn record_val_loss(&mut self, loss: f32) {
54        self.val_loss_history.push(loss);
55    }
56
57    /// Record learning rate.
58    pub fn record_lr(&mut self, lr: f32) {
59        self.lr_history.push(lr);
60    }
61
62    /// Record a custom metric.
63    pub fn record_metric(&mut self, name: &str, value: f32) {
64        self.custom_metrics
65            .entry(name.to_string())
66            .or_default()
67            .push(value);
68    }
69
70    /// Update best metric if improved.
71    pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
72        let improved = match self.best_metric {
73            None => true,
74            Some(best) => {
75                if higher_is_better {
76                    value > best
77                } else {
78                    value < best
79                }
80            }
81        };
82
83        if improved {
84            self.best_metric = Some(value);
85            self.best_metric_name = Some(name.to_string());
86        }
87
88        improved
89    }
90
91    /// Get the average loss over recent values.
92    #[must_use] pub fn avg_loss(&self, n: usize) -> Option<f32> {
93        if self.loss_history.is_empty() {
94            return None;
95        }
96        let start = self.loss_history.len().saturating_sub(n);
97        let slice = &self.loss_history[start..];
98        Some(slice.iter().sum::<f32>() / slice.len() as f32)
99    }
100
101    /// Increment epoch.
102    pub fn next_epoch(&mut self) {
103        self.epoch += 1;
104        self.step = 0;
105    }
106
107    /// Increment step.
108    pub fn next_step(&mut self) {
109        self.step += 1;
110        self.global_step += 1;
111    }
112}
113
114// =============================================================================
115// Checkpoint
116// =============================================================================
117
/// A complete training checkpoint.
///
/// Bundles everything needed to resume a run: model and optimizer
/// parameters, training progress, optional RNG state for reproducibility,
/// and the configuration used for the run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Model state dictionary (parameter name -> tensor data).
    pub model_state: StateDict,
    /// Optimizer state dictionary.
    pub optimizer_state: StateDict,
    /// Training progress counters and metric histories.
    pub training_state: TrainingState,
    /// Serialized random number generator state (for reproducibility), if captured.
    pub rng_state: Option<Vec<u8>>,
    /// Configuration key/value pairs used for training.
    pub config: HashMap<String, String>,
    /// Axonml crate version that produced this checkpoint (from `CARGO_PKG_VERSION`).
    pub axonml_version: String,
    /// Creation time, as seconds since the Unix epoch rendered as a decimal string.
    pub timestamp: String,
}
136
137impl Checkpoint {
138    /// Create a new checkpoint builder.
139    #[must_use] pub fn builder() -> CheckpointBuilder {
140        CheckpointBuilder::new()
141    }
142
143    /// Get the epoch from this checkpoint.
144    #[must_use] pub fn epoch(&self) -> usize {
145        self.training_state.epoch
146    }
147
148    /// Get the global step from this checkpoint.
149    #[must_use] pub fn global_step(&self) -> usize {
150        self.training_state.global_step
151    }
152
153    /// Get the best metric value.
154    #[must_use] pub fn best_metric(&self) -> Option<f32> {
155        self.training_state.best_metric
156    }
157}
158
159// =============================================================================
160// CheckpointBuilder
161// =============================================================================
162
/// Builder for creating checkpoints.
///
/// Every component is optional; anything not supplied falls back to an
/// empty/default value when `build` is called.
pub struct CheckpointBuilder {
    // Model parameters; build() substitutes an empty StateDict if unset.
    model_state: Option<StateDict>,
    // Optimizer parameters; build() substitutes an empty StateDict if unset.
    optimizer_state: Option<StateDict>,
    // Progress counters/metrics; starts zeroed via TrainingState::new().
    training_state: TrainingState,
    // RNG snapshot; stays None unless explicitly provided.
    rng_state: Option<Vec<u8>>,
    // Free-form configuration key/value pairs.
    config: HashMap<String, String>,
}
171
172impl CheckpointBuilder {
173    /// Create a new checkpoint builder.
174    #[must_use] pub fn new() -> Self {
175        Self {
176            model_state: None,
177            optimizer_state: None,
178            training_state: TrainingState::new(),
179            rng_state: None,
180            config: HashMap::new(),
181        }
182    }
183
184    /// Set the model state.
185    #[must_use] pub fn model_state(mut self, state: StateDict) -> Self {
186        self.model_state = Some(state);
187        self
188    }
189
190    /// Set the optimizer state.
191    #[must_use] pub fn optimizer_state(mut self, state: StateDict) -> Self {
192        self.optimizer_state = Some(state);
193        self
194    }
195
196    /// Set the training state.
197    #[must_use] pub fn training_state(mut self, state: TrainingState) -> Self {
198        self.training_state = state;
199        self
200    }
201
202    /// Set the RNG state.
203    #[must_use] pub fn rng_state(mut self, state: Vec<u8>) -> Self {
204        self.rng_state = Some(state);
205        self
206    }
207
208    /// Add a configuration value.
209    #[must_use] pub fn config(mut self, key: &str, value: &str) -> Self {
210        self.config.insert(key.to_string(), value.to_string());
211        self
212    }
213
214    /// Set the epoch.
215    #[must_use] pub fn epoch(mut self, epoch: usize) -> Self {
216        self.training_state.epoch = epoch;
217        self
218    }
219
220    /// Set the global step.
221    #[must_use] pub fn global_step(mut self, step: usize) -> Self {
222        self.training_state.global_step = step;
223        self
224    }
225
226    /// Build the checkpoint.
227    #[must_use] pub fn build(self) -> Checkpoint {
228        Checkpoint {
229            model_state: self.model_state.unwrap_or_default(),
230            optimizer_state: self.optimizer_state.unwrap_or_default(),
231            training_state: self.training_state,
232            rng_state: self.rng_state,
233            config: self.config,
234            axonml_version: env!("CARGO_PKG_VERSION").to_string(),
235            timestamp: chrono_timestamp(),
236        }
237    }
238}
239
impl Default for CheckpointBuilder {
    /// Equivalent to [`CheckpointBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
245
246// =============================================================================
247// Utilities
248// =============================================================================
249
/// Current wall-clock time as whole seconds since the Unix epoch, rendered
/// as a decimal string.
///
/// Kept dependency-free on purpose (no `chrono`). If the system clock reads
/// before the epoch, `duration_since` fails and the duration defaults to
/// zero, so `"0"` is returned rather than panicking.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let duration = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // `to_string()` is the idiomatic form of `format!("{}", x)`.
    duration.as_secs().to_string()
}
258
259// =============================================================================
260// Tests
261// =============================================================================
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    // Counters advance and the per-epoch step resets on epoch rollover.
    #[test]
    fn test_training_state_basic() {
        let mut state = TrainingState::new();
        assert_eq!(state.epoch, 0);
        assert_eq!(state.step, 0);

        state.next_step();
        assert_eq!(state.step, 1);
        assert_eq!(state.global_step, 1);

        state.next_epoch();
        assert_eq!(state.epoch, 1);
        assert_eq!(state.step, 0);
    }

    // avg_loss(n) averages only the trailing n entries.
    #[test]
    fn test_training_state_loss_recording() {
        let mut state = TrainingState::new();

        state.record_loss(1.0);
        state.record_loss(0.8);
        state.record_loss(0.6);

        assert_eq!(state.loss_history.len(), 3);
        let avg = state.avg_loss(2).unwrap();
        assert!((avg - 0.7).abs() < 1e-5, "Expected ~0.7, got {avg}");
    }

    // update_best honors the comparison direction for both loss-like and
    // accuracy-like metrics.
    #[test]
    fn test_training_state_best_metric() {
        let mut state = TrainingState::new();

        // Lower is better (like loss)
        assert!(state.update_best("loss", 1.0, false));
        assert!(!state.update_best("loss", 1.5, false));
        assert!(state.update_best("loss", 0.5, false));
        assert_eq!(state.best_metric, Some(0.5));

        // Higher is better (like accuracy)
        let mut state2 = TrainingState::new();
        assert!(state2.update_best("accuracy", 0.8, true));
        assert!(!state2.update_best("accuracy", 0.7, true));
        assert!(state2.update_best("accuracy", 0.9, true));
        assert_eq!(state2.best_metric, Some(0.9));
    }

    // Builder setters propagate into the finished Checkpoint.
    #[test]
    fn test_checkpoint_builder() {
        let mut model_state = StateDict::new();
        model_state.insert(
            "weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let checkpoint = Checkpoint::builder()
            .model_state(model_state)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(checkpoint.epoch(), 5);
        assert_eq!(checkpoint.global_step(), 1000);
        assert!(checkpoint.config.contains_key("learning_rate"));
    }

    // Round-trip through bincode preserves the counters.
    #[test]
    fn test_checkpoint_serialization() {
        let checkpoint = Checkpoint::builder().epoch(10).global_step(5000).build();

        // Serialize
        let bytes = bincode::serialize(&checkpoint).unwrap();
        assert!(!bytes.is_empty());

        // Deserialize
        let restored: Checkpoint = bincode::deserialize(&bytes).unwrap();
        assert_eq!(restored.epoch(), 10);
        assert_eq!(restored.global_step(), 5000);
    }
}