osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Layer-normalization wrapper for the OSF model.
//!
//! Python equivalent: `nn.LayerNorm(dim)`, used throughout the ViT.

use burn::prelude::*;
use burn::nn::{LayerNorm, LayerNormConfig};

#[derive(Module, Debug)]
pub struct OsfLayerNorm<B: Backend> {
    pub inner: LayerNorm<B>,
}

impl<B: Backend> OsfLayerNorm<B> {
    pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self {
        Self {
            inner: LayerNormConfig::new(dim).with_epsilon(eps).init(device),
        }
    }

    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        self.inner.forward(x)
    }
}