Skip to main content

rlx_vjepa2/
pooler.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! V-JEPA2 attentive pooler + optional classifier (finetuned checkpoints).
17
18use super::config::Vjepa2Config;
19use super::layers::{attention_plain, cross_attention};
20use super::weights::{Vjepa2PoolerCrossWeights, Vjepa2PoolerSelfBlockWeights, Vjepa2PoolerWeights};
21use anyhow::Result;
22use rlx_tensor::{gelu_tanh, layer_norm, linear};
23
/// Result of [`pool_native`]: pooled embeddings plus optional classifier logits.
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2PoolerOutput {
    /// Pooled embeddings, flattened `[batch, hidden]`: one `hidden_size`-long
    /// row appended per batch item.
    pub embedding: Vec<f32>,
    /// Classifier output computed from `embedding`; `None` unless the pooler
    /// weights carry both a classifier matrix and a classifier bias.
    pub logits: Option<Vec<f32>>,
}
28
29/// Pool encoder tokens `[batch, seq, hidden]` → `[batch, hidden]` embedding.
30pub fn pool_native(
31    encoder_tokens: &[f32],
32    weights: &Vjepa2PoolerWeights,
33    cfg: &Vjepa2Config,
34    batch: usize,
35    seq: usize,
36) -> Result<Vjepa2PoolerOutput> {
37    let e = cfg.hidden_size;
38    let nh = cfg.num_attention_heads;
39    let head_dim = cfg.head_dim();
40    let hidden = cfg.pooler_intermediate_size();
41    let eps = cfg.layer_norm_eps as f32;
42
43    let mut per_batch = Vec::with_capacity(batch * e);
44
45    for bi in 0..batch {
46        let mut x = encoder_tokens[bi * seq * e..(bi + 1) * seq * e].to_vec();
47
48        for block in &weights.self_blocks {
49            pooler_self_block(&mut x, block, 1, seq, e, nh, head_dim, hidden, eps)?;
50        }
51
52        let mut q = weights.query_tokens.clone();
53        cross_block(
54            &mut q,
55            &x,
56            &weights.cross,
57            1,
58            1,
59            seq,
60            e,
61            nh,
62            head_dim,
63            hidden,
64            eps,
65        )?;
66        per_batch.extend_from_slice(&q[..e]);
67    }
68
69    let logits = match (&weights.classifier_w_t, &weights.classifier_b) {
70        (Some(w), Some(b)) => {
71            let nc = b.len();
72            Some(linear(&per_batch, batch, e, w, nc, b)?)
73        }
74        _ => None,
75    };
76
77    Ok(Vjepa2PoolerOutput {
78        embedding: per_batch,
79        logits,
80    })
81}
82
83#[allow(clippy::too_many_arguments)]
84fn pooler_self_block(
85    x: &mut [f32],
86    block: &Vjepa2PoolerSelfBlockWeights,
87    batch: usize,
88    seq: usize,
89    e: usize,
90    nh: usize,
91    head_dim: usize,
92    hidden: usize,
93    eps: f32,
94) -> Result<()> {
95    let rows = batch * seq;
96    let n1 = layer_norm(x, &block.norm1_w, &block.norm1_b, e, eps)?;
97    let attn = attention_plain(
98        &n1,
99        batch,
100        seq,
101        e,
102        nh,
103        head_dim,
104        &block.q_w_t,
105        &block.q_b,
106        &block.k_w_t,
107        &block.k_b,
108        &block.v_w_t,
109        &block.v_b,
110        &block.out_w_t,
111        &block.out_b,
112    )?;
113    for i in 0..x.len() {
114        x[i] += attn[i];
115    }
116
117    let n2 = layer_norm(x, &block.norm2_w, &block.norm2_b, e, eps)?;
118    let mut mlp = linear(&n2, rows, e, &block.mlp_fc1_w_t, hidden, &block.mlp_fc1_b)?;
119    gelu_tanh(&mut mlp);
120    let ffn = linear(&mlp, rows, hidden, &block.mlp_fc2_w_t, e, &block.mlp_fc2_b)?;
121    for i in 0..x.len() {
122        x[i] += ffn[i];
123    }
124    Ok(())
125}
126
127#[allow(clippy::too_many_arguments)]
128fn cross_block(
129    queries: &mut [f32],
130    context: &[f32],
131    block: &Vjepa2PoolerCrossWeights,
132    batch: usize,
133    l_q: usize,
134    l_kv: usize,
135    e: usize,
136    nh: usize,
137    head_dim: usize,
138    hidden: usize,
139    eps: f32,
140) -> Result<()> {
141    let residual = queries.to_vec();
142    let ctx_norm = layer_norm(context, &block.norm1_w, &block.norm1_b, e, eps)?;
143    let attn = cross_attention(
144        queries,
145        &ctx_norm,
146        batch,
147        l_q,
148        l_kv,
149        e,
150        nh,
151        head_dim,
152        &block.q_w_t,
153        &block.q_b,
154        &block.k_w_t,
155        &block.k_b,
156        &block.v_w_t,
157        &block.v_b,
158    )?;
159    for i in 0..queries.len() {
160        queries[i] = residual[i] + attn[i];
161    }
162
163    let n2 = layer_norm(queries, &block.norm2_w, &block.norm2_b, e, eps)?;
164    let mut mlp = linear(
165        &n2,
166        batch * l_q,
167        e,
168        &block.mlp_fc1_w_t,
169        hidden,
170        &block.mlp_fc1_b,
171    )?;
172    gelu_tanh(&mut mlp);
173    let ffn = linear(
174        &mlp,
175        batch * l_q,
176        hidden,
177        &block.mlp_fc2_w_t,
178        e,
179        &block.mlp_fc2_b,
180    )?;
181    for i in 0..queries.len() {
182        queries[i] += ffn[i];
183    }
184    Ok(())
185}