Skip to main content

rlx_vjepa2/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight extraction for V-JEPA2 checkpoints (HF + Meta key layouts).
17
18use super::config::Vjepa2Config;
19use super::preprocess::{Vjepa2PatchEmbedWeights, extract_patch_embed_weights};
20use anyhow::{Result, ensure};
21use rlx_core::weight_map::WeightMap;
22
/// Parameters of one transformer block, stored as flat `f32` buffers.
///
/// Populated by `extract_transformer_block`. Fields ending in `_w_t` are
/// linear weights fetched through `WeightMap::take_transposed` and validated
/// as `[in_dim, out_dim]`; fields ending in `_b` are 1-D biases and `norm*`
/// fields are 1-D layer-norm parameters.
#[derive(Clone)]
pub struct Vjepa2BlockWeights {
    /// `norm1.weight`, length `embed`.
    pub norm1_w: Vec<f32>,
    /// `norm1.bias`, length `embed`.
    pub norm1_b: Vec<f32>,
    /// Query projection weight, `[embed, embed]`.
    pub q_w_t: Vec<f32>,
    /// Query projection bias, length `embed`.
    pub q_b: Vec<f32>,
    /// Key projection weight, `[embed, embed]`.
    pub k_w_t: Vec<f32>,
    /// Key projection bias, length `embed`.
    pub k_b: Vec<f32>,
    /// Value projection weight, `[embed, embed]`.
    pub v_w_t: Vec<f32>,
    /// Value projection bias, length `embed`.
    pub v_b: Vec<f32>,
    /// Attention output projection weight (`proj.weight`), `[embed, embed]`.
    pub proj_w_t: Vec<f32>,
    /// Attention output projection bias (`proj.bias`), length `embed`.
    pub proj_b: Vec<f32>,
    /// `norm2.weight`, length `embed`.
    pub norm2_w: Vec<f32>,
    /// `norm2.bias`, length `embed`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1.weight`, `[embed, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1.bias`, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2.weight`, `[hidden, embed]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2.bias`, length `embed`.
    pub mlp_fc2_b: Vec<f32>,
}
42
/// All encoder parameters, as produced by `extract_encoder_weights`.
#[derive(Clone)]
pub struct Vjepa2EncoderWeights {
    /// Patch-embedding parameters returned by `extract_patch_embed_weights`.
    pub patch: Vjepa2PatchEmbedWeights,
    /// One entry per encoder layer (`cfg.num_hidden_layers` of them), in layer order.
    pub blocks: Vec<Vjepa2BlockWeights>,
    /// Final layer-norm weight (`encoder.layernorm.weight` or `norm.weight`),
    /// length `cfg.hidden_size`.
    pub norm_w: Vec<f32>,
    /// Final layer-norm bias (`encoder.layernorm.bias` or `norm.bias`),
    /// length `cfg.hidden_size`.
    pub norm_b: Vec<f32>,
}
50
/// All predictor parameters, as produced by `extract_predictor_weights`.
///
/// Below, `enc` is `cfg.hidden_size` and `pred` is `cfg.pred_hidden_size`.
#[derive(Clone)]
pub struct Vjepa2PredictorWeights {
    /// Predictor input embedding weight, validated as `[enc, pred]`.
    pub embed_w_t: Vec<f32>,
    /// Predictor input embedding bias, length `pred`.
    pub embed_b: Vec<f32>,
    /// Mask tokens, validated as shape `[cfg.pred_num_mask_tokens, 1, 1, pred]`.
    pub mask_tokens: Vec<f32>,
    /// One entry per predictor layer (`cfg.pred_num_hidden_layers` of them).
    pub blocks: Vec<Vjepa2BlockWeights>,
    /// Predictor layer-norm weight, length `pred`.
    pub norm_w: Vec<f32>,
    /// Predictor layer-norm bias, length `pred`.
    pub norm_b: Vec<f32>,
    /// Output projection weight, validated as `[pred, enc]`.
    pub proj_w_t: Vec<f32>,
    /// Output projection bias, length `enc`.
    pub proj_b: Vec<f32>,
}
62
/// Parameters of one pooler self-attention layer.
///
/// Read from keys under `pooler.self_attention_layers.{i}` by
/// `extract_pooler_weights`. `e` is `cfg.hidden_size` and `hidden` is
/// `cfg.pooler_intermediate_size()`.
#[derive(Clone)]
pub struct Vjepa2PoolerSelfBlockWeights {
    /// `layer_norm1.weight`, length `e`.
    pub norm1_w: Vec<f32>,
    /// `layer_norm1.bias`, length `e`.
    pub norm1_b: Vec<f32>,
    /// `self_attn.q_proj.weight`, validated as `[e, e]`.
    pub q_w_t: Vec<f32>,
    /// `self_attn.q_proj.bias`, length `e`.
    pub q_b: Vec<f32>,
    /// `self_attn.k_proj.weight`, validated as `[e, e]`.
    pub k_w_t: Vec<f32>,
    /// `self_attn.k_proj.bias`, length `e`.
    pub k_b: Vec<f32>,
    /// `self_attn.v_proj.weight`, validated as `[e, e]`.
    pub v_w_t: Vec<f32>,
    /// `self_attn.v_proj.bias`, length `e`.
    pub v_b: Vec<f32>,
    /// `self_attn.out_proj.weight`, validated as `[e, e]`.
    pub out_w_t: Vec<f32>,
    /// `self_attn.out_proj.bias`, length `e`.
    pub out_b: Vec<f32>,
    /// `layer_norm2.weight`, length `e`.
    pub norm2_w: Vec<f32>,
    /// `layer_norm2.bias`, length `e`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1.weight`, validated as `[e, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1.bias`, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2.weight`, validated as `[hidden, e]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2.bias`, length `e`.
    pub mlp_fc2_b: Vec<f32>,
}
82
/// Parameters of the pooler cross-attention layer.
///
/// Read from keys under `pooler.cross_attention_layer` by
/// `extract_pooler_weights`. Unlike the self-attention block, no attention
/// output projection is stored here. `e` is `cfg.hidden_size` and `hidden` is
/// `cfg.pooler_intermediate_size()`.
#[derive(Clone)]
pub struct Vjepa2PoolerCrossWeights {
    /// `layer_norm1.weight`, length `e`.
    pub norm1_w: Vec<f32>,
    /// `layer_norm1.bias`, length `e`.
    pub norm1_b: Vec<f32>,
    /// `cross_attn.q_proj.weight`, validated as `[e, e]`.
    pub q_w_t: Vec<f32>,
    /// `cross_attn.q_proj.bias`, length `e`.
    pub q_b: Vec<f32>,
    /// `cross_attn.k_proj.weight`, validated as `[e, e]`.
    pub k_w_t: Vec<f32>,
    /// `cross_attn.k_proj.bias`, length `e`.
    pub k_b: Vec<f32>,
    /// `cross_attn.v_proj.weight`, validated as `[e, e]`.
    pub v_w_t: Vec<f32>,
    /// `cross_attn.v_proj.bias`, length `e`.
    pub v_b: Vec<f32>,
    /// `layer_norm2.weight`, length `e`.
    pub norm2_w: Vec<f32>,
    /// `layer_norm2.bias`, length `e`.
    pub norm2_b: Vec<f32>,
    /// `mlp.fc1.weight`, validated as `[e, hidden]`.
    pub mlp_fc1_w_t: Vec<f32>,
    /// `mlp.fc1.bias`, length `hidden`.
    pub mlp_fc1_b: Vec<f32>,
    /// `mlp.fc2.weight`, validated as `[hidden, e]`.
    pub mlp_fc2_w_t: Vec<f32>,
    /// `mlp.fc2.bias`, length `e`.
    pub mlp_fc2_b: Vec<f32>,
}
100
/// Pooler parameters plus the optional classification head, as produced by
/// `extract_pooler_weights`.
#[derive(Clone)]
pub struct Vjepa2PoolerWeights {
    /// `pooler.query_tokens`, validated as shape `[1, 1, cfg.hidden_size]`.
    pub query_tokens: Vec<f32>,
    /// One entry per pooler self-attention layer (`cfg.num_pooler_layers` of them).
    pub self_blocks: Vec<Vjepa2PoolerSelfBlockWeights>,
    /// The single cross-attention layer.
    pub cross: Vjepa2PoolerCrossWeights,
    /// Transposed `classifier.weight`; `None` when the checkpoint has no such key.
    pub classifier_w_t: Option<Vec<f32>>,
    /// `classifier.bias` (1-D); `None` when the checkpoint has no such key.
    pub classifier_b: Option<Vec<f32>>,
}
109
/// Everything `extract_model_weights` could find in a checkpoint.
#[derive(Clone)]
pub struct Vjepa2ModelWeights {
    /// Encoder parameters; always extracted.
    pub encoder: Vjepa2EncoderWeights,
    /// Predictor parameters; `Some` only when a first-layer predictor attention
    /// key is present in the checkpoint.
    pub predictor: Option<Vjepa2PredictorWeights>,
    /// Pooler parameters; `Some` only when `pooler.query_tokens` is present.
    pub pooler: Option<Vjepa2PoolerWeights>,
}
116
117pub fn extract_encoder_weights(
118    weights: &mut WeightMap,
119    cfg: &Vjepa2Config,
120) -> Result<Vjepa2EncoderWeights> {
121    let patch = extract_patch_embed_weights(weights, cfg)?;
122    let e = cfg.hidden_size;
123    let hidden = cfg.intermediate_size();
124    let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
125
126    for i in 0..cfg.num_hidden_layers {
127        let hf = format!("encoder.layer.{i}");
128        let meta = format!("blocks.{i}");
129        blocks.push(extract_transformer_block(
130            weights,
131            &[hf, meta],
132            e,
133            hidden,
134            "attention",
135            "attn",
136        )?);
137    }
138
139    let norm_w = take_first_vec(
140        weights,
141        &["encoder.layernorm.weight", "norm.weight"],
142        vec![e],
143    )?;
144    let norm_b = take_first_vec(weights, &["encoder.layernorm.bias", "norm.bias"], vec![e])?;
145
146    Ok(Vjepa2EncoderWeights {
147        patch,
148        blocks,
149        norm_w,
150        norm_b,
151    })
152}
153
154pub fn extract_predictor_weights(
155    weights: &mut WeightMap,
156    cfg: &Vjepa2Config,
157) -> Result<Vjepa2PredictorWeights> {
158    let enc = cfg.hidden_size;
159    let pred = cfg.pred_hidden_size;
160    let hidden = cfg.pred_intermediate_size();
161
162    let embed_key = pick_key(
163        weights,
164        &[
165            "predictor.embeddings.predictor_embeddings.weight",
166            "predictor_embed.weight",
167        ],
168    )?;
169    let embed_w_t = take_linear_w_key(weights, &embed_key, enc, pred)?;
170    let embed_b = take_first_vec(
171        weights,
172        &[
173            "predictor.embeddings.predictor_embeddings.bias",
174            "predictor_embed.bias",
175        ],
176        vec![pred],
177    )?;
178
179    let n_masks = cfg.pred_num_mask_tokens;
180    let mask_tokens = take_first_vec(
181        weights,
182        &["predictor.embeddings.mask_tokens", "mask_tokens"],
183        vec![n_masks, 1, 1, pred],
184    )?;
185
186    let mut blocks = Vec::with_capacity(cfg.pred_num_hidden_layers);
187    for i in 0..cfg.pred_num_hidden_layers {
188        let hf = format!("predictor.layer.{i}");
189        let meta = format!("predictor_blocks.{i}");
190        blocks.push(extract_transformer_block(
191            weights,
192            &[hf, meta],
193            pred,
194            hidden,
195            "attention",
196            "attn",
197        )?);
198    }
199
200    let norm_w = take_first_vec(
201        weights,
202        &["predictor.layernorm.weight", "predictor_norm.weight"],
203        vec![pred],
204    )?;
205    let norm_b = take_first_vec(
206        weights,
207        &["predictor.layernorm.bias", "predictor_norm.bias"],
208        vec![pred],
209    )?;
210    let proj_key = pick_key(weights, &["predictor.proj.weight", "predictor_proj.weight"])?;
211    let proj_w_t = take_linear_w_key(weights, &proj_key, pred, enc)?;
212    let proj_b = take_first_vec(
213        weights,
214        &["predictor.proj.bias", "predictor_proj.bias"],
215        vec![enc],
216    )?;
217
218    Ok(Vjepa2PredictorWeights {
219        embed_w_t,
220        embed_b,
221        mask_tokens,
222        blocks,
223        norm_w,
224        norm_b,
225        proj_w_t,
226        proj_b,
227    })
228}
229
230pub fn extract_pooler_weights(
231    weights: &mut WeightMap,
232    cfg: &Vjepa2Config,
233) -> Result<Vjepa2PoolerWeights> {
234    let e = cfg.hidden_size;
235    let hidden = cfg.pooler_intermediate_size();
236
237    let query_tokens = take_first_vec(weights, &["pooler.query_tokens"], vec![1, 1, e])?;
238
239    let mut self_blocks = Vec::with_capacity(cfg.num_pooler_layers);
240    for i in 0..cfg.num_pooler_layers {
241        let p = format!("pooler.self_attention_layers.{i}");
242        self_blocks.push(Vjepa2PoolerSelfBlockWeights {
243            norm1_w: take_ln_w(weights, &[&p], "layer_norm1", e)?,
244            norm1_b: take_ln_b(weights, &[&p], "layer_norm1", e)?,
245            q_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.q_proj.weight"), e, e)?,
246            q_b: take_first_vec(weights, &[&format!("{p}.self_attn.q_proj.bias")], vec![e])?,
247            k_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.k_proj.weight"), e, e)?,
248            k_b: take_first_vec(weights, &[&format!("{p}.self_attn.k_proj.bias")], vec![e])?,
249            v_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.v_proj.weight"), e, e)?,
250            v_b: take_first_vec(weights, &[&format!("{p}.self_attn.v_proj.bias")], vec![e])?,
251            out_w_t: take_linear_w_key(weights, &format!("{p}.self_attn.out_proj.weight"), e, e)?,
252            out_b: take_first_vec(weights, &[&format!("{p}.self_attn.out_proj.bias")], vec![e])?,
253            norm2_w: take_ln_w(weights, &[&p], "layer_norm2", e)?,
254            norm2_b: take_ln_b(weights, &[&p], "layer_norm2", e)?,
255            mlp_fc1_w_t: take_linear_w_key(weights, &format!("{p}.mlp.fc1.weight"), e, hidden)?,
256            mlp_fc1_b: take_first_vec(weights, &[&format!("{p}.mlp.fc1.bias")], vec![hidden])?,
257            mlp_fc2_w_t: take_linear_w_key(weights, &format!("{p}.mlp.fc2.weight"), hidden, e)?,
258            mlp_fc2_b: take_first_vec(weights, &[&format!("{p}.mlp.fc2.bias")], vec![e])?,
259        });
260    }
261
262    let cp = "pooler.cross_attention_layer";
263    let cross = Vjepa2PoolerCrossWeights {
264        norm1_w: take_ln_w(weights, &[cp], "layer_norm1", e)?,
265        norm1_b: take_ln_b(weights, &[cp], "layer_norm1", e)?,
266        q_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.q_proj.weight"), e, e)?,
267        q_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.q_proj.bias")], vec![e])?,
268        k_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.k_proj.weight"), e, e)?,
269        k_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.k_proj.bias")], vec![e])?,
270        v_w_t: take_linear_w_key(weights, &format!("{cp}.cross_attn.v_proj.weight"), e, e)?,
271        v_b: take_first_vec(weights, &[&format!("{cp}.cross_attn.v_proj.bias")], vec![e])?,
272        norm2_w: take_ln_w(weights, &[cp], "layer_norm2", e)?,
273        norm2_b: take_ln_b(weights, &[cp], "layer_norm2", e)?,
274        mlp_fc1_w_t: take_linear_w_key(weights, &format!("{cp}.mlp.fc1.weight"), e, hidden)?,
275        mlp_fc1_b: take_first_vec(weights, &[&format!("{cp}.mlp.fc1.bias")], vec![hidden])?,
276        mlp_fc2_w_t: take_linear_w_key(weights, &format!("{cp}.mlp.fc2.weight"), hidden, e)?,
277        mlp_fc2_b: take_first_vec(weights, &[&format!("{cp}.mlp.fc2.bias")], vec![e])?,
278    };
279
280    let classifier_w_t = if weights.has("classifier.weight") {
281        let (data, shape) = weights.take_transposed("classifier.weight")?;
282        ensure!(shape[1] == e, "classifier weight second dim must be {e}");
283        Some(data)
284    } else {
285        None
286    };
287    let classifier_b = if weights.has("classifier.bias") {
288        let (data, shape) = weights.take("classifier.bias")?;
289        ensure!(shape.len() == 1, "classifier bias must be 1d");
290        Some(data)
291    } else {
292        None
293    };
294
295    Ok(Vjepa2PoolerWeights {
296        query_tokens,
297        self_blocks,
298        cross,
299        classifier_w_t,
300        classifier_b,
301    })
302}
303
304pub fn extract_model_weights(
305    weights: &mut WeightMap,
306    cfg: &Vjepa2Config,
307) -> Result<Vjepa2ModelWeights> {
308    let encoder = extract_encoder_weights(weights, cfg)?;
309    let predictor = if weights.has("predictor.layer.0.attention.query.weight")
310        || weights.has("predictor_blocks.0.attn.qkv.weight")
311    {
312        Some(extract_predictor_weights(weights, cfg)?)
313    } else {
314        None
315    };
316    let pooler = if weights.has("pooler.query_tokens") {
317        Some(extract_pooler_weights(weights, cfg)?)
318    } else {
319        None
320    };
321    Ok(Vjepa2ModelWeights {
322        encoder,
323        predictor,
324        pooler,
325    })
326}
327
328pub(crate) fn extract_transformer_block(
329    weights: &mut WeightMap,
330    prefixes: &[String],
331    embed: usize,
332    hidden: usize,
333    attn_hf: &str,
334    attn_meta: &str,
335) -> Result<Vjepa2BlockWeights> {
336    let pref_refs: Vec<&str> = prefixes.iter().map(String::as_str).collect();
337    Ok(Vjepa2BlockWeights {
338        norm1_w: take_ln_w(weights, &pref_refs, "norm1", embed)?,
339        norm1_b: take_ln_b(weights, &pref_refs, "norm1", embed)?,
340        q_w_t: take_linear_w(
341            weights, &pref_refs, "query", embed, embed, attn_hf, attn_meta,
342        )?,
343        q_b: take_linear_b(weights, &pref_refs, "query", embed, attn_hf, attn_meta)?,
344        k_w_t: take_linear_w(weights, &pref_refs, "key", embed, embed, attn_hf, attn_meta)?,
345        k_b: take_linear_b(weights, &pref_refs, "key", embed, attn_hf, attn_meta)?,
346        v_w_t: take_linear_w(
347            weights, &pref_refs, "value", embed, embed, attn_hf, attn_meta,
348        )?,
349        v_b: take_linear_b(weights, &pref_refs, "value", embed, attn_hf, attn_meta)?,
350        proj_w_t: take_attn_proj_w(weights, &pref_refs, embed, attn_hf, attn_meta)?,
351        proj_b: take_attn_proj_b(weights, &pref_refs, embed, attn_hf, attn_meta)?,
352        norm2_w: take_ln_w(weights, &pref_refs, "norm2", embed)?,
353        norm2_b: take_ln_b(weights, &pref_refs, "norm2", embed)?,
354        mlp_fc1_w_t: take_mlp_w(weights, &pref_refs, "fc1", embed, hidden)?,
355        mlp_fc1_b: take_mlp_b(weights, &pref_refs, "fc1", hidden)?,
356        mlp_fc2_w_t: take_mlp_w(weights, &pref_refs, "fc2", hidden, embed)?,
357        mlp_fc2_b: take_mlp_b(weights, &pref_refs, "fc2", embed)?,
358    })
359}
360
361fn pick_key(weights: &WeightMap, keys: &[&str]) -> Result<String> {
362    for k in keys {
363        if weights.has(k) {
364            return Ok((*k).to_string());
365        }
366    }
367    anyhow::bail!("none of keys found: {keys:?}")
368}
369
370fn take_attn_proj_w(
371    weights: &mut WeightMap,
372    prefixes: &[&str],
373    e: usize,
374    attn_hf: &str,
375    attn_meta: &str,
376) -> Result<Vec<f32>> {
377    for p in prefixes {
378        let hf = format!("{p}.{attn_hf}.proj.weight");
379        if weights.has(&hf) {
380            return take_linear_w_key(weights, &hf, e, e);
381        }
382        let meta = format!("{p}.{attn_meta}.proj.weight");
383        if weights.has(&meta) {
384            return take_linear_w_key(weights, &meta, e, e);
385        }
386    }
387    anyhow::bail!("attention proj weight not found for {prefixes:?}")
388}
389
390fn take_attn_proj_b(
391    weights: &mut WeightMap,
392    prefixes: &[&str],
393    e: usize,
394    attn_hf: &str,
395    attn_meta: &str,
396) -> Result<Vec<f32>> {
397    for p in prefixes {
398        for suffix in [
399            format!("{attn_hf}.proj.bias"),
400            format!("{attn_meta}.proj.bias"),
401        ] {
402            let key = format!("{p}.{suffix}");
403            if weights.has(&key) {
404                let (data, shape) = weights.take(&key)?;
405                ensure!(shape == vec![e]);
406                return Ok(data);
407            }
408        }
409    }
410    anyhow::bail!("attention proj bias not found")
411}
412
413fn take_linear_w(
414    weights: &mut WeightMap,
415    prefixes: &[&str],
416    name: &str,
417    in_dim: usize,
418    out_dim: usize,
419    attn_hf: &str,
420    attn_meta: &str,
421) -> Result<Vec<f32>> {
422    for p in prefixes {
423        let hf = format!("{p}.{attn_hf}.{name}.weight");
424        if weights.has(&hf) {
425            return take_linear_w_key(weights, &hf, in_dim, out_dim);
426        }
427    }
428    for p in prefixes {
429        if !p.starts_with("blocks.") && !p.starts_with("predictor_blocks.") {
430            continue;
431        }
432        let key = format!("{p}.{attn_meta}.qkv.weight");
433        if weights.has(&key) {
434            let (data, shape) = weights.take_transposed(&key)?;
435            ensure!(shape == vec![in_dim, 3 * out_dim]);
436            return Ok(split_qkv_w(&data, in_dim, out_dim, name));
437        }
438    }
439    anyhow::bail!("linear weight {name} not found for {prefixes:?}")
440}
441
442fn take_linear_b(
443    weights: &mut WeightMap,
444    prefixes: &[&str],
445    name: &str,
446    dim: usize,
447    attn_hf: &str,
448    attn_meta: &str,
449) -> Result<Vec<f32>> {
450    for p in prefixes {
451        let hf = format!("{p}.{attn_hf}.{name}.bias");
452        if weights.has(&hf) {
453            let (data, shape) = weights.take(&hf)?;
454            ensure!(shape == vec![dim]);
455            return Ok(data);
456        }
457    }
458    for p in prefixes {
459        if !p.starts_with("blocks.") && !p.starts_with("predictor_blocks.") {
460            continue;
461        }
462        let key = format!("{p}.{attn_meta}.qkv.bias");
463        if weights.has(&key) {
464            let (data, shape) = weights.take(&key)?;
465            ensure!(shape == vec![3 * dim]);
466            return Ok(split_qkv_b(&data, dim, name));
467        }
468    }
469    anyhow::bail!("linear bias {name} not found")
470}
471
/// Slices one projection out of a fused, row-major `[in_dim, 3 * out_dim]`
/// weight buffer and returns it as a row-major `[in_dim, out_dim]` buffer.
///
/// Within each row the columns are laid out as query, key, value.
///
/// # Panics
/// Panics when `which` is not `"query"`, `"key"` or `"value"`, or when `data`
/// is too short for the requested dimensions.
fn split_qkv_w(data: &[f32], in_dim: usize, out_dim: usize, which: &str) -> Vec<f32> {
    let slot = match which {
        "query" => 0,
        "key" => 1,
        "value" => 2,
        _ => panic!("bad qkv split {which}"),
    };
    let row_len = 3 * out_dim;
    let first_col = slot * out_dim;

    // Copy the selected `out_dim`-wide column band of each row in one go.
    let mut out = Vec::with_capacity(in_dim * out_dim);
    for row in 0..in_dim {
        let start = row * row_len + first_col;
        out.extend_from_slice(&data[start..start + out_dim]);
    }
    out
}
487
/// Slices one projection out of a fused bias laid out as query, key, value,
/// each of length `dim`.
///
/// # Panics
/// Panics when `which` is not `"query"`, `"key"` or `"value"`, or when `data`
/// is shorter than the requested slice.
fn split_qkv_b(data: &[f32], dim: usize, which: &str) -> Vec<f32> {
    let slot = match which {
        "query" => 0,
        "key" => 1,
        "value" => 2,
        _ => panic!("bad qkv split {which}"),
    };
    let start = slot * dim;
    Vec::from(&data[start..start + dim])
}
497
498fn take_mlp_w(
499    weights: &mut WeightMap,
500    prefixes: &[&str],
501    fc: &str,
502    in_dim: usize,
503    out_dim: usize,
504) -> Result<Vec<f32>> {
505    for p in prefixes {
506        let key = format!("{p}.mlp.{fc}.weight");
507        if weights.has(&key) {
508            return take_linear_w_key(weights, &key, in_dim, out_dim);
509        }
510    }
511    anyhow::bail!("mlp {fc} weight not found")
512}
513
514fn take_mlp_b(
515    weights: &mut WeightMap,
516    prefixes: &[&str],
517    fc: &str,
518    dim: usize,
519) -> Result<Vec<f32>> {
520    for p in prefixes {
521        let key = format!("{p}.mlp.{fc}.bias");
522        if weights.has(&key) {
523            let (data, shape) = weights.take(&key)?;
524            ensure!(shape == vec![dim]);
525            return Ok(data);
526        }
527    }
528    anyhow::bail!("mlp {fc} bias not found")
529}
530
531fn take_ln_w(
532    weights: &mut WeightMap,
533    prefixes: &[&str],
534    norm: &str,
535    dim: usize,
536) -> Result<Vec<f32>> {
537    for p in prefixes {
538        let key = format!("{p}.{norm}.weight");
539        if weights.has(&key) {
540            let (data, shape) = weights.take(&key)?;
541            ensure!(shape == vec![dim]);
542            return Ok(data);
543        }
544    }
545    anyhow::bail!("{norm} weight not found")
546}
547
548fn take_ln_b(
549    weights: &mut WeightMap,
550    prefixes: &[&str],
551    norm: &str,
552    dim: usize,
553) -> Result<Vec<f32>> {
554    for p in prefixes {
555        let key = format!("{p}.{norm}.bias");
556        if weights.has(&key) {
557            let (data, shape) = weights.take(&key)?;
558            ensure!(shape == vec![dim]);
559            return Ok(data);
560        }
561    }
562    anyhow::bail!("{norm} bias not found")
563}
564
565fn take_linear_w_key(
566    weights: &mut WeightMap,
567    key: &str,
568    in_dim: usize,
569    out_dim: usize,
570) -> Result<Vec<f32>> {
571    let (data, shape) = weights.take_transposed(key)?;
572    ensure!(
573        shape == vec![in_dim, out_dim],
574        "{key} expected [{in_dim}, {out_dim}], got {shape:?}"
575    );
576    Ok(data)
577}
578
579fn take_first_vec(
580    weights: &mut WeightMap,
581    keys: &[&str],
582    expected: Vec<usize>,
583) -> Result<Vec<f32>> {
584    for key in keys {
585        if weights.has(key) {
586            let (data, shape) = weights.take(key)?;
587            ensure!(
588                shape == expected,
589                "{key} shape mismatch: {shape:?} vs {expected:?}"
590            );
591            return Ok(data);
592        }
593    }
594    anyhow::bail!("keys not found: {keys:?}")
595}