1use super::config::{IMAGENET_MEAN, IMAGENET_STD, Vjepa2Config};
19use anyhow::{Result, ensure};
20use rlx_core::weight_map::WeightMap;
21
/// Parameters of the V-JEPA 2 tubelet patch embedding (a 3D convolution).
#[derive(Clone)]
pub struct Vjepa2PatchEmbedWeights {
    /// Convolution kernel, flattened row-major as
    /// `[embed_dim, in_chans, tubelet_size, patch_size, patch_size]`.
    pub proj_w: Vec<f32>,
    /// Per-output-channel bias, length `embed_dim`.
    pub proj_b: Vec<f32>,
    /// Number of output channels (features per token).
    pub embed_dim: usize,
    /// Number of input channels.
    pub in_chans: usize,
    /// Temporal kernel size (frames per tubelet).
    pub tubelet_size: usize,
    /// Spatial kernel size (square patch side).
    pub patch_size: usize,
}

// Hand-written so that debug output reports buffer sizes instead of dumping every weight.
impl std::fmt::Debug for Vjepa2PatchEmbedWeights {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Vjepa2PatchEmbedWeights")
            .field("proj_w", &format_args!("<{} x f32>", self.proj_w.len()))
            .field("proj_b", &format_args!("<{} x f32>", self.proj_b.len()))
            .field("embed_dim", &self.embed_dim)
            .field("in_chans", &self.in_chans)
            .field("tubelet_size", &self.tubelet_size)
            .field("patch_size", &self.patch_size)
            .finish()
    }
}
33
34pub fn extract_patch_embed_weights(
35 weights: &mut WeightMap,
36 cfg: &Vjepa2Config,
37) -> Result<Vjepa2PatchEmbedWeights> {
38 let e = cfg.hidden_size;
39 let c = cfg.in_chans;
40 let ts = cfg.tubelet_size;
41 let ps = cfg.patch_size;
42 let expected = vec![e, c, ts, ps, ps];
43
44 let w_keys = [
45 "encoder.embeddings.patch_embeddings.proj.weight",
46 "patch_embed.proj.weight",
47 ];
48 let b_keys = [
49 "encoder.embeddings.patch_embeddings.proj.bias",
50 "patch_embed.proj.bias",
51 ];
52
53 let (proj_w, shape) = take_first(weights, &w_keys)?;
54 ensure!(
55 shape == expected,
56 "patch embed weight expected {expected:?}, got {shape:?}"
57 );
58 let (proj_b, bshape) = take_first(weights, &b_keys)?;
59 ensure!(bshape == vec![e], "patch embed bias expected [{e}]");
60
61 Ok(Vjepa2PatchEmbedWeights {
62 proj_w,
63 proj_b,
64 embed_dim: e,
65 in_chans: c,
66 tubelet_size: ts,
67 patch_size: ps,
68 })
69}
70
71pub fn normalize_video_hwc(frames: &[u8], num_frames: usize, crop: usize) -> Vec<f32> {
74 let plane = crop * crop;
75 let mut out = vec![0f32; 3 * num_frames * plane];
76 for t in 0..num_frames {
77 for y in 0..crop {
78 for x in 0..crop {
79 let src = (t * plane + y * crop + x) * 3;
80 for c in 0..3 {
81 let v = frames[src + c] as f32 / 255.0;
82 let norm = (v - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
83 out[c * num_frames * plane + t * plane + y * crop + x] = norm;
84 }
85 }
86 }
87 }
88 out
89}
90
91pub fn conv3d_patch_embed(
93 patch: &Vjepa2PatchEmbedWeights,
94 video_ncthw: &[f32],
95 frames: usize,
96 height: usize,
97 width: usize,
98) -> Result<Vec<f32>> {
99 let c = patch.in_chans;
100 let ts = patch.tubelet_size;
101 let ps = patch.patch_size;
102 let e = patch.embed_dim;
103 ensure!(
104 video_ncthw.len() == c * frames * height * width,
105 "video tensor size mismatch"
106 );
107 ensure!(frames.is_multiple_of(ts) && height.is_multiple_of(ps) && width.is_multiple_of(ps));
108
109 let t_out = frames / ts;
110 let h_out = height / ps;
111 let w_out = width / ps;
112 let seq = t_out * h_out * w_out;
113 let mut tokens = vec![0f32; seq * e];
114
115 let plane = height * width;
116 let vol = frames * plane;
117
118 for ot in 0..t_out {
119 for oh in 0..h_out {
120 for ow in 0..w_out {
121 let tok = (ot * h_out + oh) * w_out + ow;
122 for oc in 0..e {
123 let mut acc = patch.proj_b[oc];
124 for ic in 0..c {
125 for kt in 0..ts {
126 for kh in 0..ps {
127 for kw in 0..ps {
128 let it = ot * ts + kt;
129 let ih = oh * ps + kh;
130 let iw = ow * ps + kw;
131 let in_idx = ic * vol + it * plane + ih * width + iw;
132 let w_idx = oc * (c * ts * ps * ps)
133 + ic * (ts * ps * ps)
134 + kt * (ps * ps)
135 + kh * ps
136 + kw;
137 acc += patch.proj_w[w_idx] * video_ncthw[in_idx];
138 }
139 }
140 }
141 }
142 tokens[tok * e + oc] = acc;
143 }
144 }
145 }
146 }
147 Ok(tokens)
148}
149
150fn take_first(weights: &mut WeightMap, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
151 for key in keys {
152 if weights.has(key) {
153 return weights.take(key);
154 }
155 }
156 anyhow::bail!("none of the patch-embed keys found: {keys:?}")
157}