1use crate::StateDict;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct TrainingState {
17 pub epoch: usize,
19 pub step: usize,
21 pub global_step: usize,
23 pub best_metric: Option<f32>,
25 pub best_metric_name: Option<String>,
27 pub loss_history: Vec<f32>,
29 pub val_loss_history: Vec<f32>,
31 pub lr_history: Vec<f32>,
33 pub custom_metrics: HashMap<String, Vec<f32>>,
35}
36
37impl TrainingState {
38 #[must_use]
40 pub fn new() -> Self {
41 Self::default()
42 }
43
44 pub fn record_loss(&mut self, loss: f32) {
46 self.loss_history.push(loss);
47 if self.loss_history.len() > 1000 {
49 self.loss_history.remove(0);
50 }
51 }
52
53 pub fn record_val_loss(&mut self, loss: f32) {
55 self.val_loss_history.push(loss);
56 }
57
58 pub fn record_lr(&mut self, lr: f32) {
60 self.lr_history.push(lr);
61 }
62
63 pub fn record_metric(&mut self, name: &str, value: f32) {
65 self.custom_metrics
66 .entry(name.to_string())
67 .or_default()
68 .push(value);
69 }
70
71 pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
73 let improved = match self.best_metric {
74 None => true,
75 Some(best) => {
76 if higher_is_better {
77 value > best
78 } else {
79 value < best
80 }
81 }
82 };
83
84 if improved {
85 self.best_metric = Some(value);
86 self.best_metric_name = Some(name.to_string());
87 }
88
89 improved
90 }
91
92 #[must_use]
94 pub fn avg_loss(&self, n: usize) -> Option<f32> {
95 if self.loss_history.is_empty() {
96 return None;
97 }
98 let start = self.loss_history.len().saturating_sub(n);
99 let slice = &self.loss_history[start..];
100 Some(slice.iter().sum::<f32>() / slice.len() as f32)
101 }
102
103 pub fn next_epoch(&mut self) {
105 self.epoch += 1;
106 self.step = 0;
107 }
108
109 pub fn next_step(&mut self) {
111 self.step += 1;
112 self.global_step += 1;
113 }
114}
115
/// A snapshot of a training run: model and optimizer state dictionaries, loop
/// bookkeeping, optional RNG state, free-form configuration, and provenance.
///
/// NOTE(review): this type is serialized via serde (the tests round-trip it through
/// bincode, a positional format), so field order is part of the saved layout — do
/// not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Model state dictionary.
    pub model_state: StateDict,
    /// Optimizer state dictionary; `CheckpointBuilder::build` fills in
    /// `StateDict::default()` when none was supplied.
    pub optimizer_state: StateDict,
    /// Epoch/step counters, histories, and best-metric tracking.
    pub training_state: TrainingState,
    /// Opaque RNG state bytes, if the caller supplied them.
    pub rng_state: Option<Vec<u8>>,
    /// Free-form string key/value configuration.
    pub config: HashMap<String, String>,
    /// When built via `CheckpointBuilder`, the crate's `CARGO_PKG_VERSION`.
    pub axonml_version: String,
    /// When built via `CheckpointBuilder`, creation time as Unix seconds in
    /// decimal string form.
    pub timestamp: String,
}
138
139impl Checkpoint {
140 #[must_use]
142 pub fn builder() -> CheckpointBuilder {
143 CheckpointBuilder::new()
144 }
145
146 #[must_use]
148 pub fn epoch(&self) -> usize {
149 self.training_state.epoch
150 }
151
152 #[must_use]
154 pub fn global_step(&self) -> usize {
155 self.training_state.global_step
156 }
157
158 #[must_use]
160 pub fn best_metric(&self) -> Option<f32> {
161 self.training_state.best_metric
162 }
163}
164
165pub struct CheckpointBuilder {
171 model_state: Option<StateDict>,
172 optimizer_state: Option<StateDict>,
173 training_state: TrainingState,
174 rng_state: Option<Vec<u8>>,
175 config: HashMap<String, String>,
176}
177
178impl CheckpointBuilder {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 model_state: None,
184 optimizer_state: None,
185 training_state: TrainingState::new(),
186 rng_state: None,
187 config: HashMap::new(),
188 }
189 }
190
191 #[must_use]
193 pub fn model_state(mut self, state: StateDict) -> Self {
194 self.model_state = Some(state);
195 self
196 }
197
198 #[must_use]
200 pub fn optimizer_state(mut self, state: StateDict) -> Self {
201 self.optimizer_state = Some(state);
202 self
203 }
204
205 #[must_use]
207 pub fn training_state(mut self, state: TrainingState) -> Self {
208 self.training_state = state;
209 self
210 }
211
212 #[must_use]
214 pub fn rng_state(mut self, state: Vec<u8>) -> Self {
215 self.rng_state = Some(state);
216 self
217 }
218
219 #[must_use]
221 pub fn config(mut self, key: &str, value: &str) -> Self {
222 self.config.insert(key.to_string(), value.to_string());
223 self
224 }
225
226 #[must_use]
228 pub fn epoch(mut self, epoch: usize) -> Self {
229 self.training_state.epoch = epoch;
230 self
231 }
232
233 #[must_use]
235 pub fn global_step(mut self, step: usize) -> Self {
236 self.training_state.global_step = step;
237 self
238 }
239
240 #[must_use]
242 pub fn build(self) -> Checkpoint {
243 Checkpoint {
244 model_state: self.model_state.unwrap_or_default(),
245 optimizer_state: self.optimizer_state.unwrap_or_default(),
246 training_state: self.training_state,
247 rng_state: self.rng_state,
248 config: self.config,
249 axonml_version: env!("CARGO_PKG_VERSION").to_string(),
250 timestamp: chrono_timestamp(),
251 }
252 }
253}
254
255impl Default for CheckpointBuilder {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// rendered as a decimal string.
///
/// Despite its name, this does not use the `chrono` crate. If the system clock
/// reports a time before the epoch, the error is deliberately swallowed and `"0"`
/// is returned, so building a checkpoint never fails because of a bad clock.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    #[test]
    fn test_training_state_basic() {
        let mut state = TrainingState::new();
        assert_eq!((state.epoch, state.step), (0, 0));

        state.next_step();
        assert_eq!((state.step, state.global_step), (1, 1));

        state.next_epoch();
        assert_eq!((state.epoch, state.step), (1, 0));
    }

    #[test]
    fn test_training_state_loss_recording() {
        let mut state = TrainingState::new();
        for loss in [1.0_f32, 0.8, 0.6] {
            state.record_loss(loss);
        }

        assert_eq!(state.loss_history.len(), 3);
        let avg = state.avg_loss(2).unwrap();
        assert!((avg - 0.7).abs() < 1e-5, "Expected ~0.7, got {avg}");
    }

    #[test]
    fn test_training_state_best_metric() {
        // Lower is better: the first value is always accepted, later ones must decrease.
        let mut lower = TrainingState::new();
        assert!(lower.update_best("loss", 1.0, false));
        assert!(!lower.update_best("loss", 1.5, false));
        assert!(lower.update_best("loss", 0.5, false));
        assert_eq!(lower.best_metric, Some(0.5));

        // Higher is better: later values must increase to be accepted.
        let mut higher = TrainingState::new();
        assert!(higher.update_best("accuracy", 0.8, true));
        assert!(!higher.update_best("accuracy", 0.7, true));
        assert!(higher.update_best("accuracy", 0.9, true));
        assert_eq!(higher.best_metric, Some(0.9));
    }

    #[test]
    fn test_checkpoint_builder() {
        let weight = TensorData {
            shape: vec![10, 5],
            values: vec![0.0; 50],
        };
        let mut weights = StateDict::new();
        weights.insert("weight".to_string(), weight);

        let ckpt = Checkpoint::builder()
            .model_state(weights)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(ckpt.epoch(), 5);
        assert_eq!(ckpt.global_step(), 1000);
        assert!(ckpt.config.contains_key("learning_rate"));
    }

    #[test]
    fn test_checkpoint_serialization() {
        let original = Checkpoint::builder().epoch(10).global_step(5000).build();

        let encoded = bincode::serialize(&original).unwrap();
        assert!(!encoded.is_empty());

        let decoded: Checkpoint = bincode::deserialize(&encoded).unwrap();
        assert_eq!(decoded.epoch(), 10);
        assert_eq!(decoded.global_step(), 5000);
    }
}