ghostflow_optim/
optimizer.rs1use ghostflow_core::Tensor;
4
5pub trait Optimizer {
7 fn step(&mut self);
9
10 fn zero_grad(&mut self);
12
13 fn get_lr(&self) -> f32;
15
16 fn set_lr(&mut self, lr: f32);
18
19 fn parameters(&self) -> &[Tensor];
21}
22
23pub struct ParamGroup {
25 pub params: Vec<Tensor>,
26 pub lr: f32,
27 pub weight_decay: f32,
28}
29
30impl ParamGroup {
31 pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
32 ParamGroup {
33 params,
34 lr,
35 weight_decay: 0.0,
36 }
37 }
38
39 pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
40 self.weight_decay = weight_decay;
41 self
42 }
43}