labram_rs/weights.rs

//! Load LaBraM weights from safetensors.

use std::collections::HashMap;
use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;
use crate::model::labram::LaBraM;
use crate::config::ModelConfig;

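/// Checkpoint tensors decoded to `f32`, keyed by name (with any leading
/// `model.` prefix already stripped); each entry also carries its shape.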
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
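    /// Read a safetensors file from `path`, decoding every tensor to `f32`.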
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());
        for (key, view) in st.tensors() {
            // Checkpoints may nest all keys under a "model." prefix; normalize it away.
            let key = key.strip_prefix("model.").unwrap_or(&key).to_string();
            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();
            // safetensors payloads are little-endian; widen BF16 to f32, pass F32 through.
            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data
                    .chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                safetensors::Dtype::F32 => data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
                other => anyhow::bail!("unsupported dtype {:?}", other),
            };
            tensors.insert(key, (f32s, shape));
        }
        Ok(Self { tensors })
    }
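    /// Remove the tensor stored under `key` and materialize it as a rank-`N`
    /// burn tensor on `device`. Taking entries out of the map means a caller
    /// can inspect `tensors` afterwards for anything the model never consumed.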
    pub fn take<B: Backend, const N: usize>(&mut self, key: &str, device: &B::Device) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self.tensors.remove(key).ok_or_else(|| anyhow::anyhow!("key not found: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
    }

    pub fn has(&self, key: &str) -> bool { self.tensors.contains_key(key) }
}

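// Burn's `Linear` stores its weight as [d_input, d_output], whereas PyTorch
// checkpoints use [out_features, in_features]; the transposes below bridge
// the two layouts.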
fn set_linear_wb<B: Backend>(l: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    l.weight = l.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = l.bias { l.bias = Some(bias.clone().map(|_| b)); }
}
fn set_linear_w<B: Backend>(l: &mut burn::nn::Linear<B>, w: Tensor<B, 2>) {
    l.weight = l.weight.clone().map(|_| w.transpose());
}
fn set_ln<B: Backend>(n: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    n.gamma = n.gamma.clone().map(|_| w);
    if let Some(ref beta) = n.beta { n.beta = Some(beta.clone().map(|_| b)); }
}
fn set_conv2d_wb<B: Backend>(c: &mut burn::nn::conv::Conv2d<B>, w: Tensor<B, 4>, b: Tensor<B, 1>) {
    c.weight = c.weight.clone().map(|_| w);
    if let Some(ref bias) = c.bias { c.bias = Some(bias.clone().map(|_| b)); }
}
fn set_gn<B: Backend>(g: &mut burn::nn::GroupNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    if let Some(ref gamma) = g.gamma { g.gamma = Some(gamma.clone().map(|_| w)); }
    if let Some(ref beta) = g.beta { g.beta = Some(beta.clone().map(|_| b)); }
}

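/// Build a `LaBraM` model from `cfg` and populate it with weights read from
/// the safetensors checkpoint at `path`.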
pub fn load_model<B: Backend>(cfg: &ModelConfig, path: &str, device: &B::Device) -> anyhow::Result<LaBraM<B>> {
    let mut wm = WeightMap::from_file(path)?;
    load_model_from_wm(cfg, &mut wm, device)
}

pub fn load_model_from_wm<B: Backend>(cfg: &ModelConfig, wm: &mut WeightMap, device: &B::Device) -> anyhow::Result<LaBraM<B>> {
    let mut model = LaBraM::new(
        cfg.n_outputs, cfg.n_chans, cfg.n_times, cfg.patch_size, cfg.embed_dim,
        cfg.num_layers, cfg.num_heads, cfg.mlp_ratio,
        true, Some(0.1), true, false, cfg.conv_out_channels,
        cfg.n_pos_embeddings, 1e-6, device,
    );
    load_weights(wm, &mut model, cfg, device)?;
    Ok(model)
}

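// Every assignment below is tolerant: a key missing from the checkpoint is
// skipped rather than treated as an error, so partial checkpoints still load.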
fn load_weights<B: Backend>(wm: &mut WeightMap, m: &mut LaBraM<B>, cfg: &ModelConfig, dev: &B::Device) -> anyhow::Result<()> {
    // CLS token and learned embeddings
    if let Ok(t) = wm.take::<B,3>("cls_token", dev) { m.cls_token = m.cls_token.clone().map(|_| t); }
    if let Ok(t) = wm.take::<B,3>("position_embedding", dev) {
        if let Some(ref mut pe) = m.position_embedding { *pe = pe.clone().map(|_| t); }
    }
    if let Ok(t) = wm.take::<B,3>("temporal_embedding", dev) { m.temporal_embedding = m.temporal_embedding.clone().map(|_| t); }

    // TemporalConv: three conv + group-norm pairs share the same key pattern.
    macro_rules! load_tc { ($c:expr, $n:expr, $ci:expr, $ni:expr) => {
        if let (Ok(w), Ok(b)) = (wm.take::<B,4>(&format!("patch_embed.temporal_conv.{}.weight", $ci), dev),
                                 wm.take::<B,1>(&format!("patch_embed.temporal_conv.{}.bias", $ci), dev)) { set_conv2d_wb(&mut $c, w, b); }
        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("patch_embed.temporal_conv.{}.weight", $ni), dev),
                                 wm.take::<B,1>(&format!("patch_embed.temporal_conv.{}.bias", $ni), dev)) { set_gn(&mut $n, w, b); }
    }; }
    load_tc!(m.temporal_conv.conv1, m.temporal_conv.norm1, "conv1", "norm1");
    load_tc!(m.temporal_conv.conv2, m.temporal_conv.norm2, "conv2", "norm2");
    load_tc!(m.temporal_conv.conv3, m.temporal_conv.norm3, "conv3", "norm3");

    // Transformer blocks
    for i in 0..cfg.num_layers {
        let blk = &mut m.blocks[i];
        let p = format!("blocks.{i}");
        // LayerScale parameters (present only when layer scale is enabled)
        if let Ok(w) = wm.take::<B,1>(&format!("{p}.gamma_1"), dev) {
            if let Some(ref mut g) = blk.gamma_1 { *g = g.clone().map(|_| w); }
        }
        if let Ok(w) = wm.take::<B,1>(&format!("{p}.gamma_2"), dev) {
            if let Some(ref mut g) = blk.gamma_2 { *g = g.clone().map(|_| w); }
        }
        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm1.weight"), dev), wm.take::<B,1>(&format!("{p}.norm1.bias"), dev)) { set_ln(&mut blk.norm1, w, b); }
        // qkv projection (no bias)
        if let Ok(w) = wm.take::<B,2>(&format!("{p}.attn.qkv.weight"), dev) { set_linear_w(&mut blk.attn.qkv, w); }
        // q_norm, k_norm
        if let Some(ref mut qn) = blk.attn.q_norm {
            if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.attn.q_norm.weight"), dev), wm.take::<B,1>(&format!("{p}.attn.q_norm.bias"), dev)) { set_ln(qn, w, b); }
        }
        if let Some(ref mut kn) = blk.attn.k_norm {
            if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.attn.k_norm.weight"), dev), wm.take::<B,1>(&format!("{p}.attn.k_norm.bias"), dev)) { set_ln(kn, w, b); }
        }
        // attention output projection
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.attn.proj.weight"), dev), wm.take::<B,1>(&format!("{p}.attn.proj.bias"), dev)) { set_linear_wb(&mut blk.attn.proj, w, b); }
        // norm2
        if let (Ok(w), Ok(b)) = (wm.take::<B,1>(&format!("{p}.norm2.weight"), dev), wm.take::<B,1>(&format!("{p}.norm2.bias"), dev)) { set_ln(&mut blk.norm2, w, b); }
        // MLP: the checkpoint uses Sequential indices, so `mlp.0` is fc1 and
        // `mlp.2` is fc2 (the activation in between carries no parameters).
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.mlp.0.weight"), dev), wm.take::<B,1>(&format!("{p}.mlp.0.bias"), dev)) { set_linear_wb(&mut blk.mlp_fc1, w, b); }
        if let (Ok(w), Ok(b)) = (wm.take::<B,2>(&format!("{p}.mlp.2.weight"), dev), wm.take::<B,1>(&format!("{p}.mlp.2.bias"), dev)) { set_linear_wb(&mut blk.mlp_fc2, w, b); }
    }

    // Final norm and classification head
    if let (Ok(w), Ok(b)) = (wm.take::<B,1>("norm.weight", dev), wm.take::<B,1>("norm.bias", dev)) { set_ln(&mut m.norm, w, b); }
    if let (Ok(w), Ok(b)) = (wm.take::<B,2>("final_layer.weight", dev), wm.take::<B,1>("final_layer.bias", dev)) { set_linear_wb(&mut m.final_linear, w, b); }

    Ok(())
}
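
// A minimal smoke-test sketch: the backend alias (`burn::backend::NdArray`,
// behind the `ndarray` feature), `ModelConfig::default()`, and the checkpoint
// path are assumptions about the surrounding crate, not guarantees.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a checkpoint file on disk"]
    fn load_model_smoke() {
        type B = burn::backend::NdArray;
        let device = Default::default();
        let cfg = ModelConfig::default(); // assumed constructor
        // Illustrative path; point this at a real LaBraM safetensors checkpoint.
        let model = load_model::<B>(&cfg, "labram.safetensors", &device)
            .expect("checkpoint should load");
        let _ = model;
    }
}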