Skip to main content

eegpt_rs/
weights.rs

//! Load EEGPT weights from safetensors checkpoints into a Burn `EEGPT` model.
2
3use std::collections::HashMap;
4use burn::prelude::*;
5use half::bf16;
6use safetensors::SafeTensors;
7use crate::model::eegpt::EEGPT;
8use crate::config::ModelConfig;
9
/// Checkpoint tensors decoded to `f32`, keyed by tensor name.
///
/// Each entry stores the flat values together with the tensor's shape.
/// `WeightMap::from_file` strips a leading `model.` prefix from every key.
/// `Default` gives an empty map that can be filled by hand, e.g. for
/// `load_model_from_wm`.
#[derive(Default)]
pub struct WeightMap {
    /// `name -> (values, shape)`; values are laid out as stored in the checkpoint.
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
13
14impl WeightMap {
15    pub fn from_file(path: &str) -> anyhow::Result<Self> {
16        let bytes = std::fs::read(path)?;
17        let st = SafeTensors::deserialize(&bytes)?;
18        let mut tensors = HashMap::with_capacity(st.len());
19        for (key, view) in st.tensors() {
20            let key = key.strip_prefix("model.").unwrap_or(&key).to_string();
21            let shape: Vec<usize> = view.shape().to_vec();
22            let data = view.data();
23            let f32s: Vec<f32> = match view.dtype() {
24                safetensors::Dtype::BF16 => data.chunks_exact(2).map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32()).collect(),
25                safetensors::Dtype::F32 => data.chunks_exact(4).map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect(),
26                safetensors::Dtype::F16 => data.chunks_exact(2).map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32()).collect(),
27                other => anyhow::bail!("unsupported dtype {:?} for key {key}", other),
28            };
29            tensors.insert(key, (f32s, shape));
30        }
31        Ok(Self { tensors })
32    }
33
34    pub fn take<B: Backend, const N: usize>(&mut self, key: &str, device: &B::Device) -> anyhow::Result<Tensor<B, N>> {
35        let (data, shape) = self.tensors.remove(key).ok_or_else(|| anyhow::anyhow!("key not found: {key}"))?;
36        if shape.len() != N { anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len()); }
37        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
38    }
39
40    pub fn has(&self, key: &str) -> bool { self.tensors.contains_key(key) }
41}
42
43fn set_linear_wb<B: Backend>(l: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
44    l.weight = l.weight.clone().map(|_| w.transpose());
45    if let Some(ref bias) = l.bias { l.bias = Some(bias.clone().map(|_| b)); }
46}
47
48fn set_layernorm<B: Backend>(n: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
49    n.gamma = n.gamma.clone().map(|_| w);
50    if let Some(ref beta) = n.beta { n.beta = Some(beta.clone().map(|_| b)); }
51}
52
53fn set_conv2d_wb<B: Backend>(c: &mut burn::nn::conv::Conv2d<B>, w: Tensor<B, 4>, b: Tensor<B, 1>) {
54    c.weight = c.weight.clone().map(|_| w);
55    if let Some(ref bias) = c.bias { c.bias = Some(bias.clone().map(|_| b)); }
56}
57
58pub fn load_model<B: Backend>(cfg: &ModelConfig, path: &str, device: &B::Device) -> anyhow::Result<EEGPT<B>> {
59    let mut wm = WeightMap::from_file(path)?;
60    eprintln!("Loading {} weight tensors...", wm.tensors.len());
61    load_model_from_wm(cfg, &mut wm, device)
62}
63
64pub fn load_model_from_wm<B: Backend>(cfg: &ModelConfig, wm: &mut WeightMap, device: &B::Device) -> anyhow::Result<EEGPT<B>> {
65    let mut model = EEGPT::new(
66        cfg.n_outputs, cfg.n_chans, cfg.n_times,
67        cfg.patch_size, cfg.patch_stride, cfg.embed_num, cfg.embed_dim,
68        cfg.depth, cfg.num_heads, cfg.mlp_ratio, cfg.qkv_bias,
69        cfg.n_chan_embeddings, cfg.probe_hidden_dim, 1e-6, device,
70    );
71    load_weights(wm, &mut model, cfg, device)?;
72    Ok(model)
73}
74
75fn load_weights<B: Backend>(wm: &mut WeightMap, model: &mut EEGPT<B>, cfg: &ModelConfig, device: &B::Device) -> anyhow::Result<()> {
76    let te = &mut model.target_encoder;
77
78    // Summary token
79    if let Ok(t) = wm.take::<B, 3>("target_encoder.summary_token", device) {
80        te.summary_token = te.summary_token.clone().map(|_| t);
81    }
82
83    // Patch embedding conv
84    if let (Ok(w), Ok(b)) = (wm.take::<B,4>("target_encoder.patch_embed.proj.weight", device),
85                             wm.take::<B,1>("target_encoder.patch_embed.proj.bias", device)) {
86        set_conv2d_wb(&mut te.patch_embed.proj, w, b);
87    }
88
89    // Channel embedding
90    if let Ok(w) = wm.take::<B, 2>("target_encoder.chan_embed.weight", device) {
91        te.chan_embed.weight = te.chan_embed.weight.clone().map(|_| w);
92    }
93
94    // Transformer blocks
95    for i in 0..cfg.depth {
96        let block = &mut te.blocks[i];
97        let p = format!("target_encoder.blocks.{i}");
98
99        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm1.weight"), device),
100                                 wm.take::<B,1>(&format!("{p}.norm1.bias"), device)) {
101            set_layernorm(&mut block.norm1, w, b);
102        }
103        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.attn.qkv.weight"), device),
104                                 wm.take::<B,1>(&format!("{p}.attn.qkv.bias"), device)) {
105            set_linear_wb(&mut block.attn.qkv, w, b);
106        }
107        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.attn.proj.weight"), device),
108                                 wm.take::<B,1>(&format!("{p}.attn.proj.bias"), device)) {
109            set_linear_wb(&mut block.attn.proj, w, b);
110        }
111        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm2.weight"), device),
112                                 wm.take::<B,1>(&format!("{p}.norm2.bias"), device)) {
113            set_layernorm(&mut block.norm2, w, b);
114        }
115        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.mlp.fc1.weight"), device),
116                                 wm.take::<B,1>(&format!("{p}.mlp.fc1.bias"), device)) {
117            set_linear_wb(&mut block.mlp_fc1, w, b);
118        }
119        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.mlp.fc2.weight"), device),
120                                 wm.take::<B,1>(&format!("{p}.mlp.fc2.bias"), device)) {
121            set_linear_wb(&mut block.mlp_fc2, w, b);
122        }
123    }
124
125    // Final norm
126    if let (Ok(w), Ok(b)) = (wm.take::<B,1>("target_encoder.norm.weight", device),
127                             wm.take::<B,1>("target_encoder.norm.bias", device)) {
128        set_layernorm(&mut te.norm, w, b);
129    }
130
131    // Probe layers
132    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("final_layer.probe1.weight", device),
133                             wm.take::<B,1>("final_layer.probe1.bias", device)) {
134        set_linear_wb(&mut model.probe1, w, b);
135    }
136    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("final_layer.probe2.weight", device),
137                             wm.take::<B,1>("final_layer.probe2.bias", device)) {
138        set_linear_wb(&mut model.probe2, w, b);
139    }
140
141    Ok(())
142}