osf-rs 0.0.1

OSF Sleep Foundation Model — inference in Rust with Burn ML
Documentation
//! Model and data configuration for OSF inference.
//!
//! [`ModelConfig`] mirrors the Python OSF ViT hyperparameters stored in
//! the `metadata` dict of `osf_backbone.pth`.

// ── ModelConfig ───────────────────────────────────────────────────────────────

/// Hyperparameters of the OSF ViT backbone.
///
/// Every field has its own serde default (the `vit_base` values), so a
/// partial metadata dict deserializes cleanly. Note that missing fields are
/// filled independently of the ones present: e.g. a missing `seq_len` is not
/// recomputed from a supplied `sample_rate`/`window_size_sec`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct ModelConfig {
    /// Model variant name (e.g. "vit_base"); see [`ModelConfig::for_variant`].
    #[serde(default = "default_encoder_name")]
    pub encoder_name: String,

    /// Number of input channels (default 12).
    #[serde(default = "default_num_leads")]
    pub num_leads: usize,

    /// Temporal patch size in samples (default 64).
    #[serde(default = "default_patch_size_time")]
    pub patch_size_time: usize,

    /// Channels per patch row when `lead_wise != 0` (default 4).
    #[serde(default = "default_patch_size_ch")]
    pub patch_size_ch: usize,

    /// Patchify mode: 0 = 1D patchify (a single row spanning all channels);
    /// any nonzero value (conventionally 1) = 2D patchify into groups of
    /// `patch_size_ch` channels. Default 1.
    #[serde(default = "default_lead_wise")]
    pub lead_wise: usize,

    /// Sampling rate (Hz, default 64).
    #[serde(default = "default_sample_rate")]
    pub sample_rate: usize,

    /// Window size in seconds (default 30).
    #[serde(default = "default_window_size_sec")]
    pub window_size_sec: usize,

    /// Samples per input window (default 1920 = 64 Hz * 30 s).
    ///
    /// Stored independently of `sample_rate` and `window_size_sec`; keep the
    /// three consistent when constructing a config by hand.
    #[serde(default = "default_seq_len")]
    pub seq_len: usize,

    /// Transformer hidden dimension (default 768).
    #[serde(default = "default_width")]
    pub width: usize,

    /// Number of transformer blocks (default 12).
    #[serde(default = "default_depth")]
    pub depth: usize,

    /// Number of attention heads (default 12).
    #[serde(default = "default_heads")]
    pub heads: usize,

    /// MLP hidden dimension (default 3072 = 4 * width).
    #[serde(default = "default_mlp_dim")]
    pub mlp_dim: usize,

    /// Attention head dimension (default 64 = width / heads).
    #[serde(default = "default_dim_head")]
    pub dim_head: usize,
}

// Per-field serde defaults; together they describe the `vit_base` variant and
// are also what `ModelConfig::default()` uses. Quantities that depend on other
// defaults are computed from them rather than hard-coded, so they cannot drift.

fn default_encoder_name() -> String {
    String::from("vit_base")
}

fn default_num_leads() -> usize {
    12
}

fn default_patch_size_time() -> usize {
    64
}

fn default_patch_size_ch() -> usize {
    4
}

fn default_lead_wise() -> usize {
    1
}

fn default_sample_rate() -> usize {
    64
}

fn default_window_size_sec() -> usize {
    30
}

/// `sample_rate * window_size_sec` = 64 * 30 = 1920.
fn default_seq_len() -> usize {
    default_sample_rate() * default_window_size_sec()
}

fn default_width() -> usize {
    768
}

fn default_depth() -> usize {
    12
}

fn default_heads() -> usize {
    12
}

/// `4 * width` = 3072.
fn default_mlp_dim() -> usize {
    4 * default_width()
}

/// `width / heads` = 64.
fn default_dim_head() -> usize {
    default_width() / default_heads()
}

impl Default for ModelConfig {
    fn default() -> Self {
        Self {
            encoder_name:    default_encoder_name(),
            num_leads:       default_num_leads(),
            patch_size_time: default_patch_size_time(),
            patch_size_ch:   default_patch_size_ch(),
            lead_wise:       default_lead_wise(),
            sample_rate:     default_sample_rate(),
            window_size_sec: default_window_size_sec(),
            seq_len:         default_seq_len(),
            width:           default_width(),
            depth:           default_depth(),
            heads:           default_heads(),
            mlp_dim:         default_mlp_dim(),
            dim_head:        default_dim_head(),
        }
    }
}

impl ModelConfig {
    /// Number of time patches per channel row: `seq_len / patch_size_time`.
    ///
    /// Uses integer division, so trailing samples that do not fill a whole
    /// patch are not counted.
    ///
    /// # Panics
    ///
    /// Panics if `patch_size_time` is 0.
    pub fn num_patches_time(&self) -> usize {
        self.seq_len / self.patch_size_time
    }

    /// Number of channel rows (lead groups) in the patch grid.
    ///
    /// With `lead_wise == 0` (1D patchify) all channels form a single row;
    /// any other value splits `num_leads` into groups of `patch_size_ch`
    /// (integer division, so leftover channels are not counted).
    ///
    /// # Panics
    ///
    /// Panics if `lead_wise != 0` and `patch_size_ch` is 0.
    pub fn num_lead_rows(&self) -> usize {
        if self.lead_wise == 0 {
            1
        } else {
            self.num_leads / self.patch_size_ch
        }
    }

    /// Total number of patches (excluding the CLS token).
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`ModelConfig::num_lead_rows`] and
    /// [`ModelConfig::num_patches_time`].
    pub fn num_patches(&self) -> usize {
        self.num_lead_rows() * self.num_patches_time()
    }

    /// Build the config for a named ViT variant, or `None` if `name` is not
    /// one of `vit_nano`, `vit_tiny`, `vit_small`, `vit_middle`, `vit_base`.
    ///
    /// Only `encoder_name` and the transformer shape (`width`, `depth`,
    /// `heads`, `mlp_dim`, `dim_head`) vary; every other field is the default.
    pub fn try_for_variant(name: &str) -> Option<Self> {
        // (width, depth, heads, mlp_dim, dim_head)
        let (width, depth, heads, mlp_dim, dim_head) = match name {
            "vit_nano" => (128, 6, 4, 512, 32),
            "vit_tiny" => (192, 12, 3, 768, 64),
            "vit_small" => (384, 12, 6, 1536, 64),
            "vit_middle" => (512, 12, 8, 2048, 64),
            "vit_base" => return Some(Self::default()),
            _ => return None,
        };
        Some(Self {
            encoder_name: name.to_string(),
            width,
            depth,
            heads,
            mlp_dim,
            dim_head,
            ..Self::default()
        })
    }

    /// Build config for a specific variant.
    ///
    /// Unknown names fall back to the `vit_base` defaults (including
    /// `encoder_name == "vit_base"`); use [`ModelConfig::try_for_variant`]
    /// to detect an unrecognized name instead.
    pub fn for_variant(name: &str) -> Self {
        Self::try_for_variant(name).unwrap_or_default()
    }
}

// ── PSG Channel definitions ─────────────────────────────────────────────────

/// The 12 standard PSG channels used by OSF, in canonical order.
///
/// NOTE(review): presumably this order is the channel order the model expects
/// in its input; confirm before reordering.
pub const PSG_CHANNELS: &[&str] = &[
    "ECG",
    "EMG_Chin",
    "EMG_LLeg",
    "EMG_RLeg",
    "ABD",
    "THX",
    "NP",
    "SN",
    "EOG_E1_A2",
    "EOG_E2_A1",
    "EEG_C3_A2",
    "EEG_C4_A1",
];

/// Number of PSG channels, derived from [`PSG_CHANNELS`] so the two always agree.
pub const NUM_PSG_CHANNELS: usize = PSG_CHANNELS.len();

/// Sampling frequency (Hz).
pub const SAMPLE_RATE: usize = 64;

/// Epoch duration (seconds).
pub const EPOCH_SEC: usize = 30;

/// Samples per epoch (`SAMPLE_RATE * EPOCH_SEC` = 1920).
pub const EPOCH_SAMPLES: usize = SAMPLE_RATE * EPOCH_SEC;