1use anyhow::{Result, anyhow};
17use rlx_core::validate_standard_device;
18use rlx_flow::CompileProfile;
19use rlx_runtime::Device;
20use std::path::PathBuf;
21
/// Encoder tokens returned by `Vjepa2Runner::encode_video`, split per batch entry.
///
/// `PartialEq` is derived so results can be compared directly (e.g. compiled vs. native
/// runs in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2Output {
    /// One flat token buffer per batch entry; the runner fills each with `seq * hidden` values.
    pub per_batch: Vec<Vec<f32>>,
    /// Number of tokens in each batch entry.
    pub seq: usize,
    /// Number of values per token.
    pub hidden: usize,
}
29
/// Predictor tokens returned by `Vjepa2Runner::predict`, split per batch entry.
///
/// `PartialEq` is derived so results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2PredictOutput {
    /// One flat buffer per batch entry; the runner fills each with `num_target * hidden` values.
    pub per_batch: Vec<Vec<f32>>,
    /// Number of predicted target tokens per batch entry.
    pub num_target: usize,
    /// Number of values per token.
    pub hidden: usize,
}
37
/// Pooled result returned by `Vjepa2Runner::pool`.
///
/// `PartialEq` is derived so results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2PoolOutput {
    /// Pooled embedding values.
    pub embedding: Vec<f32>,
    /// Additional logits output; `None` when the pooler did not produce one.
    pub logits: Option<Vec<f32>>,
}
44
45#[derive(Debug, Clone, Default)]
46pub struct Vjepa2RunnerBuilder {
47 weights: Option<PathBuf>,
48 config: Option<crate::Vjepa2Config>,
49 config_path: Option<PathBuf>,
50 batch: Option<usize>,
51 device: Option<Device>,
52 predictor_masks: Option<crate::Vjepa2Masks>,
54}
55
56impl Vjepa2RunnerBuilder {
57 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
58 self.weights = Some(p.into());
59 self
60 }
61 pub fn config(mut self, cfg: crate::Vjepa2Config) -> Self {
62 self.config = Some(cfg);
63 self
64 }
65 pub fn config_path<P: Into<PathBuf>>(mut self, p: P) -> Self {
66 self.config_path = Some(p.into());
67 self
68 }
69 pub fn batch(mut self, n: usize) -> Self {
70 self.batch = Some(n);
71 self
72 }
73 pub fn device(mut self, d: Device) -> Self {
75 self.device = Some(d);
76 self
77 }
78 pub fn predictor_masks(mut self, masks: crate::Vjepa2Masks) -> Self {
80 self.predictor_masks = Some(masks);
81 self
82 }
83
84 pub fn build(self) -> Result<Vjepa2Runner> {
85 use crate::{
86 Vjepa2Config, Vjepa2GraphParams, build_vjepa2_encoder_graph_sized,
87 build_vjepa2_pooler_graph_sized, build_vjepa2_predictor_graph_sized,
88 extract_model_weights, predictor_mask_rows, prepare_predictor_layout,
89 };
90 use rlx_runtime::Session;
91
92 let weights_path = self
93 .weights
94 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
95 let cfg = match (self.config, self.config_path) {
96 (Some(c), _) => c,
97 (_, Some(p)) => Vjepa2Config::from_file(&p)?,
98 _ => Vjepa2Config::vit_g_384(),
99 };
100 let device = self.device.unwrap_or(Device::Cpu);
101 validate_standard_device("vjepa2", device)?;
102 let batch = self.batch.unwrap_or(1);
103
104 let mut wm = rlx_core::load_weight_map(&weights_path, rlx_core::VJEPA2_GGUF_ARCHES)?;
105 let model = extract_model_weights(&mut wm, &cfg)?;
106
107 let compiled = if self.device.is_some() {
108 let (graph, params, _pre) =
109 build_vjepa2_encoder_graph_sized(&cfg, &model.encoder, batch)?;
110 let opts = rlx_core::flow_bridge::compile_options_for_profile(
111 &CompileProfile::encoder(),
112 device,
113 );
114 let mut compiled = Session::new(device).compile_with(graph, &opts);
115 Vjepa2GraphParams::from_f32(params).load(&mut compiled);
116 Some(compiled)
117 } else {
118 None
119 };
120
121 let compiled_predictor = if self.device.is_some() {
122 if let (Some(pred), Some(masks)) = (&model.predictor, &self.predictor_masks) {
123 let layout = prepare_predictor_layout(&cfg, masks, batch)?;
124 let mask_rows = predictor_mask_rows(pred, &cfg, masks, batch);
125 let (graph, params) =
126 build_vjepa2_predictor_graph_sized(&cfg, pred, &layout, &mask_rows, batch)?;
127 let opts = rlx_core::flow_bridge::compile_options_for_profile(
128 &CompileProfile::encoder(),
129 device,
130 );
131 let mut compiled = Session::new(device).compile_with(graph, &opts);
132 params.load(&mut compiled);
133 Some((compiled, masks.clone()))
134 } else {
135 None
136 }
137 } else {
138 None
139 };
140
141 let compiled_pooler = if self.device.is_some() {
142 if let Some(pooler) = &model.pooler {
143 let (graph, params) = build_vjepa2_pooler_graph_sized(&cfg, pooler, batch)?;
144 let opts = rlx_core::flow_bridge::compile_options_for_profile(
145 &CompileProfile::encoder(),
146 device,
147 );
148 let mut compiled = Session::new(device).compile_with(graph, &opts);
149 params.load(&mut compiled);
150 Some(compiled)
151 } else {
152 None
153 }
154 } else {
155 None
156 };
157
158 Ok(Vjepa2Runner {
159 model,
160 cfg,
161 batch,
162 device,
163 compiled,
164 compiled_predictor,
165 compiled_pooler,
166 })
167 }
168}
169
/// Runner for a loaded vjepa2 checkpoint, created through `Vjepa2RunnerBuilder::build`.
///
/// Holds the extracted weights plus optionally compiled graphs; when a compiled graph is
/// absent the corresponding method falls back to the native implementation.
pub struct Vjepa2Runner {
    // Weights extracted from the checkpoint; `predictor` and `pooler` are optional.
    model: crate::Vjepa2ModelWeights,
    // Configuration the weights were extracted with.
    cfg: crate::Vjepa2Config,
    // Batch size every graph and output is sized for.
    batch: usize,
    // Device chosen at build time (`Device::Cpu` when none was set).
    device: Device,
    // Compiled encoder graph; `None` when no device was set explicitly.
    compiled: Option<rlx_runtime::CompiledGraph>,
    // Compiled predictor graph together with the masks it was built for.
    compiled_predictor: Option<(rlx_runtime::CompiledGraph, crate::Vjepa2Masks)>,
    // Compiled pooler graph, when the checkpoint has pooler weights and a device was set.
    compiled_pooler: Option<rlx_runtime::CompiledGraph>,
}
180
181impl Vjepa2Runner {
182 pub fn builder() -> Vjepa2RunnerBuilder {
183 Vjepa2RunnerBuilder::default()
184 }
185 pub fn config(&self) -> &crate::Vjepa2Config {
186 &self.cfg
187 }
188 pub fn device(&self) -> Device {
189 self.device
190 }
191 pub fn has_predictor(&self) -> bool {
192 self.model.predictor.is_some()
193 }
194 pub fn has_pooler(&self) -> bool {
195 self.model.pooler.is_some()
196 }
197
198 fn encode_tokens_inner(&mut self, video_ncthw: &[f32]) -> Result<Vjepa2Output> {
199 use crate::{conv3d_patch_embed, encode_video_native};
200
201 let crop = self.cfg.crop_size;
202 let frames = self.cfg.frames_per_clip;
203 let expected = 3 * frames * crop * crop;
204 anyhow::ensure!(
205 video_ncthw.len() == expected,
206 "expected {expected} f32 values for NCTHW video, got {}",
207 video_ncthw.len()
208 );
209
210 let out = if let Some(compiled) = self.compiled.as_mut() {
211 let patch = &self.model.encoder.patch;
212 let mut hidden = conv3d_patch_embed(patch, video_ncthw, frames, crop, crop)?;
213 if self.batch > 1 {
214 let per = hidden.len();
215 let mut batched = Vec::with_capacity(per * self.batch);
216 for _ in 0..self.batch {
217 batched.extend_from_slice(&hidden);
218 }
219 hidden = batched;
220 }
221 let flat = compiled
222 .run(&[("hidden", hidden.as_slice())])
223 .into_iter()
224 .next()
225 .ok_or_else(|| anyhow!("vjepa2 graph forward returned no output"))?;
226 crate::Vjepa2EncoderOutput {
227 tokens: flat,
228 seq: self.cfg.num_patches(),
229 hidden: self.cfg.hidden_size,
230 }
231 } else {
232 encode_video_native(&self.model.encoder, &self.cfg, video_ncthw, self.batch)?
233 };
234
235 let per = out.seq * out.hidden;
236 let mut per_batch = Vec::with_capacity(self.batch);
237 for b in 0..self.batch {
238 per_batch.push(out.tokens[b * per..(b + 1) * per].to_vec());
239 }
240 Ok(Vjepa2Output {
241 per_batch,
242 seq: out.seq,
243 hidden: out.hidden,
244 })
245 }
246
247 pub fn encode_video(&mut self, video_ncthw: &[f32]) -> Result<Vjepa2Output> {
249 self.encode_tokens_inner(video_ncthw)
250 }
251
252 pub fn encode_video_hwc(&mut self, frames: &[u8]) -> Result<Vjepa2Output> {
254 use crate::normalize_video_hwc;
255
256 let crop = self.cfg.crop_size;
257 let nframes = self.cfg.frames_per_clip;
258 let expected = nframes * crop * crop * 3;
259 anyhow::ensure!(
260 frames.len() == expected,
261 "expected {expected} u8 pixels HWC, got {}",
262 frames.len()
263 );
264 let ncthw = normalize_video_hwc(frames, nframes, crop);
265 self.encode_video(&ncthw)
266 }
267
268 pub fn predict(
270 &mut self,
271 enc: &Vjepa2Output,
272 masks: &crate::Vjepa2Masks,
273 ) -> Result<Vjepa2PredictOutput> {
274 use crate::predict_native;
275
276 let pred = self
277 .model
278 .predictor
279 .as_ref()
280 .ok_or_else(|| anyhow!("checkpoint has no predictor weights"))?;
281 let mut flat = Vec::with_capacity(enc.per_batch.len() * enc.seq * enc.hidden);
282 for batch in &enc.per_batch {
283 flat.extend_from_slice(batch);
284 }
285
286 let out = if let Some((compiled, cached_masks)) = self.compiled_predictor.as_mut() {
287 if cached_masks == masks {
288 let mut outputs = compiled.run(&[("encoder", flat.as_slice())]);
289 let tokens = outputs
290 .pop()
291 .ok_or_else(|| anyhow!("vjepa2 predictor graph returned no output"))?;
292 let num_target = masks.target.len();
293 crate::Vjepa2PredictorOutput {
294 tokens,
295 num_target,
296 hidden: enc.hidden,
297 }
298 } else {
299 predict_native(&flat, pred, &self.cfg, self.batch, enc.seq, masks)?
300 }
301 } else {
302 predict_native(&flat, pred, &self.cfg, self.batch, enc.seq, masks)?
303 };
304 let per = out.num_target * out.hidden;
305 let mut per_batch = Vec::with_capacity(self.batch);
306 for b in 0..self.batch {
307 per_batch.push(out.tokens[b * per..(b + 1) * per].to_vec());
308 }
309 Ok(Vjepa2PredictOutput {
310 per_batch,
311 num_target: out.num_target,
312 hidden: out.hidden,
313 })
314 }
315
316 pub fn pool(&self, enc: &Vjepa2Output) -> Result<Vjepa2PoolOutput> {
318 use crate::pool_native;
319
320 let pooler = self
321 .model
322 .pooler
323 .as_ref()
324 .ok_or_else(|| anyhow!("checkpoint has no pooler weights"))?;
325 let mut flat = Vec::with_capacity(enc.per_batch.len() * enc.seq * enc.hidden);
326 for batch in &enc.per_batch {
327 flat.extend_from_slice(batch);
328 }
329
330 let out = if let Some(compiled) = &self.compiled_pooler {
331 let mut compiled = compiled.clone();
332 let mut outputs = compiled.run(&[("encoder", flat.as_slice())]);
333 anyhow::ensure!(
334 !outputs.is_empty(),
335 "vjepa2 pooler graph returned no embedding"
336 );
337 let embedding = outputs.remove(0);
338 let logits = outputs.pop();
339 crate::Vjepa2PoolerOutput { embedding, logits }
340 } else {
341 pool_native(&flat, pooler, &self.cfg, self.batch, enc.seq)?
342 };
343 Ok(Vjepa2PoolOutput {
344 embedding: out.embedding,
345 logits: out.logits,
346 })
347 }
348}