1use super::config::Vjepa2Config;
19use super::preprocess::{Vjepa2PatchEmbedWeights, extract_patch_embed_weights};
20use anyhow::{Result, ensure};
21use rlx_core::weight_map::WeightMap;
22
/// Tensors of a single transformer block, as assembled by
/// `extract_transformer_block`.
///
/// Fields ending in `_w_t` are linear weights obtained through
/// `WeightMap::take_transposed` and validated as `[in_dim, out_dim]`;
/// `_b` fields are the matching 1-D biases. Below, `embed` and `hidden`
/// refer to the two sizes passed to `extract_transformer_block`.
#[derive(Clone)]
pub struct Vjepa2BlockWeights {
    /// `norm1.weight`, length `embed`.
    pub norm1_w: Vec<f32>,
    /// `norm1.bias`, length `embed`.
    pub norm1_b: Vec<f32>,
    /// Query projection weight, `[embed, embed]`.
    pub q_w_t: Vec<f32>,
    /// Query projection bias, length `embed`.
    pub q_b: Vec<f32>,
    /// Key projection weight, `[embed, embed]`.
    pub k_w_t: Vec<f32>,
    /// Key projection bias, length `embed`.
    pub k_b: Vec<f32>,
    /// Value projection weight, `[embed, embed]`.
    pub v_w_t: Vec<f32>,
    /// Value projection bias, length `embed`.
    pub v_b: Vec<f32>,
    /// Attention output (`proj`) weight, `[embed, embed]`.
    pub proj_w_t: Vec<f32>,
    /// Attention output (`proj`) bias, length `embed`.
    pub proj_b: Vec<f32>,
    /// `norm2.weight`, length `embed`.
    pub norm2_w: Vec<f32>,
    /// `norm2.bias`, length `embed`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1` weight, `[embed, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1` bias, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2` weight, `[hidden, embed]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2` bias, length `embed`.
    pub mlp_fc2_b: Vec<f32>,
}
42
43#[derive(Clone)]
44pub struct Vjepa2EncoderWeights {
45 pub patch: Vjepa2PatchEmbedWeights,
46 pub blocks: Vec<Vjepa2BlockWeights>,
47 pub norm_w: Vec<f32>,
48 pub norm_b: Vec<f32>,
49}
50
51#[derive(Clone)]
52pub struct Vjepa2PredictorWeights {
53 pub embed_w_t: Vec<f32>,
54 pub embed_b: Vec<f32>,
55 pub mask_tokens: Vec<f32>,
56 pub blocks: Vec<Vjepa2BlockWeights>,
57 pub norm_w: Vec<f32>,
58 pub norm_b: Vec<f32>,
59 pub proj_w_t: Vec<f32>,
60 pub proj_b: Vec<f32>,
61}
62
/// Tensors of one `pooler.self_attention_layers.{i}` block.
///
/// `e` below is `cfg.hidden_size` and `hidden` is
/// `cfg.pooler_intermediate_size()`. `_w_t` fields come from
/// `WeightMap::take_transposed` and are validated as `[in_dim, out_dim]`.
#[derive(Clone)]
pub struct Vjepa2PoolerSelfBlockWeights {
    /// `layer_norm1.weight`, length `e`.
    pub norm1_w: Vec<f32>,
    /// `layer_norm1.bias`, length `e`.
    pub norm1_b: Vec<f32>,
    /// `self_attn.q_proj` weight, `[e, e]`.
    pub q_w_t: Vec<f32>,
    /// `self_attn.q_proj` bias, length `e`.
    pub q_b: Vec<f32>,
    /// `self_attn.k_proj` weight, `[e, e]`.
    pub k_w_t: Vec<f32>,
    /// `self_attn.k_proj` bias, length `e`.
    pub k_b: Vec<f32>,
    /// `self_attn.v_proj` weight, `[e, e]`.
    pub v_w_t: Vec<f32>,
    /// `self_attn.v_proj` bias, length `e`.
    pub v_b: Vec<f32>,
    /// `self_attn.out_proj` weight, `[e, e]`.
    pub out_w_t: Vec<f32>,
    /// `self_attn.out_proj` bias, length `e`.
    pub out_b: Vec<f32>,
    /// `layer_norm2.weight`, length `e`.
    pub norm2_w: Vec<f32>,
    /// `layer_norm2.bias`, length `e`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1` weight, `[e, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1` bias, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2` weight, `[hidden, e]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2` bias, length `e`.
    pub mlp_fc2_b: Vec<f32>,
}
82
/// Tensors of the `pooler.cross_attention_layer` block.
///
/// Unlike the pooler self-attention block, this struct has no attention
/// output-projection fields. `e` below is `cfg.hidden_size` and `hidden` is
/// `cfg.pooler_intermediate_size()`.
#[derive(Clone)]
pub struct Vjepa2PoolerCrossWeights {
    /// `layer_norm1.weight`, length `e`.
    pub norm1_w: Vec<f32>,
    /// `layer_norm1.bias`, length `e`.
    pub norm1_b: Vec<f32>,
    /// `cross_attn.q_proj` weight, `[e, e]`.
    pub q_w_t: Vec<f32>,
    /// `cross_attn.q_proj` bias, length `e`.
    pub q_b: Vec<f32>,
    /// `cross_attn.k_proj` weight, `[e, e]`.
    pub k_w_t: Vec<f32>,
    /// `cross_attn.k_proj` bias, length `e`.
    pub k_b: Vec<f32>,
    /// `cross_attn.v_proj` weight, `[e, e]`.
    pub v_w_t: Vec<f32>,
    /// `cross_attn.v_proj` bias, length `e`.
    pub v_b: Vec<f32>,
    /// `layer_norm2.weight`, length `e`.
    pub norm2_w: Vec<f32>,
    /// `layer_norm2.bias`, length `e`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1` weight, `[e, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1` bias, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2` weight, `[hidden, e]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2` bias, length `e`.
    pub mlp_fc2_b: Vec<f32>,
}
100
101#[derive(Clone)]
102pub struct Vjepa2PoolerWeights {
103 pub query_tokens: Vec<f32>,
104 pub self_blocks: Vec<Vjepa2PoolerSelfBlockWeights>,
105 pub cross: Vjepa2PoolerCrossWeights,
106 pub classifier_w_t: Option<Vec<f32>>,
107 pub classifier_b: Option<Vec<f32>>,
108}
109
110#[derive(Clone)]
111pub struct Vjepa2ModelWeights {
112 pub encoder: Vjepa2EncoderWeights,
113 pub predictor: Option<Vjepa2PredictorWeights>,
114 pub pooler: Option<Vjepa2PoolerWeights>,
115}
116
117pub fn extract_encoder_weights(
118 weights: &mut WeightMap,
119 cfg: &Vjepa2Config,
120) -> Result<Vjepa2EncoderWeights> {
121 let patch = extract_patch_embed_weights(weights, cfg)?;
122 let e = cfg.hidden_size;
123 let hidden = cfg.intermediate_size();
124 let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
125
126 for i in 0..cfg.num_hidden_layers {
127 let hf = format!("encoder.layer.{i}");
128 let meta = format!("blocks.{i}");
129 blocks.push(extract_transformer_block(
130 weights,
131 &[hf, meta],
132 e,
133 hidden,
134 "attention",
135 "attn",
136 )?);
137 }
138
139 let norm_w = take_first_vec(
140 weights,
141 &["encoder.layernorm.weight", "norm.weight"],
142 vec![e],
143 )?;
144 let norm_b = take_first_vec(weights, &["encoder.layernorm.bias", "norm.bias"], vec![e])?;
145
146 Ok(Vjepa2EncoderWeights {
147 patch,
148 blocks,
149 norm_w,
150 norm_b,
151 })
152}
153
154pub fn extract_predictor_weights(
155 weights: &mut WeightMap,
156 cfg: &Vjepa2Config,
157) -> Result<Vjepa2PredictorWeights> {
158 let enc = cfg.hidden_size;
159 let pred = cfg.pred_hidden_size;
160 let hidden = cfg.pred_intermediate_size();
161
162 let embed_key = pick_key(
163 weights,
164 &[
165 "predictor.embeddings.predictor_embeddings.weight",
166 "predictor_embed.weight",
167 ],
168 )?;
169 let embed_w_t = take_linear_w_key(weights, &embed_key, enc, pred)?;
170 let embed_b = take_first_vec(
171 weights,
172 &[
173 "predictor.embeddings.predictor_embeddings.bias",
174 "predictor_embed.bias",
175 ],
176 vec![pred],
177 )?;
178
179 let n_masks = cfg.pred_num_mask_tokens;
180 let mask_tokens = take_first_vec(
181 weights,
182 &["predictor.embeddings.mask_tokens", "mask_tokens"],
183 vec![n_masks, 1, 1, pred],
184 )?;
185
186 let mut blocks = Vec::with_capacity(cfg.pred_num_hidden_layers);
187 for i in 0..cfg.pred_num_hidden_layers {
188 let hf = format!("predictor.layer.{i}");
189 let meta = format!("predictor_blocks.{i}");
190 blocks.push(extract_transformer_block(
191 weights,
192 &[hf, meta],
193 pred,
194 hidden,
195 "attention",
196 "attn",
197 )?);
198 }
199
200 let norm_w = take_first_vec(
201 weights,
202 &["predictor.layernorm.weight", "predictor_norm.weight"],
203 vec![pred],
204 )?;
205 let norm_b = take_first_vec(
206 weights,
207 &["predictor.layernorm.bias", "predictor_norm.bias"],
208 vec![pred],
209 )?;
210 let proj_key = pick_key(weights, &["predictor.proj.weight", "predictor_proj.weight"])?;
211 let proj_w_t = take_linear_w_key(weights, &proj_key, pred, enc)?;
212 let proj_b = take_first_vec(
213 weights,
214 &["predictor.proj.bias", "predictor_proj.bias"],
215 vec![enc],
216 )?;
217
218 Ok(Vjepa2PredictorWeights {
219 embed_w_t,
220 embed_b,
221 mask_tokens,
222 blocks,
223 norm_w,
224 norm_b,
225 proj_w_t,
226 proj_b,
227 })
228}
229
230pub fn extract_pooler_weights(
231 weights: &mut WeightMap,
232 cfg: &Vjepa2Config,
233) -> Result<Vjepa2PoolerWeights> {
234 let e = cfg.hidden_size;
235 let hidden = cfg.pooler_intermediate_size();
236
237 let query_tokens = take_first_vec(weights, &["pooler.query_tokens"], vec![1, 1, e])?;
238
239 let mut self_blocks = Vec::with_capacity(cfg.num_pooler_layers);
240 for i in 0..cfg.num_pooler_layers {
241 let p = format!("pooler.self_attention_layers.{i}");
242 self_blocks.push(Vjepa2PoolerSelfBlockWeights {
243 norm1_w: take_ln_w(weights, &[&p], "layer_norm1", e)?,
244 norm1_b: take_ln_b(weights, &[&p], "layer_norm1", e)?,
245 q_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.q_proj.weight"), e, e)?,
246 q_b: take_first_vec(weights, &[&format!("{p}.self_attn.q_proj.bias")], vec![e])?,
247 k_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.k_proj.weight"), e, e)?,
248 k_b: take_first_vec(weights, &[&format!("{p}.self_attn.k_proj.bias")], vec![e])?,
249 v_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.v_proj.weight"), e, e)?,
250 v_b: take_first_vec(weights, &[&format!("{p}.self_attn.v_proj.bias")], vec![e])?,
251 out_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.out_proj.weight"), e, e)?,
252 out_b: take_first_vec(weights, &[&format!("{p}.self_attn.out_proj.bias")], vec![e])?,
253 norm2_w: take_ln_w(weights, &[&p], "layer_norm2", e)?,
254 norm2_b: take_ln_b(weights, &[&p], "layer_norm2", e)?,
255 mlp_fc1_w_t: take_linear_w_key(weights, &format!("{p}.mlp.fc1.weight"), e, hidden)?,
256 mlp_fc1_b: take_first_vec(weights, &[&format!("{p}.mlp.fc1.bias")], vec![hidden])?,
257 mlp_fc2_w_t: take_linear_w_key(weights, &format!("{p}.mlp.fc2.weight"), hidden, e)?,
258 mlp_fc2_b: take_first_vec(weights, &[&format!("{p}.mlp.fc2.bias")], vec![e])?,
259 });
260 }
261
262 let cp = "pooler.cross_attention_layer";
263 let cross = Vjepa2PoolerCrossWeights {
264 norm1_w: take_ln_w(weights, &[cp], "layer_norm1", e)?,
265 norm1_b: take_ln_b(weights, &[cp], "layer_norm1", e)?,
266 q_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.q_proj.weight"), e, e)?,
267 q_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.q_proj.bias")], vec![e])?,
268 k_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.k_proj.weight"), e, e)?,
269 k_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.k_proj.bias")], vec![e])?,
270 v_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.v_proj.weight"), e, e)?,
271 v_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.v_proj.bias")], vec![e])?,
272 norm2_w: take_ln_w(weights, &[cp], "layer_norm2", e)?,
273 norm2_b: take_ln_b(weights, &[cp], "layer_norm2", e)?,
274 mlp_fc1_w_t: take_linear_w_key(weights, &format!("{cp}.mlp.fc1.weight"), e, hidden)?,
275 mlp_fc1_b: take_first_vec(weights, &[&format!("{cp}.mlp.fc1.bias")], vec![hidden])?,
276 mlp_fc2_w_t: take_linear_w_key(weights, &format!("{cp}.mlp.fc2.weight"), hidden, e)?,
277 mlp_fc2_b: take_first_vec(weights, &[&format!("{cp}.mlp.fc2.bias")], vec![e])?,
278 };
279
280 let classifier_w_t = if weights.has("classifier.weight") {
281 let (data, shape) = weights.take_transposed("classifier.weight")?;
282 ensure!(shape[1] == e, "classifier weight second dim must be {e}");
283 Some(data)
284 } else {
285 None
286 };
287 let classifier_b = if weights.has("classifier.bias") {
288 let (data, shape) = weights.take("classifier.bias")?;
289 ensure!(shape.len() == 1, "classifier bias must be 1d");
290 Some(data)
291 } else {
292 None
293 };
294
295 Ok(Vjepa2PoolerWeights {
296 query_tokens,
297 self_blocks,
298 cross,
299 classifier_w_t,
300 classifier_b,
301 })
302}
303
304pub fn extract_model_weights(
305 weights: &mut WeightMap,
306 cfg: &Vjepa2Config,
307) -> Result<Vjepa2ModelWeights> {
308 let encoder = extract_encoder_weights(weights, cfg)?;
309 let predictor = if weights.has("predictor.layer.0.attention.query.weight")
310 || weights.has("predictor_blocks.0.attn.qkv.weight")
311 {
312 Some(extract_predictor_weights(weights, cfg)?)
313 } else {
314 None
315 };
316 let pooler = if weights.has("pooler.query_tokens") {
317 Some(extract_pooler_weights(weights, cfg)?)
318 } else {
319 None
320 };
321 Ok(Vjepa2ModelWeights {
322 encoder,
323 predictor,
324 pooler,
325 })
326}
327
328pub(crate) fn extract_transformer_block(
329 weights: &mut WeightMap,
330 prefixes: &[String],
331 embed: usize,
332 hidden: usize,
333 attn_hf: &str,
334 attn_meta: &str,
335) -> Result<Vjepa2BlockWeights> {
336 let pref_refs: Vec<&str> = prefixes.iter().map(String::as_str).collect();
337 Ok(Vjepa2BlockWeights {
338 norm1_w: take_ln_w(weights, &pref_refs, "norm1", embed)?,
339 norm1_b: take_ln_b(weights, &pref_refs, "norm1", embed)?,
340 q_w_t: take_linear_w(
341 weights, &pref_refs, "query", embed, embed, attn_hf, attn_meta,
342 )?,
343 q_b: take_linear_b(weights, &pref_refs, "query", embed, attn_hf, attn_meta)?,
344 k_w_t: take_linear_w(weights, &pref_refs, "key", embed, embed, attn_hf, attn_meta)?,
345 k_b: take_linear_b(weights, &pref_refs, "key", embed, attn_hf, attn_meta)?,
346 v_w_t: take_linear_w(
347 weights, &pref_refs, "value", embed, embed, attn_hf, attn_meta,
348 )?,
349 v_b: take_linear_b(weights, &pref_refs, "value", embed, attn_hf, attn_meta)?,
350 proj_w_t: take_attn_proj_w(weights, &pref_refs, embed, attn_hf, attn_meta)?,
351 proj_b: take_attn_proj_b(weights, &pref_refs, embed, attn_hf, attn_meta)?,
352 norm2_w: take_ln_w(weights, &pref_refs, "norm2", embed)?,
353 norm2_b: take_ln_b(weights, &pref_refs, "norm2", embed)?,
354 mlp_fc1_w_t: take_mlp_w(weights, &pref_refs, "fc1", embed, hidden)?,
355 mlp_fc1_b: take_mlp_b(weights, &pref_refs, "fc1", hidden)?,
356 mlp_fc2_w_t: take_mlp_w(weights, &pref_refs, "fc2", hidden, embed)?,
357 mlp_fc2_b: take_mlp_b(weights, &pref_refs, "fc2", embed)?,
358 })
359}
360
361fn pick_key(weights: &WeightMap, keys: &[&str]) -> Result<String> {
362 for k in keys {
363 if weights.has(k) {
364 return Ok((*k).to_string());
365 }
366 }
367 anyhow::bail!("none of keys found: {keys:?}")
368}
369
370fn take_attn_proj_w(
371 weights: &mut WeightMap,
372 prefixes: &[&str],
373 e: usize,
374 attn_hf: &str,
375 attn_meta: &str,
376) -> Result<Vec<f32>> {
377 for p in prefixes {
378 let hf = format!("{p}.{attn_hf}.proj.weight");
379 if weights.has(&hf) {
380 return take_linear_w_key(weights, &hf, e, e);
381 }
382 let meta = format!("{p}.{attn_meta}.proj.weight");
383 if weights.has(&meta) {
384 return take_linear_w_key(weights, &meta, e, e);
385 }
386 }
387 anyhow::bail!("attention proj weight not found for {prefixes:?}")
388}
389
390fn take_attn_proj_b(
391 weights: &mut WeightMap,
392 prefixes: &[&str],
393 e: usize,
394 attn_hf: &str,
395 attn_meta: &str,
396) -> Result<Vec<f32>> {
397 for p in prefixes {
398 for suffix in [
399 format!("{attn_hf}.proj.bias"),
400 format!("{attn_meta}.proj.bias"),
401 ] {
402 let key = format!("{p}.{suffix}");
403 if weights.has(&key) {
404 let (data, shape) = weights.take(&key)?;
405 ensure!(shape == vec![e]);
406 return Ok(data);
407 }
408 }
409 }
410 anyhow::bail!("attention proj bias not found")
411}
412
413fn take_linear_w(
414 weights: &mut WeightMap,
415 prefixes: &[&str],
416 name: &str,
417 in_dim: usize,
418 out_dim: usize,
419 attn_hf: &str,
420 attn_meta: &str,
421) -> Result<Vec<f32>> {
422 for p in prefixes {
423 let hf = format!("{p}.{attn_hf}.{name}.weight");
424 if weights.has(&hf) {
425 return take_linear_w_key(weights, &hf, in_dim, out_dim);
426 }
427 }
428 for p in prefixes {
429 if !p.starts_with("blocks.") && !p.starts_with("predictor_blocks.") {
430 continue;
431 }
432 let key = format!("{p}.{attn_meta}.qkv.weight");
433 if weights.has(&key) {
434 let (data, shape) = weights.take_transposed(&key)?;
435 ensure!(shape == vec![in_dim, 3 * out_dim]);
436 return Ok(split_qkv_w(&data, in_dim, out_dim, name));
437 }
438 }
439 anyhow::bail!("linear weight {name} not found for {prefixes:?}")
440}
441
442fn take_linear_b(
443 weights: &mut WeightMap,
444 prefixes: &[&str],
445 name: &str,
446 dim: usize,
447 attn_hf: &str,
448 attn_meta: &str,
449) -> Result<Vec<f32>> {
450 for p in prefixes {
451 let hf = format!("{p}.{attn_hf}.{name}.bias");
452 if weights.has(&hf) {
453 let (data, shape) = weights.take(&hf)?;
454 ensure!(shape == vec![dim]);
455 return Ok(data);
456 }
457 }
458 for p in prefixes {
459 if !p.starts_with("blocks.") && !p.starts_with("predictor_blocks.") {
460 continue;
461 }
462 let key = format!("{p}.{attn_meta}.qkv.bias");
463 if weights.has(&key) {
464 let (data, shape) = weights.take(&key)?;
465 ensure!(shape == vec![3 * dim]);
466 return Ok(split_qkv_b(&data, dim, name));
467 }
468 }
469 anyhow::bail!("linear bias {name} not found")
470}
471
/// Cuts the query, key or value columns out of a fused QKV weight.
///
/// `data` is read as `in_dim` rows of `3 * out_dim` values each; within a
/// row the query columns come first, then key, then value. The result is
/// `in_dim` rows of `out_dim` values.
///
/// # Panics
/// Panics when `which` is not `"query"`, `"key"` or `"value"`, or when `data`
/// is too short for the requested rows.
fn split_qkv_w(data: &[f32], in_dim: usize, out_dim: usize, which: &str) -> Vec<f32> {
    let first_col = match which {
        "query" => 0,
        "key" => out_dim,
        "value" => 2 * out_dim,
        _ => panic!("bad qkv split {which}"),
    };
    let row_stride = 3 * out_dim;

    let mut picked = Vec::with_capacity(in_dim * out_dim);
    for row in 0..in_dim {
        let start = row * row_stride + first_col;
        picked.extend_from_slice(&data[start..start + out_dim]);
    }
    picked
}
487
/// Cuts the query, key or value segment out of a fused QKV bias.
///
/// `data` is read as three consecutive segments of `dim` values, in the
/// order query, key, value.
///
/// # Panics
/// Panics when `which` is not `"query"`, `"key"` or `"value"`, or when `data`
/// is too short for the requested segment.
fn split_qkv_b(data: &[f32], dim: usize, which: &str) -> Vec<f32> {
    let segment = match which {
        "query" => 0,
        "key" => 1,
        "value" => 2,
        _ => panic!("bad qkv split {which}"),
    };
    let start = segment * dim;
    Vec::from(&data[start..start + dim])
}
497
498fn take_mlp_w(
499 weights: &mut WeightMap,
500 prefixes: &[&str],
501 fc: &str,
502 in_dim: usize,
503 out_dim: usize,
504) -> Result<Vec<f32>> {
505 for p in prefixes {
506 let key = format!("{p}.mlp.{fc}.weight");
507 if weights.has(&key) {
508 return take_linear_w_key(weights, &key, in_dim, out_dim);
509 }
510 }
511 anyhow::bail!("mlp {fc} weight not found")
512}
513
514fn take_mlp_b(
515 weights: &mut WeightMap,
516 prefixes: &[&str],
517 fc: &str,
518 dim: usize,
519) -> Result<Vec<f32>> {
520 for p in prefixes {
521 let key = format!("{p}.mlp.{fc}.bias");
522 if weights.has(&key) {
523 let (data, shape) = weights.take(&key)?;
524 ensure!(shape == vec![dim]);
525 return Ok(data);
526 }
527 }
528 anyhow::bail!("mlp {fc} bias not found")
529}
530
531fn take_ln_w(
532 weights: &mut WeightMap,
533 prefixes: &[&str],
534 norm: &str,
535 dim: usize,
536) -> Result<Vec<f32>> {
537 for p in prefixes {
538 let key = format!("{p}.{norm}.weight");
539 if weights.has(&key) {
540 let (data, shape) = weights.take(&key)?;
541 ensure!(shape == vec![dim]);
542 return Ok(data);
543 }
544 }
545 anyhow::bail!("{norm} weight not found")
546}
547
548fn take_ln_b(
549 weights: &mut WeightMap,
550 prefixes: &[&str],
551 norm: &str,
552 dim: usize,
553) -> Result<Vec<f32>> {
554 for p in prefixes {
555 let key = format!("{p}.{norm}.bias");
556 if weights.has(&key) {
557 let (data, shape) = weights.take(&key)?;
558 ensure!(shape == vec![dim]);
559 return Ok(data);
560 }
561 }
562 anyhow::bail!("{norm} bias not found")
563}
564
565fn take_linear_w_key(
566 weights: &mut WeightMap,
567 key: &str,
568 in_dim: usize,
569 out_dim: usize,
570) -> Result<Vec<f32>> {
571 let (data, shape) = weights.take_transposed(key)?;
572 ensure!(
573 shape == vec![in_dim, out_dim],
574 "{key} expected [{in_dim}, {out_dim}], got {shape:?}"
575 );
576 Ok(data)
577}
578
579fn take_first_vec(
580 weights: &mut WeightMap,
581 keys: &[&str],
582 expected: Vec<usize>,
583) -> Result<Vec<f32>> {
584 for key in keys {
585 if weights.has(key) {
586 let (data, shape) = weights.take(key)?;
587 ensure!(
588 shape == expected,
589 "{key} shape mismatch: {shape:?} vs {expected:?}"
590 );
591 return Ok(data);
592 }
593 }
594 anyhow::bail!("keys not found: {keys:?}")
595}