pub trait GroupedOptimizer<A: Float + ScalarOperand + Debug, D: Dimension>: Optimizer<A, D> {
// Required methods
fn add_group(
&mut self,
params: Vec<Array<A, D>>,
config: ParameterGroupConfig<A>,
) -> Result<usize>;
fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>>;
fn get_group_mut(
&mut self,
groupid: usize,
) -> Result<&mut ParameterGroup<A, D>>;
fn groups(&self) -> &[ParameterGroup<A, D>];
fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>];
fn step_group(
&mut self,
group_id: usize,
gradients: &[Array<A, D>],
) -> Result<Vec<Array<A, D>>>;
fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()>;
fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()>;
}Expand description
Optimizer with parameter group support
Required Methods
fn add_group(&mut self, params: Vec<Array<A, D>>, config: ParameterGroupConfig<A>) -> Result<usize>
Add a parameter group
fn get_group(&self, groupid: usize) -> Result<&ParameterGroup<A, D>>
Get parameter group by ID
fn get_group_mut(&mut self, groupid: usize) -> Result<&mut ParameterGroup<A, D>>
Get mutable parameter group by ID
fn groups(&self) -> &[ParameterGroup<A, D>]
Get all parameter groups
fn groups_mut(&mut self) -> &mut [ParameterGroup<A, D>]
Get all parameter groups mutably
fn step_group(&mut self, group_id: usize, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>>
Step for a specific group
fn set_group_learning_rate(&mut self, groupid: usize, lr: A) -> Result<()>
Set learning rate for a specific group
fn set_group_weight_decay(&mut self, groupid: usize, wd: A) -> Result<()>
Set weight decay for a specific group