Skip to main content

rlx_vjepa2/
encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Native CPU forward for the V-JEPA2 video encoder.
17
18use super::config::Vjepa2Config;
19use super::layers::block_forward;
20use super::preprocess::conv3d_patch_embed;
21use super::weights::Vjepa2EncoderWeights;
22use anyhow::Result;
23use rlx_tensor::layer_norm;
24
/// Result of a native V-JEPA2 encoder forward pass.
///
/// Produced by [`encode_video_native`] and [`encode_video_native_ext`].
#[derive(Debug, Clone)]
pub struct Vjepa2EncoderOutput {
    /// Flat buffer of encoder activations.
    ///
    /// NOTE(review): presumably laid out as `[batch, seq, hidden]` in row-major
    /// order — confirm against `block_forward` / `layer_norm`.
    pub tokens: Vec<f32>,
    /// Tokens per clip; the encoder fills this from `Vjepa2Config::num_patches()`.
    pub seq: usize,
    /// Width of each token; the encoder fills this from `Vjepa2Config::hidden_size`.
    pub hidden: usize,
}
30
31/// Encode a pre-normalized video tensor `[C, T, H, W]`.
32pub fn encode_video_native(
33    weights: &Vjepa2EncoderWeights,
34    cfg: &Vjepa2Config,
35    video_ncthw: &[f32],
36    batch: usize,
37) -> Result<Vjepa2EncoderOutput> {
38    encode_video_native_ext(weights, cfg, video_ncthw, batch, None)
39}
40
41/// Like [`encode_video_native`], but stop after transformer block `stop_after_block`
42/// (inclusive). Skips final layer norm unless stopping at the last block.
43pub fn encode_video_native_ext(
44    weights: &Vjepa2EncoderWeights,
45    cfg: &Vjepa2Config,
46    video_ncthw: &[f32],
47    batch: usize,
48    stop_after_block: Option<usize>,
49) -> Result<Vjepa2EncoderOutput> {
50    let e = cfg.hidden_size;
51    let frames = cfg.frames_per_clip;
52    let crop = cfg.crop_size;
53    let seq = cfg.num_patches();
54    let head_dim = cfg.head_dim();
55    let nh = cfg.num_attention_heads;
56    let (d_dim, h_dim, w_dim) = cfg.rope_segment_dims();
57    let grid_t = cfg.grid_temporal();
58    let grid_h = cfg.grid_spatial();
59    let grid_w = cfg.grid_spatial();
60    let eps = cfg.layer_norm_eps as f32;
61
62    let mut x = conv3d_patch_embed(&weights.patch, video_ncthw, frames, crop, crop)?;
63    if batch > 1 {
64        let per = x.len();
65        let mut batched = Vec::with_capacity(per * batch);
66        for _ in 0..batch {
67            batched.extend_from_slice(&x);
68        }
69        x = batched;
70    }
71
72    let last_block = weights.blocks.len().saturating_sub(1);
73    for (i, block) in weights.blocks.iter().enumerate() {
74        block_forward(
75            &mut x, block, batch, seq, e, nh, head_dim, d_dim, h_dim, w_dim, grid_t, grid_h,
76            grid_w, eps, None,
77        )?;
78        if stop_after_block == Some(i) {
79            break;
80        }
81    }
82
83    if stop_after_block.is_none() || stop_after_block == Some(last_block) {
84        x = layer_norm(&x, &weights.norm_w, &weights.norm_b, e, eps)?;
85    }
86
87    Ok(Vjepa2EncoderOutput {
88        tokens: x,
89        seq,
90        hidden: e,
91    })
92}