native_neural_network 0.1.6

A `no_std` Rust library for native neural networks (`.rnn` format).
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
use super::decay::{cosine_decay, step_decay};
use super::policy::{LrSchedule, ScheduleError};
use super::warmup::linear_warmup;

/// Computes the learning rate to use at `step` under the given `schedule`.
///
/// `LrSchedule::Constant` yields `base_lr` unchanged. Every other variant is
/// delegated to its helper (`step_decay`, `cosine_decay`, `linear_warmup`),
/// which receives `base_lr`, `step`, and the variant's own parameters.
///
/// # Errors
///
/// Returns [`ScheduleError::InvalidConfig`] when:
/// - the selected helper returns `None`, or
/// - the resulting rate is NaN, infinite, zero, or negative.
///
/// The second case also applies to `Constant`, so a bad `base_lr` is rejected.
pub fn compute_learning_rate(
    base_lr: f32,
    step: u32,
    schedule: LrSchedule,
) -> Result<f32, ScheduleError> {
    let candidate = match schedule {
        LrSchedule::Constant => Some(base_lr),
        LrSchedule::StepDecay { step_size, gamma } => step_decay(base_lr, step, step_size, gamma),
        LrSchedule::Cosine { total_steps, min_lr_ratio } => {
            cosine_decay(base_lr, step, total_steps, min_lr_ratio)
        }
        LrSchedule::LinearWarmup { warmup_steps } => linear_warmup(base_lr, step, warmup_steps),
    };

    // Only a finite, strictly positive rate is usable. NaN fails `is_finite`,
    // so `> 0.0` never has to reason about NaN ordering.
    candidate
        .filter(|value| value.is_finite() && *value > 0.0)
        .ok_or(ScheduleError::InvalidConfig)
}