Skip to main content

rlx_vjepa2/
preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Host-side 3-D patch embedding and video normalization for V-JEPA2.
17
18use super::config::{IMAGENET_MEAN, IMAGENET_STD, Vjepa2Config};
19use anyhow::{Result, ensure};
20use rlx_core::weight_map::WeightMap;
21
/// Host-side parameters of the V-JEPA2 Conv3d patch embedding.
#[derive(Clone)]
pub struct Vjepa2PatchEmbedWeights {
    /// Flattened Conv3d kernel with `embed_dim * in_chans * tubelet_size * patch_size *
    /// patch_size` entries, stored row-major in PyTorch `[oc, ic, kt, kh, kw]` order.
    pub proj_w: Vec<f32>,
    /// Per-output-channel bias; one value for each of the `embed_dim` channels.
    pub proj_b: Vec<f32>,
    /// Number of output channels, i.e. the width of each produced token.
    pub embed_dim: usize,
    /// Number of input channels of the video tensor.
    pub in_chans: usize,
    /// Temporal kernel extent, also used as the temporal stride (frames per tubelet).
    pub tubelet_size: usize,
    /// Spatial kernel extent, also used as the spatial stride along height and width.
    pub patch_size: usize,
}
33
34pub fn extract_patch_embed_weights(
35    weights: &mut WeightMap,
36    cfg: &Vjepa2Config,
37) -> Result<Vjepa2PatchEmbedWeights> {
38    let e = cfg.hidden_size;
39    let c = cfg.in_chans;
40    let ts = cfg.tubelet_size;
41    let ps = cfg.patch_size;
42    let expected = vec![e, c, ts, ps, ps];
43
44    let w_keys = [
45        "encoder.embeddings.patch_embeddings.proj.weight",
46        "patch_embed.proj.weight",
47    ];
48    let b_keys = [
49        "encoder.embeddings.patch_embeddings.proj.bias",
50        "patch_embed.proj.bias",
51    ];
52
53    let (proj_w, shape) = take_first(weights, &w_keys)?;
54    ensure!(
55        shape == expected,
56        "patch embed weight expected {expected:?}, got {shape:?}"
57    );
58    let (proj_b, bshape) = take_first(weights, &b_keys)?;
59    ensure!(bshape == vec![e], "patch embed bias expected [{e}]");
60
61    Ok(Vjepa2PatchEmbedWeights {
62        proj_w,
63        proj_b,
64        embed_dim: e,
65        in_chans: c,
66        tubelet_size: ts,
67        patch_size: ps,
68    })
69}
70
71/// Normalize RGB u8 frames to NCTHW f32 in `[0,1]` then ImageNet stats.
72/// `frames` is `[num_frames, crop, crop, 3]` HWC u8 row-major.
73pub fn normalize_video_hwc(frames: &[u8], num_frames: usize, crop: usize) -> Vec<f32> {
74    let plane = crop * crop;
75    let mut out = vec![0f32; 3 * num_frames * plane];
76    for t in 0..num_frames {
77        for y in 0..crop {
78            for x in 0..crop {
79                let src = (t * plane + y * crop + x) * 3;
80                for c in 0..3 {
81                    let v = frames[src + c] as f32 / 255.0;
82                    let norm = (v - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
83                    out[c * num_frames * plane + t * plane + y * crop + x] = norm;
84                }
85            }
86        }
87    }
88    out
89}
90
91/// 3-D conv patch embedding: input `[C, T, H, W]` → tokens `[seq, embed_dim]`.
92pub fn conv3d_patch_embed(
93    patch: &Vjepa2PatchEmbedWeights,
94    video_ncthw: &[f32],
95    frames: usize,
96    height: usize,
97    width: usize,
98) -> Result<Vec<f32>> {
99    let c = patch.in_chans;
100    let ts = patch.tubelet_size;
101    let ps = patch.patch_size;
102    let e = patch.embed_dim;
103    ensure!(
104        video_ncthw.len() == c * frames * height * width,
105        "video tensor size mismatch"
106    );
107    ensure!(frames.is_multiple_of(ts) && height.is_multiple_of(ps) && width.is_multiple_of(ps));
108
109    let t_out = frames / ts;
110    let h_out = height / ps;
111    let w_out = width / ps;
112    let seq = t_out * h_out * w_out;
113    let mut tokens = vec![0f32; seq * e];
114
115    let plane = height * width;
116    let vol = frames * plane;
117
118    for ot in 0..t_out {
119        for oh in 0..h_out {
120            for ow in 0..w_out {
121                let tok = (ot * h_out + oh) * w_out + ow;
122                for oc in 0..e {
123                    let mut acc = patch.proj_b[oc];
124                    for ic in 0..c {
125                        for kt in 0..ts {
126                            for kh in 0..ps {
127                                for kw in 0..ps {
128                                    let it = ot * ts + kt;
129                                    let ih = oh * ps + kh;
130                                    let iw = ow * ps + kw;
131                                    let in_idx = ic * vol + it * plane + ih * width + iw;
132                                    let w_idx = oc * (c * ts * ps * ps)
133                                        + ic * (ts * ps * ps)
134                                        + kt * (ps * ps)
135                                        + kh * ps
136                                        + kw;
137                                    acc += patch.proj_w[w_idx] * video_ncthw[in_idx];
138                                }
139                            }
140                        }
141                    }
142                    tokens[tok * e + oc] = acc;
143                }
144            }
145        }
146    }
147    Ok(tokens)
148}
149
150fn take_first(weights: &mut WeightMap, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
151    for key in keys {
152        if weights.has(key) {
153            return weights.take(key);
154        }
155    }
156    anyhow::bail!("none of the patch-embed keys found: {keys:?}")
157}