//! `native_neural_network` 0.3.1
//!
//! A `no_std` Rust library for native neural networks (`.rnn` format).
//! See the crate documentation for details.
use super::decay::{cosine_decay, cosine_decay_f64, step_decay, step_decay_f64};
use super::policy::{LrSchedule, LrScheduleF64, ScheduleError};
use super::warmup::{linear_warmup, linear_warmup_f64};

/// Resolves the effective `f32` learning rate for `step` under `schedule`.
///
/// `base_lr` is the nominal rate the schedule scales from.
///
/// # Errors
///
/// Returns [`ScheduleError::InvalidConfig`] when the schedule rejects its
/// parameters (the helper returns `None`) or when the resolved rate is
/// non-finite or non-positive.
pub fn compute_learning_rate(
    base_lr: f32,
    step: u32,
    schedule: LrSchedule,
) -> Result<f32, ScheduleError> {
    // Ask the configured schedule for a candidate rate; `None` means the
    // schedule's own parameters were invalid.
    let candidate = match schedule {
        LrSchedule::Constant => Some(base_lr),
        LrSchedule::StepDecay { step_size, gamma } => {
            step_decay(base_lr, step, step_size, gamma)
        }
        LrSchedule::Cosine {
            total_steps,
            min_lr_ratio,
        } => cosine_decay(base_lr, step, total_steps, min_lr_ratio),
        LrSchedule::LinearWarmup { warmup_steps } => {
            linear_warmup(base_lr, step, warmup_steps)
        }
    };

    // Accept only a finite, strictly positive rate; everything else
    // (missing value, NaN, infinity, zero, negative) is a config error.
    match candidate {
        Some(lr) if lr.is_finite() && lr > 0.0 => Ok(lr),
        _ => Err(ScheduleError::InvalidConfig),
    }
}

pub fn compute_learning_rate_f64(
    base_lr: f64,
    step: u32,
    schedule: super::policy::LrScheduleF64,
) -> Result<f64, ScheduleError> {
    let lr = match schedule {
        super::policy::LrScheduleF64::Constant => Some(base_lr),
        super::policy::LrScheduleF64::StepDecay { step_size, gamma } => {
            step_decay_f64(base_lr, step, step_size, gamma)
        }
        super::policy::LrScheduleF64::Cosine {
            total_steps,
            min_lr_ratio,
        } => cosine_decay_f64(base_lr, step, total_steps, min_lr_ratio),
        super::policy::LrScheduleF64::LinearWarmup { warmup_steps } => {
            linear_warmup_f64(base_lr, step, warmup_steps)
        }
    };

    let lr = lr.ok_or(ScheduleError::InvalidConfig)?;
    if !lr.is_finite() || lr <= 0.0 {
        return Err(ScheduleError::InvalidConfig);
    }
    Ok(lr)
}