Trait Optimizer

pub trait Optimizer<A, D> {
    // Required methods
    fn step(
        &mut self,
        params: &Array<A, D>,
        gradients: &Array<A, D>,
    ) -> Result<Array<A, D>>;
    fn get_learning_rate(&self) -> A;
    fn set_learning_rate(&mut self, learning_rate: A);

    // Provided method
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> { ... }
}

Trait that defines the common interface for optimization algorithms.
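
For a concrete picture of what implementors provide, here is a minimal sketch of a custom implementor. It assumes the crate's `Result` alias and the `Optimizer` trait are in scope, along with ndarray's `Array`, `Dimension`, and `ScalarOperand`; `PlainGD` is a hypothetical type used only for illustration:

use ndarray::{Array, Dimension, ScalarOperand};
use num_traits::Float;

/// Hypothetical plain gradient-descent optimizer (illustration only).
struct PlainGD<A> {
    learning_rate: A,
}

impl<A, D> Optimizer<A, D> for PlainGD<A>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Vanilla gradient descent: new_params = params - learning_rate * gradients
        Ok(params - &(gradients * self.learning_rate))
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}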

Required Methods

fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>>

Updates parameters using the given gradients.

Arguments
  • params - The current parameter values
  • gradients - The gradients with respect to the parameters
Returns

The updated parameters.
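
A sketch of a single update, reusing the hypothetical `PlainGD` from the trait description on one-dimensional arrays (the concrete values are illustrative):

use ndarray::array;

let mut opt = PlainGD { learning_rate: 0.1_f64 };
let params = array![1.0, 2.0, 3.0];
let gradients = array![0.5, 0.5, 0.5];
// new value = old value - 0.1 * gradient
let updated = opt.step(&params, &gradients).expect("step failed");
// updated ≈ [0.95, 1.95, 2.95]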

fn get_learning_rate(&self) -> A

Gets the current learning rate.

fn set_learning_rate(&mut self, learning_rate: A)

Sets a new learning rate.
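
Together, the two accessors support simple scheduling from outside the optimizer. Continuing the sketch above (the 0.9 decay factor is illustrative):

// Decay the learning rate by 10% after each epoch.
let lr = opt.get_learning_rate();
opt.set_learning_rate(lr * 0.9);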

Provided Methods

fn step_list(&mut self, params_list: &[&Array<A, D>], gradients_list: &[&Array<A, D>]) -> Result<Vec<Array<A, D>>>

Updates multiple parameter arrays at once.

Arguments
  • params_list - List of parameter arrays
  • gradients_list - List of gradient arrays corresponding to the parameters
Returns

The updated parameter arrays.
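
A sketch of a batched update with the hypothetical `PlainGD`; the two slices are expected to have the same length and pair up element-wise:

use ndarray::array;

let w = array![1.0_f64, 2.0];
let b = array![0.5_f64];
let grad_w = array![0.1, 0.2];
let grad_b = array![0.05];
let updated = opt.step_list(&[&w, &b], &[&grad_w, &grad_b]).expect("step_list failed");
assert_eq!(updated.len(), 2); // one updated array per input array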

Implementors

impl Optimizer<f32, Dim<[usize; 1]>> for SimdSGD<f32>

impl Optimizer<f64, Dim<[usize; 1]>> for SimdSGD<f64>

impl<A> Optimizer<A, Dim<[usize; 1]>> for SparseAdam<A>
where A: Float + ScalarOperand + Debug + Send + Sync,

impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>

impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>

impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>

impl<A, D> Optimizer<A, D> for Adagrad<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for Adam<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for AdamW<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for LAMB<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for LBFGS<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for Lion<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for RAdam<A>
where A: Float + ScalarOperand + Debug + Send + Sync + From<f64>, D: Dimension,

impl<A, D> Optimizer<A, D> for RMSprop<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, D> Optimizer<A, D> for SGD<A>
where A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension,

impl<A, O, D> Optimizer<A, D> for Lookahead<A, O, D>
where A: Float + ScalarOperand + Debug + Send + Sync, O: Optimizer<A, D> + Clone + Send + Sync, D: Dimension,

impl<A, O, D> Optimizer<A, D> for SAM<A, O, D>
where A: Float + ScalarOperand + Debug + Send + Sync, O: Optimizer<A, D> + Clone + Send + Sync, D: Dimension,

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D> for GroupedAdam<A, D>

impl<A: Float + ScalarOperand + Debug + Send + Sync, D: Dimension + Send + Sync> Optimizer<A, D> for LARS<A>