Skip to main content

labram_rs/
config.rs

1/// Model configuration for LaBraM.
2
3#[derive(Debug, Clone, serde::Deserialize)]
4pub struct ModelConfig {
5    #[serde(default = "d200")]  pub patch_size: usize,
6    #[serde(default = "d200")]  pub embed_dim: usize,
7    #[serde(default = "d12")]   pub num_layers: usize,
8    #[serde(default = "d10")]   pub num_heads: usize,
9    #[serde(default = "d4f")]   pub mlp_ratio: f64,
10    #[serde(default = "d8")]    pub conv_out_channels: usize,
11    #[serde(default = "d128")]  pub n_pos_embeddings: usize,
12    #[serde(default)]           pub n_outputs: usize,
13    #[serde(default)]           pub n_chans: usize,
14    #[serde(default)]           pub n_times: usize,
15}
16
17fn d200() -> usize { 200 }
18fn d12() -> usize { 12 }
19fn d10() -> usize { 10 }
20fn d4f() -> f64 { 4.0 }
21fn d8() -> usize { 8 }
22fn d128() -> usize { 128 }
23
24impl Default for ModelConfig {
25    fn default() -> Self {
26        Self {
27            patch_size: 200, embed_dim: 200, num_layers: 12, num_heads: 10,
28            mlp_ratio: 4.0, conv_out_channels: 8, n_pos_embeddings: 128,
29            n_outputs: 4, n_chans: 64, n_times: 1600,
30        }
31    }
32}