// optical_embeddings/config.rs

//! Configuration structures for Optical Embeddings.
//!
//! - Resolution modes (Tiny/Small/Base/Large/Gundam)
//! - SAM encoder configuration
//! - CLIP encoder configuration
//! - Projector configuration
//! - Model configuration utilities
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
12pub enum ResolutionMode {
13    Tiny,
14    Small,
15    Base,
16    Large,
17    Gundam,
18}
19
20impl ResolutionMode {
21    pub fn base_size(&self) -> u32 {
22        match self {
23            ResolutionMode::Tiny => 512,
24            ResolutionMode::Small => 640,
25            ResolutionMode::Base => 1024,
26            ResolutionMode::Large => 1280,
27            ResolutionMode::Gundam => 1024,
28        }
29    }
30    pub fn image_size(&self) -> u32 {
31        match self {
32            ResolutionMode::Tiny => 512,
33            ResolutionMode::Small => 640,
34            ResolutionMode::Base => 1024,
35            ResolutionMode::Large => 1280,
36            ResolutionMode::Gundam => 640,
37        }
38    }
39    pub fn crop_mode(&self) -> bool {
40        matches!(self, ResolutionMode::Gundam)
41    }
42    pub fn num_tokens(&self) -> usize {
43        match self {
44            ResolutionMode::Tiny => 64,
45            ResolutionMode::Small => 100,
46            ResolutionMode::Base => 256,
47            ResolutionMode::Large => 400,
48            ResolutionMode::Gundam => 100,
49        }
50    }
51    pub fn all_modes() -> Vec<ResolutionMode> {
52        vec![
53            ResolutionMode::Tiny,
54            ResolutionMode::Small,
55            ResolutionMode::Base,
56            ResolutionMode::Large,
57            ResolutionMode::Gundam,
58        ]
59    }
60}
61
62impl Default for ResolutionMode {
63    fn default() -> Self {
64        ResolutionMode::Base
65    }
66}
67
68impl core::fmt::Display for ResolutionMode {
69    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
70        match self {
71            ResolutionMode::Tiny => write!(f, "Tiny (512x512)"),
72            ResolutionMode::Small => write!(f, "Small (640x640)"),
73            ResolutionMode::Base => write!(f, "Base (1024x1024)"),
74            ResolutionMode::Large => write!(f, "Large (1280x1280)"),
75            ResolutionMode::Gundam => write!(f, "Gundam (1024 base, 640 image, with crops)"),
76        }
77    }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct ResolutionConfig {
82    pub base_size: u32,
83    pub image_size: u32,
84    pub crop_mode: bool,
85    pub min_crops: usize,
86    pub max_crops: usize,
87}
88
89impl ResolutionConfig {
90    pub fn new(mode: ResolutionMode) -> Self {
91        Self {
92            base_size: mode.base_size(),
93            image_size: mode.image_size(),
94            crop_mode: mode.crop_mode(),
95            min_crops: 2,
96            max_crops: 6,
97        }
98    }
99    pub fn tiny() -> Self {
100        Self::new(ResolutionMode::Tiny)
101    }
102    pub fn small() -> Self {
103        Self::new(ResolutionMode::Small)
104    }
105    pub fn base() -> Self {
106        Self::new(ResolutionMode::Base)
107    }
108    pub fn large() -> Self {
109        Self::new(ResolutionMode::Large)
110    }
111    pub fn gundam() -> Self {
112        Self::new(ResolutionMode::Gundam)
113    }
114
115    pub fn num_patches(&self, patch_size: usize) -> usize {
116        let g = self.base_size as usize / patch_size;
117        g * g
118    }
119    pub fn patch_grid_size(&self, patch_size: usize) -> (usize, usize) {
120        let g = self.base_size as usize / patch_size;
121        (g, g)
122    }
123}
124
125impl Default for ResolutionConfig {
126    fn default() -> Self {
127        Self::base()
128    }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct SamConfig {
133    pub embed_dim: usize,
134    pub depth: usize,
135    pub num_heads: usize,
136    pub img_size: usize,
137    pub patch_size: usize,
138    pub mlp_ratio: f32,
139    pub out_chans: usize,
140    pub window_size: usize,
141    pub global_attn_indexes: Vec<usize>,
142    pub use_rel_pos: bool,
143    pub qkv_bias: bool,
144}
145
146impl SamConfig {
147    pub fn tiny() -> Self {
148        Self {
149            embed_dim: 192,
150            depth: 12,
151            num_heads: 3,
152            img_size: 1024,
153            patch_size: 16,
154            mlp_ratio: 4.0,
155            out_chans: 256,
156            window_size: 14,
157            global_attn_indexes: vec![2, 5, 8, 11],
158            use_rel_pos: true,
159            qkv_bias: true,
160        }
161    }
162    pub fn small() -> Self {
163        Self {
164            embed_dim: 384,
165            depth: 12,
166            num_heads: 6,
167            img_size: 1024,
168            patch_size: 16,
169            mlp_ratio: 4.0,
170            out_chans: 256,
171            window_size: 14,
172            global_attn_indexes: vec![2, 5, 8, 11],
173            use_rel_pos: true,
174            qkv_bias: true,
175        }
176    }
177    pub fn base() -> Self {
178        Self::default()
179    }
180    pub fn large() -> Self {
181        Self {
182            embed_dim: 1024,
183            depth: 24,
184            num_heads: 16,
185            img_size: 1024,
186            patch_size: 16,
187            mlp_ratio: 4.0,
188            out_chans: 256,
189            window_size: 14,
190            global_attn_indexes: vec![5, 11, 17, 23],
191            use_rel_pos: true,
192            qkv_bias: true,
193        }
194    }
195    pub fn with_img_size(mut self, size: usize) -> Self {
196        self.img_size = size;
197        self
198    }
199    pub fn with_window_size(mut self, size: usize) -> Self {
200        self.window_size = size;
201        self
202    }
203    pub fn num_patches(&self) -> usize {
204        (self.img_size / self.patch_size).pow(2)
205    }
206    pub fn patch_grid_size(&self) -> (usize, usize) {
207        let g = self.img_size / self.patch_size;
208        (g, g)
209    }
210}
211
212impl Default for SamConfig {
213    fn default() -> Self {
214        Self {
215            embed_dim: 768,
216            depth: 12,
217            num_heads: 12,
218            img_size: 1024,
219            patch_size: 16,
220            mlp_ratio: 4.0,
221            out_chans: 256,
222            window_size: 14,
223            global_attn_indexes: vec![2, 5, 8, 11],
224            use_rel_pos: true,
225            qkv_bias: true,
226        }
227    }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct ClipConfig {
232    pub hidden_size: usize,
233    pub num_layers: usize,
234    pub num_attention_heads: usize,
235    pub ffn_hidden_size: usize,
236    pub image_size: usize,
237    pub patch_size: usize,
238    pub use_flash_attn: bool,
239    pub layernorm_epsilon: f32,
240    pub dropout: f32,
241}
242
243impl ClipConfig {
244    pub fn base() -> Self {
245        Self {
246            hidden_size: 768,
247            num_layers: 12,
248            num_attention_heads: 12,
249            ffn_hidden_size: 3072,
250            image_size: 224,
251            patch_size: 14,
252            use_flash_attn: false,
253            layernorm_epsilon: 1e-5,
254            dropout: 0.0,
255        }
256    }
257    pub fn large() -> Self {
258        Self::default()
259    }
260    pub fn with_image_size(mut self, size: usize) -> Self {
261        self.image_size = size;
262        self
263    }
264    pub fn with_flash_attn(mut self, enable: bool) -> Self {
265        self.use_flash_attn = enable;
266        self
267    }
268    pub fn num_patches(&self) -> usize {
269        (self.image_size / self.patch_size).pow(2)
270    }
271    pub fn num_positions(&self) -> usize {
272        self.num_patches() + 1
273    }
274}
275
276impl Default for ClipConfig {
277    fn default() -> Self {
278        Self {
279            hidden_size: 1024,
280            num_layers: 24,
281            num_attention_heads: 16,
282            ffn_hidden_size: 4096,
283            image_size: 224,
284            patch_size: 14,
285            use_flash_attn: false,
286            layernorm_epsilon: 1e-5,
287            dropout: 0.0,
288        }
289    }
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct ProjectorConfig {
294    pub projector_type: String,
295    pub input_dim: usize,
296    pub n_embed: usize,
297    pub num_layers: usize,
298}
299
300impl ProjectorConfig {
301    pub fn linear(input_dim: usize, output_dim: usize) -> Self {
302        Self {
303            projector_type: "linear".to_string(),
304            input_dim,
305            n_embed: output_dim,
306            num_layers: 1,
307        }
308    }
309    pub fn mlp_gelu(input_dim: usize, output_dim: usize, num_layers: usize) -> Self {
310        Self {
311            projector_type: "mlp_gelu".to_string(),
312            input_dim,
313            n_embed: output_dim,
314            num_layers,
315        }
316    }
317    pub fn with_type(mut self, projector_type: &str) -> Self {
318        self.projector_type = projector_type.to_string();
319        self
320    }
321    pub fn with_num_layers(mut self, num_layers: usize) -> Self {
322        self.num_layers = num_layers;
323        self
324    }
325}
326
327impl Default for ProjectorConfig {
328    fn default() -> Self {
329        Self {
330            projector_type: "linear".to_string(),
331            input_dim: 2048,
332            n_embed: 1280,
333            num_layers: 1,
334        }
335    }
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct ModelConfig {
340    pub sam_config: SamConfig,
341    pub clip_config: ClipConfig,
342    pub projector_config: ProjectorConfig,
343    pub resolution_config: ResolutionConfig,
344    pub patch_size: usize,
345    pub downsample_ratio: usize,
346    pub tile_tag: String,
347    pub image_mean: [f32; 3],
348    pub image_std: [f32; 3],
349}
350
351impl ModelConfig {
352    pub fn with_resolution(mut self, mode: ResolutionMode) -> Self {
353        self.resolution_config = ResolutionConfig::new(mode);
354        self.sam_config.img_size = mode.base_size() as usize;
355        self
356    }
357    pub fn with_sam_config(mut self, config: SamConfig) -> Self {
358        self.sam_config = config;
359        self
360    }
361    pub fn with_clip_config(mut self, config: ClipConfig) -> Self {
362        self.clip_config = config;
363        self
364    }
365    pub fn with_projector_config(mut self, config: ProjectorConfig) -> Self {
366        self.projector_config = config;
367        self
368    }
369    pub fn with_normalization(mut self, mean: [f32; 3], std: [f32; 3]) -> Self {
370        self.image_mean = mean;
371        self.image_std = std;
372        self
373    }
374
375    pub fn num_vision_tokens(&self) -> usize {
376        let h =
377            (self.resolution_config.base_size as usize / self.patch_size) / self.downsample_ratio;
378        let w = h;
379        h * (w + 1) + 1
380    }
381    pub fn compression_ratio(&self, text_length: usize) -> f32 {
382        text_length as f32 / self.num_vision_tokens() as f32
383    }
384    pub fn num_patches(&self) -> usize {
385        self.resolution_config.num_patches(self.patch_size)
386    }
387    pub fn patch_grid_size(&self) -> (usize, usize) {
388        self.resolution_config.patch_grid_size(self.patch_size)
389    }
390    pub fn compressed_grid_size(&self) -> (usize, usize) {
391        let (h, w) = self.patch_grid_size();
392        (h / self.downsample_ratio, w / self.downsample_ratio)
393    }
394
395    pub fn production() -> Self {
396        Self::default().with_resolution(ResolutionMode::Base)
397    }
398    pub fn evaluation() -> Self {
399        Self::default().with_resolution(ResolutionMode::Large)
400    }
401    pub fn fast() -> Self {
402        Self::default().with_resolution(ResolutionMode::Tiny)
403    }
404}
405
406impl Default for ModelConfig {
407    fn default() -> Self {
408        Self {
409            sam_config: SamConfig::default(),
410            clip_config: ClipConfig::default(),
411            projector_config: ProjectorConfig::default(),
412            resolution_config: ResolutionConfig::base(),
413            patch_size: 16,
414            downsample_ratio: 4,
415            tile_tag: "2D".to_string(),
416            image_mean: [0.5, 0.5, 0.5],
417            image_std: [0.5, 0.5, 0.5],
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has a 16x16 compressed grid, giving
    /// 16 * 17 + 1 = 273 vision tokens; a text length of 2730 therefore
    /// compresses at a ratio of about 10.
    #[test]
    fn test_tokens_and_ratios() {
        let config = ModelConfig::default();
        let tokens = config.num_vision_tokens();
        assert_eq!(tokens, 273);

        let ratio = config.compression_ratio(2730);
        let error = (ratio - 10.0).abs();
        assert!(error < 0.1);
    }
}