// osf-rs 0.0.1
//
// OSF Sleep Foundation Model — inference in Rust with Burn ML
//! Patch Embedding for OSF ViT.
//!
//! Two modes controlled by `lead_wise`:
//!
//! lead_wise=0 (1D patchify):
//!   Conv1d(num_leads, width, kernel=patch_size, stride=patch_size, bias=False)
//!   Input [B, C, T] → output [B, T/patch_size, width]
//!
//! lead_wise=1, or any other non-zero value (2D patchify — used by OSF):
//!   Conv2d(1, width, kernel=(patch_size_ch, patch_size_time),
//!          stride=(patch_size_ch, patch_size_time), bias=False)
//!   Input [B, C, T] → unsqueeze → [B, 1, C, T]
//!   → Conv2d → [B, width, Lr, Nt]
//!   → rearrange → [B, Lr*Nt, width]
//!   where Lr = C/patch_size_ch, Nt = T/patch_size_time

use burn::prelude::*;
use burn::nn::conv::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};

/// Unified patch embedding supporting both 1D and 2D modes.
///
/// [`PatchEmbed::new`] populates exactly one of `conv1d` / `conv2d`, chosen by
/// `lead_wise`, and `forward` dispatches on that same flag. The fields are
/// public, so callers building the struct by hand must keep them consistent.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
    /// 1D conv for lead_wise=0 (`new` sets this to `Some` only when `lead_wise == 0`).
    pub conv1d: Option<Conv1d<B>>,
    /// 2D conv for lead_wise=1 (`new` sets this to `Some` for any non-zero `lead_wise`).
    pub conv2d: Option<Conv2d<B>>,
    /// Mode flag: `0` selects the 1D path; any other value selects the 2D path.
    pub lead_wise: usize,
}

impl<B: Backend> PatchEmbed<B> {
    /// Create a new patch embedding.
    ///
    /// * `num_leads` – number of input channels `C`. Only the 1D path uses it,
    ///   as the Conv1d input-channel count.
    /// * `width` – embedding dimension (output channels of the convolution).
    /// * `patch_size_time` – kernel size and stride along the time axis.
    /// * `patch_size_ch` – kernel size and stride along the channel axis.
    ///   Only the 2D path uses it.
    /// * `lead_wise` – `0` builds the 1D (Conv1d) path; any other value builds
    ///   the 2D (Conv2d) path.
    /// * `device` – device on which the convolution weights are initialised.
    ///
    /// Both convolutions are bias-free. Their stride equals the kernel size,
    /// so patches do not overlap.
    pub fn new(
        num_leads: usize,
        width: usize,
        patch_size_time: usize,
        patch_size_ch: usize,
        lead_wise: usize,
        device: &B::Device,
    ) -> Self {
        if lead_wise == 0 {
            let conv = Conv1dConfig::new(num_leads, width, patch_size_time)
                .with_stride(patch_size_time)
                .with_bias(false)
                .init(device);
            Self { conv1d: Some(conv), conv2d: None, lead_wise }
        } else {
            // A single input channel: the leads axis becomes a spatial axis
            // and is patchified together with time.
            let conv = Conv2dConfig::new([1, width], [patch_size_ch, patch_size_time])
                .with_stride([patch_size_ch, patch_size_time])
                .with_bias(false)
                .init(device);
            Self { conv1d: None, conv2d: Some(conv), lead_wise }
        }
    }

    /// Patchify input signal.
    ///
    /// Input: `[B, C, T]`
    /// Output: `[B, N, width]`, where `N` is the number of patches:
    /// * 1D mode: `T / patch_size_time`.
    /// * 2D mode: `(C / patch_size_ch) * (T / patch_size_time)`, ordered so
    ///   that patch index = `lr * Nt + nt`.
    ///
    /// # Panics
    ///
    /// Panics if the convolution required by `lead_wise` is `None`.
    /// [`PatchEmbed::new`] never produces that state, but the fields are
    /// public and may have been set inconsistently by hand.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        if self.lead_wise == 0 {
            let conv = self
                .conv1d
                .as_ref()
                .expect("PatchEmbed: lead_wise == 0 requires conv1d to be set");
            // 1D: Conv1d([B, C, T]) → [B, width, Nt] → [B, Nt, width]
            conv.forward(x).swap_dims(1, 2)
        } else {
            let conv = self
                .conv2d
                .as_ref()
                .expect("PatchEmbed: lead_wise != 0 requires conv2d to be set");
            // 2D: [B, C, T] → [B, 1, C, T] → Conv2d → [B, width, Lr, Nt]
            let out = conv.forward(x.unsqueeze_dim::<4>(1));
            // rearrange 'b w lr nt -> b (lr nt) w':
            // 1. Merge the patch-grid axes row-major, giving [B, width, Lr*Nt]
            //    with index lr * Nt + nt.
            // 2. Move width last, giving [B, Lr*Nt, width].
            out.flatten::<3>(2, 3).swap_dims(1, 2)
        }
    }
}