// axonml_serialize/checkpoint.rs

//! Checkpoint - Training State Persistence
//!
//! Provides functionality for saving and resuming training sessions,
//! including model parameters, optimizer state, and training metrics.
5
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::StateDict;
9
10// =============================================================================
11// TrainingState
12// =============================================================================
13
14/// Training state for checkpointing.
/// Mutable bookkeeping for a training run: progress counters, metric
/// histories, and the best metric observed so far.
///
/// Serialized as part of a checkpoint so a run can resume where it
/// left off.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TrainingState {
    /// Current epoch (0-based).
    pub epoch: usize,
    /// Current step within the epoch; reset to 0 by `next_epoch`.
    pub step: usize,
    /// Monotonic step count across all epochs.
    pub global_step: usize,
    /// Best metric value seen so far (maintained by `update_best`).
    pub best_metric: Option<f32>,
    /// Name of the metric currently stored in `best_metric`.
    pub best_metric_name: Option<String>,
    /// Training loss history; `record_loss` caps this at the most
    /// recent 1000 values.
    pub loss_history: Vec<f32>,
    /// Validation loss history (unbounded).
    pub val_loss_history: Vec<f32>,
    /// Learning rate history (unbounded).
    pub lr_history: Vec<f32>,
    /// Arbitrary named metric series recorded via `record_metric`.
    pub custom_metrics: HashMap<String, Vec<f32>>,
}
36
37impl TrainingState {
38    /// Create a new training state.
39    #[must_use]
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Record a training loss value.
45    pub fn record_loss(&mut self, loss: f32) {
46        self.loss_history.push(loss);
47        // Keep last 1000 values
48        if self.loss_history.len() > 1000 {
49            self.loss_history.remove(0);
50        }
51    }
52
53    /// Record a validation loss value.
54    pub fn record_val_loss(&mut self, loss: f32) {
55        self.val_loss_history.push(loss);
56    }
57
58    /// Record learning rate.
59    pub fn record_lr(&mut self, lr: f32) {
60        self.lr_history.push(lr);
61    }
62
63    /// Record a custom metric.
64    pub fn record_metric(&mut self, name: &str, value: f32) {
65        self.custom_metrics
66            .entry(name.to_string())
67            .or_default()
68            .push(value);
69    }
70
71    /// Update best metric if improved.
72    pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
73        let improved = match self.best_metric {
74            None => true,
75            Some(best) => {
76                if higher_is_better {
77                    value > best
78                } else {
79                    value < best
80                }
81            }
82        };
83
84        if improved {
85            self.best_metric = Some(value);
86            self.best_metric_name = Some(name.to_string());
87        }
88
89        improved
90    }
91
92    /// Get the average loss over recent values.
93    #[must_use]
94    pub fn avg_loss(&self, n: usize) -> Option<f32> {
95        if self.loss_history.is_empty() {
96            return None;
97        }
98        let start = self.loss_history.len().saturating_sub(n);
99        let slice = &self.loss_history[start..];
100        Some(slice.iter().sum::<f32>() / slice.len() as f32)
101    }
102
103    /// Increment epoch.
104    pub fn next_epoch(&mut self) {
105        self.epoch += 1;
106        self.step = 0;
107    }
108
109    /// Increment step.
110    pub fn next_step(&mut self) {
111        self.step += 1;
112        self.global_step += 1;
113    }
114}
115
116// =============================================================================
117// Checkpoint
118// =============================================================================
119
120/// A complete training checkpoint.
/// A complete training checkpoint: everything needed to resume a run.
///
/// Bundles model and optimizer parameters with training progress,
/// optional RNG state, and metadata about when/how it was created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Model state dictionary (parameter tensors keyed by name).
    pub model_state: StateDict,
    /// Optimizer state dictionary.
    pub optimizer_state: StateDict,
    /// Training progress and metric histories.
    pub training_state: TrainingState,
    /// Random number generator state (for reproducibility); `None` when
    /// not captured.
    pub rng_state: Option<Vec<u8>>,
    /// Configuration used for training, as string key/value pairs.
    pub config: HashMap<String, String>,
    /// Axonml crate version that wrote this checkpoint
    /// (from `CARGO_PKG_VERSION` at build time).
    pub axonml_version: String,
    /// Creation time: Unix-epoch seconds rendered as a decimal string.
    pub timestamp: String,
}
138
139impl Checkpoint {
140    /// Create a new checkpoint builder.
141    #[must_use]
142    pub fn builder() -> CheckpointBuilder {
143        CheckpointBuilder::new()
144    }
145
146    /// Get the epoch from this checkpoint.
147    #[must_use]
148    pub fn epoch(&self) -> usize {
149        self.training_state.epoch
150    }
151
152    /// Get the global step from this checkpoint.
153    #[must_use]
154    pub fn global_step(&self) -> usize {
155        self.training_state.global_step
156    }
157
158    /// Get the best metric value.
159    #[must_use]
160    pub fn best_metric(&self) -> Option<f32> {
161        self.training_state.best_metric
162    }
163}
164
165// =============================================================================
166// CheckpointBuilder
167// =============================================================================
168
169/// Builder for creating checkpoints.
/// Builder for creating checkpoints.
///
/// Every field is optional; `build` substitutes defaults for anything
/// not supplied (empty state dicts, fresh training state, no RNG state).
pub struct CheckpointBuilder {
    // Model parameters; `build` falls back to `StateDict::default()`.
    model_state: Option<StateDict>,
    // Optimizer state; `build` falls back to `StateDict::default()`.
    optimizer_state: Option<StateDict>,
    // Training progress; starts from `TrainingState::new()`.
    training_state: TrainingState,
    // Optional RNG snapshot carried through verbatim.
    rng_state: Option<Vec<u8>>,
    // String key/value configuration carried through verbatim.
    config: HashMap<String, String>,
}
177
178impl CheckpointBuilder {
179    /// Create a new checkpoint builder.
180    #[must_use]
181    pub fn new() -> Self {
182        Self {
183            model_state: None,
184            optimizer_state: None,
185            training_state: TrainingState::new(),
186            rng_state: None,
187            config: HashMap::new(),
188        }
189    }
190
191    /// Set the model state.
192    #[must_use]
193    pub fn model_state(mut self, state: StateDict) -> Self {
194        self.model_state = Some(state);
195        self
196    }
197
198    /// Set the optimizer state.
199    #[must_use]
200    pub fn optimizer_state(mut self, state: StateDict) -> Self {
201        self.optimizer_state = Some(state);
202        self
203    }
204
205    /// Set the training state.
206    #[must_use]
207    pub fn training_state(mut self, state: TrainingState) -> Self {
208        self.training_state = state;
209        self
210    }
211
212    /// Set the RNG state.
213    #[must_use]
214    pub fn rng_state(mut self, state: Vec<u8>) -> Self {
215        self.rng_state = Some(state);
216        self
217    }
218
219    /// Add a configuration value.
220    #[must_use]
221    pub fn config(mut self, key: &str, value: &str) -> Self {
222        self.config.insert(key.to_string(), value.to_string());
223        self
224    }
225
226    /// Set the epoch.
227    #[must_use]
228    pub fn epoch(mut self, epoch: usize) -> Self {
229        self.training_state.epoch = epoch;
230        self
231    }
232
233    /// Set the global step.
234    #[must_use]
235    pub fn global_step(mut self, step: usize) -> Self {
236        self.training_state.global_step = step;
237        self
238    }
239
240    /// Build the checkpoint.
241    #[must_use]
242    pub fn build(self) -> Checkpoint {
243        Checkpoint {
244            model_state: self.model_state.unwrap_or_default(),
245            optimizer_state: self.optimizer_state.unwrap_or_default(),
246            training_state: self.training_state,
247            rng_state: self.rng_state,
248            config: self.config,
249            axonml_version: env!("CARGO_PKG_VERSION").to_string(),
250            timestamp: chrono_timestamp(),
251        }
252    }
253}
254
255impl Default for CheckpointBuilder {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261// =============================================================================
262// Utilities
263// =============================================================================
264
/// Current wall-clock time as whole seconds since the Unix epoch,
/// rendered as a decimal string.
///
/// Kept dependency-free on purpose (no `chrono`). A system clock set
/// before 1970 yields `"0"` via `unwrap_or_default` instead of panicking.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
273
274// =============================================================================
275// Tests
276// =============================================================================
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    #[test]
    fn test_training_state_basic() {
        let mut ts = TrainingState::new();
        assert_eq!(ts.epoch, 0);
        assert_eq!(ts.step, 0);

        // One step bumps both the per-epoch and global counters.
        ts.next_step();
        assert_eq!(ts.step, 1);
        assert_eq!(ts.global_step, 1);

        // Rolling the epoch resets only the per-epoch counter.
        ts.next_epoch();
        assert_eq!(ts.epoch, 1);
        assert_eq!(ts.step, 0);
    }

    #[test]
    fn test_training_state_loss_recording() {
        let mut ts = TrainingState::new();
        for &loss in &[1.0_f32, 0.8, 0.6] {
            ts.record_loss(loss);
        }

        assert_eq!(ts.loss_history.len(), 3);
        // Average over the last two values: (0.8 + 0.6) / 2 = 0.7.
        let mean = ts.avg_loss(2).unwrap();
        assert!((mean - 0.7).abs() < 1e-5, "Expected ~0.7, got {mean}");
    }

    #[test]
    fn test_training_state_best_metric() {
        // Minimizing (loss-like): only smaller values are improvements.
        let mut minimizing = TrainingState::new();
        assert!(minimizing.update_best("loss", 1.0, false));
        assert!(!minimizing.update_best("loss", 1.5, false));
        assert!(minimizing.update_best("loss", 0.5, false));
        assert_eq!(minimizing.best_metric, Some(0.5));

        // Maximizing (accuracy-like): only larger values are improvements.
        let mut maximizing = TrainingState::new();
        assert!(maximizing.update_best("accuracy", 0.8, true));
        assert!(!maximizing.update_best("accuracy", 0.7, true));
        assert!(maximizing.update_best("accuracy", 0.9, true));
        assert_eq!(maximizing.best_metric, Some(0.9));
    }

    #[test]
    fn test_checkpoint_builder() {
        let mut weights = StateDict::new();
        weights.insert(
            "weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let ckpt = Checkpoint::builder()
            .model_state(weights)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(ckpt.epoch(), 5);
        assert_eq!(ckpt.global_step(), 1000);
        assert!(ckpt.config.contains_key("learning_rate"));
    }

    #[test]
    fn test_checkpoint_serialization() {
        let original = Checkpoint::builder().epoch(10).global_step(5000).build();

        // Round-trip through bincode and confirm the counters survive.
        let encoded = bincode::serialize(&original).unwrap();
        assert!(!encoded.is_empty());

        let decoded: Checkpoint = bincode::deserialize(&encoded).unwrap();
        assert_eq!(decoded.epoch(), 10);
        assert_eq!(decoded.global_step(), 5000);
    }
}
365}