1use serde::Deserialize;
19use std::path::Path;
20
/// Three `f32` values named as the ImageNet normalization mean.
///
/// NOTE(review): presumably one entry per colour channel (RGB order) for pixel
/// values already scaled to `[0, 1]` — confirm against the preprocessing code.
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Three `f32` values named as the ImageNet normalization standard deviation.
///
/// NOTE(review): presumably paired element-wise with [`IMAGENET_MEAN`] — confirm
/// channel order against the preprocessing code.
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
24
/// Model hyper-parameters, deserializable with serde (e.g. from a JSON config file).
///
/// Fields without a `#[serde(default ...)]` attribute are required in the input;
/// the remaining fields fall back to the values produced by the referenced
/// default functions when the key is absent.
///
/// NOTE(review): the `pred_*` fields presumably configure a separate predictor
/// network and `num_pooler_layers` / `num_classes` a classification head —
/// confirm against the model code.
#[derive(Debug, Clone, Deserialize)]
pub struct Vjepa2Config {
    /// Width used for `head_dim()` and `intermediate_size()`. Required.
    pub hidden_size: usize,
    /// Number of hidden layers. Required.
    pub num_hidden_layers: usize,
    /// Number of attention heads; `head_dim()` divides `hidden_size` by it. Required.
    pub num_attention_heads: usize,
    /// Spatial input size; divided by `patch_size` in `grid_spatial()`.
    /// Also accepted under the key `image_size`. Required.
    #[serde(alias = "image_size")]
    pub crop_size: usize,
    /// Spatial patch size; divisor in `grid_spatial()`. Required.
    pub patch_size: usize,
    /// Number of frames grouped together; divisor in `grid_temporal()`. Required.
    pub tubelet_size: usize,
    /// Number of frames per clip; dividend in `grid_temporal()`. Required.
    pub frames_per_clip: usize,
    /// Multiplier applied to `hidden_size` in `intermediate_size()`.
    /// Defaults to `48.0 / 11.0` when absent.
    #[serde(default = "default_mlp_ratio")]
    pub mlp_ratio: f64,
    /// Layer-norm epsilon. Defaults to `1e-6` when absent.
    #[serde(default = "default_ln_eps")]
    pub layer_norm_eps: f64,
    /// Number of input channels. Defaults to `3` when absent.
    #[serde(default = "default_in_chans")]
    pub in_chans: usize,
    /// Width used for `pred_head_dim()` and `pred_intermediate_size()`.
    /// Defaults to `384` when absent.
    #[serde(default = "default_pred_hidden")]
    pub pred_hidden_size: usize,
    /// Divisor of `pred_hidden_size` in `pred_head_dim()`. Defaults to `12` when absent.
    #[serde(default = "default_pred_heads")]
    pub pred_num_attention_heads: usize,
    /// Defaults to `12` when absent.
    #[serde(default = "default_pred_layers")]
    pub pred_num_hidden_layers: usize,
    /// Multiplier applied to `pred_hidden_size` in `pred_intermediate_size()`.
    /// Defaults to `4.0` when absent.
    #[serde(default = "default_pred_mlp_ratio")]
    pub pred_mlp_ratio: f64,
    /// Defaults to `10` when absent.
    #[serde(default = "default_pred_mask_tokens")]
    pub pred_num_mask_tokens: usize,
    /// Defaults to `true` when absent.
    #[serde(default = "default_true")]
    pub pred_zero_init_mask_tokens: bool,
    /// Defaults to `3` when absent.
    #[serde(default = "default_pooler_layers")]
    pub num_pooler_layers: usize,
    /// Defaults to `0` (the `usize` default) when absent.
    #[serde(default)]
    pub num_classes: usize,
}
60
/// Serde fallback for `Vjepa2Config::mlp_ratio`: the quotient 48 / 11 (about 4.36).
fn default_mlp_ratio() -> f64 {
    let (numerator, denominator): (f64, f64) = (48.0, 11.0);
    numerator / denominator
}
/// Serde fallback for `Vjepa2Config::layer_norm_eps`.
fn default_ln_eps() -> f64 {
    const LAYER_NORM_EPS: f64 = 1e-6;
    LAYER_NORM_EPS
}
/// Serde fallback for `Vjepa2Config::in_chans`.
fn default_in_chans() -> usize {
    const IN_CHANS: usize = 3;
    IN_CHANS
}
/// Serde fallback for `Vjepa2Config::pred_hidden_size`.
fn default_pred_hidden() -> usize {
    const PRED_HIDDEN_SIZE: usize = 384;
    PRED_HIDDEN_SIZE
}
/// Serde fallback for `Vjepa2Config::pred_num_attention_heads`.
fn default_pred_heads() -> usize {
    const PRED_NUM_ATTENTION_HEADS: usize = 12;
    PRED_NUM_ATTENTION_HEADS
}
/// Serde fallback for `Vjepa2Config::pred_num_hidden_layers`.
fn default_pred_layers() -> usize {
    const PRED_NUM_HIDDEN_LAYERS: usize = 12;
    PRED_NUM_HIDDEN_LAYERS
}
/// Serde fallback for `Vjepa2Config::pred_mlp_ratio`.
fn default_pred_mlp_ratio() -> f64 {
    const PRED_MLP_RATIO: f64 = 4.0;
    PRED_MLP_RATIO
}
/// Serde fallback for `Vjepa2Config::pred_num_mask_tokens`.
fn default_pred_mask_tokens() -> usize {
    const PRED_NUM_MASK_TOKENS: usize = 10;
    PRED_NUM_MASK_TOKENS
}
/// Serde fallback for boolean fields that are enabled unless the input says otherwise
/// (used by `Vjepa2Config::pred_zero_init_mask_tokens`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde fallback for `Vjepa2Config::num_pooler_layers`.
fn default_pooler_layers() -> usize {
    const NUM_POOLER_LAYERS: usize = 3;
    NUM_POOLER_LAYERS
}
91
92impl Vjepa2Config {
93 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
94 let data = std::fs::read_to_string(path)?;
95 Ok(serde_json::from_str(&data)?)
96 }
97
98 pub fn vit_g_384() -> Self {
100 Self {
101 hidden_size: 1408,
102 num_hidden_layers: 40,
103 num_attention_heads: 22,
104 crop_size: 384,
105 patch_size: 16,
106 tubelet_size: 2,
107 frames_per_clip: 64,
108 mlp_ratio: 48.0 / 11.0,
109 layer_norm_eps: 1e-6,
110 in_chans: 3,
111 pred_hidden_size: 384,
112 pred_num_attention_heads: 12,
113 pred_num_hidden_layers: 12,
114 pred_mlp_ratio: 4.0,
115 pred_num_mask_tokens: 10,
116 pred_zero_init_mask_tokens: true,
117 num_pooler_layers: 3,
118 num_classes: 0,
119 }
120 }
121
122 pub fn head_dim(&self) -> usize {
123 self.hidden_size / self.num_attention_heads
124 }
125
126 pub fn pred_head_dim(&self) -> usize {
127 self.pred_hidden_size / self.pred_num_attention_heads
128 }
129
130 pub fn intermediate_size(&self) -> usize {
131 (self.hidden_size as f64 * self.mlp_ratio) as usize
132 }
133
134 pub fn pred_intermediate_size(&self) -> usize {
135 (self.pred_hidden_size as f64 * self.pred_mlp_ratio) as usize
136 }
137
138 pub fn pooler_intermediate_size(&self) -> usize {
139 (self.hidden_size as f64 * self.mlp_ratio) as usize
140 }
141
142 pub fn grid_spatial(&self) -> usize {
143 self.crop_size / self.patch_size
144 }
145
146 pub fn grid_temporal(&self) -> usize {
147 self.frames_per_clip / self.tubelet_size
148 }
149
150 pub fn num_patches(&self) -> usize {
151 self.grid_temporal() * self.grid_spatial() * self.grid_spatial()
152 }
153
154 pub fn rope_segment_dims(&self) -> (usize, usize, usize) {
156 rope_segment_dims(self.head_dim())
157 }
158
159 pub fn pred_rope_segment_dims(&self) -> (usize, usize, usize) {
160 rope_segment_dims(self.pred_head_dim())
161 }
162}
163
/// Splits `head_dim` into three equal, even-sized segments.
///
/// Each segment is the largest even number not exceeding `head_dim / 3`, so the
/// three segments together may cover fewer than `head_dim` dimensions.
///
/// NOTE(review): the three entries presumably correspond to the temporal,
/// height, and width axes of a 3-D rotary embedding — confirm against the caller.
pub fn rope_segment_dims(head_dim: usize) -> (usize, usize, usize) {
    let per_axis = head_dim / 3;
    // Drop the odd remainder so the segment width is even.
    let even = per_axis - per_axis % 2;
    (even, even, even)
}