use std::collections::HashMap;

use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;

use crate::model::reve::Reve;
use crate::config::ModelConfig;

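/// In-memory copy of a safetensors checkpoint: each tensor decoded to f32
/// together with its shape, keyed by name (minus any leading `model.` prefix).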
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl WeightMap {
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());

        for (raw_key, view) in st.tensors() {
            // Checkpoints exported from a wrapping module carry a `model.`
            // prefix; strip it so lookups use the bare module path.
            let key = raw_key
                .strip_prefix("model.")
                .unwrap_or(raw_key.as_str())
                .to_string();

            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();

            // safetensors exposes raw little-endian bytes; decode each
            // supported dtype to f32 so the rest of the loader is uniform.
            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data
                    .chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                safetensors::Dtype::F32 => data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
                safetensors::Dtype::F16 => data
                    .chunks_exact(2)
                    .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                other => anyhow::bail!("unsupported dtype {other:?} for key {key}"),
            };

            tensors.insert(key, (f32s, shape));
        }

        Ok(Self { tensors })
    }

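    /// Remove the tensor stored under `key` and materialise it as a rank-`N`
    /// burn tensor on `device`. Taking ownership lets the f32 buffer be
    /// dropped as soon as the weight has been uploaded.
    ///
    /// A sketch of a typical call (the key is one actually used below):
    ///
    /// ```ignore
    /// let w: Tensor<B, 2> = wm.take::<B, 2>("to_patch_embedding.0.weight", &device)?;
    /// ```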
    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self
            .tensors
            .remove(key)
            .ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
    }

    pub fn has(&self, key: &str) -> bool {
        self.tensors.contains_key(key)
    }

    /// Print the remaining keys and their shapes, sorted; handy when a
    /// checkpoint and the model disagree about layout.
    pub fn print_keys(&self) {
        let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
        keys.sort();
        for k in keys {
            let (_, s) = &self.tensors[k];
            println!(" {k:80} {s:?}");
        }
    }
}

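// Burn's `Linear` stores its weight as `[d_input, d_output]`, whereas
// checkpoints exported from PyTorch use `[out_features, in_features]`; the
// setters below therefore transpose on assignment. `Param::map` swaps the
// underlying tensor while keeping the parameter id intact.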
fn set_linear_w<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
}

fn set_linear_wb<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = linear.bias {
        linear.bias = Some(bias.clone().map(|_| b));
    }
}

fn set_layernorm<B: Backend>(norm: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    norm.gamma = norm.gamma.clone().map(|_| w);
    if let Some(ref beta) = norm.beta {
        norm.beta = Some(beta.clone().map(|_| b));
    }
}

fn set_rmsnorm<B: Backend>(norm: &mut crate::model::rms_norm::RmsNorm<B>, w: Tensor<B, 1>) {
    norm.weight = norm.weight.clone().map(|_| w);
}

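/// Read a safetensors checkpoint from `weights_path` and build a `Reve`
/// model on `device`.
///
/// A minimal call-site sketch; the backend choice and the config
/// construction here are assumptions, not fixed by this module:
///
/// ```ignore
/// use burn::backend::NdArray;
///
/// let cfg = ModelConfig::default();
/// let device = Default::default();
/// let model = load_model::<NdArray>(&cfg, "weights/reve.safetensors", &device)?;
/// ```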
pub fn load_model<B: Backend>(
    cfg: &ModelConfig,
    weights_path: &str,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut wm = WeightMap::from_file(weights_path)?;
    eprintln!("Loading {} weight tensors...", wm.tensors.len());
    load_model_from_wm(cfg, &mut wm, device)
}

/// Same as [`load_model`], but reuses an already-parsed [`WeightMap`].
pub fn load_model_from_wm<B: Backend>(
    cfg: &ModelConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut model = Reve::new(
        cfg.n_outputs,
        cfg.n_chans,
        cfg.n_times,
        cfg.embed_dim,
        cfg.depth,
        cfg.heads,
        cfg.head_dim,
        cfg.mlp_dim_ratio,
        cfg.use_geglu,
        cfg.freqs,
        cfg.patch_size,
        cfg.patch_overlap,
        cfg.attention_pooling,
        device,
    );

    load_reve_weights(wm, &mut model, cfg, device)?;
    Ok(model)
}

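/// Copy weights from `wm` into `model`, following the PyTorch checkpoint's
/// key layout (`to_patch_embedding.0.*`, `mlp4d.*`, `ln.*`,
/// `transformer.layers.{i}.{0,1}.*`, `final_layer.*`). Every load is
/// optional: a missing key leaves the freshly initialised weight in place,
/// so `wm.print_keys()` is the tool of choice when something silently fails
/// to match.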
fn load_reve_weights<B: Backend>(
    wm: &mut WeightMap,
    model: &mut Reve<B>,
    cfg: &ModelConfig,
    device: &B::Device,
) -> anyhow::Result<()> {
    // Patch embedding: the Linear at index 0 of `to_patch_embedding`.
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 2>("to_patch_embedding.0.weight", device),
        wm.take::<B, 1>("to_patch_embedding.0.bias", device),
    ) {
        set_linear_wb(&mut model.patch_embed, w, b);
    }

    // `mlp4d`: Linear at index 0, LayerNorm at index 2 of the Sequential.
    if let Ok(w) = wm.take::<B, 2>("mlp4d.0.weight", device) {
        set_linear_w(&mut model.mlp4d_linear, w);
    }
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("mlp4d.2.weight", device),
        wm.take::<B, 1>("mlp4d.2.bias", device),
    ) {
        set_layernorm(&mut model.mlp4d_ln, w, b);
    }

    // LayerNorm applied after the positional encoding.
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("ln.weight", device),
        wm.take::<B, 1>("ln.bias", device),
    ) {
        set_layernorm(&mut model.pos_ln, w, b);
    }

    for i in 0..cfg.depth {
        let block = &mut model.transformer.layers[i];

        // Attention sub-block (index 0 of the layer pair).
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.0.norm.weight"), device,
        ) {
            set_rmsnorm(&mut block.attn.norm, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_qkv.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_qkv, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_out.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_out, w);
        }

        // Feed-forward sub-block (index 1): RMSNorm, then two Linears.
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.1.net.0.weight"), device,
        ) {
            set_rmsnorm(&mut block.ff.norm, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.1.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear1, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.3.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear2, w);
        }
    }

    if cfg.attention_pooling {
        // Attention-pooled head: learned query token, then LayerNorm and
        // Linear at `final_layer` indices 0 and 1.
        if let Ok(t) = wm.take::<B, 3>("cls_query_token", device) {
            if let Some(ref mut q) = model.cls_query_token {
                *q = q.clone().map(|_| t);
            }
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.0.weight", device),
            wm.take::<B, 1>("final_layer.0.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    } else {
        // Without attention pooling, index 0 of `final_layer` carries no
        // weights, so the LayerNorm and Linear sit at indices 1 and 2.
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.2.weight", device),
            wm.take::<B, 1>("final_layer.2.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    }

    Ok(())
}
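
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test for the reader. The checkpoint path is an assumption for
    // local runs, so the test is ignored by default; run it with
    // `cargo test -- --ignored` when a checkpoint is available.
    #[test]
    #[ignore]
    fn weight_map_reads_checkpoint() {
        let wm = WeightMap::from_file("weights/reve.safetensors").expect("readable checkpoint");
        assert!(!wm.tensors.is_empty());
        wm.print_keys();
    }
}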