ghostflow_optim/
optimizer.rs

//! Base Optimizer trait

use ghostflow_core::Tensor;

/// Base trait for all optimizers
pub trait Optimizer {
    /// Perform a single optimization step
    fn step(&mut self);

    /// Zero all gradients
    fn zero_grad(&mut self);

    /// Get current learning rate
    fn get_lr(&self) -> f32;

    /// Set learning rate
    fn set_lr(&mut self, lr: f32);

    /// Get all parameters
    fn parameters(&self) -> &[Tensor];
}
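
// A minimal sketch of a concrete optimizer implementing the trait above,
// shown for illustration only. The bodies of `step` and `zero_grad` are left
// as comments because the exact gradient and in-place update methods on
// `Tensor` in ghostflow_core are assumptions not visible in this file.
pub struct Sgd {
    params: Vec<Tensor>,
    lr: f32,
}

impl Sgd {
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        Sgd { params, lr }
    }
}

impl Optimizer for Sgd {
    fn step(&mut self) {
        for _param in &self.params {
            // param <- param - lr * param.grad
            // (in-place update; the actual call depends on Tensor's gradient API)
        }
    }

    fn zero_grad(&mut self) {
        for _param in &self.params {
            // Reset the accumulated gradient on each parameter.
        }
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }

    fn parameters(&self) -> &[Tensor] {
        &self.params
    }
}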

/// A group of parameters sharing a learning rate and weight decay,
/// so different parts of a model can use different hyperparameters
pub struct ParamGroup {
    pub params: Vec<Tensor>,
    pub lr: f32,
    pub weight_decay: f32,
}

impl ParamGroup {
    /// Create a parameter group with the given learning rate and no weight decay
    pub fn new(params: Vec<Tensor>, lr: f32) -> Self {
        ParamGroup {
            params,
            lr,
            weight_decay: 0.0,
        }
    }

    /// Builder-style setter for the weight decay coefficient
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}
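
// Example usage (sketch): splitting model parameters into groups with
// different hyperparameters. `backbone_params` and `head_params` are
// hypothetical parameter lists supplied by the caller; only the builder
// API defined above is used.
pub fn example_param_groups(
    backbone_params: Vec<Tensor>,
    head_params: Vec<Tensor>,
) -> Vec<ParamGroup> {
    vec![
        // Pretrained backbone: small learning rate plus weight decay.
        ParamGroup::new(backbone_params, 1e-4).with_weight_decay(1e-2),
        // Freshly initialized head: larger learning rate, no weight decay (default 0.0).
        ParamGroup::new(head_params, 1e-3),
    ]
}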