1use crate::StateDict;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
/// Mutable bookkeeping for a training run: epoch/step counters, the best
/// metric accepted so far, and loss / learning-rate / custom-metric histories.
///
/// `TrainingState::default()` starts every counter at 0, with no best metric
/// and empty histories.
///
/// NOTE(review): with a non-self-describing format such as bincode (used by
/// this file's tests), fields are encoded positionally in declaration order.
/// Reordering, removing or inserting fields makes previously written
/// checkpoints undecodable.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TrainingState {
    /// Current epoch; starts at 0 and is incremented by `next_epoch`.
    pub epoch: usize,
    /// Step within the current epoch; reset to 0 by `next_epoch`.
    pub step: usize,
    /// Total number of steps across all epochs; `next_epoch` does not reset it.
    pub global_step: usize,
    /// Best value accepted by `update_best`, or `None` if nothing was accepted yet.
    pub best_metric: Option<f32>,
    /// The `name` argument passed to `update_best` when `best_metric` was last set.
    pub best_metric_name: Option<String>,
    /// Recent training losses appended by `record_loss`; capped at 1000
    /// entries, with the oldest dropped first.
    pub loss_history: Vec<f32>,
    /// Validation losses appended by `record_val_loss` (not capped).
    pub val_loss_history: Vec<f32>,
    /// Learning rates appended by `record_lr` (not capped).
    pub lr_history: Vec<f32>,
    /// Named metric series; `record_metric` appends to the series for its name.
    pub custom_metrics: HashMap<String, Vec<f32>>,
}
47
48impl TrainingState {
49 #[must_use]
51 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn record_loss(&mut self, loss: f32) {
57 self.loss_history.push(loss);
58 if self.loss_history.len() > 1000 {
60 self.loss_history.remove(0);
61 }
62 }
63
64 pub fn record_val_loss(&mut self, loss: f32) {
66 self.val_loss_history.push(loss);
67 }
68
69 pub fn record_lr(&mut self, lr: f32) {
71 self.lr_history.push(lr);
72 }
73
74 pub fn record_metric(&mut self, name: &str, value: f32) {
76 self.custom_metrics
77 .entry(name.to_string())
78 .or_default()
79 .push(value);
80 }
81
82 pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
84 let improved = match self.best_metric {
85 None => true,
86 Some(best) => {
87 if higher_is_better {
88 value > best
89 } else {
90 value < best
91 }
92 }
93 };
94
95 if improved {
96 self.best_metric = Some(value);
97 self.best_metric_name = Some(name.to_string());
98 }
99
100 improved
101 }
102
103 #[must_use]
105 pub fn avg_loss(&self, n: usize) -> Option<f32> {
106 if self.loss_history.is_empty() {
107 return None;
108 }
109 let start = self.loss_history.len().saturating_sub(n);
110 let slice = &self.loss_history[start..];
111 Some(slice.iter().sum::<f32>() / slice.len() as f32)
112 }
113
114 pub fn next_epoch(&mut self) {
116 self.epoch += 1;
117 self.step = 0;
118 }
119
120 pub fn next_step(&mut self) {
122 self.step += 1;
123 self.global_step += 1;
124 }
125}
126
/// A serializable snapshot of a training run: model and optimizer state,
/// training progress, optional RNG state, free-form config, and build metadata.
///
/// Usually constructed with `Checkpoint::builder()`.
///
/// NOTE(review): with a non-self-describing format such as bincode (used by
/// this file's tests), fields are encoded positionally in declaration order.
/// Changing the field list or order breaks previously written checkpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Model state dictionary (e.g. named weight tensors).
    pub model_state: StateDict,
    /// Optimizer state dictionary; `StateDict::default()` when the builder was
    /// not given one.
    pub optimizer_state: StateDict,
    /// Epoch/step counters, best metric and histories at save time.
    pub training_state: TrainingState,
    /// Opaque RNG state bytes, if the caller supplied them.
    pub rng_state: Option<Vec<u8>>,
    /// Free-form string key/value configuration (e.g. hyperparameters).
    pub config: HashMap<String, String>,
    /// Crate version (`CARGO_PKG_VERSION`) when built through the builder.
    pub axonml_version: String,
    /// When built through the builder: wall-clock time as whole seconds since
    /// the Unix epoch, in decimal.
    pub timestamp: String,
}
149
150impl Checkpoint {
151 #[must_use]
153 pub fn builder() -> CheckpointBuilder {
154 CheckpointBuilder::new()
155 }
156
157 #[must_use]
159 pub fn epoch(&self) -> usize {
160 self.training_state.epoch
161 }
162
163 #[must_use]
165 pub fn global_step(&self) -> usize {
166 self.training_state.global_step
167 }
168
169 #[must_use]
171 pub fn best_metric(&self) -> Option<f32> {
172 self.training_state.best_metric
173 }
174}
175
176pub struct CheckpointBuilder {
182 model_state: Option<StateDict>,
183 optimizer_state: Option<StateDict>,
184 training_state: TrainingState,
185 rng_state: Option<Vec<u8>>,
186 config: HashMap<String, String>,
187}
188
189impl CheckpointBuilder {
190 #[must_use]
192 pub fn new() -> Self {
193 Self {
194 model_state: None,
195 optimizer_state: None,
196 training_state: TrainingState::new(),
197 rng_state: None,
198 config: HashMap::new(),
199 }
200 }
201
202 #[must_use]
204 pub fn model_state(mut self, state: StateDict) -> Self {
205 self.model_state = Some(state);
206 self
207 }
208
209 #[must_use]
211 pub fn optimizer_state(mut self, state: StateDict) -> Self {
212 self.optimizer_state = Some(state);
213 self
214 }
215
216 #[must_use]
218 pub fn training_state(mut self, state: TrainingState) -> Self {
219 self.training_state = state;
220 self
221 }
222
223 #[must_use]
225 pub fn rng_state(mut self, state: Vec<u8>) -> Self {
226 self.rng_state = Some(state);
227 self
228 }
229
230 #[must_use]
232 pub fn config(mut self, key: &str, value: &str) -> Self {
233 self.config.insert(key.to_string(), value.to_string());
234 self
235 }
236
237 #[must_use]
239 pub fn epoch(mut self, epoch: usize) -> Self {
240 self.training_state.epoch = epoch;
241 self
242 }
243
244 #[must_use]
246 pub fn global_step(mut self, step: usize) -> Self {
247 self.training_state.global_step = step;
248 self
249 }
250
251 #[must_use]
253 pub fn build(self) -> Checkpoint {
254 Checkpoint {
255 model_state: self.model_state.unwrap_or_default(),
256 optimizer_state: self.optimizer_state.unwrap_or_default(),
257 training_state: self.training_state,
258 rng_state: self.rng_state,
259 config: self.config,
260 axonml_version: env!("CARGO_PKG_VERSION").to_string(),
261 timestamp: chrono_timestamp(),
262 }
263 }
264}
265
266impl Default for CheckpointBuilder {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// rendered as a plain decimal string (e.g. `"1700000000"`).
///
/// Despite its name, this does not use the `chrono` crate and produces no
/// calendar / ISO-8601 formatting.
///
/// If the system clock reports a time before the Unix epoch, the duration
/// falls back to zero and `"0"` is returned. This is a deliberate choice so
/// that building a checkpoint never fails because of a misconfigured clock.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    #[test]
    fn test_training_state_basic() {
        let mut state = TrainingState::new();
        assert_eq!(state.epoch, 0);
        assert_eq!(state.step, 0);

        state.next_step();
        assert_eq!(state.step, 1);
        assert_eq!(state.global_step, 1);

        state.next_epoch();
        assert_eq!(state.epoch, 1);
        assert_eq!(state.step, 0);
    }

    #[test]
    fn test_global_step_survives_epochs() {
        let mut state = TrainingState::new();
        state.next_step();
        state.next_step();
        state.next_epoch();
        state.next_step();
        // `next_epoch` resets only the per-epoch counter.
        assert_eq!((state.epoch, state.step, state.global_step), (1, 1, 3));
    }

    #[test]
    fn test_training_state_loss_recording() {
        let mut state = TrainingState::new();

        state.record_loss(1.0);
        state.record_loss(0.8);
        state.record_loss(0.6);

        assert_eq!(state.loss_history.len(), 3);
        let avg = state.avg_loss(2).unwrap();
        assert!((avg - 0.7).abs() < 1e-5, "Expected ~0.7, got {avg}");
    }

    #[test]
    fn test_avg_loss_windows() {
        let mut state = TrainingState::new();
        assert_eq!(state.avg_loss(5), None);

        state.record_loss(2.0);
        state.record_loss(4.0);
        // A window larger than the history averages everything recorded.
        assert_eq!(state.avg_loss(10), Some(3.0));
        assert_eq!(state.avg_loss(1), Some(4.0));
    }

    #[test]
    fn test_loss_history_is_capped_to_most_recent() {
        let mut state = TrainingState::new();
        for i in 0..1005 {
            state.record_loss(i as f32);
        }
        // Only the most recent 1000 losses are kept; the oldest go first.
        assert_eq!(state.loss_history.len(), 1000);
        assert_eq!(state.loss_history.first(), Some(&5.0));
        assert_eq!(state.loss_history.last(), Some(&1004.0));
    }

    #[test]
    fn test_record_histories_and_metrics() {
        let mut state = TrainingState::new();
        state.record_val_loss(0.9);
        state.record_lr(1e-3);
        state.record_metric("acc", 0.5);
        state.record_metric("acc", 0.75);
        state.record_metric("f1", 0.1);

        assert_eq!(state.val_loss_history, vec![0.9]);
        assert_eq!(state.lr_history, vec![1e-3]);
        assert_eq!(state.custom_metrics["acc"], vec![0.5, 0.75]);
        assert_eq!(state.custom_metrics["f1"], vec![0.1]);
    }

    #[test]
    fn test_training_state_best_metric() {
        let mut state = TrainingState::new();

        // Lower is better (e.g. a loss).
        assert!(state.update_best("loss", 1.0, false));
        assert!(!state.update_best("loss", 1.5, false));
        assert!(state.update_best("loss", 0.5, false));
        assert_eq!(state.best_metric, Some(0.5));
        assert_eq!(state.best_metric_name.as_deref(), Some("loss"));

        // Higher is better (e.g. an accuracy).
        let mut state2 = TrainingState::new();
        assert!(state2.update_best("accuracy", 0.8, true));
        assert!(!state2.update_best("accuracy", 0.7, true));
        assert!(state2.update_best("accuracy", 0.9, true));
        assert_eq!(state2.best_metric, Some(0.9));
        assert_eq!(state2.best_metric_name.as_deref(), Some("accuracy"));
    }

    #[test]
    fn test_checkpoint_builder() {
        let mut model_state = StateDict::new();
        model_state.insert(
            "weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let checkpoint = Checkpoint::builder()
            .model_state(model_state)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(checkpoint.epoch(), 5);
        assert_eq!(checkpoint.global_step(), 1000);
        assert!(checkpoint.config.contains_key("learning_rate"));
    }

    #[test]
    fn test_checkpoint_builder_defaults() {
        let checkpoint = Checkpoint::builder().build();

        assert!(checkpoint.rng_state.is_none());
        assert!(checkpoint.config.is_empty());
        assert_eq!(checkpoint.best_metric(), None);
        assert_eq!(checkpoint.axonml_version, env!("CARGO_PKG_VERSION"));
        assert!(checkpoint.timestamp.parse::<u64>().is_ok());
    }

    #[test]
    fn test_builder_setter_order() {
        let mut training = TrainingState::new();
        training.epoch = 2;

        // `epoch` applied after `training_state` overrides it...
        let checkpoint = Checkpoint::builder()
            .training_state(training.clone())
            .epoch(7)
            .build();
        assert_eq!(checkpoint.epoch(), 7);

        // ...while `training_state` applied last replaces the earlier `epoch`.
        let checkpoint = Checkpoint::builder()
            .epoch(7)
            .training_state(training)
            .build();
        assert_eq!(checkpoint.epoch(), 2);
    }

    #[test]
    fn test_checkpoint_serialization() {
        let checkpoint = Checkpoint::builder().epoch(10).global_step(5000).build();

        let bytes = bincode::serialize(&checkpoint).unwrap();
        assert!(!bytes.is_empty());

        let restored: Checkpoint = bincode::deserialize(&bytes).unwrap();
        assert_eq!(restored.epoch(), 10);
        assert_eq!(restored.global_step(), 5000);
    }

    #[test]
    fn test_checkpoint_serialization_preserves_state() {
        let mut training = TrainingState::new();
        training.record_loss(0.5);
        training.update_best("loss", 0.5, false);

        let checkpoint = Checkpoint::builder()
            .training_state(training)
            .rng_state(vec![1, 2, 3])
            .config("optimizer", "adam")
            .build();

        let bytes = bincode::serialize(&checkpoint).unwrap();
        let restored: Checkpoint = bincode::deserialize(&bytes).unwrap();

        assert_eq!(restored.training_state.loss_history, vec![0.5]);
        assert_eq!(restored.best_metric(), Some(0.5));
        assert_eq!(restored.rng_state, Some(vec![1, 2, 3]));
        assert_eq!(
            restored.config.get("optimizer").map(String::as_str),
            Some("adam")
        );
        assert_eq!(restored.axonml_version, checkpoint.axonml_version);
        assert_eq!(restored.timestamp, checkpoint.timestamp);
    }
}