Skip to main content

osf_rs/
config.rs

//! Model and data configuration for OSF inference.
//!
//! `ModelConfig` mirrors the Python OSF ViT hyperparameters stored in
//! the `metadata` dict of `osf_backbone.pth`.
5
6// ── ModelConfig ───────────────────────────────────────────────────────────────
7
8#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
9pub struct ModelConfig {
10    /// Model variant name (e.g. "vit_base").
11    #[serde(default = "default_encoder_name")]
12    pub encoder_name: String,
13
14    /// Number of input channels (default 12).
15    #[serde(default = "default_num_leads")]
16    pub num_leads: usize,
17
18    /// Temporal patch size in samples (default 64).
19    #[serde(default = "default_patch_size_time")]
20    pub patch_size_time: usize,
21
22    /// Channel patch size (default 4).
23    #[serde(default = "default_patch_size_ch")]
24    pub patch_size_ch: usize,
25
26    /// 0 = 1D patchify (all channels), 1 = 2D patchify (channel groups).
27    #[serde(default = "default_lead_wise")]
28    pub lead_wise: usize,
29
30    /// Sampling rate (Hz, default 64).
31    #[serde(default = "default_sample_rate")]
32    pub sample_rate: usize,
33
34    /// Window size in seconds (default 30).
35    #[serde(default = "default_window_size_sec")]
36    pub window_size_sec: usize,
37
38    /// Sequence length = sample_rate * window_size_sec (default 1920).
39    #[serde(default = "default_seq_len")]
40    pub seq_len: usize,
41
42    /// Transformer hidden dimension (default 768).
43    #[serde(default = "default_width")]
44    pub width: usize,
45
46    /// Number of transformer blocks (default 12).
47    #[serde(default = "default_depth")]
48    pub depth: usize,
49
50    /// Number of attention heads (default 12).
51    #[serde(default = "default_heads")]
52    pub heads: usize,
53
54    /// MLP hidden dimension (default 3072 = 4 * width).
55    #[serde(default = "default_mlp_dim")]
56    pub mlp_dim: usize,
57
58    /// Attention head dimension (default 64 = width / heads).
59    #[serde(default = "default_dim_head")]
60    pub dim_head: usize,
61}
62
63fn default_encoder_name() -> String { "vit_base".to_string() }
64fn default_num_leads()     -> usize { 12 }
65fn default_patch_size_time() -> usize { 64 }
66fn default_patch_size_ch() -> usize { 4 }
67fn default_lead_wise()     -> usize { 1 }
68fn default_sample_rate()   -> usize { 64 }
69fn default_window_size_sec() -> usize { 30 }
70fn default_seq_len()       -> usize { 1920 }
71fn default_width()         -> usize { 768 }
72fn default_depth()         -> usize { 12 }
73fn default_heads()         -> usize { 12 }
74fn default_mlp_dim()       -> usize { 3072 }
75fn default_dim_head()      -> usize { 64 }
76
77impl Default for ModelConfig {
78    fn default() -> Self {
79        Self {
80            encoder_name:    default_encoder_name(),
81            num_leads:       default_num_leads(),
82            patch_size_time: default_patch_size_time(),
83            patch_size_ch:   default_patch_size_ch(),
84            lead_wise:       default_lead_wise(),
85            sample_rate:     default_sample_rate(),
86            window_size_sec: default_window_size_sec(),
87            seq_len:         default_seq_len(),
88            width:           default_width(),
89            depth:           default_depth(),
90            heads:           default_heads(),
91            mlp_dim:         default_mlp_dim(),
92            dim_head:        default_dim_head(),
93        }
94    }
95}
96
97impl ModelConfig {
98    /// Number of time patches per channel row.
99    pub fn num_patches_time(&self) -> usize {
100        self.seq_len / self.patch_size_time
101    }
102
103    /// Number of channel rows (lead groups).
104    pub fn num_lead_rows(&self) -> usize {
105        if self.lead_wise == 0 { 1 } else { self.num_leads / self.patch_size_ch }
106    }
107
108    /// Total number of patches (excluding CLS token).
109    pub fn num_patches(&self) -> usize {
110        self.num_lead_rows() * self.num_patches_time()
111    }
112
113    /// Build config for a specific variant.
114    pub fn for_variant(name: &str) -> Self {
115        match name {
116            "vit_nano" => Self {
117                encoder_name: "vit_nano".into(),
118                width: 128, depth: 6, heads: 4, mlp_dim: 512, dim_head: 32,
119                ..Default::default()
120            },
121            "vit_tiny" => Self {
122                encoder_name: "vit_tiny".into(),
123                width: 192, depth: 12, heads: 3, mlp_dim: 768, dim_head: 64,
124                ..Default::default()
125            },
126            "vit_small" => Self {
127                encoder_name: "vit_small".into(),
128                width: 384, depth: 12, heads: 6, mlp_dim: 1536, dim_head: 64,
129                ..Default::default()
130            },
131            "vit_middle" => Self {
132                encoder_name: "vit_middle".into(),
133                width: 512, depth: 12, heads: 8, mlp_dim: 2048, dim_head: 64,
134                ..Default::default()
135            },
136            "vit_base" | _ => Self::default(),
137        }
138    }
139}
140
141// ── PSG Channel definitions ─────────────────────────────────────────────────
142
143/// The 12 standard PSG channels used by OSF, in canonical order.
144pub const PSG_CHANNELS: &[&str] = &[
145    "ECG",
146    "EMG_Chin",
147    "EMG_LLeg",
148    "EMG_RLeg",
149    "ABD",
150    "THX",
151    "NP",
152    "SN",
153    "EOG_E1_A2",
154    "EOG_E2_A1",
155    "EEG_C3_A2",
156    "EEG_C4_A1",
157];
158
159/// Number of PSG channels.
160pub const NUM_PSG_CHANNELS: usize = 12;
161
162/// Sampling frequency (Hz).
163pub const SAMPLE_RATE: usize = 64;
164
165/// Epoch duration (seconds).
166pub const EPOCH_SEC: usize = 30;
167
168/// Samples per epoch.
169pub const EPOCH_SAMPLES: usize = SAMPLE_RATE * EPOCH_SEC; // 1920