1use crate::StateDict;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct TrainingState {
17 pub epoch: usize,
19 pub step: usize,
21 pub global_step: usize,
23 pub best_metric: Option<f32>,
25 pub best_metric_name: Option<String>,
27 pub loss_history: Vec<f32>,
29 pub val_loss_history: Vec<f32>,
31 pub lr_history: Vec<f32>,
33 pub custom_metrics: HashMap<String, Vec<f32>>,
35}
36
37impl TrainingState {
38 #[must_use] pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn record_loss(&mut self, loss: f32) {
45 self.loss_history.push(loss);
46 if self.loss_history.len() > 1000 {
48 self.loss_history.remove(0);
49 }
50 }
51
52 pub fn record_val_loss(&mut self, loss: f32) {
54 self.val_loss_history.push(loss);
55 }
56
57 pub fn record_lr(&mut self, lr: f32) {
59 self.lr_history.push(lr);
60 }
61
62 pub fn record_metric(&mut self, name: &str, value: f32) {
64 self.custom_metrics
65 .entry(name.to_string())
66 .or_default()
67 .push(value);
68 }
69
70 pub fn update_best(&mut self, name: &str, value: f32, higher_is_better: bool) -> bool {
72 let improved = match self.best_metric {
73 None => true,
74 Some(best) => {
75 if higher_is_better {
76 value > best
77 } else {
78 value < best
79 }
80 }
81 };
82
83 if improved {
84 self.best_metric = Some(value);
85 self.best_metric_name = Some(name.to_string());
86 }
87
88 improved
89 }
90
91 #[must_use] pub fn avg_loss(&self, n: usize) -> Option<f32> {
93 if self.loss_history.is_empty() {
94 return None;
95 }
96 let start = self.loss_history.len().saturating_sub(n);
97 let slice = &self.loss_history[start..];
98 Some(slice.iter().sum::<f32>() / slice.len() as f32)
99 }
100
101 pub fn next_epoch(&mut self) {
103 self.epoch += 1;
104 self.step = 0;
105 }
106
107 pub fn next_step(&mut self) {
109 self.step += 1;
110 self.global_step += 1;
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct Checkpoint {
121 pub model_state: StateDict,
123 pub optimizer_state: StateDict,
125 pub training_state: TrainingState,
127 pub rng_state: Option<Vec<u8>>,
129 pub config: HashMap<String, String>,
131 pub axonml_version: String,
133 pub timestamp: String,
135}
136
137impl Checkpoint {
138 #[must_use] pub fn builder() -> CheckpointBuilder {
140 CheckpointBuilder::new()
141 }
142
143 #[must_use] pub fn epoch(&self) -> usize {
145 self.training_state.epoch
146 }
147
148 #[must_use] pub fn global_step(&self) -> usize {
150 self.training_state.global_step
151 }
152
153 #[must_use] pub fn best_metric(&self) -> Option<f32> {
155 self.training_state.best_metric
156 }
157}
158
159pub struct CheckpointBuilder {
165 model_state: Option<StateDict>,
166 optimizer_state: Option<StateDict>,
167 training_state: TrainingState,
168 rng_state: Option<Vec<u8>>,
169 config: HashMap<String, String>,
170}
171
172impl CheckpointBuilder {
173 #[must_use] pub fn new() -> Self {
175 Self {
176 model_state: None,
177 optimizer_state: None,
178 training_state: TrainingState::new(),
179 rng_state: None,
180 config: HashMap::new(),
181 }
182 }
183
184 #[must_use] pub fn model_state(mut self, state: StateDict) -> Self {
186 self.model_state = Some(state);
187 self
188 }
189
190 #[must_use] pub fn optimizer_state(mut self, state: StateDict) -> Self {
192 self.optimizer_state = Some(state);
193 self
194 }
195
196 #[must_use] pub fn training_state(mut self, state: TrainingState) -> Self {
198 self.training_state = state;
199 self
200 }
201
202 #[must_use] pub fn rng_state(mut self, state: Vec<u8>) -> Self {
204 self.rng_state = Some(state);
205 self
206 }
207
208 #[must_use] pub fn config(mut self, key: &str, value: &str) -> Self {
210 self.config.insert(key.to_string(), value.to_string());
211 self
212 }
213
214 #[must_use] pub fn epoch(mut self, epoch: usize) -> Self {
216 self.training_state.epoch = epoch;
217 self
218 }
219
220 #[must_use] pub fn global_step(mut self, step: usize) -> Self {
222 self.training_state.global_step = step;
223 self
224 }
225
226 #[must_use] pub fn build(self) -> Checkpoint {
228 Checkpoint {
229 model_state: self.model_state.unwrap_or_default(),
230 optimizer_state: self.optimizer_state.unwrap_or_default(),
231 training_state: self.training_state,
232 rng_state: self.rng_state,
233 config: self.config,
234 axonml_version: env!("CARGO_PKG_VERSION").to_string(),
235 timestamp: chrono_timestamp(),
236 }
237 }
238}
239
240impl Default for CheckpointBuilder {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
/// Returns the current wall-clock time as whole seconds since the Unix epoch, formatted as a
/// decimal string.
///
/// Despite its name, this function does not use the `chrono` crate. If the system clock
/// reports a time before the epoch, `"0"` is returned instead of failing.
fn chrono_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    };
    secs.to_string()
}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorData;

    #[test]
    fn test_training_state_basic() {
        let mut state = TrainingState::new();
        assert_eq!(state.epoch, 0);
        assert_eq!(state.step, 0);

        state.next_step();
        assert_eq!(state.step, 1);
        assert_eq!(state.global_step, 1);

        state.next_epoch();
        assert_eq!(state.epoch, 1);
        assert_eq!(state.step, 0);
    }

    #[test]
    fn test_training_state_loss_recording() {
        let mut state = TrainingState::new();

        state.record_loss(1.0);
        state.record_loss(0.8);
        state.record_loss(0.6);

        assert_eq!(state.loss_history.len(), 3);
        let avg = state.avg_loss(2).unwrap();
        assert!((avg - 0.7).abs() < 1e-5, "Expected ~0.7, got {avg}");
    }

    #[test]
    fn test_training_state_loss_history_is_bounded() {
        let mut state = TrainingState::new();
        for i in 0..1005 {
            state.record_loss(i as f32);
        }

        // Only the most recent 1000 losses are kept; the oldest are dropped first.
        assert_eq!(state.loss_history.len(), 1000);
        assert_eq!(state.loss_history.first(), Some(&5.0));
        assert_eq!(state.loss_history.last(), Some(&1004.0));
    }

    #[test]
    fn test_training_state_custom_metrics() {
        let mut state = TrainingState::new();
        state.record_metric("accuracy", 0.5);
        state.record_metric("accuracy", 0.75);
        state.record_metric("f1", 0.25);

        assert_eq!(state.custom_metrics["accuracy"], vec![0.5, 0.75]);
        assert_eq!(state.custom_metrics["f1"], vec![0.25]);
    }

    #[test]
    fn test_training_state_best_metric() {
        let mut state = TrainingState::new();

        assert!(state.update_best("loss", 1.0, false));
        assert!(!state.update_best("loss", 1.5, false));
        assert!(state.update_best("loss", 0.5, false));
        assert_eq!(state.best_metric, Some(0.5));

        let mut state2 = TrainingState::new();
        assert!(state2.update_best("accuracy", 0.8, true));
        assert!(!state2.update_best("accuracy", 0.7, true));
        assert!(state2.update_best("accuracy", 0.9, true));
        assert_eq!(state2.best_metric, Some(0.9));
    }

    #[test]
    fn test_checkpoint_builder() {
        let mut model_state = StateDict::new();
        model_state.insert(
            "weight".to_string(),
            TensorData {
                shape: vec![10, 5],
                values: vec![0.0; 50],
            },
        );

        let checkpoint = Checkpoint::builder()
            .model_state(model_state)
            .epoch(5)
            .global_step(1000)
            .config("learning_rate", "0.001")
            .build();

        assert_eq!(checkpoint.epoch(), 5);
        assert_eq!(checkpoint.global_step(), 1000);
        assert!(checkpoint.config.contains_key("learning_rate"));
    }

    #[test]
    fn test_checkpoint_serialization() {
        let checkpoint = Checkpoint::builder().epoch(10).global_step(5000).build();

        let bytes = bincode::serialize(&checkpoint).unwrap();
        assert!(!bytes.is_empty());

        let restored: Checkpoint = bincode::deserialize(&bytes).unwrap();
        assert_eq!(restored.epoch(), 10);
        assert_eq!(restored.global_step(), 5000);
        assert_eq!(restored.axonml_version, checkpoint.axonml_version);
        assert_eq!(restored.timestamp, checkpoint.timestamp);
    }
}