// atomr_accel_train/optimizer.rs

//! Optimizer kinds. F4 ships SGD and AdamW configs; the actual
//! parameter-update kernels will live in F4.x once the gradient
//! buffers are flowing through NCCL.

/// Optimizer selection together with its hyperparameters.
///
/// This type only describes configuration; it performs no parameter
/// updates itself. It is `Copy` and `PartialEq`, so configs can be passed
/// by value and compared directly (e.g. in tests or when checking that two
/// runs used the same settings).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizerKind {
    /// Stochastic gradient descent configuration.
    Sgd {
        /// Learning rate.
        lr: f32,
        /// Momentum coefficient (presumably `0.0` disables momentum —
        /// confirm against the update kernel).
        momentum: f32,
        /// Weight-decay coefficient.
        weight_decay: f32,
    },
    /// AdamW configuration.
    AdamW {
        /// Learning rate.
        lr: f32,
        /// Conventionally the first-moment (mean) decay rate.
        beta1: f32,
        /// Conventionally the second-moment (variance) decay rate.
        beta2: f32,
        /// Small constant, conventionally added to the denominator for
        /// numerical stability.
        eps: f32,
        /// Weight-decay coefficient.
        weight_decay: f32,
    },
}

21impl OptimizerKind {
22    pub fn lr(&self) -> f32 {
23        match self {
24            OptimizerKind::Sgd { lr, .. } => *lr,
25            OptimizerKind::AdamW { lr, .. } => *lr,
26        }
27    }
28}
29
/// Statistics reported for one training step.
///
/// `Default` yields all-zero values. `PartialEq` allows direct comparison
/// (note that a `NaN` loss or norm never compares equal, as with `f32`).
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct StepStats {
    /// Loss value for the step.
    pub loss: f32,
    /// Gradient norm for the step (presumably the global L2 norm —
    /// confirm with the code that fills this in).
    pub grad_norm: f32,
    /// Step duration; the name suggests microseconds — confirm with the
    /// code that fills this in.
    pub step_micros: u64,
}