sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Learning-rate schedulers.
//!
//! # Rsqrt with warm-up and cool-down
//!
//! The reference SensorLM implementation uses Big Vision's `decay_type='rsqrt'`
//! schedule with `warmup_steps = 0.2 × total` and `cooldown_steps = 0.2 × total`.
//!
//! Three phases:
//!
//! ```text
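//! Phase 1 – Warm-up   (0 ≤ t ≤ warmup_steps):
//!     lr(t) = peak_lr × t / warmup_steps
//! Phase 2 – Rsqrt     (warmup_steps < t < cooldown_start):
//!     lr(t) = peak_lr × sqrt(warmup_steps / t)
//! Phase 3 – Cool-down (cooldown_start ≤ t ≤ total_steps):
//!     lr(t) = lr(cooldown_start) × (1 − (t − cooldown_start) / (total_steps − cooldown_start))
//!     where lr(cooldown_start) is the Phase 2 rsqrt value at cooldown_start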
//! ```

use burn::lr_scheduler::LrScheduler;
use burn::record::Record;
use burn::tensor::backend::Backend;

/// Rsqrt learning-rate schedule with linear warm-up and linear cool-down.
///
/// Implements [`burn::lr_scheduler::LrScheduler`] so it can be passed
/// directly to `LearnerBuilder`.
#[derive(Clone, Debug)]
pub struct RsqrtScheduler {
    /// Learning rate reached on the last warm-up step.
    peak_lr: f64,
    /// Number of linear warm-up steps (`total_steps × warmup_fraction`, rounded).
    warmup_steps: usize,
    /// First step of the cool-down phase (`total_steps − cooldown_steps`).
    cooldown_start: usize,
    /// Step at which the cool-down reaches zero.
    total_steps: usize,
    /// Step counter; starts at 0.
    current_step: usize,
}

impl RsqrtScheduler {
    /// Construct a new scheduler.
    ///
    /// * `peak_lr` – learning rate at the end of warm-up.
    /// * `total_steps` – length of the whole schedule, in steps.
    /// * `warmup_fraction` – share of `total_steps` spent in linear warm-up.
    /// * `cooldown_fraction` – share of `total_steps` spent in linear cool-down
    ///   at the end of the run; `0.0` disables the cool-down.
    ///
    /// Both fractions are converted to step counts by rounding to the nearest
    /// integer. The float-to-`usize` cast saturates, so negative or NaN
    /// fractions yield zero steps.
    pub fn new(
        peak_lr: f64,
        total_steps: usize,
        warmup_fraction: f64,
        cooldown_fraction: f64,
    ) -> Self {
        let steps_for = |fraction: f64| ((total_steps as f64) * fraction).round() as usize;
        let warmup_steps = steps_for(warmup_fraction);
        let cooldown_start = total_steps.saturating_sub(steps_for(cooldown_fraction));
        Self {
            peak_lr,
            warmup_steps,
            cooldown_start,
            total_steps,
            current_step: 0,
        }
    }

    /// Compute the learning rate for an absolute step index.
    ///
    /// * `step == 0` always yields `0.0`.
    /// * `1 ..= warmup_steps`: linear ramp up to `peak_lr`.
    /// * afterwards: `peak_lr × sqrt(timescale / step)`, where `timescale` is
    ///   `warmup_steps` clamped to at least 1.
    /// * from `cooldown_start` on (only if a cool-down was configured): the
    ///   rsqrt value at the start of the cool-down, scaled linearly down to
    ///   `0.0` at `total_steps` and held at `0.0` beyond it.
    ///
    /// The returned value is finite and never exceeds `peak_lr` for a finite,
    /// non-negative `peak_lr`, including degenerate configurations (no warm-up,
    /// or a cool-down that overlaps the warm-up).
    pub fn lr_at(&self, step: usize) -> f64 {
        if step == 0 {
            return 0.0;
        }
        if step <= self.warmup_steps {
            return self.peak_lr * step as f64 / self.warmup_steps as f64;
        }

        // With no warm-up the rsqrt numerator would be 0 and the schedule
        // would emit a learning rate of 0 forever; use a timescale of one
        // step instead so the decay starts from `peak_lr` at step 1.
        let timescale = self.warmup_steps.max(1);
        let rsqrt = |t: usize| self.peak_lr * (timescale as f64 / t as f64).sqrt();

        let in_cooldown = self.cooldown_start < self.total_steps && step >= self.cooldown_start;
        if !in_cooldown {
            return rsqrt(step);
        }

        // Anchor the cool-down on a step that is actually in the rsqrt phase.
        // If the cool-down begins during (or before) warm-up, the raw
        // `cooldown_start` would give a value above `peak_lr`, or a division
        // by zero when it is 0.
        let anchor_step = self.cooldown_start.max(timescale);
        let remaining = 1.0
            - (step - self.cooldown_start) as f64
                / (self.total_steps - self.cooldown_start) as f64;
        rsqrt(anchor_step) * remaining.max(0.0)
    }
}

// ---------------------------------------------------------------------------
// Burn 0.14 LrScheduler<B: Backend> trait implementation
// ---------------------------------------------------------------------------

/// Checkpoint payload for the scheduler: only the step counter is persisted.
#[derive(Clone, Debug)]
pub struct RsqrtSchedulerRecord {
    current_step: usize,
}

impl<B: Backend> Record<B> for RsqrtSchedulerRecord {
    /// The serialised form is the bare step count, independent of precision settings.
    type Item<S: burn::record::PrecisionSettings> = usize;

    /// Unwrap the record into its serialisable step count.
    fn into_item<S: burn::record::PrecisionSettings>(self) -> Self::Item<S> {
        let Self { current_step } = self;
        current_step
    }

    /// Rebuild the record from a stored step count; the device is not needed.
    fn from_item<S: burn::record::PrecisionSettings>(
        item: Self::Item<S>,
        _device: &B::Device,
    ) -> Self {
        let current_step = item;
        Self { current_step }
    }
}

impl<B: Backend> LrScheduler<B> for RsqrtScheduler {
    type Record = RsqrtSchedulerRecord;

    /// Return the learning rate for the current step, then advance the counter.
    fn step(&mut self) -> f64 {
        let step_index = self.current_step;
        self.current_step += 1;
        self.lr_at(step_index)
    }

    /// Snapshot the step counter so the schedule can be resumed later.
    fn to_record(&self) -> Self::Record {
        let current_step = self.current_step;
        RsqrtSchedulerRecord { current_step }
    }

    /// Restore the step counter from a snapshot, keeping all other settings.
    fn load_record(self, record: Self::Record) -> Self {
        Self {
            current_step: record.current_step,
            ..self
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Warm-up starts at zero and reaches the peak on its last step.
    #[test]
    fn test_scheduler_warmup() {
        let peak = 1e-3;
        let sched = RsqrtScheduler::new(peak, 1000, 0.2, 0.2);
        assert_eq!(sched.lr_at(0), 0.0);
        // 0.2 × 1000 = 200 warm-up steps, so step 200 is the last warm-up step.
        let gap = (sched.lr_at(200) - peak).abs();
        assert!(gap < 1e-10, "LR at end of warmup should be peak");
    }

    /// In the pure rsqrt region, a 4× later step has half the learning rate.
    #[test]
    fn test_scheduler_rsqrt_decay() {
        // warmup_steps = 0.01 × 10_000 = 100 and there is no cool-down, so
        // steps 400 and 1600 both fall in the rsqrt phase, where
        // lr(t) = peak × sqrt(warmup / t) and hence
        // lr(400) / lr(1600) = sqrt(1600 / 400) = 2.
        let sched = RsqrtScheduler::new(1e-3, 10_000, 0.01, 0.0);
        let (early, late) = (sched.lr_at(400), sched.lr_at(1600));
        let ratio = early / late;
        assert!(
            (ratio - 2.0).abs() < 0.01,
            "Rsqrt: quadrupling step should halve LR; got ratio={ratio:.4}"
        );
    }

    /// The cool-down ends at exactly zero on the final step.
    #[test]
    fn test_scheduler_cooldown_reaches_zero() {
        let sched = RsqrtScheduler::new(1e-3, 1000, 0.2, 0.2);
        let final_lr = sched.lr_at(1000);
        assert_eq!(final_lr, 0.0);
    }
}