1use super::config::Vjepa2Config;
19use super::layers::{attention_plain, cross_attention};
20use super::weights::{Vjepa2PoolerCrossWeights, Vjepa2PoolerSelfBlockWeights, Vjepa2PoolerWeights};
21use anyhow::Result;
22use rlx_tensor::{gelu_tanh, layer_norm, linear};
23
/// Result of running the pooler over a batch of encoder tokens.
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2PoolerOutput {
    /// Pooled embeddings, one `hidden_size`-long vector per batch item, stored back to back
    /// in batch order.
    pub embedding: Vec<f32>,
    /// Output of the classifier linear layer applied to `embedding`; `None` unless both the
    /// classifier weight and bias are available.
    /// Presumably `batch * num_classes` values — confirm against `linear`.
    pub logits: Option<Vec<f32>>,
}
28
29pub fn pool_native(
31 encoder_tokens: &[f32],
32 weights: &Vjepa2PoolerWeights,
33 cfg: &Vjepa2Config,
34 batch: usize,
35 seq: usize,
36) -> Result<Vjepa2PoolerOutput> {
37 let e = cfg.hidden_size;
38 let nh = cfg.num_attention_heads;
39 let head_dim = cfg.head_dim();
40 let hidden = cfg.pooler_intermediate_size();
41 let eps = cfg.layer_norm_eps as f32;
42
43 let mut per_batch = Vec::with_capacity(batch * e);
44
45 for bi in 0..batch {
46 let mut x = encoder_tokens[bi * seq * e..(bi + 1) * seq * e].to_vec();
47
48 for block in &weights.self_blocks {
49 pooler_self_block(&mut x, block, 1, seq, e, nh, head_dim, hidden, eps)?;
50 }
51
52 let mut q = weights.query_tokens.clone();
53 cross_block(
54 &mut q,
55 &x,
56 &weights.cross,
57 1,
58 1,
59 seq,
60 e,
61 nh,
62 head_dim,
63 hidden,
64 eps,
65 )?;
66 per_batch.extend_from_slice(&q[..e]);
67 }
68
69 let logits = match (&weights.classifier_w_t, &weights.classifier_b) {
70 (Some(w), Some(b)) => {
71 let nc = b.len();
72 Some(linear(&per_batch, batch, e, w, nc, b)?)
73 }
74 _ => None,
75 };
76
77 Ok(Vjepa2PoolerOutput {
78 embedding: per_batch,
79 logits,
80 })
81}
82
83#[allow(clippy::too_many_arguments)]
84fn pooler_self_block(
85 x: &mut [f32],
86 block: &Vjepa2PoolerSelfBlockWeights,
87 batch: usize,
88 seq: usize,
89 e: usize,
90 nh: usize,
91 head_dim: usize,
92 hidden: usize,
93 eps: f32,
94) -> Result<()> {
95 let rows = batch * seq;
96 let n1 = layer_norm(x, &block.norm1_w, &block.norm1_b, e, eps)?;
97 let attn = attention_plain(
98 &n1,
99 batch,
100 seq,
101 e,
102 nh,
103 head_dim,
104 &block.q_w_t,
105 &block.q_b,
106 &block.k_w_t,
107 &block.k_b,
108 &block.v_w_t,
109 &block.v_b,
110 &block.out_w_t,
111 &block.out_b,
112 )?;
113 for i in 0..x.len() {
114 x[i] += attn[i];
115 }
116
117 let n2 = layer_norm(x, &block.norm2_w, &block.norm2_b, e, eps)?;
118 let mut mlp = linear(&n2, rows, e, &block.mlp_fc1_w_t, hidden, &block.mlp_fc1_b)?;
119 gelu_tanh(&mut mlp);
120 let ffn = linear(&mlp, rows, hidden, &block.mlp_fc2_w_t, e, &block.mlp_fc2_b)?;
121 for i in 0..x.len() {
122 x[i] += ffn[i];
123 }
124 Ok(())
125}
126
127#[allow(clippy::too_many_arguments)]
128fn cross_block(
129 queries: &mut [f32],
130 context: &[f32],
131 block: &Vjepa2PoolerCrossWeights,
132 batch: usize,
133 l_q: usize,
134 l_kv: usize,
135 e: usize,
136 nh: usize,
137 head_dim: usize,
138 hidden: usize,
139 eps: f32,
140) -> Result<()> {
141 let residual = queries.to_vec();
142 let ctx_norm = layer_norm(context, &block.norm1_w, &block.norm1_b, e, eps)?;
143 let attn = cross_attention(
144 queries,
145 &ctx_norm,
146 batch,
147 l_q,
148 l_kv,
149 e,
150 nh,
151 head_dim,
152 &block.q_w_t,
153 &block.q_b,
154 &block.k_w_t,
155 &block.k_b,
156 &block.v_w_t,
157 &block.v_b,
158 )?;
159 for i in 0..queries.len() {
160 queries[i] = residual[i] + attn[i];
161 }
162
163 let n2 = layer_norm(queries, &block.norm2_w, &block.norm2_b, e, eps)?;
164 let mut mlp = linear(
165 &n2,
166 batch * l_q,
167 e,
168 &block.mlp_fc1_w_t,
169 hidden,
170 &block.mlp_fc1_b,
171 )?;
172 gelu_tanh(&mut mlp);
173 let ffn = linear(
174 &mlp,
175 batch * l_q,
176 hidden,
177 &block.mlp_fc2_w_t,
178 e,
179 &block.mlp_fc2_b,
180 )?;
181 for i in 0..queries.len() {
182 queries[i] += ffn[i];
183 }
184 Ok(())
185}