
cbramod_rs/weights.rs

//! Load CBraMod weights from safetensors.

use std::collections::HashMap;
use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;
use crate::model::cbramod::CBraMod;
use crate::config::ModelConfig;

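/// Every tensor in a checkpoint, decoded to `f32` and keyed by name (any
/// leading `model.` prefix is stripped). Usage sketch; the backend alias and
/// the checkpoint path below are illustrative assumptions, not part of this
/// crate:
///
/// ```ignore
/// let mut wm = WeightMap::from_file("cbramod.safetensors")?;
/// wm.print_keys();
/// let w: Tensor<NdArray, 2> = wm.take("proj_out.0.weight", &device)?;
/// ```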
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());
        for (key, view) in st.tensors() {
            // Some exports prefix every key with "model."; strip it so lookups
            // use the bare module path.
            let key = key.strip_prefix("model.").unwrap_or(&key).to_string();
            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();
            // Widen every supported dtype to f32 on load.
            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data.chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32()).collect(),
                safetensors::Dtype::F32 => data.chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect(),
                safetensors::Dtype::F16 => data.chunks_exact(2)
                    .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32()).collect(),
                other => anyhow::bail!("unsupported dtype {:?} for key {key}", other),
            };
            tensors.insert(key, (f32s, shape));
        }
        Ok(Self { tensors })
    }

    /// Remove `key` from the map and build a rank-`N` tensor on `device`.
    /// Note that the entry is consumed even when the rank check fails.
    pub fn take<B: Backend, const N: usize>(&mut self, key: &str, device: &B::Device) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self.tensors.remove(key)
            .ok_or_else(|| anyhow::anyhow!("key not found: {key}"))?;
        if shape.len() != N { anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len()); }
        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
    }

    pub fn has(&self, key: &str) -> bool { self.tensors.contains_key(key) }

    pub fn print_keys(&self) {
        let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
        keys.sort();
        for k in keys { let (_, s) = &self.tensors[k]; println!("  {k:70}  {s:?}"); }
    }
}

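// The setters below copy checkpoint tensors into Burn modules in place.
// PyTorch's `nn.Linear` stores its weight as [out_features, in_features]
// while Burn's `Linear` expects [d_input, d_output], hence the transpose in
// the linear setters; conv and norm parameters keep the same layout.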
fn set_linear_wb<B: Backend>(l: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    l.weight = l.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = l.bias { l.bias = Some(bias.clone().map(|_| b)); }
}

#[allow(dead_code)]
fn set_linear_w<B: Backend>(l: &mut burn::nn::Linear<B>, w: Tensor<B, 2>) {
    l.weight = l.weight.clone().map(|_| w.transpose());
}

fn set_layernorm<B: Backend>(n: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    n.gamma = n.gamma.clone().map(|_| w);
    if let Some(ref beta) = n.beta { n.beta = Some(beta.clone().map(|_| b)); }
}

fn set_conv2d_wb<B: Backend>(c: &mut burn::nn::conv::Conv2d<B>, w: Tensor<B, 4>, b: Tensor<B, 1>) {
    c.weight = c.weight.clone().map(|_| w);
    if let Some(ref bias) = c.bias { c.bias = Some(bias.clone().map(|_| b)); }
}

fn set_groupnorm<B: Backend>(g: &mut burn::nn::GroupNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    if let Some(ref gamma) = g.gamma { g.gamma = Some(gamma.clone().map(|_| w)); }
    if let Some(ref beta) = g.beta { g.beta = Some(beta.clone().map(|_| b)); }
}

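/// Read the safetensors checkpoint at `path`, construct a `CBraMod` from `cfg`,
/// and copy the weights in. Call-site sketch; the backend alias, config, and
/// path below are illustrative assumptions:
///
/// ```ignore
/// let device = Default::default();
/// let model = load_model::<NdArray>(&cfg, "cbramod.safetensors", &device)?;
/// ```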
pub fn load_model<B: Backend>(cfg: &ModelConfig, path: &str, device: &B::Device) -> anyhow::Result<CBraMod<B>> {
    let mut wm = WeightMap::from_file(path)?;
    eprintln!("Loading {} weight tensors...", wm.tensors.len());
    load_model_from_wm(cfg, &mut wm, device)
}

pub fn load_model_from_wm<B: Backend>(cfg: &ModelConfig, wm: &mut WeightMap, device: &B::Device) -> anyhow::Result<CBraMod<B>> {
    let mut model = CBraMod::new(
        cfg.n_outputs, cfg.n_chans, cfg.n_times, cfg.patch_size,
        cfg.dim_feedforward, cfg.n_layer, cfg.nhead, cfg.emb_dim, false, device,
    );
    load_weights(wm, &mut model, cfg, device)?;
    Ok(model)
}

fn load_weights<B: Backend>(wm: &mut WeightMap, model: &mut CBraMod<B>, cfg: &ModelConfig, device: &B::Device) -> anyhow::Result<()> {
    // Patch embedding: conv + group-norm blocks, loaded via the helper macro below.
    macro_rules! load_conv_gn {
        ($block:expr, $conv_key:expr, $gn_key:expr) => {
            if let (Ok(w), Ok(b)) = (wm.take::<B,4>(&format!("{}.weight", $conv_key), device),
                                     wm.take::<B,1>(&format!("{}.bias", $conv_key), device)) {
                set_conv2d_wb(&mut $block.conv, w, b);
            }
            if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{}.weight", $gn_key), device),
                                     wm.take::<B,1>(&format!("{}.bias", $gn_key), device)) {
                set_groupnorm(&mut $block.gn, w, b);
            }
        };
    }
    load_conv_gn!(model.patch_embedding.conv1, "patch_embedding.proj_in.0", "patch_embedding.proj_in.1");
    load_conv_gn!(model.patch_embedding.conv2, "patch_embedding.proj_in.3", "patch_embedding.proj_in.4");
    load_conv_gn!(model.patch_embedding.conv3, "patch_embedding.proj_in.6", "patch_embedding.proj_in.7");

    // Spectral projection
    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("patch_embedding.spectral_proj.0.weight", device),
                             wm.take::<B,1>("patch_embedding.spectral_proj.0.bias", device)) {
        set_linear_wb(&mut model.patch_embedding.spectral_linear, w, b);
    }

    // Positional encoding conv
    if let (Ok(w), Ok(b)) = (wm.take::<B,4>("patch_embedding.positional_encoding.0.weight", device),
                             wm.take::<B,1>("patch_embedding.positional_encoding.0.bias", device)) {
        set_conv2d_wb(&mut model.patch_embedding.pos_conv, w, b);
    }

    // Encoder layers
    for i in 0..cfg.n_layer {
        let layer = &mut model.encoder.layers[i];
        let p = format!("encoder.layers.{i}");

        // S-Attention
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.self_attn_s.in_proj_weight"), device),
                                 wm.take::<B,1>(&format!("{p}.self_attn_s.in_proj_bias"), device)) {
            set_linear_wb(&mut layer.self_attn_s.in_proj, w, b);
        }
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.self_attn_s.out_proj.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.self_attn_s.out_proj.bias"), device)) {
            set_linear_wb(&mut layer.self_attn_s.out_proj, w, b);
        }

        // T-Attention
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.self_attn_t.in_proj_weight"), device),
                                 wm.take::<B,1>(&format!("{p}.self_attn_t.in_proj_bias"), device)) {
            set_linear_wb(&mut layer.self_attn_t.in_proj, w, b);
        }
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.self_attn_t.out_proj.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.self_attn_t.out_proj.bias"), device)) {
            set_linear_wb(&mut layer.self_attn_t.out_proj, w, b);
        }

        // FFN
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.linear1.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.linear1.bias"), device)) {
            set_linear_wb(&mut layer.linear1, w, b);
        }
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.linear2.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.linear2.bias"), device)) {
            set_linear_wb(&mut layer.linear2, w, b);
        }

        // Norms
        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm1.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.norm1.bias"), device)) {
            set_layernorm(&mut layer.norm1, w, b);
        }
        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm2.weight"), device),
                                 wm.take::<B,1>(&format!("{p}.norm2.bias"), device)) {
            set_layernorm(&mut layer.norm2, w, b);
        }
    }

    // proj_out
    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("proj_out.0.weight", device),
                             wm.take::<B,1>("proj_out.0.bias", device)) {
        set_linear_wb(&mut model.proj_out, w, b);
    }

    // final_layer: in the checkpoint this is Sequential(Flatten, LazyLinear),
    // so the linear parameters live under final_layer.1.
    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("final_layer.1.weight", device),
                             wm.take::<B,1>("final_layer.1.bias", device)) {
        set_linear_wb(&mut model.final_linear, w, b);
    }

    Ok(())
}
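
// A small sanity check for `WeightMap::take` (a sketch: it assumes a Burn
// backend such as `burn-ndarray` is available to the test build, which is not
// shown in this file).
#[cfg(test)]
mod tests {
    use super::*;

    // Backend choice is an assumption for the test only; any Burn backend works.
    type B = burn::backend::NdArray;

    #[test]
    fn take_returns_tensor_and_consumes_key() {
        let device = Default::default();
        let mut tensors = HashMap::new();
        tensors.insert("w".to_string(), (vec![0.0f32; 6], vec![2, 3]));
        tensors.insert("bad".to_string(), (vec![0.0f32; 6], vec![2, 3]));
        let mut wm = WeightMap { tensors };

        // A matching rank yields a tensor and removes the entry from the map.
        let t = wm.take::<B, 2>("w", &device).unwrap();
        assert_eq!(t.dims(), [2, 3]);
        assert!(!wm.has("w"));

        // A rank mismatch and a missing key both surface as errors.
        assert!(wm.take::<B, 3>("bad", &device).is_err());
        assert!(wm.take::<B, 2>("missing", &device).is_err());
    }
}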