1use super::config::Vjepa2Config;
19use super::layers::{block_forward, gather_rows};
20use super::weights::Vjepa2PredictorWeights;
21use anyhow::{Result, ensure};
22use rlx_tensor::{layer_norm, linear};
23
/// Patch-index masks describing one predictor query.
///
/// The same masks are applied to every sample of a batch.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Vjepa2Masks {
    /// Patch positions whose encoder tokens are gathered and given to the predictor
    /// as context. Must be non-empty.
    pub context: Vec<usize>,
    /// Patch positions the predictor fills in, one mask token each. Must be non-empty.
    pub target: Vec<usize>,
    /// Selects the learned mask token; reduced modulo `pred_num_mask_tokens` before use.
    pub mask_index: usize,
}
32
/// Predictions for a whole batch, as produced by `predict_native`.
///
/// Derives `Debug` and `Clone` so callers can inspect and copy results like the other
/// public types in this module.
#[derive(Debug, Clone)]
pub struct Vjepa2PredictorOutput {
    /// Flattened predictions, sample-major: `batch * num_target * hidden` values, with
    /// targets in the order of `Vjepa2Masks::target`.
    pub tokens: Vec<f32>,
    /// Number of predicted target tokens per sample.
    pub num_target: usize,
    /// Width of each predicted token (the encoder hidden size).
    pub hidden: usize,
}
38
/// Weight-independent index and RoPE tables for running the predictor on fixed masks.
///
/// The combined per-sample sequence is `[context..., target...]`; it is reordered by
/// patch position before the predictor blocks and restored afterwards.
#[derive(Clone, Debug)]
pub struct Vjepa2PredictorLayout {
    /// Context tokens per sample.
    pub n_ctxt: usize,
    /// Target tokens per sample.
    pub n_tgt: usize,
    /// `n_ctxt + n_tgt`.
    pub n_combined: usize,
    /// Context patch indices, repeated once per batch sample (`batch * n_ctxt` entries).
    pub ctxt_idx: Vec<i64>,
    /// Per-sample permutation, `sort_idx[new] = old`, mapping position-sorted order back
    /// to the combined order; repeated per batch sample (`batch * n_combined` entries).
    pub sort_idx: Vec<i64>,
    /// Inverse of `sort_idx` (`unsort_idx[old] = new`), repeated per batch sample.
    pub unsort_idx: Vec<i64>,
    /// `batch * n_tgt * pred_hidden_size` values. `prepare_predictor_layout` zero-fills
    /// this; presumably the caller fills it with mask-token rows (verify at call sites).
    pub mask_rows: Vec<f32>,
    /// RoPE cosines for the position-sorted sequence: `n_combined * pred_head_dim / 2`
    /// values, shared by all batch samples.
    pub rope_cos: Vec<f32>,
    /// RoPE sines, same layout as `rope_cos`.
    pub rope_sin: Vec<f32>,
}
57
58pub fn prepare_predictor_layout(
60 cfg: &Vjepa2Config,
61 masks: &Vjepa2Masks,
62 batch: usize,
63) -> Result<Vjepa2PredictorLayout> {
64 use super::rope::build_vjepa2_rope_tables;
65
66 ensure!(!masks.context.is_empty(), "context mask must be non-empty");
67 ensure!(!masks.target.is_empty(), "target mask must be non-empty");
68
69 let pred = cfg.pred_hidden_size;
70 let pred_dh = cfg.pred_head_dim();
71 let (d_dim, h_dim, w_dim) = cfg.pred_rope_segment_dims();
72 let grid_h = cfg.grid_spatial();
73 let grid_w = cfg.grid_spatial();
74 let enc_seq = cfg.num_patches();
75
76 let n_ctxt = masks.context.len();
77 let n_tgt = masks.target.len();
78 let n_combined = n_ctxt + n_tgt;
79
80 let mut position_ids: Vec<usize> = Vec::with_capacity(n_combined);
81 position_ids.extend_from_slice(&masks.context);
82 position_ids.extend_from_slice(&masks.target);
83
84 let mut order: Vec<usize> = (0..n_combined).collect();
85 order.sort_by_key(|&i| position_ids[i]);
86
87 let mut sort_idx = vec![0i64; n_combined];
88 let mut unsort_idx = vec![0i64; n_combined];
89 for (new_i, &old_i) in order.iter().enumerate() {
90 sort_idx[new_i] = old_i as i64;
91 unsort_idx[old_i] = new_i as i64;
92 }
93
94 let sorted_pos: Vec<usize> = order.iter().map(|&i| position_ids[i]).collect();
95 let (full_cos, full_sin) =
96 build_vjepa2_rope_tables(enc_seq, pred_dh, d_dim, h_dim, w_dim, grid_h, grid_w);
97 let half = pred_dh / 2;
98 let mut rope_cos = vec![0f32; n_combined * half];
99 let mut rope_sin = vec![0f32; n_combined * half];
100 for (i, &p) in sorted_pos.iter().enumerate() {
101 rope_cos[i * half..(i + 1) * half].copy_from_slice(&full_cos[p * half..(p + 1) * half]);
102 rope_sin[i * half..(i + 1) * half].copy_from_slice(&full_sin[p * half..(p + 1) * half]);
103 }
104
105 let mut ctxt_idx = Vec::with_capacity(batch * n_ctxt);
106 let mut sort_flat = Vec::with_capacity(batch * n_combined);
107 let mut unsort_flat = Vec::with_capacity(batch * n_combined);
108 for _ in 0..batch {
109 ctxt_idx.extend(masks.context.iter().map(|&i| i as i64));
110 sort_flat.extend_from_slice(&sort_idx);
111 unsort_flat.extend_from_slice(&unsort_idx);
112 }
113
114 Ok(Vjepa2PredictorLayout {
115 n_ctxt,
116 n_tgt,
117 n_combined,
118 ctxt_idx,
119 sort_idx: sort_flat,
120 unsort_idx: unsort_flat,
121 mask_rows: vec![0f32; batch * n_tgt * pred],
122 rope_cos,
123 rope_sin,
124 })
125}
126
127pub fn predictor_mask_rows(
129 weights: &super::weights::Vjepa2PredictorWeights,
130 cfg: &Vjepa2Config,
131 masks: &Vjepa2Masks,
132 batch: usize,
133) -> Vec<f32> {
134 let pred = cfg.pred_hidden_size;
135 let n_tgt = masks.target.len();
136 let mask_idx = masks.mask_index % cfg.pred_num_mask_tokens;
137 let mask_vec = &weights.mask_tokens[mask_idx * pred..(mask_idx + 1) * pred];
138 let mut rows = Vec::with_capacity(batch * n_tgt * pred);
139 for _ in 0..batch {
140 for _ in 0..n_tgt {
141 rows.extend_from_slice(mask_vec);
142 }
143 }
144 rows
145}
146
147pub fn predict_native(
149 encoder_tokens: &[f32],
150 weights: &Vjepa2PredictorWeights,
151 cfg: &Vjepa2Config,
152 batch: usize,
153 seq: usize,
154 masks: &Vjepa2Masks,
155) -> Result<Vjepa2PredictorOutput> {
156 let enc = cfg.hidden_size;
157 let pred = cfg.pred_hidden_size;
158 let nh = cfg.pred_num_attention_heads;
159 let head_dim = cfg.pred_head_dim();
160 let (d_dim, h_dim, w_dim) = cfg.pred_rope_segment_dims();
161 let grid_t = cfg.grid_temporal();
162 let grid_h = cfg.grid_spatial();
163 let grid_w = cfg.grid_spatial();
164 let eps = cfg.layer_norm_eps as f32;
165
166 ensure!(!masks.context.is_empty(), "context mask must be non-empty");
167 ensure!(!masks.target.is_empty(), "target mask must be non-empty");
168
169 let n_ctxt = masks.context.len();
170 let n_tgt = masks.target.len();
171 let n_combined = n_ctxt + n_tgt;
172
173 let mut per_batch = Vec::with_capacity(batch * n_combined * pred);
174
175 for bi in 0..batch {
176 let enc_batch = &encoder_tokens[bi * seq * enc..(bi + 1) * seq * enc];
177 let ctxt = gather_rows(enc_batch, &masks.context, seq, enc);
178 let mut x = linear(
179 &ctxt,
180 n_ctxt,
181 enc,
182 &weights.embed_w_t,
183 pred,
184 &weights.embed_b,
185 )?;
186
187 let mask_idx = masks.mask_index % cfg.pred_num_mask_tokens;
188 let mask_vec = &weights.mask_tokens[mask_idx * pred..(mask_idx + 1) * pred];
189 let mut targets = vec![0f32; n_tgt * pred];
190 for ti in 0..n_tgt {
191 targets[ti * pred..(ti + 1) * pred].copy_from_slice(mask_vec);
192 }
193
194 x.extend_from_slice(&targets);
195
196 let mut position_ids: Vec<usize> = Vec::with_capacity(n_combined);
197 position_ids.extend_from_slice(&masks.context);
198 position_ids.extend_from_slice(&masks.target);
199
200 let mut order: Vec<usize> = (0..n_combined).collect();
202 order.sort_by_key(|&i| position_ids[i]);
203 let mut sorted_pos = vec![0usize; n_combined];
204 let mut sorted_x = vec![0f32; n_combined * pred];
205 for (new_i, &old_i) in order.iter().enumerate() {
206 sorted_pos[new_i] = position_ids[old_i];
207 sorted_x[new_i * pred..(new_i + 1) * pred]
208 .copy_from_slice(&x[old_i * pred..(old_i + 1) * pred]);
209 }
210 x = sorted_x;
211 position_ids = sorted_pos;
212
213 for block in &weights.blocks {
214 block_forward(
215 &mut x,
216 block,
217 1,
218 n_combined,
219 pred,
220 nh,
221 head_dim,
222 d_dim,
223 h_dim,
224 w_dim,
225 grid_t,
226 grid_h,
227 grid_w,
228 eps,
229 Some(&position_ids),
230 )?;
231 }
232 x = layer_norm(&x, &weights.norm_w, &weights.norm_b, pred, eps)?;
233
234 let mut unsorted = vec![0f32; n_combined * pred];
236 for (new_i, &old_i) in order.iter().enumerate() {
237 unsorted[old_i * pred..(old_i + 1) * pred]
238 .copy_from_slice(&x[new_i * pred..(new_i + 1) * pred]);
239 }
240 let target_slice = &unsorted[n_ctxt * pred..];
241 let projected = linear(
242 target_slice,
243 n_tgt,
244 pred,
245 &weights.proj_w_t,
246 enc,
247 &weights.proj_b,
248 )?;
249 per_batch.extend_from_slice(&projected);
250 }
251
252 Ok(Vjepa2PredictorOutput {
253 tokens: per_batch,
254 num_target: n_tgt,
255 hidden: enc,
256 })
257}