// crates/axonml-serialize/src/checkpoint.rs
1//! Checkpoint - Training State Persistence
2//!
3//! # File
4//! `crates/axonml-serialize/src/checkpoint.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::StateDict;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21// =============================================================================
22// TrainingState
23// =============================================================================
24
/// Training state for checkpointing.
///
/// Holds epoch/step counters, the best metric seen so far, and rolling
/// histories of losses, learning rates, and custom metrics. Serialized as
/// part of a [`Checkpoint`] so training can resume where it left off.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TrainingState {
    /// Current epoch (0-based; advanced by `next_epoch`).
    pub epoch: usize,
    /// Current step within the epoch (reset to 0 by `next_epoch`).
    pub step: usize,
    /// Global step count across all epochs (never reset).
    pub global_step: usize,
    /// Best metric value seen so far (`None` until `update_best` is called).
    pub best_metric: Option<f32>,
    /// Name of the metric that produced `best_metric`.
    pub best_metric_name: Option<String>,
    /// Training loss history; `record_loss` caps this at the last 1000 values.
    pub loss_history: Vec<f32>,
    /// Validation loss history (unbounded).
    pub val_loss_history: Vec<f32>,
    /// Learning rate history (unbounded).
    pub lr_history: Vec<f32>,
    /// Custom metric histories keyed by metric name.
    pub custom_metrics: HashMap<String, Vec<f32>>,
}
47
48impl TrainingState {
49    /// Create a new training state.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    /// Record a training loss value.
56    pub fn record_loss(&mut self, loss: f32) {
57        self.loss_history.push(loss);
58        // Keep last 1000 values
59        if self.loss_history.len() > 1000 {
60            self.loss_history.remove(0);
61        }
62    }
63
64    /// Record a validation loss value.
65    pub fn record_val_loss(&mut self, loss: f32) {
66        self.val_loss_history.push(loss);
67    }
68
69    /// Record learning rate.
70    pub fn record_lr(&mut self, lr: f32) {
71        self.lr_history.push(lr);
72    }
73
74    /// Record a custom metric.
75    pub fn record_metric(&mut self, name: &str, value: f32) {
76        self.custom_metrics
77            .entry(name.to_string())
78            .or_default()
79            .push(value);
80    }
81
82    /// Update best metric if improved.
83    pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
84        let improved = match self.best_metric {
85            None => true,
86            Some(best) => {
87                if higher_is_better {
88                    value > best
89                } else {
90                    value < best
91                }
92            }
93        };
94
95        if improved {
96            self.best_metric = Some(value);
97            self.best_metric_name = Some(name.to_string());
98        }
99
100        improved
101    }
102
103    /// Get the average loss over recent values.
104    #[must_use]
105    pub fn avg_loss(&self, n: usize) -> Option<f32> {
106        if self.loss_history.is_empty() {
107            return None;
108        }
109        let start = self.loss_history.len().saturating_sub(n);
110        let slice = &self.loss_history[start..];
111        Some(slice.iter().sum::<f32>() / slice.len() as f32)
112    }
113
114    /// Increment epoch.
115    pub fn next_epoch(&mut self) {
116        self.epoch += 1;
117        self.step = 0;
118    }
119
120    /// Increment step.
121    pub fn next_step(&mut self) {
122        self.step += 1;
123        self.global_step += 1;
124    }
125}
126
127// =============================================================================
128// Checkpoint
129// =============================================================================
130
/// A complete training checkpoint.
///
/// Bundles everything needed to resume a run: model and optimizer state
/// dicts, training progress, optional RNG state, the training config, and
/// version/timestamp metadata filled in by [`CheckpointBuilder::build`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Model state dictionary (tensor name -> tensor data).
    pub model_state: StateDict,
    /// Optimizer state dictionary.
    pub optimizer_state: StateDict,
    /// Training state (counters, metric/loss histories).
    pub training_state: TrainingState,
    /// Random number generator state (for reproducibility), if captured.
    pub rng_state: Option<Vec<u8>>,
    /// Configuration key/value pairs used for training.
    pub config: HashMap<String, String>,
    /// Axonml crate version that wrote this checkpoint.
    pub axonml_version: String,
    /// Timestamp when the checkpoint was created (Unix seconds as a string).
    pub timestamp: String,
}
149
150impl Checkpoint {
151    /// Create a new checkpoint builder.
152    #[must_use]
153    pub fn builder() -> CheckpointBuilder {
154        CheckpointBuilder::new()
155    }
156
157    /// Get the epoch from this checkpoint.
158    #[must_use]
159    pub fn epoch(&self) -> usize {
160        self.training_state.epoch
161    }
162
163    /// Get the global step from this checkpoint.
164    #[must_use]
165    pub fn global_step(&self) -> usize {
166        self.training_state.global_step
167    }
168
169    /// Get the best metric value.
170    #[must_use]
171    pub fn best_metric(&self) -> Option<f32> {
172        self.training_state.best_metric
173    }
174}
175
176// =============================================================================
177// CheckpointBuilder
178// =============================================================================
179
/// Builder for creating checkpoints.
///
/// Obtain one via [`Checkpoint::builder`], chain the setter methods, then
/// call [`CheckpointBuilder::build`].
pub struct CheckpointBuilder {
    /// Model state; an empty `StateDict` is used if never set.
    model_state: Option<StateDict>,
    /// Optimizer state; an empty `StateDict` is used if never set.
    optimizer_state: Option<StateDict>,
    /// Training progress carried into the checkpoint.
    training_state: TrainingState,
    /// Optional serialized RNG state for reproducibility.
    rng_state: Option<Vec<u8>>,
    /// Free-form configuration key/value pairs.
    config: HashMap<String, String>,
}
188
189impl CheckpointBuilder {
190    /// Create a new checkpoint builder.
191    #[must_use]
192    pub fn new() -> Self {
193        Self {
194            model_state: None,
195            optimizer_state: None,
196            training_state: TrainingState::new(),
197            rng_state: None,
198            config: HashMap::new(),
199        }
200    }
201
202    /// Set the model state.
203    #[must_use]
204    pub fn model_state(mut self, state: StateDict) -> Self {
205        self.model_state = Some(state);
206        self
207    }
208
209    /// Set the optimizer state.
210    #[must_use]
211    pub fn optimizer_state(mut self, state: StateDict) -> Self {
212        self.optimizer_state = Some(state);
213        self
214    }
215
216    /// Set the training state.
217    #[must_use]
218    pub fn training_state(mut self, state: TrainingState) -> Self {
219        self.training_state = state;
220        self
221    }
222
223    /// Set the RNG state.
224    #[must_use]
225    pub fn rng_state(mut self, state: Vec<u8>) -> Self {
226        self.rng_state = Some(state);
227        self
228    }
229
230    /// Add a configuration value.
231    #[must_use]
232    pub fn config(mut self, key: &str, value: &str) -> Self {
233        self.config.insert(key.to_string(), value.to_string());
234        self
235    }
236
237    /// Set the epoch.
238    #[must_use]
239    pub fn epoch(mut self, epoch: usize) -> Self {
240        self.training_state.epoch = epoch;
241        self
242    }
243
244    /// Set the global step.
245    #[must_use]
246    pub fn global_step(mut self, step: usize) -> Self {
247        self.training_state.global_step = step;
248        self
249    }
250
251    /// Build the checkpoint.
252    #[must_use]
253    pub fn build(self) -> Checkpoint {
254        Checkpoint {
255            model_state: self.model_state.unwrap_or_default(),
256            optimizer_state: self.optimizer_state.unwrap_or_default(),
257            training_state: self.training_state,
258            rng_state: self.rng_state,
259            config: self.config,
260            axonml_version: env!("CARGO_PKG_VERSION").to_string(),
261            timestamp: chrono_timestamp(),
262        }
263    }
264}
265
impl Default for CheckpointBuilder {
    /// Equivalent to [`CheckpointBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
271
272// =============================================================================
273// Utilities
274// =============================================================================
275
/// Current Unix time (whole seconds since the epoch) as a decimal string.
///
/// Implemented on `std::time` alone to avoid a `chrono` dependency; falls
/// back to `"0"` if the system clock reports a time before the Unix epoch.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
284
285// =============================================================================
286// Tests
287// =============================================================================
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    /// Counters start at zero and advance/reset as documented.
    #[test]
    fn test_training_state_basic() {
        let mut ts = TrainingState::new();
        assert_eq!((ts.epoch, ts.step), (0, 0));

        ts.next_step();
        assert_eq!((ts.step, ts.global_step), (1, 1));

        ts.next_epoch();
        assert_eq!((ts.epoch, ts.step), (1, 0));
    }

    /// Recorded losses accumulate and `avg_loss` averages the last `n`.
    #[test]
    fn test_training_state_loss_recording() {
        let mut ts = TrainingState::new();
        for loss in [1.0_f32, 0.8, 0.6] {
            ts.record_loss(loss);
        }

        assert_eq!(ts.loss_history.len(), 3);
        let avg = ts.avg_loss(2).expect("history is non-empty");
        assert!((avg - 0.7).abs() < 1e-5, "Expected ~0.7, got {avg}");
    }

    /// `update_best` respects the improvement direction in both modes.
    #[test]
    fn test_training_state_best_metric() {
        // Lower is better (like loss)
        let mut by_loss = TrainingState::new();
        assert!(by_loss.update_best("loss", 1.0, false));
        assert!(!by_loss.update_best("loss", 1.5, false));
        assert!(by_loss.update_best("loss", 0.5, false));
        assert_eq!(by_loss.best_metric, Some(0.5));

        // Higher is better (like accuracy)
        let mut by_acc = TrainingState::new();
        assert!(by_acc.update_best("accuracy", 0.8, true));
        assert!(!by_acc.update_best("accuracy", 0.7, true));
        assert!(by_acc.update_best("accuracy", 0.9, true));
        assert_eq!(by_acc.best_metric, Some(0.9));
    }

    /// Builder setters land in the finished checkpoint.
    #[test]
    fn test_checkpoint_builder() {
        let mut weights = StateDict::new();
        let tensor = TensorData {
            shape: vec![10, 5],
            values: vec![0.0; 50],
        };
        weights.insert("weight".to_string(), tensor);

        let ckpt = Checkpoint::builder()
            .model_state(weights)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(ckpt.epoch(), 5);
        assert_eq!(ckpt.global_step(), 1000);
        assert!(ckpt.config.contains_key("learning_rate"));
    }

    /// A checkpoint round-trips through bincode unchanged.
    #[test]
    fn test_checkpoint_serialization() {
        let ckpt = Checkpoint::builder().epoch(10).global_step(5000).build();

        let encoded = bincode::serialize(&ckpt).unwrap();
        assert!(!encoded.is_empty());

        let decoded: Checkpoint = bincode::deserialize(&encoded).unwrap();
        assert_eq!(decoded.epoch(), 10);
        assert_eq!(decoded.global_step(), 5000);
    }
}