opensrdk-optimization 0.1.4

Standard mathematical optimization library for OpenSRDK toolchain.
Documentation
use crate::{vec::Vector as _, Status};
use opensrdk_linear_algebra::*;
use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng};

/// # Stochastic Gradient Descent Adam
///
/// If you are indecisive for values, it is sufficient to use `default()`.
///
/// - `epsilon`: Threshold of the norm of gradients for finishing
/// - `max_iter`: Max limitation of iteration count
/// - `alpha`: Learning rate
/// - `beta1`: Weight for moving average of Momentum
/// - `beta2`: Weight for moving average of RMSProp
/// - `e`: A small value for preventing zero divisions
pub struct SgdAdam {
    epsilon: f64,
    max_iter: usize,
    alpha: f64,
    beta1: f64,
    beta2: f64,
    e: f64,
}

impl Default for SgdAdam {
    fn default() -> Self {
        Self {
            epsilon: 1e-6,
            max_iter: 0,
            alpha: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            e: 0.00000001,
        }
    }
}

impl SgdAdam {
    pub fn new(epsilon: f64, max_iter: usize, alpha: f64, beta1: f64, beta2: f64, e: f64) -> Self {
        Self {
            epsilon,
            max_iter,
            alpha,
            beta1,
            beta2,
            e,
        }
    }

    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
        assert!(epsilon.is_sign_positive(), "epsilon must be positive");

        self.epsilon = epsilon;

        self
    }

    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;

        self
    }

    pub fn with_alpha(mut self, alpha: f64) -> Self {
        assert!(alpha.is_sign_positive(), "delta must be positive");

        self.alpha = alpha;

        self
    }

    pub fn with_beta1(mut self, beta1: f64) -> Self {
        assert!(beta1.is_sign_positive(), "beta1 must be positive");

        self.beta2 = beta1;

        self
    }

    pub fn with_beta2(mut self, beta2: f64) -> Self {
        assert!(beta2.is_sign_positive(), "beta2 must be positive");

        self.beta2 = beta2;

        self
    }

    pub fn with_e(mut self, e: f64) -> Self {
        assert!(e.is_sign_positive(), "e must be positive");

        self.e = e;

        self
    }

    /// - `x`: Objective variables
    /// - `grad` Returns the gradients of each inputs
    ///   - `&[usize]`: Set of indices included in the batch
    ///   - `&[f64]`: Objective variables as an input
    /// - `batch`: Batch size
    /// - `total`: Total data size
    pub fn minimize(
        &self,
        x: &mut [f64],
        grad: &dyn Fn(&[usize], &[f64]) -> Vec<f64>,
        batch: usize,
        total: usize,
    ) -> Status {
        let mut batch_index = (0..total).into_iter().collect::<Vec<_>>();
        let mut w = x.to_vec().col_mat();
        let mut m = Matrix::new(w.rows(), 1);
        let mut v = Matrix::new(w.rows(), 1);
        let mut rng: StdRng = SeedableRng::seed_from_u64(1);

        for k in 0.. {
            if self.max_iter != 0 && self.max_iter <= k {
                return Status::MaxIter;
            }

            let gfx = grad(&(0..total).into_iter().collect::<Vec<_>>(), w.slice());
            if gfx.l2_norm() < self.epsilon + x.l2_norm() {
                return Status::Epsilon;
            }

            batch_index.shuffle(&mut rng);

            for minibatch in batch_index.chunks(batch) {
                let minibatch_grad = grad(&minibatch, w.slice()).col_mat();

                m = self.beta1 * m + (1.0 - self.beta1) * minibatch_grad.clone();
                v = self.beta2 * v
                    + (1.0 - self.beta2) * minibatch_grad.clone().hadamard_prod(&minibatch_grad);

                let m_hat = m.clone() * (1.0 / (1.0 - self.beta1.powi(k as i32 + 1)));
                let v_hat = v.clone() * (1.0 / (1.0 - self.beta2.powi(k as i32 + 1)));
                let v_hat_sqrt_e_inv = v_hat
                    .vec()
                    .iter()
                    .map(|vi| 1.0 / (vi.sqrt() + self.e))
                    .collect::<Vec<_>>()
                    .col_mat();

                w = w - self.alpha * m_hat.hadamard_prod(&v_hat_sqrt_e_inv);

                if !w.slice().l2_norm().is_finite() {
                    return Status::NaN;
                }
            }

            x.clone_from_slice(w.slice());
        }

        Status::Success
    }
}