// reve_rs/weights.rs

//! Load pretrained REVE weights from a safetensors file.
//!
//! Weight key patterns (from the Python state_dict):
//!
//!   to_patch_embedding.0.weight                   [embed_dim, patch_size]
//!   to_patch_embedding.0.bias                     [embed_dim]
//!   mlp4d.0.weight                                [embed_dim, 4]
//!   mlp4d.2.weight                                [embed_dim]
//!   mlp4d.2.bias                                  [embed_dim]
//!   ln.weight                                     [embed_dim]
//!   ln.bias                                       [embed_dim]
//!   transformer.layers.{i}.0.norm.weight          [embed_dim]             (RMSNorm)
//!   transformer.layers.{i}.0.to_qkv.weight        [3*inner_dim, embed_dim] (no bias)
//!   transformer.layers.{i}.0.to_out.weight        [embed_dim, inner_dim]  (no bias)
//!   transformer.layers.{i}.1.net.0.weight         [embed_dim]             (RMSNorm)
//!   transformer.layers.{i}.1.net.1.weight         [geglu_dim, embed_dim]  (no bias)
//!   transformer.layers.{i}.1.net.3.weight         [embed_dim, mlp_dim]    (no bias)
//!   final_layer.1.weight                          [final_dim]             (LayerNorm)
//!   final_layer.1.bias                            [final_dim]
//!   final_layer.2.weight                          [n_outputs, final_dim]
//!   final_layer.2.bias                            [n_outputs]
//!
//!   For attention_pooling:
//!   cls_query_token                               [1, 1, embed_dim]
//!   final_layer.0.weight                          [embed_dim]
//!   final_layer.0.bias                            [embed_dim]
//!   final_layer.1.weight                          [n_outputs, embed_dim]
//!   final_layer.1.bias                            [n_outputs]
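//!
//! Keys may optionally carry a leading `model.` prefix in the checkpoint;
//! [`WeightMap::from_file`] strips it, so both spellings below resolve to the
//! same loader key (illustrative example with `i = 0`):
//!
//! ```text
//! model.transformer.layers.0.0.to_qkv.weight  ->  transformer.layers.0.0.to_qkv.weight
//!       transformer.layers.0.0.to_qkv.weight  ->  transformer.layers.0.0.to_qkv.weight
//! ```
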
use std::collections::HashMap;
use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;

use crate::model::reve::Reve;
use crate::config::ModelConfig;

// ── WeightMap ─────────────────────────────────────────────────────────────────

/// All tensors from a checkpoint, decoded to `f32` and keyed by their
/// (prefix-stripped) state_dict name.
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
    /// Read a safetensors file, strip any leading `model.` from each key, and
    /// decode BF16/F16/F32 payloads to `f32`.
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());

        for (raw_key, view) in st.tensors() {
            let key = raw_key
                .strip_prefix("model.")
                .unwrap_or(raw_key.as_str())
                .to_string();

            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();

            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data
                    .chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                safetensors::Dtype::F32 => data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
                safetensors::Dtype::F16 => data
                    .chunks_exact(2)
                    .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                other => anyhow::bail!("unsupported dtype {:?} for key {key}", other),
            };

            tensors.insert(key, (f32s, shape));
        }

        Ok(Self { tensors })
    }

    /// Remove a tensor from the map and materialize it as a rank-`N` burn
    /// tensor on `device`. Fails if the key is missing or the stored rank does
    /// not match `N`.
    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self.tensors.remove(key)
            .ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
    }

    /// Whether a key is still present (i.e. not yet `take`n).
    pub fn has(&self, key: &str) -> bool {
        self.tensors.contains_key(key)
    }

    /// Print every remaining key and its shape, sorted, for debugging.
    pub fn print_keys(&self) {
        let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
        keys.sort();
        for k in keys {
            let (_, s) = &self.tensors[k];
            println!("  {k:80}  {s:?}");
        }
    }
}
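
// A minimal sketch of how `WeightMap` can be used on its own, e.g. to inspect
// an unfamiliar checkpoint. The checkpoint path and the `NdArray` backend are
// assumptions for illustration only; point the path at a real file, enable the
// `ndarray` feature, and run with `cargo test -- --ignored`.
#[cfg(test)]
mod weight_map_example {
    use super::*;

    #[test]
    #[ignore = "needs a local REVE checkpoint"]
    fn inspect_checkpoint() -> anyhow::Result<()> {
        type B = burn::backend::NdArray;
        let device = Default::default();

        // Hypothetical path; substitute your own checkpoint.
        let mut wm = WeightMap::from_file("weights/reve.safetensors")?;

        // Dump every key with its shape.
        wm.print_keys();

        // `take` removes the entry and checks the rank against `N`.
        let ln_w = wm.take::<B, 1>("ln.weight", &device)?;
        println!("ln.weight dims: {:?}", ln_w.dims());
        Ok(())
    }
}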

// ── Weight assignment helpers ─────────────────────────────────────────────────

/// Copy a PyTorch `[out, in]` linear weight into a burn `Linear`, which stores
/// its weight as `[in, out]`, hence the transpose.
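///
/// A minimal sketch of the intended call pattern (shapes follow the key table
/// in the module docs; `wm`, `model`, and `device` as in `load_reve_weights`):
///
/// ```ignore
/// // "mlp4d.0.weight" is stored as [embed_dim, 4] in the checkpoint ...
/// let w = wm.take::<B, 2>("mlp4d.0.weight", device)?;
/// // ... and ends up in the burn Linear as [4, embed_dim] after the transpose.
/// set_linear_w(&mut model.mlp4d_linear, w);
/// ```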
fn set_linear_w<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
}

fn set_linear_wb<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = linear.bias {
        linear.bias = Some(bias.clone().map(|_| b));
    }
}

fn set_layernorm<B: Backend>(norm: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    norm.gamma = norm.gamma.clone().map(|_| w);
    if let Some(ref beta) = norm.beta {
        norm.beta = Some(beta.clone().map(|_| b));
    }
}

fn set_rmsnorm<B: Backend>(norm: &mut crate::model::rms_norm::RmsNorm<B>, w: Tensor<B, 1>) {
    norm.weight = norm.weight.clone().map(|_| w);
}

// ── Full model loader ─────────────────────────────────────────────────────────

/// Load a REVE model from a safetensors file.
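///
/// # Example
///
/// A minimal sketch, assuming the `NdArray` backend and a `ModelConfig` built
/// elsewhere to match the checkpoint (both are assumptions, not defined in
/// this file):
///
/// ```ignore
/// use burn::backend::NdArray;
///
/// let device = Default::default();
/// let cfg: ModelConfig = /* construct or deserialize a config matching the weights */;
/// let model = load_model::<NdArray>(&cfg, "weights/reve.safetensors", &device)?;
/// ```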
pub fn load_model<B: Backend>(
    cfg: &ModelConfig,
    weights_path: &str,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut wm = WeightMap::from_file(weights_path)?;
    eprintln!("Loading {} weight tensors...", wm.tensors.len());
    load_model_from_wm(cfg, &mut wm, device)
}

pub fn load_model_from_wm<B: Backend>(
    cfg: &ModelConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut model = Reve::new(
        cfg.n_outputs,
        cfg.n_chans,
        cfg.n_times,
        cfg.embed_dim,
        cfg.depth,
        cfg.heads,
        cfg.head_dim,
        cfg.mlp_dim_ratio,
        cfg.use_geglu,
        cfg.freqs,
        cfg.patch_size,
        cfg.patch_overlap,
        cfg.attention_pooling,
        device,
    );

    load_reve_weights(wm, &mut model, cfg, device)?;
    Ok(model)
}

fn load_reve_weights<B: Backend>(
    wm: &mut WeightMap,
    model: &mut Reve<B>,
    cfg: &ModelConfig,
    device: &B::Device,
) -> anyhow::Result<()> {
    // ── Patch embedding ─────────────────────────────────────────────────────
    // to_patch_embedding.0.weight [embed_dim, patch_size]
    // to_patch_embedding.0.bias   [embed_dim]
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 2>("to_patch_embedding.0.weight", device),
        wm.take::<B, 1>("to_patch_embedding.0.bias", device),
    ) {
        set_linear_wb(&mut model.patch_embed, w, b);
    }

    // ── MLP4D (positional encoding MLP) ─────────────────────────────────────
    // mlp4d.0.weight [embed_dim, 4] (no bias)
    if let Ok(w) = wm.take::<B, 2>("mlp4d.0.weight", device) {
        set_linear_w(&mut model.mlp4d_linear, w);
    }
    // mlp4d.2.weight [embed_dim], mlp4d.2.bias [embed_dim]  (LayerNorm)
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("mlp4d.2.weight", device),
        wm.take::<B, 1>("mlp4d.2.bias", device),
    ) {
        set_layernorm(&mut model.mlp4d_ln, w, b);
    }

    // ── 4DPE output LayerNorm ───────────────────────────────────────────────
    // ln.weight [embed_dim], ln.bias [embed_dim]
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("ln.weight", device),
        wm.take::<B, 1>("ln.bias", device),
    ) {
        set_layernorm(&mut model.pos_ln, w, b);
    }

    // ── Transformer layers ──────────────────────────────────────────────────
    for i in 0..cfg.depth {
        let block = &mut model.transformer.layers[i];

        // Attention: layers.{i}.0
        // layers.{i}.0.norm.weight  (RMSNorm, no bias)
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.0.norm.weight"), device,
        ) {
            set_rmsnorm(&mut block.attn.norm, w);
        }
        // layers.{i}.0.to_qkv.weight [3*inner_dim, embed_dim] (no bias)
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_qkv.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_qkv, w);
        }
        // layers.{i}.0.to_out.weight [embed_dim, inner_dim] (no bias)
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_out.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_out, w);
        }

        // FeedForward: layers.{i}.1
        // layers.{i}.1.net.0.weight  (RMSNorm, no bias)
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.1.net.0.weight"), device,
        ) {
            set_rmsnorm(&mut block.ff.norm, w);
        }
        // layers.{i}.1.net.1.weight [geglu_dim, embed_dim] (no bias)
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.1.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear1, w);
        }
        // layers.{i}.1.net.3.weight [embed_dim, mlp_dim] (no bias)
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.3.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear2, w);
        }
    }

    // ── Classification head ─────────────────────────────────────────────────
    if cfg.attention_pooling {
        // cls_query_token [1, 1, embed_dim]
        if let Ok(t) = wm.take::<B, 3>("cls_query_token", device) {
            if let Some(ref mut q) = model.cls_query_token {
                *q = q.clone().map(|_| t);
            }
        }
        // final_layer.0 = LayerNorm, final_layer.1 = Linear
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.0.weight", device),
            wm.take::<B, 1>("final_layer.0.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    } else {
        // final_layer: Flatten, LayerNorm, Linear
        // final_layer.1 = LayerNorm, final_layer.2 = Linear
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.2.weight", device),
            wm.take::<B, 1>("final_layer.2.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    }

    Ok(())
}