use super::config::Vjepa2Config;
use super::layers::block_forward;
use super::preprocess::conv3d_patch_embed;
use super::weights::Vjepa2EncoderWeights;
use anyhow::Result;
use rlx_tensor::layer_norm;
/// Result of running the V-JEPA 2 encoder.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared (e.g. in parity tests).
#[derive(Debug, Clone, PartialEq)]
pub struct Vjepa2EncoderOutput {
    /// Flat buffer of token activations produced by the encoder.
    /// NOTE(review): presumably laid out as `[batch, seq, hidden]` in
    /// row-major order — confirm against `block_forward`.
    pub tokens: Vec<f32>,
    /// Number of tokens per sample (the encoder fills this from
    /// `cfg.num_patches()`).
    pub seq: usize,
    /// Width of each token (the encoder fills this from `cfg.hidden_size`).
    pub hidden: usize,
}
/// Encodes a video clip with the complete encoder stack.
///
/// Convenience wrapper around [`encode_video_native_ext`] that requests no
/// early exit, so every block runs and the final layer norm is applied.
///
/// # Errors
/// Returns whatever error [`encode_video_native_ext`] reports.
pub fn encode_video_native(
    weights: &Vjepa2EncoderWeights,
    cfg: &Vjepa2Config,
    video_ncthw: &[f32],
    batch: usize,
) -> Result<Vjepa2EncoderOutput> {
    // `None` means "do not stop after any particular block".
    let stop_after_block: Option<usize> = None;
    encode_video_native_ext(weights, cfg, video_ncthw, batch, stop_after_block)
}
/// Runs the V-JEPA 2 encoder over a video clip, optionally stopping early.
///
/// The clip is patch-embedded with `conv3d_patch_embed`, passed through the
/// transformer blocks in `weights.blocks` in order, and — when the whole stack
/// ran — layer-normalised with `weights.norm_w` / `weights.norm_b`.
///
/// # Arguments
/// * `weights` - encoder parameters (patch embedding, blocks, final norm).
/// * `cfg` - model configuration supplying sizes, RoPE segment dims and grid
///   dims.
/// * `video_ncthw` - flat video buffer, forwarded unchanged to the patch
///   embedding. NOTE(review): the name suggests N,C,T,H,W layout — confirm.
/// * `batch` - batch size handed to every block. When it is greater than one
///   the single patch-embedding result is tiled `batch` times.
/// * `stop_after_block` - `None` runs every block. `Some(i)` stops after the
///   block at index `i`; the output is then the raw block output without the
///   final layer norm, unless `i` is the last block or beyond. An index past
///   the end is clamped to the last block, so the result matches the `None`
///   case instead of silently skipping the final norm.
///
/// # Errors
/// Propagates any error returned by `conv3d_patch_embed`, `block_forward` or
/// `layer_norm`.
pub fn encode_video_native_ext(
    weights: &Vjepa2EncoderWeights,
    cfg: &Vjepa2Config,
    video_ncthw: &[f32],
    batch: usize,
    stop_after_block: Option<usize>,
) -> Result<Vjepa2EncoderOutput> {
    let e = cfg.hidden_size;
    let frames = cfg.frames_per_clip;
    let crop = cfg.crop_size;
    let seq = cfg.num_patches();
    let head_dim = cfg.head_dim();
    let nh = cfg.num_attention_heads;
    let (d_dim, h_dim, w_dim) = cfg.rope_segment_dims();
    let grid_t = cfg.grid_temporal();
    // Height and width share the same spatial grid size.
    let grid_h = cfg.grid_spatial();
    let grid_w = cfg.grid_spatial();
    // The tensor kernels take an `f32` epsilon.
    let eps = cfg.layer_norm_eps as f32;

    // The crop size is used for both spatial extents of the patch embedding.
    let mut x = conv3d_patch_embed(&weights.patch, video_ncthw, frames, crop, crop)?;
    if batch > 1 {
        // Tile the embedded tokens `batch` times.
        // NOTE(review): every batch entry is a copy of the same embedding;
        // confirm this is intended rather than per-sample embedding.
        x = x.repeat(batch);
    }

    // Number of blocks to execute: all of them, or up to and including the
    // requested index (clamped to the stack length).
    let total_blocks = weights.blocks.len();
    let blocks_to_run = stop_after_block.map_or(total_blocks, |stop| {
        stop.saturating_add(1).min(total_blocks)
    });

    for block in weights.blocks.iter().take(blocks_to_run) {
        // The trailing `None` leaves `block_forward`'s optional last argument
        // unset.
        block_forward(
            &mut x, block, batch, seq, e, nh, head_dim, d_dim, h_dim, w_dim, grid_t, grid_h,
            grid_w, eps, None,
        )?;
    }

    // The final norm belongs to the full encoder output only; intermediate
    // block outputs are returned as-is.
    if blocks_to_run == total_blocks {
        x = layer_norm(&x, &weights.norm_w, &weights.norm_b, e, eps)?;
    }

    Ok(Vjepa2EncoderOutput {
        tokens: x,
        seq,
        hidden: e,
    })
}