1use super::config::Vjepa2Config;
19use super::layers::block_forward;
20use super::preprocess::conv3d_patch_embed;
21use super::weights::Vjepa2EncoderWeights;
22use anyhow::Result;
23use rlx_tensor::layer_norm;
24
/// Result of an encoder forward pass, as produced by [`encode_video_native_ext`].
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, duplicate and compare
/// outputs (for example in regression tests) without reaching into the fields by hand.
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2EncoderOutput {
    /// Flat token buffer left after the encoder blocks (and the final layer norm, when it
    /// is applied). Presumably laid out as `batch * seq * hidden` values — TODO confirm
    /// against `block_forward`.
    pub tokens: Vec<f32>,
    /// Sequence length, taken from `Vjepa2Config::num_patches()`.
    pub seq: usize,
    /// Embedding width, taken from `Vjepa2Config::hidden_size`.
    pub hidden: usize,
}
30
31pub fn encode_video_native(
33 weights: &Vjepa2EncoderWeights,
34 cfg: &Vjepa2Config,
35 video_ncthw: &[f32],
36 batch: usize,
37) -> Result<Vjepa2EncoderOutput> {
38 encode_video_native_ext(weights, cfg, video_ncthw, batch, None)
39}
40
41pub fn encode_video_native_ext(
44 weights: &Vjepa2EncoderWeights,
45 cfg: &Vjepa2Config,
46 video_ncthw: &[f32],
47 batch: usize,
48 stop_after_block: Option<usize>,
49) -> Result<Vjepa2EncoderOutput> {
50 let e = cfg.hidden_size;
51 let frames = cfg.frames_per_clip;
52 let crop = cfg.crop_size;
53 let seq = cfg.num_patches();
54 let head_dim = cfg.head_dim();
55 let nh = cfg.num_attention_heads;
56 let (d_dim, h_dim, w_dim) = cfg.rope_segment_dims();
57 let grid_t = cfg.grid_temporal();
58 let grid_h = cfg.grid_spatial();
59 let grid_w = cfg.grid_spatial();
60 let eps = cfg.layer_norm_eps as f32;
61
62 let mut x = conv3d_patch_embed(&weights.patch, video_ncthw, frames, crop, crop)?;
63 if batch > 1 {
64 let per = x.len();
65 let mut batched = Vec::with_capacity(per * batch);
66 for _ in 0..batch {
67 batched.extend_from_slice(&x);
68 }
69 x = batched;
70 }
71
72 let last_block = weights.blocks.len().saturating_sub(1);
73 for (i, block) in weights.blocks.iter().enumerate() {
74 block_forward(
75 &mut x, block, batch, seq, e, nh, head_dim, d_dim, h_dim, w_dim, grid_t, grid_h,
76 grid_w, eps, None,
77 )?;
78 if stop_after_block == Some(i) {
79 break;
80 }
81 }
82
83 if stop_after_block.is_none() || stop_after_block == Some(last_block) {
84 x = layer_norm(&x, &weights.norm_w, &weights.norm_b, e, eps)?;
85 }
86
87 Ok(Vjepa2EncoderOutput {
88 tokens: x,
89 seq,
90 hidden: e,
91 })
92}