1use super::config::Vjepa2Config;
22use super::predictor::Vjepa2PredictorLayout;
23use super::preprocess::Vjepa2PatchEmbedWeights;
24use super::rope::build_vjepa2_rope_tables;
25use super::weights::{
26 Vjepa2BlockWeights, Vjepa2EncoderWeights, Vjepa2PoolerCrossWeights,
27 Vjepa2PoolerSelfBlockWeights, Vjepa2PoolerWeights, Vjepa2PredictorWeights,
28};
29use anyhow::Result;
30use rlx_ir::hir::{FusionPolicy, HirModule, HirNodeId};
31use rlx_ir::op::{Activation, BinaryOp, MaskKind};
32use rlx_ir::{DType, Graph, Op, Shape};
33use std::collections::HashMap;
34
/// Preprocessing data returned alongside the V-JEPA 2 encoder graph.
pub struct Vjepa2GraphPreprocess {
    /// Patch-embedding weights. The encoder graph built in this file starts from an
    /// already-embedded `hidden` input, so these are presumably applied by the caller
    /// before running the graph — NOTE(review): confirm against callers.
    pub patch: Vjepa2PatchEmbedWeights,
}
39
40pub struct Vjepa2GraphParams {
42 pub f32: HashMap<String, Vec<f32>>,
43}
44
45impl Vjepa2GraphParams {
46 pub fn from_f32(map: HashMap<String, Vec<f32>>) -> Self {
47 Self { f32: map }
48 }
49
50 pub fn load(&self, compiled: &mut rlx_runtime::CompiledGraph) {
51 for (name, data) in &self.f32 {
52 compiled.set_param(name, data);
53 }
54 }
55}
56
57#[allow(dead_code)]
58fn lower_hir(hir: HirModule) -> Result<Graph> {
59 Ok(hir
60 .lower_to_mir()
61 .map_err(|e| anyhow::anyhow!("{e}"))?
62 .into_graph())
63}
64
65struct VjepaBuilder {
66 hir: HirModule,
67 params: HashMap<String, Vec<f32>>,
68 f: DType,
69}
70
71impl VjepaBuilder {
72 fn new(name: &str) -> Self {
73 Self {
74 hir: HirModule::new(name).with_fusion_policy(FusionPolicy::Direct),
75 params: HashMap::new(),
76 f: DType::F32,
77 }
78 }
79
80 #[allow(dead_code)]
81 fn finish(self) -> Result<Graph> {
82 lower_hir(self.hir)
83 }
84
85 fn shape3(&self, batch: usize, seq: usize, h: usize) -> Shape {
86 Shape::new(&[batch, seq, h], self.f)
87 }
88
89 fn node_shape(&self, id: HirNodeId) -> Shape {
90 self.hir.node(id).shape.clone()
91 }
92
93 fn layer_norm(
94 &mut self,
95 x: HirNodeId,
96 gamma: HirNodeId,
97 beta: HirNodeId,
98 eps: f32,
99 shape: Shape,
100 ) -> HirNodeId {
101 self.hir
102 .mir(Op::LayerNorm { axis: -1, eps }, vec![x, gamma, beta], shape)
103 }
104
105 fn reshape(&mut self, x: HirNodeId, new_shape: Vec<i64>) -> HirNodeId {
106 let in_shape = self.hir.node(x).shape.clone();
107 let static_dims: Vec<usize> = new_shape.iter().map(|&d| d as usize).collect();
108 let out = Shape::new(&static_dims, in_shape.dtype());
109 self.hir.mir(Op::Reshape { new_shape }, vec![x], out)
110 }
111
112 fn narrow(
113 &mut self,
114 x: HirNodeId,
115 axis: usize,
116 start: usize,
117 len: usize,
118 shape: Shape,
119 ) -> HirNodeId {
120 self.hir
121 .mir(Op::Narrow { axis, start, len }, vec![x], shape)
122 }
123
124 fn concat(&mut self, inputs: Vec<HirNodeId>, axis: usize, shape: Shape) -> HirNodeId {
125 self.hir.mir(Op::Concat { axis }, inputs, shape)
126 }
127
128 fn gather(&mut self, table: HirNodeId, indices: HirNodeId, axis: usize) -> HirNodeId {
129 let out = rlx_ir::shape::gather_shape(
130 &self.hir.node(table).shape,
131 &self.hir.node(indices).shape,
132 axis,
133 )
134 .expect("gather shape");
135 self.hir.mir(Op::Gather { axis }, vec![table, indices], out)
136 }
137
138 fn add(&mut self, a: HirNodeId, b: HirNodeId, shape: Shape) -> HirNodeId {
139 self.hir.mir(Op::Binary(BinaryOp::Add), vec![a, b], shape)
140 }
141
142 fn mm(&mut self, lhs: HirNodeId, rhs: HirNodeId) -> HirNodeId {
143 let out = rlx_ir::shape::matmul_shape(&self.hir.node(lhs).shape, &self.hir.node(rhs).shape)
144 .expect("matmul shape");
145 self.hir.mir(Op::MatMul, vec![lhs, rhs], out)
146 }
147
148 fn rope_n(
149 &mut self,
150 x: HirNodeId,
151 cos: HirNodeId,
152 sin: HirNodeId,
153 head_dim: usize,
154 n_rot: usize,
155 ) -> HirNodeId {
156 let shape = self.hir.node(x).shape.clone();
157 self.hir
158 .mir(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], shape)
159 }
160
161 #[allow(dead_code)]
162 fn gelu_approx(&mut self, x: HirNodeId, shape: Shape) -> HirNodeId {
163 self.hir
164 .mir(Op::Activation(Activation::GeluApprox), vec![x], shape)
165 }
166
167 fn attention_custom(
168 &mut self,
169 q: HirNodeId,
170 k: HirNodeId,
171 v: HirNodeId,
172 mask: HirNodeId,
173 nh: usize,
174 dh: usize,
175 ) -> HirNodeId {
176 let out = rlx_ir::shape::attention_shape(&self.hir.node(q).shape);
177 self.hir
178 .attention(q, k, v, Some(mask), nh, dh, MaskKind::Custom, out)
179 }
180
181 fn attention_none(
182 &mut self,
183 q: HirNodeId,
184 k: HirNodeId,
185 v: HirNodeId,
186 nh: usize,
187 dh: usize,
188 ) -> HirNodeId {
189 let out = rlx_ir::shape::attention_shape(&self.hir.node(q).shape);
190 self.hir
191 .attention(q, k, v, None, nh, dh, MaskKind::None, out)
192 }
193
194 fn bind_vec(&mut self, name: &str, data: &[f32]) -> HirNodeId {
195 let id = self.hir.param(name, Shape::new(&[data.len()], self.f));
196 self.params.insert(name.to_string(), data.to_vec());
197 id
198 }
199
200 fn bind_mat(&mut self, name: &str, w_t: &[f32], in_dim: usize, out_dim: usize) -> HirNodeId {
201 let id = self.hir.param(name, Shape::new(&[in_dim, out_dim], self.f));
202 self.params.insert(name.to_string(), w_t.to_vec());
203 id
204 }
205
206 fn bind_indices(&mut self, name: &str, data: &[i64], shape: &[usize]) -> HirNodeId {
207 let f32_data: Vec<f32> = data.iter().map(|&v| v as f32).collect();
208 let id = self.hir.param(name, Shape::new(shape, self.f));
209 self.params.insert(name.to_string(), f32_data);
210 id
211 }
212
213 fn linear_named(
214 &mut self,
215 name: &str,
216 input: HirNodeId,
217 in_dim: usize,
218 w_t: &[f32],
219 b: &[f32],
220 ) -> HirNodeId {
221 let out_dim = b.len();
222 let w = self.bind_mat(&format!("{name}.weight"), w_t, in_dim, out_dim);
223 let bias = self.bind_vec(&format!("{name}.bias"), b);
224 let out_shape =
225 rlx_ir::shape::matmul_shape(&self.hir.node(input).shape, &self.hir.node(w).shape)
226 .expect("linear matmul shape");
227 self.hir.linear_fused(input, w, bias, None, out_shape)
228 }
229
230 fn mlp_block(
231 &mut self,
232 lp: &str,
233 x: HirNodeId,
234 embed: usize,
235 fc1_w_t: &[f32],
236 fc1_b: &[f32],
237 fc2_w_t: &[f32],
238 fc2_b: &[f32],
239 residual: HirNodeId,
240 out_shape: Shape,
241 ) -> HirNodeId {
242 let hidden = fc1_b.len();
243 let fc1_w = self.bind_mat(&format!("{lp}.mlp.fc1.weight"), fc1_w_t, embed, hidden);
244 let fc1_bias = self.bind_vec(&format!("{lp}.mlp.fc1.bias"), fc1_b);
245 let fc1_shape =
246 rlx_ir::shape::matmul_shape(&self.hir.node(x).shape, &self.hir.node(fc1_w).shape)
247 .expect("fc1 shape");
248 let up = self
249 .hir
250 .linear_fused(x, fc1_w, fc1_bias, Some(Activation::GeluApprox), fc1_shape);
251
252 let fc2_w = self.bind_mat(&format!("{lp}.mlp.fc2.weight"), fc2_w_t, hidden, embed);
253 let fc2_bias = self.bind_vec(&format!("{lp}.mlp.fc2.bias"), fc2_b);
254 let fc2_shape =
255 rlx_ir::shape::matmul_shape(&self.hir.node(up).shape, &self.hir.node(fc2_w).shape)
256 .expect("fc2 shape");
257 let ffn = self.hir.linear_fused(up, fc2_w, fc2_bias, None, fc2_shape);
258 self.add(residual, ffn, out_shape)
259 }
260}
261
262pub fn build_vjepa2_encoder_hir_sized(
264 cfg: &Vjepa2Config,
265 enc: &Vjepa2EncoderWeights,
266 batch: usize,
267) -> Result<(HirModule, HashMap<String, Vec<f32>>, Vjepa2GraphPreprocess)> {
268 let mut b = VjepaBuilder::new("vjepa2_encoder");
269
270 let h = cfg.hidden_size;
271 let nh = cfg.num_attention_heads;
272 let dh = cfg.head_dim();
273 let eps = cfg.layer_norm_eps as f32;
274 let seq = cfg.num_patches();
275 let (d_dim, hd_dim, w_dim) = cfg.rope_segment_dims();
276 let grid_h = cfg.grid_spatial();
277 let grid_w = cfg.grid_spatial();
278 let n_rot = d_dim + hd_dim + w_dim;
279
280 let preprocess = Vjepa2GraphPreprocess {
281 patch: enc.patch.clone(),
282 };
283
284 let (cos_data, sin_data) =
285 build_vjepa2_rope_tables(seq, dh, d_dim, hd_dim, w_dim, grid_h, grid_w);
286 let half = dh / 2;
287 let cos_id = b.bind_mat("rope_cos", &cos_data, seq, half);
288 let sin_id = b.bind_mat("rope_sin", &sin_data, seq, half);
289
290 let mask_data = vec![1.0f32; batch * seq];
291 let mask_id = b.hir.param("attn_mask", Shape::new(&[batch, seq], b.f));
292 b.params.insert("attn_mask".into(), mask_data);
293
294 let hidden_input = b.hir.input("hidden", b.shape3(batch, seq, h));
295 let mut x = hidden_input;
296 let enc_shape = b.shape3(batch, seq, h);
297
298 for (layer_idx, block) in enc.blocks.iter().enumerate() {
299 let lp = format!("blocks.{layer_idx}");
300 x = append_rope_block(
301 &mut b,
302 x,
303 block,
304 &lp,
305 h,
306 nh,
307 dh,
308 n_rot,
309 cos_id,
310 sin_id,
311 Some(mask_id),
312 eps,
313 true,
314 enc_shape.clone(),
315 );
316 }
317
318 let fn_g = b.bind_vec("norm.weight", &enc.norm_w);
319 let fn_b = b.bind_vec("norm.bias", &enc.norm_b);
320 let encoded = b.layer_norm(x, fn_g, fn_b, eps, enc_shape);
321 b.hir.outputs = vec![encoded];
322
323 Ok((b.hir, b.params, preprocess))
324}
325
326pub fn build_vjepa2_encoder_graph_sized(
328 cfg: &Vjepa2Config,
329 enc: &Vjepa2EncoderWeights,
330 batch: usize,
331) -> Result<(Graph, HashMap<String, Vec<f32>>, Vjepa2GraphPreprocess)> {
332 let built = super::flow::Vjepa2EncoderFlow::new(cfg, enc, batch).build()?;
333 let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
334 Ok((graph, params, built.preprocess))
335}
336
337pub fn build_vjepa2_predictor_hir_sized(
339 cfg: &Vjepa2Config,
340 pred: &Vjepa2PredictorWeights,
341 layout: &Vjepa2PredictorLayout,
342 mask_rows: &[f32],
343 batch: usize,
344) -> Result<(HirModule, Vjepa2GraphParams)> {
345 let mut b = VjepaBuilder::new("vjepa2_predictor");
346
347 let enc = cfg.hidden_size;
348 let pred_h = cfg.pred_hidden_size;
349 let nh = cfg.pred_num_attention_heads;
350 let dh = cfg.pred_head_dim();
351 let eps = cfg.layer_norm_eps as f32;
352 let enc_seq = cfg.num_patches();
353 let (d_dim, hd_dim, w_dim) = cfg.pred_rope_segment_dims();
354 let n_rot = d_dim + hd_dim + w_dim;
355 let n_ctxt = layout.n_ctxt;
356 let n_tgt = layout.n_tgt;
357 let n_combined = layout.n_combined;
358 let half = dh / 2;
359
360 let encoder = b.hir.input("encoder", b.shape3(batch, enc_seq, enc));
361
362 let ctxt_idx_id = b.bind_indices("ctxt_idx", &layout.ctxt_idx, &[batch, n_ctxt]);
363 let ctxt = b.gather(encoder, ctxt_idx_id, 1);
364 let ctxt = b.reshape(ctxt, vec![batch as i64, n_ctxt as i64, enc as i64]);
365
366 let embed_w = b.bind_mat("embed.weight", &pred.embed_w_t, enc, pred_h);
367 let embed_b = b.bind_vec("embed.bias", &pred.embed_b);
368 let mm_embed = b.mm(ctxt, embed_w);
369 let ctxt_up = b.add(mm_embed, embed_b, b.shape3(batch, n_ctxt, pred_h));
370 let ctxt_embed = b.reshape(ctxt_up, vec![batch as i64, n_ctxt as i64, pred_h as i64]);
371
372 let mask_id = b
373 .hir
374 .param("mask_rows", Shape::new(&[batch, n_tgt, pred_h], b.f));
375 b.params.insert("mask_rows".into(), mask_rows.to_vec());
376 let mut x = b.concat(
377 vec![ctxt_embed, mask_id],
378 1,
379 b.shape3(batch, n_combined, pred_h),
380 );
381 x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
382
383 let sort_idx_id = b.bind_indices("sort_idx", &layout.sort_idx, &[batch, n_combined]);
384 x = b.gather(x, sort_idx_id, 1);
385 x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
386
387 let cos_id = b.bind_mat("rope_cos", &layout.rope_cos, n_combined, half);
388 let sin_id = b.bind_mat("rope_sin", &layout.rope_sin, n_combined, half);
389 let pred_shape = b.shape3(batch, n_combined, pred_h);
390
391 for (layer_idx, block) in pred.blocks.iter().enumerate() {
392 let lp = format!("blocks.{layer_idx}");
393 x = append_rope_block(
394 &mut b,
395 x,
396 block,
397 &lp,
398 pred_h,
399 nh,
400 dh,
401 n_rot,
402 cos_id,
403 sin_id,
404 None,
405 eps,
406 false,
407 pred_shape.clone(),
408 );
409 }
410
411 let fn_g = b.bind_vec("norm.weight", &pred.norm_w);
412 let fn_b = b.bind_vec("norm.bias", &pred.norm_b);
413 x = b.layer_norm(x, fn_g, fn_b, eps, pred_shape.clone());
414
415 let unsort_idx_id = b.bind_indices("unsort_idx", &layout.unsort_idx, &[batch, n_combined]);
416 x = b.gather(x, unsort_idx_id, 1);
417 x = b.reshape(x, vec![batch as i64, n_combined as i64, pred_h as i64]);
418 x = b.narrow(x, 1, n_ctxt, n_tgt, b.shape3(batch, n_tgt, pred_h));
419 x = b.reshape(x, vec![batch as i64, n_tgt as i64, pred_h as i64]);
420
421 let proj_w = b.bind_mat("proj.weight", &pred.proj_w_t, pred_h, enc);
422 let proj_b = b.bind_vec("proj.bias", &pred.proj_b);
423 let mm_proj = b.mm(x, proj_w);
424 let out = b.add(mm_proj, proj_b, b.shape3(batch, n_tgt, enc));
425 b.hir.outputs = vec![out];
426
427 Ok((b.hir, Vjepa2GraphParams { f32: b.params }))
428}
429
430pub fn build_vjepa2_predictor_graph_sized(
432 cfg: &Vjepa2Config,
433 pred: &Vjepa2PredictorWeights,
434 layout: &Vjepa2PredictorLayout,
435 mask_rows: &[f32],
436 batch: usize,
437) -> Result<(Graph, Vjepa2GraphParams)> {
438 let built =
439 super::flow::Vjepa2PredictorFlow::new(cfg, pred, layout, mask_rows, batch).build()?;
440 let (graph, params) = rlx_core::flow_util::graph_from_built(built)?;
441 Ok((graph, Vjepa2GraphParams { f32: params }))
442}
443
444pub fn build_vjepa2_pooler_hir_sized(
446 cfg: &Vjepa2Config,
447 pooler: &Vjepa2PoolerWeights,
448 batch: usize,
449) -> Result<(HirModule, Vjepa2GraphParams)> {
450 let mut b = VjepaBuilder::new("vjepa2_pooler");
451
452 let e = cfg.hidden_size;
453 let nh = cfg.num_attention_heads;
454 let dh = cfg.head_dim();
455 let hidden = cfg.pooler_intermediate_size();
456 let eps = cfg.layer_norm_eps as f32;
457 let seq = cfg.num_patches();
458
459 let encoder = b.hir.input("encoder", b.shape3(batch, seq, e));
460 let mut ctx = encoder;
461 let ctx_shape = b.shape3(batch, seq, e);
462
463 for (layer_idx, block) in pooler.self_blocks.iter().enumerate() {
464 let lp = format!("self.{layer_idx}");
465 ctx = append_pooler_self_block(
466 &mut b,
467 ctx,
468 block,
469 &lp,
470 e,
471 nh,
472 dh,
473 hidden,
474 eps,
475 ctx_shape.clone(),
476 );
477 }
478
479 let mut query_data = Vec::with_capacity(batch * e);
480 for _ in 0..batch {
481 query_data.extend_from_slice(&pooler.query_tokens);
482 }
483 let query_id = b.bind_vec("query_tokens", &query_data);
484 let mut queries = b.reshape(query_id, vec![batch as i64, 1, e as i64]);
485 let query_shape = b.shape3(batch, 1, e);
486
487 queries = append_pooler_cross_block(
488 &mut b,
489 queries,
490 ctx,
491 &pooler.cross,
492 "cross",
493 e,
494 nh,
495 dh,
496 hidden,
497 eps,
498 query_shape.clone(),
499 );
500
501 queries = b.narrow(queries, 1, 0, 1, query_shape.clone());
502 let embedding = b.reshape(queries, vec![batch as i64, e as i64]);
503
504 let mut outputs = vec![embedding];
505 if let (Some(w_t), Some(bias)) = (&pooler.classifier_w_t, &pooler.classifier_b) {
506 let nc = bias.len();
507 let cls_w = b.bind_mat("classifier.weight", w_t, e, nc);
508 let cls_b = b.bind_vec("classifier.bias", bias);
509 let mm = b.mm(embedding, cls_w);
510 let logits = b.add(mm, cls_b, Shape::new(&[batch, nc], b.f));
511 outputs.push(logits);
512 }
513 b.hir.outputs = outputs;
514
515 Ok((b.hir, Vjepa2GraphParams { f32: b.params }))
516}
517
518pub fn build_vjepa2_pooler_graph_sized(
520 cfg: &Vjepa2Config,
521 pooler: &Vjepa2PoolerWeights,
522 batch: usize,
523) -> Result<(Graph, Vjepa2GraphParams)> {
524 let built = super::flow::Vjepa2PoolerFlow::new(cfg, pooler, batch).build()?;
525 let (graph, params) = rlx_core::flow_util::graph_from_built(built)?;
526 Ok((graph, Vjepa2GraphParams { f32: params }))
527}
528
529pub fn compile_vjepa2_encoder(
531 cfg: &Vjepa2Config,
532 enc: &Vjepa2EncoderWeights,
533 batch: usize,
534 device: rlx_runtime::Device,
535) -> Result<(
536 rlx_runtime::CompiledGraph,
537 HashMap<String, Vec<f32>>,
538 Vjepa2GraphPreprocess,
539)> {
540 use rlx_runtime::Session;
541
542 let (hir, params, preprocess) = build_vjepa2_encoder_hir_sized(cfg, enc, batch)?;
543 let opts = rlx_core::flow_bridge::compile_options_for_profile(
544 &rlx_flow::CompileProfile::encoder(),
545 device,
546 );
547 let mut compiled = Session::new(device).compile_hir_with(hir, &opts)?;
548 for (name, data) in ¶ms {
549 compiled.set_param(name, data);
550 }
551 Ok((compiled, params, preprocess))
552}
553
554#[allow(clippy::too_many_arguments)]
555fn append_rope_block(
556 b: &mut VjepaBuilder,
557 x: HirNodeId,
558 block: &Vjepa2BlockWeights,
559 lp: &str,
560 embed: usize,
561 nh: usize,
562 dh: usize,
563 n_rot: usize,
564 cos_id: HirNodeId,
565 sin_id: HirNodeId,
566 mask_id: Option<HirNodeId>,
567 eps: f32,
568 use_mask: bool,
569 block_shape: Shape,
570) -> HirNodeId {
571 let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
572 let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
573 let normed1 = b.layer_norm(x, n1_g, n1_b, eps, block_shape.clone());
574
575 let q = b.linear_named(
576 &format!("{lp}.attn.q"),
577 normed1,
578 embed,
579 &block.q_w_t,
580 &block.q_b,
581 );
582 let k = b.linear_named(
583 &format!("{lp}.attn.k"),
584 normed1,
585 embed,
586 &block.k_w_t,
587 &block.k_b,
588 );
589 let v = b.linear_named(
590 &format!("{lp}.attn.v"),
591 normed1,
592 embed,
593 &block.v_w_t,
594 &block.v_b,
595 );
596
597 let q_rot = b.rope_n(q, cos_id, sin_id, dh, n_rot);
598 let k_rot = b.rope_n(k, cos_id, sin_id, dh, n_rot);
599 let attn = if use_mask {
600 let mask = mask_id.expect("rope block with use_mask requires attn mask");
601 b.attention_custom(q_rot, k_rot, v, mask, nh, dh)
602 } else {
603 b.attention_none(q_rot, k_rot, v, nh, dh)
604 };
605
606 let p_w = b.bind_mat(
607 &format!("{lp}.attn.proj.weight"),
608 &block.proj_w_t,
609 embed,
610 embed,
611 );
612 let p_b = b.bind_vec(&format!("{lp}.attn.proj.bias"), &block.proj_b);
613 let mm_proj = b.mm(attn, p_w);
614 let proj = b.add(mm_proj, p_b, block_shape.clone());
615 let x = b.add(x, proj, block_shape.clone());
616
617 let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
618 let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
619 let normed2 = b.layer_norm(x, n2_g, n2_b, eps, block_shape.clone());
620
621 b.mlp_block(
622 lp,
623 normed2,
624 embed,
625 &block.mlp_fc1_w_t,
626 &block.mlp_fc1_b,
627 &block.mlp_fc2_w_t,
628 &block.mlp_fc2_b,
629 x,
630 block_shape,
631 )
632}
633
634#[allow(clippy::too_many_arguments)]
635fn append_pooler_self_block(
636 b: &mut VjepaBuilder,
637 x: HirNodeId,
638 block: &Vjepa2PoolerSelfBlockWeights,
639 lp: &str,
640 embed: usize,
641 nh: usize,
642 dh: usize,
643 _hidden: usize,
644 eps: f32,
645 block_shape: Shape,
646) -> HirNodeId {
647 let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
648 let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
649 let normed1 = b.layer_norm(x, n1_g, n1_b, eps, block_shape.clone());
650
651 let q = b.linear_named(&format!("{lp}.q"), normed1, embed, &block.q_w_t, &block.q_b);
652 let k = b.linear_named(&format!("{lp}.k"), normed1, embed, &block.k_w_t, &block.k_b);
653 let v = b.linear_named(&format!("{lp}.v"), normed1, embed, &block.v_w_t, &block.v_b);
654 let attn = b.attention_none(q, k, v, nh, dh);
655
656 let out_w = b.bind_mat(&format!("{lp}.out.weight"), &block.out_w_t, embed, embed);
657 let out_b = b.bind_vec(&format!("{lp}.out.bias"), &block.out_b);
658 let mm_out = b.mm(attn, out_w);
659 let proj = b.add(mm_out, out_b, block_shape.clone());
660 let x = b.add(x, proj, block_shape.clone());
661
662 let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
663 let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
664 let normed2 = b.layer_norm(x, n2_g, n2_b, eps, block_shape.clone());
665
666 b.mlp_block(
667 lp,
668 normed2,
669 embed,
670 &block.mlp_fc1_w_t,
671 &block.mlp_fc1_b,
672 &block.mlp_fc2_w_t,
673 &block.mlp_fc2_b,
674 x,
675 block_shape,
676 )
677}
678
679#[allow(clippy::too_many_arguments)]
680fn append_pooler_cross_block(
681 b: &mut VjepaBuilder,
682 queries: HirNodeId,
683 context: HirNodeId,
684 block: &Vjepa2PoolerCrossWeights,
685 lp: &str,
686 embed: usize,
687 nh: usize,
688 dh: usize,
689 _hidden: usize,
690 eps: f32,
691 query_shape: Shape,
692) -> HirNodeId {
693 let ctx_shape = b.node_shape(context);
694 let residual = queries;
695
696 let n1_g = b.bind_vec(&format!("{lp}.norm1.weight"), &block.norm1_w);
697 let n1_b = b.bind_vec(&format!("{lp}.norm1.bias"), &block.norm1_b);
698 let ctx_norm = b.layer_norm(context, n1_g, n1_b, eps, ctx_shape);
699
700 let q = b.linear_named(&format!("{lp}.q"), queries, embed, &block.q_w_t, &block.q_b);
701 let k = b.linear_named(
702 &format!("{lp}.k"),
703 ctx_norm,
704 embed,
705 &block.k_w_t,
706 &block.k_b,
707 );
708 let v = b.linear_named(
709 &format!("{lp}.v"),
710 ctx_norm,
711 embed,
712 &block.v_w_t,
713 &block.v_b,
714 );
715 let attn = b.attention_none(q, k, v, nh, dh);
716 let queries = b.add(residual, attn, query_shape.clone());
717
718 let n2_g = b.bind_vec(&format!("{lp}.norm2.weight"), &block.norm2_w);
719 let n2_b = b.bind_vec(&format!("{lp}.norm2.bias"), &block.norm2_b);
720 let normed2 = b.layer_norm(queries, n2_g, n2_b, eps, query_shape.clone());
721
722 b.mlp_block(
723 lp,
724 normed2,
725 embed,
726 &block.mlp_fc1_w_t,
727 &block.mlp_fc1_b,
728 &block.mlp_fc2_w_t,
729 &block.mlp_fc2_b,
730 queries,
731 query_shape,
732 )
733}