1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6pub enum ModelSize {
7 Small,
8 Medium,
9 Large,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ModelConfig {
17 pub model_size: ModelSize,
18 pub feature_size: usize,
20 pub num_heads: usize,
22 pub num_layers: usize,
24 pub dim_feedforward: usize,
26 pub num_global_tokens: usize,
28 pub global_token_layer: usize,
30 pub num_channels: usize,
32 pub patch_size: usize,
34 pub conv_channels: [usize; 3],
36 pub norm_groups: [usize; 3],
38 pub layer_norm_eps: f64,
40}
41
42impl ModelConfig {
43 pub fn from_size(size: ModelSize) -> Self {
44 match size {
45 ModelSize::Small => Self {
46 model_size: size,
47 feature_size: 200,
48 num_heads: 8,
49 num_layers: 12,
50 dim_feedforward: 512,
51 num_global_tokens: 1,
52 global_token_layer: 1,
53 num_channels: 19,
54 patch_size: 200,
55 conv_channels: [25, 25, 25],
56 norm_groups: [5, 5, 5],
57 layer_norm_eps: 1e-5,
58 },
59 ModelSize::Medium => Self {
60 model_size: size,
61 feature_size: 512,
62 num_heads: 16,
63 num_layers: 16,
64 dim_feedforward: 1024,
65 num_global_tokens: 1,
66 global_token_layer: 1,
67 num_channels: 19,
68 patch_size: 200,
69 conv_channels: [64, 128, 64],
70 norm_groups: [8, 8, 8],
71 layer_norm_eps: 1e-5,
72 },
73 ModelSize::Large => Self {
74 model_size: size,
75 feature_size: 1024,
76 num_heads: 16,
80 num_layers: 24,
81 dim_feedforward: 2048,
82 num_global_tokens: 1,
83 global_token_layer: 1,
84 num_channels: 19,
85 patch_size: 200,
86 conv_channels: [128, 256, 128],
87 norm_groups: [16, 16, 16],
88 layer_norm_eps: 1e-5,
89 },
90 }
91 }
92
93 pub fn from_file(path: &std::path::Path) -> crate::error::Result<Self> {
95 let data = std::fs::read_to_string(path)?;
96 serde_json::from_str(&data).map_err(|e| {
97 crate::error::EegDinoError::WeightLoad(format!("config parse error: {e}"))
98 })
99 }
100
101 pub fn head_dim(&self) -> usize {
103 self.feature_size / self.num_heads
104 }
105
106 pub fn spectral_bins(&self) -> usize {
108 self.patch_size / 2 + 1
109 }
110
111 pub fn temporal_conv_out(&self) -> usize {
113 (self.patch_size - 49 + 2 * 24) / 25 + 1
114 }
115}