Skip to main content

reve_rs/
config.rs

1/// Model and runtime configuration for REVE inference.
2///
3/// Field names match the REVE hyperparameters from the Python implementation.
4
5#[derive(Debug, Clone, serde::Deserialize)]
6pub struct ModelConfig {
7    /// Embedding dimension (512 for REVE-Base, 1250 for REVE-Large).
8    #[serde(default = "default_embed_dim")]
9    pub embed_dim: usize,
10
11    /// Number of Transformer layers.
12    #[serde(default = "default_depth")]
13    pub depth: usize,
14
15    /// Number of attention heads.
16    #[serde(default = "default_heads")]
17    pub heads: usize,
18
19    /// Dimension per attention head.
20    #[serde(default = "default_head_dim")]
21    pub head_dim: usize,
22
23    /// FFN hidden dimension ratio: mlp_dim = embed_dim * mlp_dim_ratio.
24    #[serde(default = "default_mlp_dim_ratio")]
25    pub mlp_dim_ratio: f64,
26
27    /// Use GEGLU activation.
28    #[serde(default = "default_use_geglu")]
29    pub use_geglu: bool,
30
31    /// Number of frequencies for Fourier positional embedding.
32    #[serde(default = "default_freqs")]
33    pub freqs: usize,
34
35    /// Temporal patch size in samples.
36    #[serde(default = "default_patch_size")]
37    pub patch_size: usize,
38
39    /// Overlap between patches in samples.
40    #[serde(default = "default_patch_overlap")]
41    pub patch_overlap: usize,
42
43    /// Use attention pooling for classification.
44    #[serde(default)]
45    pub attention_pooling: bool,
46
47    /// Number of output classes.
48    #[serde(default)]
49    pub n_outputs: usize,
50
51    /// Number of EEG channels.
52    #[serde(default)]
53    pub n_chans: usize,
54
55    /// Number of time samples per input.
56    #[serde(default)]
57    pub n_times: usize,
58}
59
/// Default embedding dimension.
fn default_embed_dim() -> usize {
    512
}

/// Default number of Transformer layers.
fn default_depth() -> usize {
    22
}

/// Default number of attention heads.
fn default_heads() -> usize {
    8
}

/// Default width of a single attention head.
fn default_head_dim() -> usize {
    64
}

/// Default FFN hidden-width multiplier relative to the embedding dimension.
fn default_mlp_dim_ratio() -> f64 {
    2.66
}

/// GEGLU activation is enabled by default.
fn default_use_geglu() -> bool {
    true
}

/// Default number of Fourier positional-embedding frequencies.
fn default_freqs() -> usize {
    4
}

/// Default temporal patch length, in samples.
fn default_patch_size() -> usize {
    200
}

/// Default overlap between consecutive patches, in samples.
fn default_patch_overlap() -> usize {
    20
}
69
70impl Default for ModelConfig {
71    fn default() -> Self {
72        Self {
73            embed_dim:         default_embed_dim(),
74            depth:             default_depth(),
75            heads:             default_heads(),
76            head_dim:          default_head_dim(),
77            mlp_dim_ratio:     default_mlp_dim_ratio(),
78            use_geglu:         default_use_geglu(),
79            freqs:             default_freqs(),
80            patch_size:        default_patch_size(),
81            patch_overlap:     default_patch_overlap(),
82            attention_pooling: false,
83            n_outputs:         4,
84            n_chans:           22,
85            n_times:           1000,
86        }
87    }
88}
89
90impl ModelConfig {
91    /// Inner attention dimension: head_dim * heads.
92    pub fn inner_dim(&self) -> usize {
93        self.head_dim * self.heads
94    }
95
96    /// FFN hidden dimension.
97    pub fn mlp_dim(&self) -> usize {
98        (self.embed_dim as f64 * self.mlp_dim_ratio) as usize
99    }
100
101    /// GEGLU doubles the FFN input features.
102    pub fn ffn_in_features(&self) -> usize {
103        let mlp = self.mlp_dim();
104        if self.use_geglu { mlp * 2 } else { mlp }
105    }
106}