// atomr_accel_train/optimizer.rs
/// Optimizer choice together with its hyperparameters.
///
/// This type only stores values; nothing here applies an update step.
/// Field meanings are assumed to follow the conventional SGD / AdamW
/// hyperparameter names — NOTE(review): confirm against the code that
/// consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic gradient descent configuration.
    Sgd {
        /// Learning rate.
        lr: f32,
        /// Momentum coefficient.
        momentum: f32,
        /// Weight-decay coefficient.
        weight_decay: f32,
    },
    /// AdamW configuration.
    AdamW {
        /// Learning rate.
        lr: f32,
        /// `beta1` hyperparameter (presumably the first-moment decay rate).
        beta1: f32,
        /// `beta2` hyperparameter (presumably the second-moment decay rate).
        beta2: f32,
        /// `eps` hyperparameter (presumably the denominator stabilizer).
        eps: f32,
        /// Weight-decay coefficient.
        weight_decay: f32,
    },
}

21impl OptimizerKind {
22 pub fn lr(&self) -> f32 {
23 match self {
24 OptimizerKind::Sgd { lr, .. } => *lr,
25 OptimizerKind::AdamW { lr, .. } => *lr,
26 }
27 }
28}
29
/// Per-step statistics record.
///
/// `Default` yields `loss = 0.0`, `grad_norm = 0.0`, `step_micros = 0`.
/// NOTE(review): how these fields are populated is not visible here —
/// confirm against the producer.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct StepStats {
    /// Loss value recorded for the step.
    pub loss: f32,
    /// Gradient norm recorded for the step (norm kind not shown here).
    pub grad_norm: f32,
    /// Step duration; presumably in microseconds, per the name — confirm.
    pub step_micros: u64,
}