1use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
12pub enum ResolutionMode {
13 Tiny,
14 Small,
15 Base,
16 Large,
17 Gundam,
18}
19
20impl ResolutionMode {
21 pub fn base_size(&self) -> u32 {
22 match self {
23 ResolutionMode::Tiny => 512,
24 ResolutionMode::Small => 640,
25 ResolutionMode::Base => 1024,
26 ResolutionMode::Large => 1280,
27 ResolutionMode::Gundam => 1024,
28 }
29 }
30 pub fn image_size(&self) -> u32 {
31 match self {
32 ResolutionMode::Tiny => 512,
33 ResolutionMode::Small => 640,
34 ResolutionMode::Base => 1024,
35 ResolutionMode::Large => 1280,
36 ResolutionMode::Gundam => 640,
37 }
38 }
39 pub fn crop_mode(&self) -> bool {
40 matches!(self, ResolutionMode::Gundam)
41 }
42 pub fn num_tokens(&self) -> usize {
43 match self {
44 ResolutionMode::Tiny => 64,
45 ResolutionMode::Small => 100,
46 ResolutionMode::Base => 256,
47 ResolutionMode::Large => 400,
48 ResolutionMode::Gundam => 100,
49 }
50 }
51 pub fn all_modes() -> Vec<ResolutionMode> {
52 vec![
53 ResolutionMode::Tiny,
54 ResolutionMode::Small,
55 ResolutionMode::Base,
56 ResolutionMode::Large,
57 ResolutionMode::Gundam,
58 ]
59 }
60}
61
62impl Default for ResolutionMode {
63 fn default() -> Self {
64 ResolutionMode::Base
65 }
66}
67
68impl core::fmt::Display for ResolutionMode {
69 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
70 match self {
71 ResolutionMode::Tiny => write!(f, "Tiny (512x512)"),
72 ResolutionMode::Small => write!(f, "Small (640x640)"),
73 ResolutionMode::Base => write!(f, "Base (1024x1024)"),
74 ResolutionMode::Large => write!(f, "Large (1280x1280)"),
75 ResolutionMode::Gundam => write!(f, "Gundam (1024 base, 640 image, with crops)"),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ResolutionConfig {
82 pub base_size: u32,
83 pub image_size: u32,
84 pub crop_mode: bool,
85 pub min_crops: usize,
86 pub max_crops: usize,
87}
88
89impl ResolutionConfig {
90 pub fn new(mode: ResolutionMode) -> Self {
91 Self {
92 base_size: mode.base_size(),
93 image_size: mode.image_size(),
94 crop_mode: mode.crop_mode(),
95 min_crops: 2,
96 max_crops: 6,
97 }
98 }
99 pub fn tiny() -> Self {
100 Self::new(ResolutionMode::Tiny)
101 }
102 pub fn small() -> Self {
103 Self::new(ResolutionMode::Small)
104 }
105 pub fn base() -> Self {
106 Self::new(ResolutionMode::Base)
107 }
108 pub fn large() -> Self {
109 Self::new(ResolutionMode::Large)
110 }
111 pub fn gundam() -> Self {
112 Self::new(ResolutionMode::Gundam)
113 }
114
115 pub fn num_patches(&self, patch_size: usize) -> usize {
116 let g = self.base_size as usize / patch_size;
117 g * g
118 }
119 pub fn patch_grid_size(&self, patch_size: usize) -> (usize, usize) {
120 let g = self.base_size as usize / patch_size;
121 (g, g)
122 }
123}
124
125impl Default for ResolutionConfig {
126 fn default() -> Self {
127 Self::base()
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct SamConfig {
133 pub embed_dim: usize,
134 pub depth: usize,
135 pub num_heads: usize,
136 pub img_size: usize,
137 pub patch_size: usize,
138 pub mlp_ratio: f32,
139 pub out_chans: usize,
140 pub window_size: usize,
141 pub global_attn_indexes: Vec<usize>,
142 pub use_rel_pos: bool,
143 pub qkv_bias: bool,
144}
145
146impl SamConfig {
147 pub fn tiny() -> Self {
148 Self {
149 embed_dim: 192,
150 depth: 12,
151 num_heads: 3,
152 img_size: 1024,
153 patch_size: 16,
154 mlp_ratio: 4.0,
155 out_chans: 256,
156 window_size: 14,
157 global_attn_indexes: vec![2, 5, 8, 11],
158 use_rel_pos: true,
159 qkv_bias: true,
160 }
161 }
162 pub fn small() -> Self {
163 Self {
164 embed_dim: 384,
165 depth: 12,
166 num_heads: 6,
167 img_size: 1024,
168 patch_size: 16,
169 mlp_ratio: 4.0,
170 out_chans: 256,
171 window_size: 14,
172 global_attn_indexes: vec![2, 5, 8, 11],
173 use_rel_pos: true,
174 qkv_bias: true,
175 }
176 }
177 pub fn base() -> Self {
178 Self::default()
179 }
180 pub fn large() -> Self {
181 Self {
182 embed_dim: 1024,
183 depth: 24,
184 num_heads: 16,
185 img_size: 1024,
186 patch_size: 16,
187 mlp_ratio: 4.0,
188 out_chans: 256,
189 window_size: 14,
190 global_attn_indexes: vec![5, 11, 17, 23],
191 use_rel_pos: true,
192 qkv_bias: true,
193 }
194 }
195 pub fn with_img_size(mut self, size: usize) -> Self {
196 self.img_size = size;
197 self
198 }
199 pub fn with_window_size(mut self, size: usize) -> Self {
200 self.window_size = size;
201 self
202 }
203 pub fn num_patches(&self) -> usize {
204 (self.img_size / self.patch_size).pow(2)
205 }
206 pub fn patch_grid_size(&self) -> (usize, usize) {
207 let g = self.img_size / self.patch_size;
208 (g, g)
209 }
210}
211
212impl Default for SamConfig {
213 fn default() -> Self {
214 Self {
215 embed_dim: 768,
216 depth: 12,
217 num_heads: 12,
218 img_size: 1024,
219 patch_size: 16,
220 mlp_ratio: 4.0,
221 out_chans: 256,
222 window_size: 14,
223 global_attn_indexes: vec![2, 5, 8, 11],
224 use_rel_pos: true,
225 qkv_bias: true,
226 }
227 }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct ClipConfig {
232 pub hidden_size: usize,
233 pub num_layers: usize,
234 pub num_attention_heads: usize,
235 pub ffn_hidden_size: usize,
236 pub image_size: usize,
237 pub patch_size: usize,
238 pub use_flash_attn: bool,
239 pub layernorm_epsilon: f32,
240 pub dropout: f32,
241}
242
243impl ClipConfig {
244 pub fn base() -> Self {
245 Self {
246 hidden_size: 768,
247 num_layers: 12,
248 num_attention_heads: 12,
249 ffn_hidden_size: 3072,
250 image_size: 224,
251 patch_size: 14,
252 use_flash_attn: false,
253 layernorm_epsilon: 1e-5,
254 dropout: 0.0,
255 }
256 }
257 pub fn large() -> Self {
258 Self::default()
259 }
260 pub fn with_image_size(mut self, size: usize) -> Self {
261 self.image_size = size;
262 self
263 }
264 pub fn with_flash_attn(mut self, enable: bool) -> Self {
265 self.use_flash_attn = enable;
266 self
267 }
268 pub fn num_patches(&self) -> usize {
269 (self.image_size / self.patch_size).pow(2)
270 }
271 pub fn num_positions(&self) -> usize {
272 self.num_patches() + 1
273 }
274}
275
276impl Default for ClipConfig {
277 fn default() -> Self {
278 Self {
279 hidden_size: 1024,
280 num_layers: 24,
281 num_attention_heads: 16,
282 ffn_hidden_size: 4096,
283 image_size: 224,
284 patch_size: 14,
285 use_flash_attn: false,
286 layernorm_epsilon: 1e-5,
287 dropout: 0.0,
288 }
289 }
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct ProjectorConfig {
294 pub projector_type: String,
295 pub input_dim: usize,
296 pub n_embed: usize,
297 pub num_layers: usize,
298}
299
300impl ProjectorConfig {
301 pub fn linear(input_dim: usize, output_dim: usize) -> Self {
302 Self {
303 projector_type: "linear".to_string(),
304 input_dim,
305 n_embed: output_dim,
306 num_layers: 1,
307 }
308 }
309 pub fn mlp_gelu(input_dim: usize, output_dim: usize, num_layers: usize) -> Self {
310 Self {
311 projector_type: "mlp_gelu".to_string(),
312 input_dim,
313 n_embed: output_dim,
314 num_layers,
315 }
316 }
317 pub fn with_type(mut self, projector_type: &str) -> Self {
318 self.projector_type = projector_type.to_string();
319 self
320 }
321 pub fn with_num_layers(mut self, num_layers: usize) -> Self {
322 self.num_layers = num_layers;
323 self
324 }
325}
326
327impl Default for ProjectorConfig {
328 fn default() -> Self {
329 Self {
330 projector_type: "linear".to_string(),
331 input_dim: 2048,
332 n_embed: 1280,
333 num_layers: 1,
334 }
335 }
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct ModelConfig {
340 pub sam_config: SamConfig,
341 pub clip_config: ClipConfig,
342 pub projector_config: ProjectorConfig,
343 pub resolution_config: ResolutionConfig,
344 pub patch_size: usize,
345 pub downsample_ratio: usize,
346 pub tile_tag: String,
347 pub image_mean: [f32; 3],
348 pub image_std: [f32; 3],
349}
350
351impl ModelConfig {
352 pub fn with_resolution(mut self, mode: ResolutionMode) -> Self {
353 self.resolution_config = ResolutionConfig::new(mode);
354 self.sam_config.img_size = mode.base_size() as usize;
355 self
356 }
357 pub fn with_sam_config(mut self, config: SamConfig) -> Self {
358 self.sam_config = config;
359 self
360 }
361 pub fn with_clip_config(mut self, config: ClipConfig) -> Self {
362 self.clip_config = config;
363 self
364 }
365 pub fn with_projector_config(mut self, config: ProjectorConfig) -> Self {
366 self.projector_config = config;
367 self
368 }
369 pub fn with_normalization(mut self, mean: [f32; 3], std: [f32; 3]) -> Self {
370 self.image_mean = mean;
371 self.image_std = std;
372 self
373 }
374
375 pub fn num_vision_tokens(&self) -> usize {
376 let h =
377 (self.resolution_config.base_size as usize / self.patch_size) / self.downsample_ratio;
378 let w = h;
379 h * (w + 1) + 1
380 }
381 pub fn compression_ratio(&self, text_length: usize) -> f32 {
382 text_length as f32 / self.num_vision_tokens() as f32
383 }
384 pub fn num_patches(&self) -> usize {
385 self.resolution_config.num_patches(self.patch_size)
386 }
387 pub fn patch_grid_size(&self) -> (usize, usize) {
388 self.resolution_config.patch_grid_size(self.patch_size)
389 }
390 pub fn compressed_grid_size(&self) -> (usize, usize) {
391 let (h, w) = self.patch_grid_size();
392 (h / self.downsample_ratio, w / self.downsample_ratio)
393 }
394
395 pub fn production() -> Self {
396 Self::default().with_resolution(ResolutionMode::Base)
397 }
398 pub fn evaluation() -> Self {
399 Self::default().with_resolution(ResolutionMode::Large)
400 }
401 pub fn fast() -> Self {
402 Self::default().with_resolution(ResolutionMode::Tiny)
403 }
404}
405
406impl Default for ModelConfig {
407 fn default() -> Self {
408 Self {
409 sam_config: SamConfig::default(),
410 clip_config: ClipConfig::default(),
411 projector_config: ProjectorConfig::default(),
412 resolution_config: ResolutionConfig::base(),
413 patch_size: 16,
414 downsample_ratio: 4,
415 tile_tag: "2D".to_string(),
416 image_mean: [0.5, 0.5, 0.5],
417 image_std: [0.5, 0.5, 0.5],
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config: 16x16 compressed grid -> 16 * 17 + 1 tokens.
    #[test]
    fn test_tokens_and_ratios() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.num_vision_tokens(), 273);
        let r = cfg.compression_ratio(2730);
        assert!((r - 10.0).abs() < 0.1);
    }

    /// The hand-written `num_tokens` table must agree with `image_size` under the default
    /// `patch_size` and `downsample_ratio`.
    #[test]
    fn test_mode_tokens_match_image_size() {
        let cfg = ModelConfig::default();
        for mode in ResolutionMode::all_modes() {
            let side = mode.image_size() as usize / cfg.patch_size / cfg.downsample_ratio;
            assert_eq!(mode.num_tokens(), side * side, "{}", mode);
        }
    }

    /// `with_resolution` must keep the SAM input size and crop mode in sync with the mode.
    #[test]
    fn test_with_resolution_syncs_sub_configs() {
        for mode in ResolutionMode::all_modes() {
            let cfg = ModelConfig::default().with_resolution(mode);
            assert_eq!(cfg.sam_config.img_size, mode.base_size() as usize);
            assert_eq!(cfg.resolution_config.base_size, mode.base_size());
            assert_eq!(cfg.resolution_config.crop_mode, mode.crop_mode());
        }
    }

    /// Preset token counts and encoder grid sizes.
    #[test]
    fn test_presets() {
        // Tiny: 512 / 16 / 4 = 8.
        assert_eq!(ModelConfig::fast().num_vision_tokens(), 8 * 9 + 1);
        // Large: 1280 / 16 / 4 = 20.
        assert_eq!(ModelConfig::evaluation().num_vision_tokens(), 20 * 21 + 1);
        assert_eq!(ModelConfig::production().num_vision_tokens(), 273);
        assert_eq!(ClipConfig::default().num_positions(), 257);
        assert_eq!(SamConfig::default().patch_grid_size(), (64, 64));
    }
}