use crate::error::{OptimError, Result};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Applies a list of optimizers one after another to the same parameters.
///
/// On each step the parameters produced by optimizer `k` are fed to optimizer
/// `k + 1`; every optimizer sees the same gradients.
pub struct SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Optimizers in the order they are applied.
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
}

impl<A, D> SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a sequential optimizer that applies `optimizers` in the given order.
    pub fn new(optimizers: Vec<Box<dyn Optimizer<A, D>>>) -> Self {
        SequentialOptimizer { optimizers }
    }

    /// Appends `optimizer` so that it runs after every optimizer already present.
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) {
        self.optimizers.push(optimizer);
    }

    /// Number of optimizers in the sequence.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Borrows the optimizer at position `index`, or `None` when out of range.
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        self.optimizers.get(index).map(|boxed| &**boxed)
    }

    /// Mutably borrows the optimizer at position `index`, or `None` when out of range.
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        match self.optimizers.get_mut(index) {
            Some(boxed) => Some(&mut **boxed),
            None => None,
        }
    }
}
impl<A, D> Optimizer<A, D> for SequentialOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Runs every optimizer in order, threading the parameters through the chain.
    ///
    /// The first optimizer steps on `params`; each later optimizer steps on the
    /// output of the previous one. All optimizers receive the same `gradients`.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidConfig` if the sequence is empty, and propagates
    /// the first error returned by an inner optimizer's `step`. Optimizers that
    /// already ran before such an error may have updated their internal state.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let (first, rest) = self.optimizers.split_first_mut().ok_or_else(|| {
            OptimError::InvalidConfig("SequentialOptimizer has no optimizers".to_string())
        })?;
        // Step the first optimizer on the borrowed input directly, which avoids
        // an up-front clone of `params`.
        let mut current_params = first.step(params, gradients)?;
        for optimizer in rest {
            current_params = optimizer.step(&current_params, gradients)?;
        }
        Ok(current_params)
    }

    /// Returns the first optimizer's learning rate, or `0.01` if the sequence is empty.
    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("unwrap failed")
        }
    }

    /// Sets `learningrate` on every optimizer in the sequence.
    fn set_learning_rate(&mut self, learningrate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learningrate);
        }
    }
}
/// A parameter array paired with the index of the optimizer responsible for it.
pub struct ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Current parameter values of this group.
    pub params: Array<A, D>,
    /// Index of the optimizer (in the owning `ParallelOptimizer`) that updates this group.
    pub optimizerindex: usize,
}

impl<A, D> ParameterGroup<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Pairs `params` with the optimizer at `optimizerindex`.
    ///
    /// The index is stored as-is; it is not checked against any optimizer list here.
    pub fn new(params: Array<A, D>, optimizerindex: usize) -> Self {
        ParameterGroup {
            optimizerindex,
            params,
        }
    }
}
/// Routes each parameter group to its own optimizer.
///
/// Every [`ParameterGroup`] names, by index, the optimizer in `optimizers` that
/// updates it. [`ParallelOptimizer::update_all_parameters`] steps each group with
/// its assigned optimizer and stores the result back into the group.
pub struct ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Optimizers available to the parameter groups, addressed by index.
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Parameter groups, each holding its current parameters and optimizer index.
    parameter_groups: Vec<ParameterGroup<A, D>>,
}

impl<A, D> ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a parallel optimizer from optimizers and pre-built parameter groups.
    ///
    /// Group optimizer indices are not validated here; an invalid index is
    /// reported by [`ParallelOptimizer::update_all_parameters`].
    pub fn new(
        optimizers: Vec<Box<dyn Optimizer<A, D>>>,
        parameter_groups: Vec<ParameterGroup<A, D>>,
    ) -> Self {
        Self {
            optimizers,
            parameter_groups,
        }
    }

    /// Appends `optimizer` and returns the index that parameter groups use to refer to it.
    pub fn add_optimizer(&mut self, optimizer: Box<dyn Optimizer<A, D>>) -> usize {
        let index = self.optimizers.len();
        self.optimizers.push(optimizer);
        index
    }

    /// Adds a parameter group updated by the optimizer at `optimizerindex`.
    ///
    /// Returns the index of the newly added group.
    ///
    /// # Errors
    /// `OptimError::InvalidConfig` if `optimizerindex` does not name an existing optimizer.
    pub fn add_parameter_group(
        &mut self,
        params: Array<A, D>,
        optimizerindex: usize,
    ) -> Result<usize> {
        self.check_optimizer_index(optimizerindex)?;
        let group_index = self.parameter_groups.len();
        self.parameter_groups
            .push(ParameterGroup::new(params, optimizerindex));
        Ok(group_index)
    }

    /// Number of optimizers.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// Number of parameter groups.
    pub fn num_parameter_groups(&self) -> usize {
        self.parameter_groups.len()
    }

    /// Borrows the optimizer at `index`, or `None` when out of range.
    pub fn get_optimizer(&self, index: usize) -> Option<&dyn Optimizer<A, D>> {
        self.optimizers.get(index).map(|boxed| &**boxed)
    }

    /// Mutably borrows the optimizer at `index`, or `None` when out of range.
    pub fn get_optimizer_mut(&mut self, index: usize) -> Option<&mut dyn Optimizer<A, D>> {
        match self.optimizers.get_mut(index) {
            Some(boxed) => Some(&mut **boxed),
            None => None,
        }
    }

    /// Borrows the parameter group at `index`, or `None` when out of range.
    pub fn get_parameter_group(&self, index: usize) -> Option<&ParameterGroup<A, D>> {
        self.parameter_groups.get(index)
    }

    /// Mutably borrows the parameter group at `index`, or `None` when out of range.
    pub fn get_parameter_group_mut(&mut self, index: usize) -> Option<&mut ParameterGroup<A, D>> {
        self.parameter_groups.get_mut(index)
    }

    /// Returns a copy of every group's current parameters, in group order.
    ///
    /// This never fails; the `Result` is kept for API compatibility.
    pub fn get_all_parameters(&self) -> Result<Vec<Array<A, D>>> {
        Ok(self
            .parameter_groups
            .iter()
            .map(|group| group.params.clone())
            .collect())
    }

    /// Steps every parameter group with its assigned optimizer.
    ///
    /// `gradients[i]` is used for parameter group `i`. Each group's `params` is
    /// replaced by the updated values, and the updated values are also returned
    /// in group order.
    ///
    /// # Errors
    /// * `OptimError::InvalidConfig` if `gradients.len()` differs from the number of
    ///   groups, or if any group refers to a non-existent optimizer. Both checks run
    ///   before any group is stepped.
    /// * Any error returned by an optimizer's `step` is propagated. Groups stepped
    ///   before that error keep their updated parameters.
    pub fn update_all_parameters(&mut self, gradients: &[Array<A, D>]) -> Result<Vec<Array<A, D>>> {
        if gradients.len() != self.parameter_groups.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients.len(),
                self.parameter_groups.len()
            )));
        }
        // Validate every index before stepping anything. Otherwise an invalid group
        // late in the list would leave earlier groups, and their optimizers' state,
        // already updated.
        for group in &self.parameter_groups {
            self.check_optimizer_index(group.optimizerindex)?;
        }
        let mut updated_params = Vec::with_capacity(self.parameter_groups.len());
        for (group, gradient) in self.parameter_groups.iter_mut().zip(gradients) {
            // Index validated above.
            let optimizer = &mut self.optimizers[group.optimizerindex];
            let updated = optimizer.step(&group.params, gradient)?;
            group.params = updated.clone();
            updated_params.push(updated);
        }
        Ok(updated_params)
    }

    /// Returns `InvalidConfig` unless `optimizerindex` refers to an existing optimizer.
    fn check_optimizer_index(&self, optimizerindex: usize) -> Result<()> {
        if optimizerindex < self.optimizers.len() {
            Ok(())
        } else {
            Err(OptimError::InvalidConfig(format!(
                "Invalid optimizer index: {}. Only {} optimizers available.",
                optimizerindex,
                self.optimizers.len()
            )))
        }
    }
}
impl<A, D> Optimizer<A, D> for ParallelOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Unsupported: a single parameter array cannot be routed across several optimizers.
    ///
    /// # Errors
    /// Always returns `OptimError::InvalidConfig`. Use `update_all_parameters` or
    /// `step_list` instead.
    fn step(&mut self, _params: &Array<A, D>, _gradients: &Array<A, D>) -> Result<Array<A, D>> {
        Err(OptimError::InvalidConfig(
            "ParallelOptimizer doesn't support the standard step method. Use update_all_parameters instead."
                .to_string(),
        ))
    }

    /// Replaces the parameter groups with `params_list`, then steps each of them.
    ///
    /// Parameter `i` is assigned to optimizer `min(i, num_optimizers - 1)`, so any
    /// parameters beyond the number of optimizers share the last optimizer.
    ///
    /// # Errors
    /// `OptimError::InvalidConfig` if the two lists differ in length, or if there
    /// are parameters but no optimizers. These checks run before the existing
    /// parameter groups are replaced, so a rejected call leaves them untouched.
    /// Errors from the underlying optimizers are propagated.
    fn step_list(
        &mut self,
        params_list: &[&Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<Vec<Array<A, D>>> {
        if params_list.len() != gradients_list.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of gradients ({}) does not match number of parameter groups ({})",
                gradients_list.len(),
                params_list.len()
            )));
        }
        if self.optimizers.is_empty() && !params_list.is_empty() {
            return Err(OptimError::InvalidConfig(
                "ParallelOptimizer has no optimizers".to_string(),
            ));
        }
        // `saturating_sub` avoids the `len() - 1` underflow. When there are no
        // optimizers, `params_list` is empty here, so the value is never used.
        let last_optimizer = self.optimizers.len().saturating_sub(1);
        self.parameter_groups = params_list
            .iter()
            .enumerate()
            .map(|(i, &params)| ParameterGroup::new(params.clone(), i.min(last_optimizer)))
            .collect();
        let gradients_vec: Vec<Array<A, D>> = gradients_list.iter().map(|&g| g.clone()).collect();
        self.update_all_parameters(&gradients_vec)
    }

    /// Returns the first optimizer's learning rate, or `0.01` if there are none.
    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("unwrap failed")
        }
    }

    /// Sets `learningrate` on every optimizer.
    fn set_learning_rate(&mut self, learningrate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learningrate);
        }
    }
}
/// Composes two optimizers: `inner` steps first, then `outer` steps on its result.
pub struct ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// First-stage optimizer, applied to the incoming parameters.
    inner: Box<dyn Optimizer<A, D>>,
    /// Second-stage optimizer, applied to the parameters produced by `inner`.
    outer: Box<dyn Optimizer<A, D>>,
}

impl<A, D> ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Builds a chain that runs `inner` and then `outer` on every step.
    pub fn new(inner: Box<dyn Optimizer<A, D>>, outer: Box<dyn Optimizer<A, D>>) -> Self {
        ChainedOptimizer { inner, outer }
    }

    /// Shared access to the first-stage optimizer.
    pub fn inner(&self) -> &dyn Optimizer<A, D> {
        &*self.inner
    }

    /// Exclusive access to the first-stage optimizer.
    pub fn inner_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        &mut *self.inner
    }

    /// Shared access to the second-stage optimizer.
    pub fn outer(&self) -> &dyn Optimizer<A, D> {
        &*self.outer
    }

    /// Exclusive access to the second-stage optimizer.
    pub fn outer_mut(&mut self) -> &mut dyn Optimizer<A, D> {
        &mut *self.outer
    }
}
impl<A, D> Optimizer<A, D> for ChainedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Steps `inner` on `params`, then steps `outer` on that intermediate result.
    ///
    /// Both stages receive the same `gradients`. If `inner` fails, `outer` is not run
    /// and the error is returned.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        let second_stage = &mut self.outer;
        self.inner
            .step(params, gradients)
            .and_then(|stage_one| second_stage.step(&stage_one, gradients))
    }

    /// Reports the learning rate of the first-stage (`inner`) optimizer.
    fn get_learning_rate(&self) -> A {
        self.inner.get_learning_rate()
    }

    /// Applies `learningrate` to `inner` and then to `outer`.
    fn set_learning_rate(&mut self, learningrate: A) {
        for stage in [&mut self.inner, &mut self.outer] {
            stage.set_learning_rate(learningrate);
        }
    }
}
/// Combines several optimizers by averaging their proposed parameters with weights.
///
/// `optimizers[i]` is paired with `weights[i]`; the builder methods always push
/// to both vectors together, so they stay the same length.
pub struct WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Member optimizers, each proposing an update from the same inputs.
    optimizers: Vec<Box<dyn Optimizer<A, D>>>,
    /// Weight of each member optimizer (same index as `optimizers`).
    weights: Vec<A>,
}

impl<A, D> Default for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// An empty weighted optimizer, identical to [`WeightedOptimizer::new`].
    fn default() -> Self {
        WeightedOptimizer::new()
    }
}

impl<A, D> WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Creates a weighted optimizer with no members.
    pub fn new() -> Self {
        WeightedOptimizer {
            optimizers: vec![],
            weights: vec![],
        }
    }

    /// Builder: adds `opt` with the given `weight` and returns the updated optimizer.
    pub fn add_optimizer(self, opt: Box<dyn Optimizer<A, D>>, weight: A) -> Self {
        let mut this = self;
        this.optimizers.push(opt);
        this.weights.push(weight);
        this
    }

    /// Builder: adds every `(optimizer, weight)` pair from `opts`, in order.
    pub fn with_optimizers(mut self, opts: Vec<(Box<dyn Optimizer<A, D>>, A)>) -> Self {
        let (new_optimizers, new_weights): (Vec<_>, Vec<_>) = opts.into_iter().unzip();
        self.optimizers.extend(new_optimizers);
        self.weights.extend(new_weights);
        self
    }

    /// Rescales the weights so they sum to one.
    ///
    /// Leaves the weights unchanged if their sum is not strictly positive.
    pub fn normalize_weights(&mut self) {
        let total = self.weights.iter().fold(A::zero(), |acc, &w| acc + w);
        if total > A::zero() {
            self.weights.iter_mut().for_each(|w| *w = *w / total);
        }
    }

    /// Number of member optimizers.
    pub fn num_optimizers(&self) -> usize {
        self.optimizers.len()
    }

    /// The current weights, in the same order as the member optimizers.
    pub fn weights(&self) -> &[A] {
        self.weights.as_slice()
    }
}
impl<A, D> Optimizer<A, D> for WeightedOptimizer<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Returns the weighted average of every member optimizer's proposed parameters.
    ///
    /// Each member steps on the same `params` and `gradients`. Its result is scaled
    /// by `weight / sum(weights)` and accumulated.
    ///
    /// # Errors
    /// * `OptimError::InvalidConfig` if there are no member optimizers.
    /// * `OptimError::InvalidConfig` if the weight sum is not a positive, finite
    ///   number. NaN and infinite sums are rejected; otherwise they would silently
    ///   produce NaN or zeroed parameters.
    /// * The first error returned by a member optimizer's `step` is propagated.
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        if self.optimizers.is_empty() {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer has no optimizers".to_string(),
            ));
        }
        let weight_sum: A = self.weights.iter().copied().fold(A::zero(), |a, b| a + b);
        // A plain `<= 0` test lets NaN through, because every comparison with NaN is false.
        if !(weight_sum.is_finite() && weight_sum > A::zero()) {
            return Err(OptimError::InvalidConfig(
                "WeightedOptimizer weight sum must be positive and finite".to_string(),
            ));
        }
        let mut result: Option<Array<A, D>> = None;
        for (optimizer, &weight) in self.optimizers.iter_mut().zip(self.weights.iter()) {
            let updated = optimizer.step(params, gradients)?;
            let normalized_weight = weight / weight_sum;
            match result {
                None => {
                    result = Some(updated * normalized_weight);
                }
                Some(ref mut acc) => {
                    acc.zip_mut_with(&updated, |a, &b| {
                        *a = *a + b * normalized_weight;
                    });
                }
            }
        }
        result.ok_or_else(|| {
            OptimError::InvalidConfig("WeightedOptimizer produced no result".to_string())
        })
    }

    /// Returns the first member's learning rate, or `0.01` if there are no members.
    fn get_learning_rate(&self) -> A {
        if let Some(optimizer) = self.optimizers.first() {
            optimizer.get_learning_rate()
        } else {
            A::from(0.01).expect("failed to convert default learning rate")
        }
    }

    /// Sets `learning_rate` on every member optimizer.
    fn set_learning_rate(&mut self, learning_rate: A) {
        for optimizer in &mut self.optimizers {
            optimizer.set_learning_rate(learning_rate);
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the optimizer combinators in this module.
    use super::*;
    use crate::optimizers::{Adam, SGD};
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_sequential_optimizer() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);
        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let updated_params = seq_optimizer
            .step(&params, &gradients)
            .expect("unwrap failed");
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_parallel_optimizer() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let group1 = ParameterGroup::new(params1.clone(), 0);
        let group2 = ParameterGroup::new(params2.clone(), 1);
        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![group1, group2]);
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let updated_params = parallel_optimizer
            .update_all_parameters(&[gradients1, gradients2])
            .expect("unwrap failed");
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);
        assert!(updated_params[1][0] != 0.0);
        assert!(updated_params[1][1] != 0.0);
        assert!(updated_params[1][2] != 0.0);
    }

    #[test]
    fn test_chained_optimizer() {
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);
        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));
        let params = Array1::zeros(3);
        let gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let updated_params = chained_optimizer
            .step(&params, &gradients)
            .expect("unwrap failed");
        assert!(updated_params[0] < -0.1);
        assert!(updated_params[1] < -0.2);
        assert!(updated_params[2] < -0.3);
    }

    #[test]
    fn test_sequential_learning_rate() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);
        let mut seq_optimizer: SequentialOptimizer<f64, scirs2_core::ndarray::Ix1> =
            SequentialOptimizer::new(vec![Box::new(sgd), Box::new(adam)]);
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.1);
        seq_optimizer.set_learning_rate(0.05);
        assert_abs_diff_eq!(seq_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(0)
                .expect("unwrap failed")
                .get_learning_rate(),
            0.05
        );
        assert_abs_diff_eq!(
            seq_optimizer
                .get_optimizer(1)
                .expect("unwrap failed")
                .get_learning_rate(),
            0.05
        );
    }

    #[test]
    fn test_parallel_optimizer_step_list() {
        let sgd = SGD::new(0.1);
        let adam = Adam::new(0.01);
        let mut parallel_optimizer: ParallelOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ParallelOptimizer::new(vec![Box::new(sgd), Box::new(adam)], vec![]);
        let params1 = Array1::zeros(2);
        let params2 = Array1::zeros(3);
        let params3 = Array1::zeros(4);
        let gradients1 = Array1::from_vec(vec![1.0, 2.0]);
        let gradients2 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        let gradients3 = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0]);
        let params_refs = vec![&params1, &params2, &params3];
        let gradients_refs = vec![&gradients1, &gradients2, &gradients3];
        let updated_params = parallel_optimizer
            .step_list(&params_refs, &gradients_refs)
            .expect("unwrap failed");
        assert_abs_diff_eq!(updated_params[0][0], -0.1);
        assert_abs_diff_eq!(updated_params[0][1], -0.2);
        assert!(updated_params[1][0] != -0.3);
        assert!(updated_params[2][0] < 0.0);
    }

    #[test]
    fn test_chained_optimizer_learning_rate() {
        let inner = SGD::new(0.1);
        let outer = Adam::new(0.01);
        let mut chained_optimizer: ChainedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            ChainedOptimizer::new(Box::new(inner), Box::new(outer));
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.1);
        chained_optimizer.set_learning_rate(0.05);
        assert_abs_diff_eq!(chained_optimizer.get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.inner().get_learning_rate(), 0.05);
        assert_abs_diff_eq!(chained_optimizer.outer().get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_basic() {
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 0.5)
                .add_optimizer(Box::new(sgd2), 0.5);
        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);
        let updated = weighted.step(&params, &gradients).expect("step failed");
        assert_abs_diff_eq!(updated[0], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[1], -0.15, epsilon = 1e-10);
        assert_abs_diff_eq!(updated[2], -0.15, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_unequal_weights() {
        let sgd1 = SGD::new(0.1);
        let sgd2 = SGD::new(0.2);
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(sgd1), 3.0)
                .add_optimizer(Box::new(sgd2), 1.0);
        let params = Array1::zeros(2);
        let gradients = Array1::ones(2);
        let updated = weighted.step(&params, &gradients).expect("step failed");
        assert_abs_diff_eq!(updated[0], -0.125, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_empty() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new();
        let params = Array1::zeros(3);
        let gradients = Array1::ones(3);
        let result = weighted.step(&params, &gradients);
        assert!(result.is_err());
    }

    #[test]
    fn test_weighted_optimizer_normalize_weights() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 2.0)
                .add_optimizer(Box::new(SGD::new(0.2)), 8.0);
        weighted.normalize_weights();
        assert_abs_diff_eq!(weighted.weights()[0], 0.2, epsilon = 1e-10);
        assert_abs_diff_eq!(weighted.weights()[1], 0.8, epsilon = 1e-10);
    }

    #[test]
    fn test_weighted_optimizer_learning_rate() {
        let mut weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new()
                .add_optimizer(Box::new(SGD::new(0.1)), 1.0)
                .add_optimizer(Box::new(Adam::new(0.01)), 1.0);
        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.1);
        weighted.set_learning_rate(0.05);
        assert_abs_diff_eq!(weighted.get_learning_rate(), 0.05);
    }

    #[test]
    fn test_weighted_optimizer_with_optimizers() {
        let opts: Vec<(Box<dyn Optimizer<f64, scirs2_core::ndarray::Ix1>>, f64)> = vec![
            (Box::new(SGD::new(0.1)), 1.0),
            (Box::new(SGD::new(0.2)), 1.0),
        ];
        let weighted: WeightedOptimizer<f64, scirs2_core::ndarray::Ix1> =
            WeightedOptimizer::new().with_optimizers(opts);
        assert_eq!(weighted.num_optimizers(), 2);
        assert_abs_diff_eq!(weighted.weights()[0], 1.0);
        assert_abs_diff_eq!(weighted.weights()[1], 1.0);
    }
}