osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! 1D Vision Transformer with CLS token for OSF.
//!
//! Python: `ViT` class in `vit1d_cls.py`.
//!
//! Architecture:
//!   1. PatchEmbed — Conv1d or Conv2d patchification
//!   2. CLS token prepended
//!   3. Positional embedding added
//!   4. N × TransformerBlock
//!   5. LayerNorm
//!   6. Identity head in the Python model (not part of this Rust struct)
//!
//! lead_wise=1 (used by OSF-Base):
//!
//! ```text
//! Input [B, 12, 1920] → Conv2d(1, 768, (4, 64), stride=(4, 64))
//! → [B, 768, 3, 30] → rearrange → [B, 90, 768]
//! Prepend CLS → [B, 91, 768]
//! + pos_embedding [1, 91, 768]
//! 12 × TransformerBlock
//! LayerNorm → CLS output [B, 1, 768], patch output [B, 90, 768]
//! ```

use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::{Embedding, EmbeddingConfig};

use crate::model::patch_embed::PatchEmbed;
use crate::model::transformer_block::TransformerBlock;
use crate::model::norm::OsfLayerNorm;

/// OSF 1D Vision Transformer encoder (Rust port of the Python `ViT` in `vit1d_cls.py`).
///
/// Holds a patch embedder, a CLS token, a positional-embedding table, an optional
/// lead embedding, a stack of transformer blocks and a final layer norm. Construct it
/// with [`OsfViT::new`] and run it with [`OsfViT::forward_encoding`],
/// [`OsfViT::forward`] or [`OsfViT::forward_avg_pool`].
///
/// NOTE(review): because this derives `Module`, the field names (and, depending on the
/// recorder used, possibly their order) define the checkpoint record layout — keep them
/// stable so pretrained weights keep loading.
#[derive(Module, Debug)]
pub struct OsfViT<B: Backend> {
    /// Patch tokenizer: turns the `[B, C, T]` input series into `[B, N, width]` tokens.
    pub patch_embed: PatchEmbed<B>,
    /// CLS token: [1, 1, width].
    /// Broadcast over the batch and prepended to the patch tokens in `forward_encoding`.
    pub cls_token: Param<Tensor<B, 3>>,
    /// Positional embedding: [1, N_max+1, width].
    /// The extra slot covers the CLS token; only the first `N+1` rows are used per call.
    pub pos_embedding: Param<Tensor<B, 3>>,
    /// Lead embedding (for lead_wise=1 only, kept for checkpoint compat).
    /// `Some` only when `lead_wise != 0`; not read by any forward method in this file.
    pub lead_emb: Option<Embedding<B>>,
    /// Transformer blocks.
    /// `depth` of them, applied in order.
    pub blocks: Vec<TransformerBlock<B>>,
    /// Final layer norm.
    /// Applied to every token (CLS and patches) after the last block.
    pub norm: OsfLayerNorm<B>,
    // Config: plain hyperparameters recorded by `new` (not tensors).
    // `width` is the token dimension D, `depth` the number of `blocks`,
    // `lead_wise` selects 1D (0) or lead-wise (non-zero) patching.
    pub width: usize,
    pub depth: usize,
    pub lead_wise: usize,
}

impl<B: Backend> OsfViT<B> {
    /// Builds an OSF ViT encoder.
    ///
    /// # Arguments
    /// * `num_leads` – number of input leads/channels, forwarded to `PatchEmbed`.
    /// * `seq_len` – number of time samples per input series.
    /// * `patch_size_time` – patch length along the time axis.
    /// * `patch_size_ch` – number of leads grouped per patch; only used here to size the
    ///   token grid and `lead_emb` when `lead_wise != 0` (also forwarded to `PatchEmbed`).
    /// * `lead_wise` – `0` gives `seq_len / patch_size_time` tokens; any other value gives
    ///   `(num_leads / patch_size_ch) * (seq_len / patch_size_time)` tokens.
    /// * `width` – token embedding dimension `D`.
    /// * `depth` – number of transformer blocks.
    /// * `mlp_dim`, `heads`, `dim_head` – forwarded unchanged to every `TransformerBlock`.
    /// * `device` – device on which all parameters are allocated.
    ///
    /// The positional table has `n_max + 1` rows (token grid plus CLS). The CLS token and
    /// the positional table both start as zeros; presumably they are overwritten when a
    /// pretrained checkpoint is loaded — TODO confirm.
    ///
    /// # Panics
    /// Panics (integer division by zero) if `patch_size_time == 0`, or if
    /// `lead_wise != 0` and `patch_size_ch == 0`.
    // The argument list mirrors the Python `ViT` constructor and is part of the public
    // API, so it is kept flat rather than bundled into a config struct.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_leads: usize,
        seq_len: usize,
        patch_size_time: usize,
        patch_size_ch: usize,
        lead_wise: usize,
        width: usize,
        depth: usize,
        mlp_dim: usize,
        heads: usize,
        dim_head: usize,
        device: &B::Device,
    ) -> Self {
        let num_patches_time = seq_len / patch_size_time;
        // Number of lead rows in the lead-wise patch grid. Computed lazily so that
        // `patch_size_ch` is only divided by when lead-wise patching is enabled.
        let lead_rows = (lead_wise != 0).then(|| num_leads / patch_size_ch);
        let n_max = lead_rows.map_or(num_patches_time, |rows| rows * num_patches_time);

        // Sub-modules are created in a fixed order (patch embed, CLS, positions, lead
        // embedding, blocks, norm) so parameter ids and any random-initializer draws
        // happen in a stable sequence.
        let patch_embed = PatchEmbed::new(
            num_leads, width, patch_size_time, patch_size_ch, lead_wise, device,
        );
        let cls_token = Param::initialized(ParamId::new(), Tensor::zeros([1, 1, width], device));
        let pos_embedding = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, n_max + 1, width], device),
        );
        // Only materialized for lead-wise patching; the forward methods do not read it.
        let lead_emb = lead_rows.map(|rows| EmbeddingConfig::new(rows, width).init(device));
        let blocks = (0..depth)
            .map(|_| TransformerBlock::new(width, width, mlp_dim, heads, dim_head, true, device))
            .collect();
        // Second argument is presumably the layer-norm epsilon — TODO confirm in `norm`.
        let norm = OsfLayerNorm::new(width, 1e-5, device);

        Self {
            patch_embed,
            cls_token,
            pos_embedding,
            lead_emb,
            blocks,
            norm,
            width,
            depth,
            lead_wise,
        }
    }

    /// Encodes `series` and returns the normalized CLS token and patch tokens.
    ///
    /// Input: `series` of shape `[B, C, T]`.
    ///
    /// Returns `(cls, patches)`:
    /// * `cls` – shape `[B, 1, D]` (the token axis is kept, not squeezed);
    /// * `patches` – shape `[B, N, D]`, where `N` is the number of tokens produced by
    ///   `patch_embed`.
    ///
    /// # Panics
    /// Panics if `N + 1` exceeds the length of the positional-embedding table, e.g.
    /// presumably when the input is larger than the shape the model was built for.
    pub fn forward_encoding(&self, series: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let tokens = self.patch_embed.forward(series); // [B, N, D]
        let [batch, _, _] = tokens.dims();

        // Prepend the CLS token to every sequence in the batch.
        let cls_tok = self.cls_token.val().expand([batch, 1, self.width]);
        let x = Tensor::cat(vec![cls_tok, tokens], 1); // [B, N+1, D]

        // Add the positional embedding, truncated to the actual sequence length so that
        // shorter inputs use only the leading rows of the table.
        let n_tokens = x.dims()[1];
        let pos = self.pos_embedding.val();
        let max_tokens = pos.dims()[1];
        assert!(
            n_tokens <= max_tokens,
            "OsfViT: {n_tokens} tokens (including CLS) exceed the positional embedding length {max_tokens}"
        );
        let pe = pos.narrow(1, 0, n_tokens).to_device(&x.device());

        // Transformer stack followed by the final layer norm.
        let x = self.blocks.iter().fold(x + pe, |acc, block| block.forward(acc));
        let x = self.norm.forward(x);

        // Token 0 is CLS; the remaining tokens are the patch embeddings.
        let seq = x.dims()[1];
        let cls = x.clone().narrow(1, 0, 1); // [B, 1, D]
        let patches = x.narrow(1, 1, seq - 1); // [B, N, D]

        (cls, patches)
    }

    /// Forward pass returning the CLS embedding `[B, 1, D]`.
    pub fn forward(&self, series: Tensor<B, 3>) -> Tensor<B, 3> {
        let (cls, _) = self.forward_encoding(series);
        cls
    }

    /// Forward pass with mean pooling over the patch tokens → `[B, 1, D]`.
    ///
    /// The CLS token is excluded from the average. NOTE(review): if `patch_embed` yields
    /// zero tokens the mean is over an empty axis — behavior is backend-dependent.
    pub fn forward_avg_pool(&self, series: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, patches) = self.forward_encoding(series);
        patches.mean_dim(1)
    }
}