Skip to main content

rlx_vjepa2/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Result, anyhow};
17use rlx_core::validate_standard_device;
18use rlx_flow::CompileProfile;
19use rlx_runtime::Device;
20use std::path::PathBuf;
21
/// Encoder tokens returned by [`Vjepa2Runner::encode_video`].
///
/// The runner slices the encoder result into one flat buffer per batch
/// item, so every entry of `per_batch` is filled with `seq * hidden` values.
#[derive(Debug, Clone)]
pub struct Vjepa2Output {
    /// One flattened token buffer per batch item.
    pub per_batch: Vec<Vec<f32>>,
    /// Number of tokens per batch item.
    pub seq: usize,
    /// Width of each token.
    pub hidden: usize,
}
29
/// Predictor result (projected target tokens), as returned by
/// [`Vjepa2Runner::predict`].
///
/// The runner slices the predictor result into one flat buffer per batch
/// item, each filled with `num_target * hidden` values.
#[derive(Debug, Clone)]
pub struct Vjepa2PredictOutput {
    /// One flattened buffer of predicted target tokens per batch item.
    pub per_batch: Vec<Vec<f32>>,
    /// Number of target tokens per batch item.
    pub num_target: usize,
    /// Width of each predicted token.
    pub hidden: usize,
}
37
/// Attentive pooler result, as returned by [`Vjepa2Runner::pool`].
#[derive(Debug, Clone)]
pub struct Vjepa2PoolOutput {
    /// Pooled embedding produced by the pooler.
    pub embedding: Vec<f32>,
    /// Classifier logits; `None` when the pooler produced no logits.
    pub logits: Option<Vec<f32>>,
}
44
45#[derive(Debug, Clone, Default)]
46pub struct Vjepa2RunnerBuilder {
47    weights: Option<PathBuf>,
48    config: Option<crate::Vjepa2Config>,
49    config_path: Option<PathBuf>,
50    batch: Option<usize>,
51    device: Option<Device>,
52    /// Fixed context/target masks for compiled predictor graphs.
53    predictor_masks: Option<crate::Vjepa2Masks>,
54}
55
56impl Vjepa2RunnerBuilder {
57    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
58        self.weights = Some(p.into());
59        self
60    }
61    pub fn config(mut self, cfg: crate::Vjepa2Config) -> Self {
62        self.config = Some(cfg);
63        self
64    }
65    pub fn config_path<P: Into<PathBuf>>(mut self, p: P) -> Self {
66        self.config_path = Some(p.into());
67        self
68    }
69    pub fn batch(mut self, n: usize) -> Self {
70        self.batch = Some(n);
71        self
72    }
73    /// When set, the encoder trunk runs via compiled IR (CPU / Metal / …).
74    pub fn device(mut self, d: Device) -> Self {
75        self.device = Some(d);
76        self
77    }
78    /// Context/target masks baked into the compiled predictor graph.
79    pub fn predictor_masks(mut self, masks: crate::Vjepa2Masks) -> Self {
80        self.predictor_masks = Some(masks);
81        self
82    }
83
84    pub fn build(self) -> Result<Vjepa2Runner> {
85        use crate::{
86            Vjepa2Config, Vjepa2GraphParams, build_vjepa2_encoder_graph_sized,
87            build_vjepa2_pooler_graph_sized, build_vjepa2_predictor_graph_sized,
88            extract_model_weights, predictor_mask_rows, prepare_predictor_layout,
89        };
90        use rlx_runtime::Session;
91
92        let weights_path = self
93            .weights
94            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
95        let cfg = match (self.config, self.config_path) {
96            (Some(c), _) => c,
97            (_, Some(p)) => Vjepa2Config::from_file(&p)?,
98            _ => Vjepa2Config::vit_g_384(),
99        };
100        let device = self.device.unwrap_or(Device::Cpu);
101        validate_standard_device("vjepa2", device)?;
102        let batch = self.batch.unwrap_or(1);
103
104        let mut wm = rlx_core::load_weight_map(&weights_path, rlx_core::VJEPA2_GGUF_ARCHES)?;
105        let model = extract_model_weights(&mut wm, &cfg)?;
106
107        let compiled = if self.device.is_some() {
108            let (graph, params, _pre) =
109                build_vjepa2_encoder_graph_sized(&cfg, &model.encoder, batch)?;
110            let opts = rlx_core::flow_bridge::compile_options_for_profile(
111                &CompileProfile::encoder(),
112                device,
113            );
114            let mut compiled = Session::new(device).compile_with(graph, &opts);
115            Vjepa2GraphParams::from_f32(params).load(&mut compiled);
116            Some(compiled)
117        } else {
118            None
119        };
120
121        let compiled_predictor = if self.device.is_some() {
122            if let (Some(pred), Some(masks)) = (&model.predictor, &self.predictor_masks) {
123                let layout = prepare_predictor_layout(&cfg, masks, batch)?;
124                let mask_rows = predictor_mask_rows(pred, &cfg, masks, batch);
125                let (graph, params) =
126                    build_vjepa2_predictor_graph_sized(&cfg, pred, &layout, &mask_rows, batch)?;
127                let opts = rlx_core::flow_bridge::compile_options_for_profile(
128                    &CompileProfile::encoder(),
129                    device,
130                );
131                let mut compiled = Session::new(device).compile_with(graph, &opts);
132                params.load(&mut compiled);
133                Some((compiled, masks.clone()))
134            } else {
135                None
136            }
137        } else {
138            None
139        };
140
141        let compiled_pooler = if self.device.is_some() {
142            if let Some(pooler) = &model.pooler {
143                let (graph, params) = build_vjepa2_pooler_graph_sized(&cfg, pooler, batch)?;
144                let opts = rlx_core::flow_bridge::compile_options_for_profile(
145                    &CompileProfile::encoder(),
146                    device,
147                );
148                let mut compiled = Session::new(device).compile_with(graph, &opts);
149                params.load(&mut compiled);
150                Some(compiled)
151            } else {
152                None
153            }
154        } else {
155            None
156        };
157
158        Ok(Vjepa2Runner {
159            model,
160            cfg,
161            batch,
162            device,
163            compiled,
164            compiled_predictor,
165            compiled_pooler,
166        })
167    }
168}
169
170/// V-JEPA2 runner — encoder (+ optional predictor / pooler).
171pub struct Vjepa2Runner {
172    model: crate::Vjepa2ModelWeights,
173    cfg: crate::Vjepa2Config,
174    batch: usize,
175    device: Device,
176    compiled: Option<rlx_runtime::CompiledGraph>,
177    compiled_predictor: Option<(rlx_runtime::CompiledGraph, crate::Vjepa2Masks)>,
178    compiled_pooler: Option<rlx_runtime::CompiledGraph>,
179}
180
181impl Vjepa2Runner {
182    pub fn builder() -> Vjepa2RunnerBuilder {
183        Vjepa2RunnerBuilder::default()
184    }
185    pub fn config(&self) -> &crate::Vjepa2Config {
186        &self.cfg
187    }
188    pub fn device(&self) -> Device {
189        self.device
190    }
191    pub fn has_predictor(&self) -> bool {
192        self.model.predictor.is_some()
193    }
194    pub fn has_pooler(&self) -> bool {
195        self.model.pooler.is_some()
196    }
197
198    fn encode_tokens_inner(&mut self, video_ncthw: &[f32]) -> Result<Vjepa2Output> {
199        use crate::{conv3d_patch_embed, encode_video_native};
200
201        let crop = self.cfg.crop_size;
202        let frames = self.cfg.frames_per_clip;
203        let expected = 3 * frames * crop * crop;
204        anyhow::ensure!(
205            video_ncthw.len() == expected,
206            "expected {expected} f32 values for NCTHW video, got {}",
207            video_ncthw.len()
208        );
209
210        let out = if let Some(compiled) = self.compiled.as_mut() {
211            let patch = &self.model.encoder.patch;
212            let mut hidden = conv3d_patch_embed(patch, video_ncthw, frames, crop, crop)?;
213            if self.batch > 1 {
214                let per = hidden.len();
215                let mut batched = Vec::with_capacity(per * self.batch);
216                for _ in 0..self.batch {
217                    batched.extend_from_slice(&hidden);
218                }
219                hidden = batched;
220            }
221            let flat = compiled
222                .run(&[("hidden", hidden.as_slice())])
223                .into_iter()
224                .next()
225                .ok_or_else(|| anyhow!("vjepa2 graph forward returned no output"))?;
226            crate::Vjepa2EncoderOutput {
227                tokens: flat,
228                seq: self.cfg.num_patches(),
229                hidden: self.cfg.hidden_size,
230            }
231        } else {
232            encode_video_native(&self.model.encoder, &self.cfg, video_ncthw, self.batch)?
233        };
234
235        let per = out.seq * out.hidden;
236        let mut per_batch = Vec::with_capacity(self.batch);
237        for b in 0..self.batch {
238            per_batch.push(out.tokens[b * per..(b + 1) * per].to_vec());
239        }
240        Ok(Vjepa2Output {
241            per_batch,
242            seq: out.seq,
243            hidden: out.hidden,
244        })
245    }
246
247    /// Encode a pre-normalized video tensor `[C, T, H, W]` (NCTHW f32).
248    pub fn encode_video(&mut self, video_ncthw: &[f32]) -> Result<Vjepa2Output> {
249        self.encode_tokens_inner(video_ncthw)
250    }
251
252    /// Convenience: u8 HWC frames `[num_frames, crop, crop, 3]` → encode.
253    pub fn encode_video_hwc(&mut self, frames: &[u8]) -> Result<Vjepa2Output> {
254        use crate::normalize_video_hwc;
255
256        let crop = self.cfg.crop_size;
257        let nframes = self.cfg.frames_per_clip;
258        let expected = nframes * crop * crop * 3;
259        anyhow::ensure!(
260            frames.len() == expected,
261            "expected {expected} u8 pixels HWC, got {}",
262            frames.len()
263        );
264        let ncthw = normalize_video_hwc(frames, nframes, crop);
265        self.encode_video(&ncthw)
266    }
267
268    /// Run the JEPA predictor on encoder outputs with context/target masks.
269    pub fn predict(
270        &mut self,
271        enc: &Vjepa2Output,
272        masks: &crate::Vjepa2Masks,
273    ) -> Result<Vjepa2PredictOutput> {
274        use crate::predict_native;
275
276        let pred = self
277            .model
278            .predictor
279            .as_ref()
280            .ok_or_else(|| anyhow!("checkpoint has no predictor weights"))?;
281        let mut flat = Vec::with_capacity(enc.per_batch.len() * enc.seq * enc.hidden);
282        for batch in &enc.per_batch {
283            flat.extend_from_slice(batch);
284        }
285
286        let out = if let Some((compiled, cached_masks)) = self.compiled_predictor.as_mut() {
287            if cached_masks == masks {
288                let mut outputs = compiled.run(&[("encoder", flat.as_slice())]);
289                let tokens = outputs
290                    .pop()
291                    .ok_or_else(|| anyhow!("vjepa2 predictor graph returned no output"))?;
292                let num_target = masks.target.len();
293                crate::Vjepa2PredictorOutput {
294                    tokens,
295                    num_target,
296                    hidden: enc.hidden,
297                }
298            } else {
299                predict_native(&flat, pred, &self.cfg, self.batch, enc.seq, masks)?
300            }
301        } else {
302            predict_native(&flat, pred, &self.cfg, self.batch, enc.seq, masks)?
303        };
304        let per = out.num_target * out.hidden;
305        let mut per_batch = Vec::with_capacity(self.batch);
306        for b in 0..self.batch {
307            per_batch.push(out.tokens[b * per..(b + 1) * per].to_vec());
308        }
309        Ok(Vjepa2PredictOutput {
310            per_batch,
311            num_target: out.num_target,
312            hidden: out.hidden,
313        })
314    }
315
316    /// Attentive pooler (+ classifier when present) on encoder tokens.
317    pub fn pool(&self, enc: &Vjepa2Output) -> Result<Vjepa2PoolOutput> {
318        use crate::pool_native;
319
320        let pooler = self
321            .model
322            .pooler
323            .as_ref()
324            .ok_or_else(|| anyhow!("checkpoint has no pooler weights"))?;
325        let mut flat = Vec::with_capacity(enc.per_batch.len() * enc.seq * enc.hidden);
326        for batch in &enc.per_batch {
327            flat.extend_from_slice(batch);
328        }
329
330        let out = if let Some(compiled) = &self.compiled_pooler {
331            let mut compiled = compiled.clone();
332            let mut outputs = compiled.run(&[("encoder", flat.as_slice())]);
333            anyhow::ensure!(
334                !outputs.is_empty(),
335                "vjepa2 pooler graph returned no embedding"
336            );
337            let embedding = outputs.remove(0);
338            let logits = outputs.pop();
339            crate::Vjepa2PoolerOutput { embedding, logits }
340        } else {
341            pool_native(&flat, pooler, &self.cfg, self.batch, enc.seq)?
342        };
343        Ok(Vjepa2PoolOutput {
344            embedding: out.embedding,
345            logits: out.logits,
346        })
347    }
348}