Skip to main content

rlx_vjepa2/
builder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! HIR-first graph builders for the V-JEPA2 encoder, predictor, and pooler.
17//!
18//! Builders emit [`HirModule`] with fused linear/attention blocks and lower to
19//! legacy [`Graph`] (MIR) for `Session::compile` / `Session::compile_hir`.
20
21use super::config::Vjepa2Config;
22use super::predictor::Vjepa2PredictorLayout;
23use super::preprocess::Vjepa2PatchEmbedWeights;
24use super::rope::build_vjepa2_rope_tables;
25use super::weights::{
26    Vjepa2BlockWeights, Vjepa2EncoderWeights, Vjepa2PoolerCrossWeights,
27    Vjepa2PoolerSelfBlockWeights, Vjepa2PoolerWeights, Vjepa2PredictorWeights,
28};
29use anyhow::Result;
30use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
31use rlx_ir::op::{Activation, BinaryOp, MaskKind};
32use rlx_ir::{DType, Graph, Op, Shape};
33use std::collections::HashMap;
34
35/// Host-side patch-embed weights used with the compiled graph.
36pub struct Vjepa2GraphPreprocess {
37    pub patch: Vjepa2PatchEmbedWeights,
38}
39
/// F32 params returned by graph builders (including gather indices stored as f32).
///
/// Each entry maps a graph parameter name to its flat `f32` buffer.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Vjepa2GraphParams {
    /// Parameter name → flat buffer, as registered while building the graph.
    pub f32: HashMap<String, Vec<f32>>,
}
44
45impl Vjepa2GraphParams {
46    pub fn from_f32(map: HashMap<String, Vec<f32>>) -> Self {
47        Self { f32: map }
48    }
49
50    pub fn load(&self, compiled: &mut rlx_runtime::CompiledGraph) {
51        for (name, data) in &self.f32 {
52            compiled.set_param(name, data);
53        }
54    }
55}
56
57#[allow(dead_code)]
58fn lower_hir(hir: HirModule) -> Result<Graph> {
59    Ok(hir
60        .lower_to_mir()
61        .map_err(|e| anyhow::anyhow!("{e}"))?
62        .into_graph())
63}
64
/// Scratch state shared by the HIR builders in this file: the module being
/// assembled plus the parameter buffers bound along the way.
struct VjepaBuilder {
    /// HIR module under construction.
    hir: HirModule,
    /// Parameter buffers collected by the `bind_*` helpers, keyed by parameter name.
    params: HashMap<String, Vec<f32>>,
    /// Element dtype used for shapes the builder creates (`DType::F32` in `new`).
    f: DType,
}
70
71impl VjepaBuilder {
72    fn new(name: &str) -> Self {
73        Self {
74            hir: HirModule::new(name).with_fusion_policy(FusionPolicy::Direct),
75            params: HashMap::new(),
76            f: DType::F32,
77        }
78    }
79
80    #[allow(dead_code)]
81    fn finish(self) -> Result<Graph> {
82        lower_hir(self.hir)
83    }
84
85    fn shape3(&self, batch: usize, seq: usize, h: usize) -> Shape {
86        Shape::new(&[batch, seq, h], self.f)
87    }
88
89    fn node_shape(&self, id: HirNodeId) -> Shape {
90        self.hir.node(id).shape.clone()
91    }
92
93    fn layer_norm(
94        &mut self,
95        x: HirNodeId,
96        gamma: HirNodeId,
97        beta: HirNodeId,
98        eps: f32,
99        shape: Shape,
100    ) -> HirNodeId {
101        self.hir
102            .mir(Op::LayerNorm { axis: -1, eps }, vec![x, gamma, beta], shape)
103    }
104
105    fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
106        let in_shape = self.hir.node(x).shape.clone();
107        let static_dims: Vec<usize> = new_shape.iter().map(|&d| d as usize).collect();
108        let out = Shape::new(&static_dims, in_shape.dtype());
109        self.hir.mir(Op::Reshape { new_shape }, vec![x], out)
110    }
111
112    fn narrow(
113        &mut self,
114        x: HirNodeId,
115        axis: usize,
116        start: usize,
117        len: usize,
118        shape: Shape,
119    ) -> HirNodeId {
120        self.hir
121            .mir(Op::Narrow { axis, start, len }, vec![x], shape)
122    }
123
124    fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
125        self.hir.mir(Op::Concat { axis }, inputs, shape)
126    }
127
128    fn gather(&mut self, table: HirNodeId, indices: HirNodeId, axis: usize) -> HirNodeId {
129        let out = rlx_ir::shape::gather_shape(
130            &self.hir.node(table).shape,
131            &self.hir.node(indices).shape,
132            axis,
133        )
134        .expect("gather shape");
135        self.hir.mir(Op::Gather { axis }, vec![table, indices], out)
136    }
137
138    fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
139        self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
140    }
141
142    fn mm(&mut self, lhs: HirNodeId, rhs: HirNodeId) -> HirNodeId {
143        let out = rlx_ir::shape::matmul_shape(&self.hir.node(lhs).shape, &self.hir.node(rhs).shape)
144            .expect("matmul shape");
145        self.hir.mir(Op::MatMul, vec![lhs, rhs], out)
146    }
147
148    fn rope_n(
149        &mut self,
150        x: HirNodeId,
151        cos: HirNodeId,
152        sin: HirNodeId,
153        head_dim: usize,
154        n_rot: usize,
155    ) -> HirNodeId {
156        let shape = self.hir.node(x).shape.clone();
157        self.hir
158            .mir(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], shape)
159    }
160
161    #[allow(dead_code)]
162    fn gelu_approx(&mut self, x: HirNodeId, shape: Shape) -> HirNodeId {
163        self.hir
164            .mir(Op::Activation(Activation::GeluApprox), vec![x], shape)
165    }
166
167    fn attention_custom(
168        &mut self,
169        q: HirNodeId,
170        k: HirNodeId,
171        v: HirNodeId,
172        mask: HirNodeId,
173        nh: usize,
174        dh: usize,
175    ) -> HirNodeId {
176        let out = rlx_ir::shape::attention_shape(&self.hir.node(q).shape);
177        self.hir
178            .attention(q, k, v, Some(mask), nh, dh, MaskKind::Custom, out)
179    }
180
181    fn attention_none(
182        &mut self,
183        q: HirNodeId,
184        k: HirNodeId,
185        v: HirNodeId,
186        nh: usize,
187        dh: usize,
188    ) -> HirNodeId {
189        let out = rlx_ir::shape::attention_shape(&self.hir.node(q).shape);
190        self.hir
191            .attention(q, k, v, None, nh, dh, MaskKind::None, out)
192    }
193
194    fn bind_vec(&mut self, name: &str, data: &[f32]) -> HirNodeId {
195        let id = self.hir.param(name, Shape::new(&[data.len()], self.f));
196        self.params.insert(name.to_string(), data.to_vec());
197        id
198    }
199
200    fn bind_mat(&mut self, name: &str, w_t: &[f32], in_dim: usize, out_dim: usize) -> HirNodeId {
201        let id = self.hir.param(name, Shape::new(&[in_dim, out_dim], self.f));
202        self.params.insert(name.to_string(), w_t.to_vec());
203        id
204    }
205
206    fn bind_indices(&mut self, name: &str, data: &[i64], shape: &[usize]) -> HirNodeId {
207        let f32_data: Vec<f32> = data.iter().map(|&v| v as f32).collect();
208        let id = self.hir.param(name, Shape::new(shape, self.f));
209        self.params.insert(name.to_string(), f32_data);
210        id
211    }
212
213    fn linear_named(
214        &mut self,
215        name: &str,
216        input: HirNodeId,
217        in_dim: usize,
218        w_t: &[f32],
219        b: &[f32],
220    ) -> HirNodeId {
221        let out_dim = b.len();
222        let w = self.bind_mat(&format!("{name}.weight"), w_t, in_dim, out_dim);
223        let bias = self.bind_vec(&format!("{name}.bias"), b);
224        let out_shape =
225            rlx_ir::shape::matmul_shape(&self.hir.node(input).shape, &self.hir.node(w).shape)
226                .expect("linear matmul shape");
227        self.hir.linear_fused(input, w, bias, None, out_shape)
228    }
229
230    fn mlp_block(
231        &mut self,
232        lp: &str,
233        x: HirNodeId,
234        embed: usize,
235        fc1_w_t: &[f32],
236        fc1_b: &[f32],
237        fc2_w_t: &[f32],
238        fc2_b: &[f32],
239        residual: HirNodeId,
240        out_shape: Shape,
241    ) -> HirNodeId {
242        let hidden = fc1_b.len();
243        let fc1_w = self.bind_mat(&format!("{lp}.mlp.fc1.weight"), fc1_w_t, embed, hidden);
244        let fc1_bias = self.bind_vec(&format!("{lp}.mlp.fc1.bias"), fc1_b);
245        let fc1_shape =
246            rlx_ir::shape::matmul_shape(&self.hir.node(x).shape, &self.hir.node(fc1_w).shape)
247                .expect("fc1 shape");
248        let up = self
249            .hir
250            .linear_fused(x, fc1_w, fc1_bias, Some(Activation::GeluApprox), fc1_shape);
251
252        let fc2_w = self.bind_mat(&format!("{lp}.mlp.fc2.weight"), fc2_w_t, hidden, embed);
253        let fc2_bias = self.bind_vec(&format!("{lp}.mlp.fc2.bias"), fc2_b);
254        let fc2_shape =
255            rlx_ir::shape::matmul_shape(&self.hir.node(up).shape, &self.hir.node(fc2_w).shape)
256                .expect("fc2 shape");
257        let ffn = self.hir.linear_fused(up, fc2_w, fc2_bias, None, fc2_shape);
258        self.add(residual, ffn, out_shape)
259    }
260}
261
262/// Build the V-JEPA2 encoder HIR module from extracted weights.
263pub fn build_vjepa2_encoder_hir_sized(
264    cfg: &Vjepa2Config,
265    enc: &Vjepa2EncoderWeights,
266    batch: usize,
267) -> Result<(HirModule, HashMap<String, Vec<f32>>, Vjepa2GraphPreprocess)> {
268    let mut b = VjepaBuilder::new("vjepa2_encoder");
269
270    let h = cfg.hidden_size;
271    let nh = cfg.num_attention_heads;
272    let dh = cfg.head_dim();
273    let eps = cfg.layer_norm_eps as f32;
274    let seq = cfg.num_patches();
275    let (d_dim, hd_dim, w_dim) = cfg.rope_segment_dims();
276    let grid_h = cfg.grid_spatial();
277    let grid_w = cfg.grid_spatial();
278    let n_rot = d_dim + hd_dim + w_dim;
279
280    let preprocess = Vjepa2GraphPreprocess {
281        patch: enc.patch.clone(),
282    };
283
284    let (cos_data, sin_data) =
285        build_vjepa2_rope_tables(seq, dh, d_dim, hd_dim, w_dim, grid_h, grid_w);
286    let half = dh / 2;
287    let cos_id = b.bind_mat("rope_cos", &cos_data, seq, half);
288    let sin_id = b.bind_mat("rope_sin", &sin_data, seq, half);
289
290    let mask_data = vec![1.0f32; batch * seq];
291    let mask_id = b.hir.param("attn_mask", Shape::new(&[batch, seq], b.f));
292    b.params.insert("attn_mask".into(), mask_data);
293
294    let hidden_input = b.hir.input("hidden", b.shape3(batch, seq, h));
295    let mut x = hidden_input;
296    let enc_shape = b.shape3(batch, seq, h);
297
298    for (layer_idx, block) in enc.blocks.iter().enumerate() {
299        let lp = format!("blocks.{layer_idx}");
300        x = append_rope_block(
301            &mut b,
302            x,
303            block,
304            &lp,
305            h,
306            nh,
307            dh,
308            n_rot,
309            cos_id,
310            sin_id,
311            Some(mask_id),
312            eps,
313            true,
314            enc_shape.clone(),
315        );
316    }
317
318    let fn_g = b.bind_vec("norm.weight", &enc.norm_w);
319    let fn_b = b.bind_vec("norm.bias", &enc.norm_b);
320    let encoded = b.layer_norm(x, fn_g, fn_b, eps, enc_shape);
321    b.hir.outputs = vec![encoded];
322
323    Ok((b.hir, b.params, preprocess))
324}
325
326/// Build the V-JEPA2 encoder IR graph from extracted weights (via [`super::flow::Vjepa2EncoderFlow`]).
327pub fn build_vjepa2_encoder_graph_sized(
328    cfg: &Vjepa2Config,
329    enc: &Vjepa2EncoderWeights,
330    batch: usize,
331) -> Result<(Graph, HashMap<String, Vec<f32>>, Vjepa2GraphPreprocess)> {
332    let built = super::flow::Vjepa2EncoderFlow::new(cfg, enc, batch).build()?;
333    let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
334    Ok((graph, params, built.preprocess))
335}
336
337/// Build the JEPA predictor HIR module for fixed context/target masks.
338pub fn build_vjepa2_predictor_hir_sized(
339    cfg: &Vjepa2Config,
340    pred: &Vjepa2PredictorWeights,
341    layout: &Vjepa2PredictorLayout,
342    mask_rows: &[f32],
343    batch: usize,
344) -> Result<(HirModule, Vjepa2GraphParams)> {
345    let mut b = VjepaBuilder::new("vjepa2_predictor");
346
347    let enc = cfg.hidden_size;
348    let pred_h = cfg.pred_hidden_size;
349    let nh = cfg.pred_num_attention_heads;
350    let dh = cfg.pred_head_dim();
351    let eps = cfg.layer_norm_eps as f32;
352    let enc_seq = cfg.num_patches();
353    let (d_dim, hd_dim, w_dim) = cfg.pred_rope_segment_dims();
354    let n_rot = d_dim + hd_dim + w_dim;
355    let n_ctxt = layout.n_ctxt;
356    let n_tgt = layout.n_tgt;
357    let n_combined = layout.n_combined;
358    let half = dh / 2;
359
360    let encoder = b.hir.input("encoder", b.shape3(batch, enc_seq, enc));
361
362    let ctxt_idx_id = b.bind_indices("ctxt_idx", &layout.ctxt_idx, &[batch, n_ctxt]);
363    let ctxt = b.gather(encoder, ctxt_idx_id, 1);
364    let ctxt = b.reshape(ctxt, vec![batch as i64, n_ctxt as i64, enc as i64]);
365
366    let embed_w = b.bind_mat("embed.weight", &pred.embed_w_t, enc, pred_h);
367    let embed_b = b.bind_vec("embed.bias", &pred.embed_b);
368    let mm_embed = b.mm(ctxt, embed_w);
369    let ctxt_up = b.add(mm_embed, embed_b, b.shape3(batch, n_ctxt, pred_h));
370    let ctxt_embed = b.reshape(ctxt_up, vec![batch as i64, n_ctxt as i64, pred_h as i64]);
371
372    let mask_id = b
373        .hir
374        .param("mask_rows", Shape::new(&[batch, n_tgt, pred_h], b.f));
375    b.params.insert("mask_rows".into(), mask_rows.to_vec());
376    let mut x = b.concat(
377        vec![ctxt_embed, mask_id],
378        1,
379        b.shape3(batch, n_combined, pred_h),
380    );
381    x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
382
383    let sort_idx_id = b.bind_indices("sort_idx", &layout.sort_idx, &[batch, n_combined]);
384    x = b.gather(x, sort_idx_id, 1);
385    x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
386
387    let cos_id = b.bind_mat("rope_cos", &layout.rope_cos, n_combined, half);
388    let sin_id = b.bind_mat("rope_sin", &layout.rope_sin, n_combined, half);
389    let pred_shape = b.shape3(batch, n_combined, pred_h);
390
391    for (layer_idx, block) in pred.blocks.iter().enumerate() {
392        let lp = format!("blocks.{layer_idx}");
393        x = append_rope_block(
394            &mut b,
395            x,
396            block,
397            &lp,
398            pred_h,
399            nh,
400            dh,
401            n_rot,
402            cos_id,
403            sin_id,
404            None,
405            eps,
406            false,
407            pred_shape.clone(),
408        );
409    }
410
411    let fn_g = b.bind_vec("norm.weight", &pred.norm_w);
412    let fn_b = b.bind_vec("norm.bias", &pred.norm_b);
413    x = b.layer_norm(x, fn_g, fn_b, eps, pred_shape.clone());
414
415    let unsort_idx_id = b.bind_indices("unsort_idx", &layout.unsort_idx, &[batch, n_combined]);
416    x = b.gather(x, unsort_idx_id, 1);
417    x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
418    x = b.narrow(x, 1, n_ctxt, n_tgt, b.shape3(batch, n_tgt, pred_h));
419    x = b.reshape(x, vec![batch as i64, n_tgt as i64, pred_h as i64]);
420
421    let proj_w = b.bind_mat("proj.weight", &pred.proj_w_t, pred_h, enc);
422    let proj_b = b.bind_vec("proj.bias", &pred.proj_b);
423    let mm_proj = b.mm(x, proj_w);
424    let out = b.add(mm_proj, proj_b, b.shape3(batch, n_tgt, enc));
425    b.hir.outputs = vec![out];
426
427    Ok((b.hir, Vjepa2GraphParams { f32: b.params }))
428}
429
430/// Build the JEPA predictor IR graph for fixed context/target masks (via [`super::flow::Vjepa2PredictorFlow`]).
431pub fn build_vjepa2_predictor_graph_sized(
432    cfg: &Vjepa2Config,
433    pred: &Vjepa2PredictorWeights,
434    layout: &Vjepa2PredictorLayout,
435    mask_rows: &[f32],
436    batch: usize,
437) -> Result<(Graph, Vjepa2GraphParams)> {
438    let built =
439        super::flow::Vjepa2PredictorFlow::new(cfg, pred, layout, mask_rows, batch).build()?;
440    let (graph, params) = rlx_core::flow_util::graph_from_built(built)?;
441    Ok((graph, Vjepa2GraphParams { f32: params }))
442}
443
444/// Build the attentive pooler HIR module (+ optional classifier head).
445pub fn build_vjepa2_pooler_hir_sized(
446    cfg: &Vjepa2Config,
447    pooler: &Vjepa2PoolerWeights,
448    batch: usize,
449) -> Result<(HirModule, Vjepa2GraphParams)> {
450    let mut b = VjepaBuilder::new("vjepa2_pooler");
451
452    let e = cfg.hidden_size;
453    let nh = cfg.num_attention_heads;
454    let dh = cfg.head_dim();
455    let hidden = cfg.pooler_intermediate_size();
456    let eps = cfg.layer_norm_eps as f32;
457    let seq = cfg.num_patches();
458
459    let encoder = b.hir.input("encoder", b.shape3(batch, seq, e));
460    let mut ctx = encoder;
461    let ctx_shape = b.shape3(batch, seq, e);
462
463    for (layer_idx, block) in pooler.self_blocks.iter().enumerate() {
464        let lp = format!("self.{layer_idx}");
465        ctx = append_pooler_self_block(
466            &mut b,
467            ctx,
468            block,
469            &lp,
470            e,
471            nh,
472            dh,
473            hidden,
474            eps,
475            ctx_shape.clone(),
476        );
477    }
478
479    let mut query_data = Vec::with_capacity(batch * e);
480    for _ in 0..batch {
481        query_data.extend_from_slice(&pooler.query_tokens);
482    }
483    let query_id = b.bind_vec("query_tokens", &query_data);
484    let mut queries = b.reshape(query_id, vec![batch as i64, 1, e as i64]);
485    let query_shape = b.shape3(batch, 1, e);
486
487    queries = append_pooler_cross_block(
488        &mut b,
489        queries,
490        ctx,
491        &pooler.cross,
492        "cross",
493        e,
494        nh,
495        dh,
496        hidden,
497        eps,
498        query_shape.clone(),
499    );
500
501    queries = b.narrow(queries, 1, 0, 1, query_shape.clone());
502    let embedding = b.reshape(queries, vec![batch as i64, e as i64]);
503
504    let mut outputs = vec![embedding];
505    if let (Some(w_t), Some(bias)) = (&pooler.classifier_w_t, &pooler.classifier_b) {
506        let nc = bias.len();
507        let cls_w = b.bind_mat("classifier.weight", w_t, e, nc);
508        let cls_b = b.bind_vec("classifier.bias", bias);
509        let mm = b.mm(embedding, cls_w);
510        let logits = b.add(mm, cls_b, Shape::new(&[batch, nc], b.f));
511        outputs.push(logits);
512    }
513    b.hir.outputs = outputs;
514
515    Ok((b.hir, Vjepa2GraphParams { f32: b.params }))
516}
517
518/// Build the attentive pooler IR graph (+ optional classifier head) (via [`super::flow::Vjepa2PoolerFlow`]).
519pub fn build_vjepa2_pooler_graph_sized(
520    cfg: &Vjepa2Config,
521    pooler: &Vjepa2PoolerWeights,
522    batch: usize,
523) -> Result<(Graph, Vjepa2GraphParams)> {
524    let built = super::flow::Vjepa2PoolerFlow::new(cfg, pooler, batch).build()?;
525    let (graph, params) = rlx_core::flow_util::graph_from_built(built)?;
526    Ok((graph, Vjepa2GraphParams { f32: params }))
527}
528
529/// Compile encoder HIR on the given device (HIR → MIR → LIR).
530pub fn compile_vjepa2_encoder(
531    cfg: &Vjepa2Config,
532    enc: &Vjepa2EncoderWeights,
533    batch: usize,
534    device: rlx_runtime::Device,
535) -> Result<(
536    rlx_runtime::CompiledGraph,
537    HashMap<String, Vec<f32>>,
538    Vjepa2GraphPreprocess,
539)> {
540    use rlx_runtime::Session;
541
542    let (hir, params, preprocess) = build_vjepa2_encoder_hir_sized(cfg, enc, batch)?;
543    let opts = rlx_core::flow_bridge::compile_options_for_profile(
544        &rlx_flow::CompileProfile::encoder(),
545        device,
546    );
547    let mut compiled = Session::new(device).compile_hir_with(hir, &opts)?;
548    for (name, data) in &params {
549        compiled.set_param(name, data);
550    }
551    Ok((compiled, params, preprocess))
552}
553
554#[allow(clippy::too_many_arguments)]
555fn append_rope_block(
556    b: &mut VjepaBuilder,
557    x: HirNodeId,
558    block: &Vjepa2BlockWeights,
559    lp: &str,
560    embed: usize,
561    nh: usize,
562    dh: usize,
563    n_rot: usize,
564    cos_id: HirNodeId,
565    sin_id: HirNodeId,
566    mask_id: Option<HirNodeId>,
567    eps: f32,
568    use_mask: bool,
569    block_shape: Shape,
570) -> HirNodeId {
571    let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
572    let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
573    let normed1 = b.layer_norm(x, n1_g, n1_b, eps, block_shape.clone());
574
575    let q = b.linear_named(
576        &format!("{lp}.attn.q"),
577        normed1,
578        embed,
579        &block.q_w_t,
580        &block.q_b,
581    );
582    let k = b.linear_named(
583        &format!("{lp}.attn.k"),
584        normed1,
585        embed,
586        &block.k_w_t,
587        &block.k_b,
588    );
589    let v = b.linear_named(
590        &format!("{lp}.attn.v"),
591        normed1,
592        embed,
593        &block.v_w_t,
594        &block.v_b,
595    );
596
597    let q_rot = b.rope_n(q, cos_id, sin_id, dh, n_rot);
598    let k_rot = b.rope_n(k, cos_id, sin_id, dh, n_rot);
599    let attn = if use_mask {
600        let mask = mask_id.expect("rope block with use_mask requires attn mask");
601        b.attention_custom(q_rot, k_rot, v, mask, nh, dh)
602    } else {
603        b.attention_none(q_rot, k_rot, v, nh, dh)
604    };
605
606    let p_w = b.bind_mat(
607        &format!("{lp}.attn.proj.weight"),
608        &block.proj_w_t,
609        embed,
610        embed,
611    );
612    let p_b = b.bind_vec(&format!("{lp}.attn.proj.bias"), &block.proj_b);
613    let mm_proj = b.mm(attn, p_w);
614    let proj = b.add(mm_proj, p_b, block_shape.clone());
615    let x = b.add(x, proj, block_shape.clone());
616
617    let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
618    let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
619    let normed2 = b.layer_norm(x, n2_g, n2_b, eps, block_shape.clone());
620
621    b.mlp_block(
622        lp,
623        normed2,
624        embed,
625        &block.mlp_fc1_w_t,
626        &block.mlp_fc1_b,
627        &block.mlp_fc2_w_t,
628        &block.mlp_fc2_b,
629        x,
630        block_shape,
631    )
632}
633
634#[allow(clippy::too_many_arguments)]
635fn append_pooler_self_block(
636    b: &mut VjepaBuilder,
637    x: HirNodeId,
638    block: &Vjepa2PoolerSelfBlockWeights,
639    lp: &str,
640    embed: usize,
641    nh: usize,
642    dh: usize,
643    _hidden: usize,
644    eps: f32,
645    block_shape: Shape,
646) -> HirNodeId {
647    let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
648    let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
649    let normed1 = b.layer_norm(x, n1_g, n1_b, eps, block_shape.clone());
650
651    let q = b.linear_named(&format!("{lp}.q"), normed1, embed, &block.q_w_t, &block.q_b);
652    let k = b.linear_named(&format!("{lp}.k"), normed1, embed, &block.k_w_t, &block.k_b);
653    let v = b.linear_named(&format!("{lp}.v"), normed1, embed, &block.v_w_t, &block.v_b);
654    let attn = b.attention_none(q, k, v, nh, dh);
655
656    let out_w = b.bind_mat(&format!("{lp}.out.weight"), &block.out_w_t, embed, embed);
657    let out_b = b.bind_vec(&format!("{lp}.out.bias"), &block.out_b);
658    let mm_out = b.mm(attn, out_w);
659    let proj = b.add(mm_out, out_b, block_shape.clone());
660    let x = b.add(x, proj, block_shape.clone());
661
662    let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
663    let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
664    let normed2 = b.layer_norm(x, n2_g, n2_b, eps, block_shape.clone());
665
666    b.mlp_block(
667        lp,
668        normed2,
669        embed,
670        &block.mlp_fc1_w_t,
671        &block.mlp_fc1_b,
672        &block.mlp_fc2_w_t,
673        &block.mlp_fc2_b,
674        x,
675        block_shape,
676    )
677}
678
679#[allow(clippy::too_many_arguments)]
680fn append_pooler_cross_block(
681    b: &mut VjepaBuilder,
682    queries: HirNodeId,
683    context: HirNodeId,
684    block: &Vjepa2PoolerCrossWeights,
685    lp: &str,
686    embed: usize,
687    nh: usize,
688    dh: usize,
689    _hidden: usize,
690    eps: f32,
691    query_shape: Shape,
692) -> HirNodeId {
693    let ctx_shape = b.node_shape(context);
694    let residual = queries;
695
696    let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
697    let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
698    let ctx_norm = b.layer_norm(context, n1_g, n1_b, eps, ctx_shape);
699
700    let q = b.linear_named(&format!("{lp}.q"), queries, embed, &block.q_w_t, &block.q_b);
701    let k = b.linear_named(
702        &format!("{lp}.k"),
703        ctx_norm,
704        embed,
705        &block.k_w_t,
706        &block.k_b,
707    );
708    let v = b.linear_named(
709        &format!("{lp}.v"),
710        ctx_norm,
711        embed,
712        &block.v_w_t,
713        &block.v_b,
714    );
715    let attn = b.attention_none(q, k, v, nh, dh);
716    let queries = b.add(residual, attn, query_shape.clone());
717
718    let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
719    let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
720    let normed2 = b.layer_norm(queries, n2_g, n2_b, eps, query_shape.clone());
721
722    b.mlp_block(
723        lp,
724        normed2,
725        embed,
726        &block.mlp_fc1_w_t,
727        &block.mlp_fc1_b,
728        &block.mlp_fc2_w_t,
729        &block.mlp_fc2_b,
730        queries,
731        query_shape,
732    )
733}