Skip to main content

rlx_vjepa2/
predictor.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! V-JEPA2 predictor — masked token prediction head.
17
18use super::config::Vjepa2Config;
19use super::layers::{block_forward, gather_rows};
20use super::weights::Vjepa2PredictorWeights;
21use anyhow::{Result, ensure};
22use rlx_tensor::{layer_norm, linear};
23
/// Patch-index masks for a single batch element.
///
/// Within this file, `context` selects the encoder rows that are fed to the
/// predictor and `target` determines how many mask-token rows are appended
/// (and therefore how many predictions are produced).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Vjepa2Masks {
    /// Patch indices of the visible (context) tokens.
    pub context: Vec<usize>,
    /// Patch indices of the tokens to predict.
    pub target: Vec<usize>,
    /// Selects one of the learned mask token vectors. Consumers in this file
    /// reduce it modulo `pred_num_mask_tokens` before use.
    pub mask_index: usize,
}
32
/// Result of running the predictor.
///
/// As produced by `predict_native` in this file, `tokens` holds the projected
/// target rows of every batch element appended one after another, with
/// `num_target` rows per batch element.
#[derive(Debug, Clone)]
pub struct Vjepa2PredictorOutput {
    /// Flat predicted token values.
    pub tokens: Vec<f32>,
    /// Number of target tokens predicted per batch element.
    pub num_target: usize,
    /// Width of each predicted token row (`predict_native` sets this to the
    /// encoder hidden size).
    pub hidden: usize,
}
38
/// Precomputed gather indices and RoPE tables baked into a compiled
/// predictor graph.
#[derive(Clone, Debug)]
pub struct Vjepa2PredictorLayout {
    /// Number of context tokens per batch element.
    pub n_ctxt: usize,
    /// Number of target tokens per batch element.
    pub n_tgt: usize,
    /// `n_ctxt + n_tgt`.
    pub n_combined: usize,
    /// Flat `[batch, n_ctxt]` patch indices into the encoder sequence.
    pub ctxt_idx: Vec<i64>,
    /// Flat `[batch, n_combined]` row gather that orders tokens by position.
    pub sort_idx: Vec<i64>,
    /// Flat `[batch, n_combined]` inverse of `sort_idx`, applied before the
    /// target slice is taken.
    pub unsort_idx: Vec<i64>,
    /// Flat `[batch, n_tgt, pred_hidden]` mask token rows.
    pub mask_rows: Vec<f32>,
    /// `[n_combined, pred_head_dim/2]` cosine table, rows in sorted order.
    pub rope_cos: Vec<f32>,
    /// Sine table with the same layout as `rope_cos`.
    pub rope_sin: Vec<f32>,
}
57
58/// Precompute indices and RoPE tables for [`super::builder::build_vjepa2_predictor_graph_sized`].
59pub fn prepare_predictor_layout(
60    cfg: &Vjepa2Config,
61    masks: &Vjepa2Masks,
62    batch: usize,
63) -> Result<Vjepa2PredictorLayout> {
64    use super::rope::build_vjepa2_rope_tables;
65
66    ensure!(!masks.context.is_empty(), "context mask must be non-empty");
67    ensure!(!masks.target.is_empty(), "target mask must be non-empty");
68
69    let pred = cfg.pred_hidden_size;
70    let pred_dh = cfg.pred_head_dim();
71    let (d_dim, h_dim, w_dim) = cfg.pred_rope_segment_dims();
72    let grid_h = cfg.grid_spatial();
73    let grid_w = cfg.grid_spatial();
74    let enc_seq = cfg.num_patches();
75
76    let n_ctxt = masks.context.len();
77    let n_tgt = masks.target.len();
78    let n_combined = n_ctxt + n_tgt;
79
80    let mut position_ids: Vec<usize> = Vec::with_capacity(n_combined);
81    position_ids.extend_from_slice(&masks.context);
82    position_ids.extend_from_slice(&masks.target);
83
84    let mut order: Vec<usize> = (0..n_combined).collect();
85    order.sort_by_key(|&i| position_ids[i]);
86
87    let mut sort_idx = vec![0i64; n_combined];
88    let mut unsort_idx = vec![0i64; n_combined];
89    for (new_i, &old_i) in order.iter().enumerate() {
90        sort_idx[new_i] = old_i as i64;
91        unsort_idx[old_i] = new_i as i64;
92    }
93
94    let sorted_pos: Vec<usize> = order.iter().map(|&i| position_ids[i]).collect();
95    let (full_cos, full_sin) =
96        build_vjepa2_rope_tables(enc_seq, pred_dh, d_dim, h_dim, w_dim, grid_h, grid_w);
97    let half = pred_dh / 2;
98    let mut rope_cos = vec![0f32; n_combined * half];
99    let mut rope_sin = vec![0f32; n_combined * half];
100    for (i, &p) in sorted_pos.iter().enumerate() {
101        rope_cos[i * half..(i + 1) * half].copy_from_slice(&full_cos[p * half..(p + 1) * half]);
102        rope_sin[i * half..(i + 1) * half].copy_from_slice(&full_sin[p * half..(p + 1) * half]);
103    }
104
105    let mut ctxt_idx = Vec::with_capacity(batch * n_ctxt);
106    let mut sort_flat = Vec::with_capacity(batch * n_combined);
107    let mut unsort_flat = Vec::with_capacity(batch * n_combined);
108    for _ in 0..batch {
109        ctxt_idx.extend(masks.context.iter().map(|&i| i as i64));
110        sort_flat.extend_from_slice(&sort_idx);
111        unsort_flat.extend_from_slice(&unsort_idx);
112    }
113
114    Ok(Vjepa2PredictorLayout {
115        n_ctxt,
116        n_tgt,
117        n_combined,
118        ctxt_idx,
119        sort_idx: sort_flat,
120        unsort_idx: unsort_flat,
121        mask_rows: vec![0f32; batch * n_tgt * pred],
122        rope_cos,
123        rope_sin,
124    })
125}
126
127/// Tile the selected mask token into `[batch, n_tgt, pred_hidden]` row-major.
128pub fn predictor_mask_rows(
129    weights: &super::weights::Vjepa2PredictorWeights,
130    cfg: &Vjepa2Config,
131    masks: &Vjepa2Masks,
132    batch: usize,
133) -> Vec<f32> {
134    let pred = cfg.pred_hidden_size;
135    let n_tgt = masks.target.len();
136    let mask_idx = masks.mask_index % cfg.pred_num_mask_tokens;
137    let mask_vec = &weights.mask_tokens[mask_idx * pred..(mask_idx + 1) * pred];
138    let mut rows = Vec::with_capacity(batch * n_tgt * pred);
139    for _ in 0..batch {
140        for _ in 0..n_tgt {
141            rows.extend_from_slice(mask_vec);
142        }
143    }
144    rows
145}
146
147/// Run the predictor on encoder outputs `[batch, seq, enc_dim]` flat.
148pub fn predict_native(
149    encoder_tokens: &[f32],
150    weights: &Vjepa2PredictorWeights,
151    cfg: &Vjepa2Config,
152    batch: usize,
153    seq: usize,
154    masks: &Vjepa2Masks,
155) -> Result<Vjepa2PredictorOutput> {
156    let enc = cfg.hidden_size;
157    let pred = cfg.pred_hidden_size;
158    let nh = cfg.pred_num_attention_heads;
159    let head_dim = cfg.pred_head_dim();
160    let (d_dim, h_dim, w_dim) = cfg.pred_rope_segment_dims();
161    let grid_t = cfg.grid_temporal();
162    let grid_h = cfg.grid_spatial();
163    let grid_w = cfg.grid_spatial();
164    let eps = cfg.layer_norm_eps as f32;
165
166    ensure!(!masks.context.is_empty(), "context mask must be non-empty");
167    ensure!(!masks.target.is_empty(), "target mask must be non-empty");
168
169    let n_ctxt = masks.context.len();
170    let n_tgt = masks.target.len();
171    let n_combined = n_ctxt + n_tgt;
172
173    let mut per_batch = Vec::with_capacity(batch * n_combined * pred);
174
175    for bi in 0..batch {
176        let enc_batch = &encoder_tokens[bi * seq * enc..(bi + 1) * seq * enc];
177        let ctxt = gather_rows(enc_batch, &masks.context, seq, enc);
178        let mut x = linear(
179            &ctxt,
180            n_ctxt,
181            enc,
182            &weights.embed_w_t,
183            pred,
184            &weights.embed_b,
185        )?;
186
187        let mask_idx = masks.mask_index % cfg.pred_num_mask_tokens;
188        let mask_vec = &weights.mask_tokens[mask_idx * pred..(mask_idx + 1) * pred];
189        let mut targets = vec![0f32; n_tgt * pred];
190        for ti in 0..n_tgt {
191            targets[ti * pred..(ti + 1) * pred].copy_from_slice(mask_vec);
192        }
193
194        x.extend_from_slice(&targets);
195
196        let mut position_ids: Vec<usize> = Vec::with_capacity(n_combined);
197        position_ids.extend_from_slice(&masks.context);
198        position_ids.extend_from_slice(&masks.target);
199
200        // Sort by patch index (argsort of position_ids).
201        let mut order: Vec<usize> = (0..n_combined).collect();
202        order.sort_by_key(|&i| position_ids[i]);
203        let mut sorted_pos = vec![0usize; n_combined];
204        let mut sorted_x = vec![0f32; n_combined * pred];
205        for (new_i, &old_i) in order.iter().enumerate() {
206            sorted_pos[new_i] = position_ids[old_i];
207            sorted_x[new_i * pred..(new_i + 1) * pred]
208                .copy_from_slice(&x[old_i * pred..(old_i + 1) * pred]);
209        }
210        x = sorted_x;
211        position_ids = sorted_pos;
212
213        for block in &weights.blocks {
214            block_forward(
215                &mut x,
216                block,
217                1,
218                n_combined,
219                pred,
220                nh,
221                head_dim,
222                d_dim,
223                h_dim,
224                w_dim,
225                grid_t,
226                grid_h,
227                grid_w,
228                eps,
229                Some(&position_ids),
230            )?;
231        }
232        x = layer_norm(&x, &weights.norm_w, &weights.norm_b, pred, eps)?;
233
234        // Unsort and take target slice.
235        let mut unsorted = vec![0f32; n_combined * pred];
236        for (new_i, &old_i) in order.iter().enumerate() {
237            unsorted[old_i * pred..(old_i + 1) * pred]
238                .copy_from_slice(&x[new_i * pred..(new_i + 1) * pred]);
239        }
240        let target_slice = &unsorted[n_ctxt * pred..];
241        let projected = linear(
242            target_slice,
243            n_tgt,
244            pred,
245            &weights.proj_w_t,
246            enc,
247            &weights.proj_b,
248        )?;
249        per_batch.extend_from_slice(&projected);
250    }
251
252    Ok(Vjepa2PredictorOutput {
253        tokens: per_batch,
254        num_target: n_tgt,
255        hidden: enc,
256    })
257}