Skip to main content

rlx_vjepa2/
config.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! V-JEPA2 configuration — mirrors Meta / HuggingFace `config.json`.
17
18use serde::Deserialize;
19use std::path::Path;
20
/// ImageNet-style per-channel mean (same values as DINOv2 / HF
/// VJEPA2VideoProcessor). Three entries, presumably in RGB order — confirm
/// against the preprocessing code.
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// ImageNet-style per-channel standard deviation, the companion of
/// [`IMAGENET_MEAN`] (same channel order).
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
24
25#[derive(Debug, Clone, Deserialize)]
26pub struct Vjepa2Config {
27    pub hidden_size: usize,
28    pub num_hidden_layers: usize,
29    pub num_attention_heads: usize,
30    #[serde(alias = "image_size")]
31    pub crop_size: usize,
32    pub patch_size: usize,
33    pub tubelet_size: usize,
34    pub frames_per_clip: usize,
35    #[serde(default = "default_mlp_ratio")]
36    pub mlp_ratio: f64,
37    #[serde(default = "default_ln_eps")]
38    pub layer_norm_eps: f64,
39    #[serde(default = "default_in_chans")]
40    pub in_chans: usize,
41    // Predictor
42    #[serde(default = "default_pred_hidden")]
43    pub pred_hidden_size: usize,
44    #[serde(default = "default_pred_heads")]
45    pub pred_num_attention_heads: usize,
46    #[serde(default = "default_pred_layers")]
47    pub pred_num_hidden_layers: usize,
48    #[serde(default = "default_pred_mlp_ratio")]
49    pub pred_mlp_ratio: f64,
50    #[serde(default = "default_pred_mask_tokens")]
51    pub pred_num_mask_tokens: usize,
52    #[serde(default = "default_true")]
53    pub pred_zero_init_mask_tokens: bool,
54    // Attentive pooler (finetuned checkpoints)
55    #[serde(default = "default_pooler_layers")]
56    pub num_pooler_layers: usize,
57    #[serde(default)]
58    pub num_classes: usize,
59}
60
// Serde fallbacks: each function supplies the value used when the matching
// key is missing from the JSON configuration.

/// Fallback for `mlp_ratio`: 48/11 (≈ 4.3636).
fn default_mlp_ratio() -> f64 {
    48.0 / 11.0
}

/// Fallback for `layer_norm_eps`.
fn default_ln_eps() -> f64 {
    1e-6
}

/// Fallback for `in_chans`.
fn default_in_chans() -> usize {
    3
}

/// Fallback for `pred_hidden_size`.
fn default_pred_hidden() -> usize {
    384
}

/// Fallback for `pred_num_attention_heads`.
fn default_pred_heads() -> usize {
    12
}

/// Fallback for `pred_num_hidden_layers`.
fn default_pred_layers() -> usize {
    12
}

/// Fallback for `pred_mlp_ratio`.
fn default_pred_mlp_ratio() -> f64 {
    4.0
}

/// Fallback for `pred_num_mask_tokens`.
fn default_pred_mask_tokens() -> usize {
    10
}

/// Fallback for boolean fields that are enabled unless stated otherwise
/// (used by `pred_zero_init_mask_tokens`).
fn default_true() -> bool {
    true
}

/// Fallback for `num_pooler_layers`.
fn default_pooler_layers() -> usize {
    3
}
91
92impl Vjepa2Config {
93    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
94        let data = std::fs::read_to_string(path)?;
95        Ok(serde_json::from_str(&data)?)
96    }
97
98    /// `facebook/vjepa2-vitg-fpc64-384` — ViT-G, 64 frames, 384².
99    pub fn vit_g_384() -> Self {
100        Self {
101            hidden_size: 1408,
102            num_hidden_layers: 40,
103            num_attention_heads: 22,
104            crop_size: 384,
105            patch_size: 16,
106            tubelet_size: 2,
107            frames_per_clip: 64,
108            mlp_ratio: 48.0 / 11.0,
109            layer_norm_eps: 1e-6,
110            in_chans: 3,
111            pred_hidden_size: 384,
112            pred_num_attention_heads: 12,
113            pred_num_hidden_layers: 12,
114            pred_mlp_ratio: 4.0,
115            pred_num_mask_tokens: 10,
116            pred_zero_init_mask_tokens: true,
117            num_pooler_layers: 3,
118            num_classes: 0,
119        }
120    }
121
122    pub fn head_dim(&self) -> usize {
123        self.hidden_size / self.num_attention_heads
124    }
125
126    pub fn pred_head_dim(&self) -> usize {
127        self.pred_hidden_size / self.pred_num_attention_heads
128    }
129
130    pub fn intermediate_size(&self) -> usize {
131        (self.hidden_size as f64 * self.mlp_ratio) as usize
132    }
133
134    pub fn pred_intermediate_size(&self) -> usize {
135        (self.pred_hidden_size as f64 * self.pred_mlp_ratio) as usize
136    }
137
138    pub fn pooler_intermediate_size(&self) -> usize {
139        (self.hidden_size as f64 * self.mlp_ratio) as usize
140    }
141
142    pub fn grid_spatial(&self) -> usize {
143        self.crop_size / self.patch_size
144    }
145
146    pub fn grid_temporal(&self) -> usize {
147        self.frames_per_clip / self.tubelet_size
148    }
149
150    pub fn num_patches(&self) -> usize {
151        self.grid_temporal() * self.grid_spatial() * self.grid_spatial()
152    }
153
154    /// Per-axis RoPE segment sizes (d, h, w). Matches Meta `RoPEAttention`.
155    pub fn rope_segment_dims(&self) -> (usize, usize, usize) {
156        rope_segment_dims(self.head_dim())
157    }
158
159    pub fn pred_rope_segment_dims(&self) -> (usize, usize, usize) {
160        rope_segment_dims(self.pred_head_dim())
161    }
162}
163
/// Splits `head_dim` into three equally sized RoPE segments, returned as a
/// `(d, h, w)` triple.
///
/// Each segment is the largest even number not exceeding `head_dim / 3`, so
/// the three segments may cover fewer than `head_dim` channels in total.
pub fn rope_segment_dims(head_dim: usize) -> (usize, usize, usize) {
    let third = head_dim / 3;
    // Round down to an even size.
    let per_axis = third - third % 2;
    (per_axis, per_axis, per_axis)
}