osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Transformer block for the OSF ViT.
//!
//! Python reference: `TransformerBlock` in vit1d_cls.py:
//!
//! ```text
//! x = droppath(attn(norm(x))) + x      # PreNorm → Attention → DropPath → residual
//! x = droppath(ff(norm(x))) + x        # PreNorm → FeedForward → DropPath → residual
//! ```
//!
//! At inference DropPath is the identity, so this port omits it.

use burn::prelude::*;
use crate::model::norm::OsfLayerNorm;
use crate::model::attention::Attention;
use crate::model::feedforward::FeedForward;

/// Pre-norm transformer block.
///
/// Holds an attention sub-layer and a feed-forward sub-layer. Each one is
/// preceded by its own layer norm and wrapped in a residual connection (see
/// `TransformerBlock::forward`).
///
/// NOTE(review): the field names are presumably matched against the converted
/// weight record. Confirm before renaming any of them.
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    /// Layer norm applied to the block input before the attention sub-layer.
    pub attn_norm: OsfLayerNorm<B>,
    /// Attention sub-layer (configured with `heads` x `dim_head` in `new`).
    pub attn: Attention<B>,
    /// Layer norm applied to the attention residual output before the feed-forward sub-layer.
    pub ff_norm: OsfLayerNorm<B>,
    /// Feed-forward sub-layer.
    pub ff: FeedForward<B>,
}

impl<B: Backend> TransformerBlock<B> {
    /// Builds a pre-norm transformer block.
    ///
    /// * `input_dim`  – feature size of the block input; width of `attn_norm` and
    ///   input size of the attention sub-layer.
    /// * `output_dim` – output size of the attention sub-layer; width of `ff_norm`
    ///   and both the input and output size of the feed-forward sub-layer.
    /// * `hidden_dim` – hidden size passed to `FeedForward::new`.
    /// * `heads`, `dim_head`, `qkv_bias` – forwarded unchanged to `Attention::new`.
    /// * `device` – device passed to every sub-module constructor.
    ///
    /// Both residual additions in [`forward`](Self::forward) add a sub-layer output
    /// to that sub-layer's input. In practice `input_dim` must therefore equal
    /// `output_dim`; this is not checked here.
    pub fn new(
        input_dim: usize,
        output_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dim_head: usize,
        qkv_bias: bool,
        device: &B::Device,
    ) -> Self {
        // NOTE(review): eps = 1e-5 presumably mirrors PyTorch's `nn.LayerNorm`
        // default used by the Python reference. Confirm against the checkpoint config.
        Self {
            attn_norm: OsfLayerNorm::new(input_dim, 1e-5, device),
            attn: Attention::new(input_dim, output_dim, heads, dim_head, qkv_bias, device),
            ff_norm: OsfLayerNorm::new(output_dim, 1e-5, device),
            ff: FeedForward::new(output_dim, output_dim, hidden_dim, device),
        }
    }

    /// Applies the block: `h = attn(norm(x)) + x`, then returns `ff(norm(h)) + h`.
    ///
    /// x: [B, N, D] → [B, N, D]
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Pre-norm attention + residual. `x` itself is moved into the addition, so
        // the input activation is released here instead of being held by the
        // function scope until the feed-forward sub-layer finishes.
        let x = self.attn.forward(self.attn_norm.forward(x.clone())) + x;
        // Pre-norm feed-forward + residual, with the same single-clone pattern.
        self.ff.forward(self.ff_norm.forward(x.clone())) + x
    }
}