1use serde::Deserialize;
23
/// Per-channel pixel mean used for input normalization.
///
/// NOTE(review): presumably applied as `(x - mean) / std` to values scaled to
/// `[0, 1]`, which with these values maps pixels to `[-1, 1]`; channel order is
/// not established here — TODO confirm against the preprocessing code.
pub const SAM3_PIXEL_MEAN: [f32; 3] = [0.5, 0.5, 0.5];
/// Per-channel pixel standard deviation used for input normalization
/// (see [`SAM3_PIXEL_MEAN`] for the assumed formula).
pub const SAM3_PIXEL_STD: [f32; 3] = [0.5, 0.5, 0.5];

/// Input image side length in pixels.
///
/// The configs use a single value for this, so inputs are presumably square.
pub const SAM3_IMG_SIZE: usize = 1008;
/// Patch size (side length, in pixels) of the ViT patch embedding.
pub const SAM3_PATCH_SIZE: usize = 14;
/// Number of patches along one side of the input: `1008 / 14 = 72`.
pub const SAM3_PATCH_GRID: usize = SAM3_IMG_SIZE / SAM3_PATCH_SIZE;
/// Embedding width of the ViT backbone.
pub const SAM3_VISION_DIM: usize = 1024;
/// Model width shared by the text `d_model`, the detector and the tracker
/// transformer defaults.
pub const SAM3_DET_DIM: usize = 256;

// `SAM3_PATCH_GRID` uses integer division; refuse to compile if a future edit
// makes the image size a non-multiple of the patch size, which would otherwise
// silently drop the remainder pixels.
const _: () = assert!(
    SAM3_IMG_SIZE % SAM3_PATCH_SIZE == 0,
    "SAM3_IMG_SIZE must be a multiple of SAM3_PATCH_SIZE"
);
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct Sam3VitConfig {
37 pub img_size: usize,
38 pub pretrain_img_size: usize,
39 pub patch_size: usize,
40 pub embed_dim: usize,
41 pub depth: usize,
42 pub num_heads: usize,
43 pub mlp_ratio: f64,
44 pub qkv_bias: bool,
45 pub bias_patch_embed: bool,
46 pub use_abs_pos: bool,
47 pub tile_abs_pos: bool,
48 pub use_rope: bool,
49 pub use_interp_rope: bool,
50 pub window_size: usize,
51 pub global_att_blocks: Vec<usize>,
52 pub layer_norm_eps: f64,
53}
54
55impl Sam3VitConfig {
56 pub fn base() -> Self {
57 Self {
58 img_size: SAM3_IMG_SIZE,
59 pretrain_img_size: 336,
60 patch_size: SAM3_PATCH_SIZE,
61 embed_dim: SAM3_VISION_DIM,
62 depth: 32,
63 num_heads: 16,
64 mlp_ratio: 4.625,
65 qkv_bias: true,
66 bias_patch_embed: false,
67 use_abs_pos: true,
68 tile_abs_pos: true,
69 use_rope: true,
70 use_interp_rope: true,
71 window_size: 24,
72 global_att_blocks: vec![7, 15, 23, 31],
73 layer_norm_eps: 1e-6,
74 }
75 }
76
77 pub fn patch_grid(&self) -> usize {
78 self.img_size / self.patch_size
79 }
80}
81
82#[derive(Debug, Clone, Deserialize)]
83pub struct Sam3TextConfig {
84 pub d_model: usize,
85 pub width: usize,
86 pub heads: usize,
87 pub layers: usize,
88}
89
90impl Default for Sam3TextConfig {
91 fn default() -> Self {
92 Self {
93 d_model: SAM3_DET_DIM,
94 width: 1024,
95 heads: 16,
96 layers: 24,
97 }
98 }
99}
100
101#[derive(Debug, Clone, Deserialize)]
102pub struct Sam3DetectorConfig {
103 pub d_model: usize,
104 pub num_queries: usize,
105 pub encoder_layers: usize,
106 pub decoder_layers: usize,
107 pub transformer_heads: usize,
108 pub dim_feedforward: usize,
109 pub presence_token: bool,
110 pub num_feature_levels: usize,
111}
112
113impl Default for Sam3DetectorConfig {
114 fn default() -> Self {
115 Self {
116 d_model: SAM3_DET_DIM,
117 num_queries: 200,
118 encoder_layers: 6,
119 decoder_layers: 6,
120 transformer_heads: 8,
121 dim_feedforward: 2048,
122 presence_token: true,
123 num_feature_levels: 1,
124 }
125 }
126}
127
128#[derive(Debug, Clone, Deserialize)]
129pub struct Sam3TrackerConfig {
130 pub image_size: usize,
131 pub backbone_stride: usize,
132 pub num_maskmem: usize,
133 pub max_cond_frames_in_attn: usize,
134 pub memory_dim: usize,
135 pub transformer_dim: usize,
136 pub transformer_layers: usize,
137 pub feat_hw: usize,
138}
139
140impl Default for Sam3TrackerConfig {
141 fn default() -> Self {
142 Self {
143 image_size: SAM3_IMG_SIZE,
144 backbone_stride: SAM3_PATCH_SIZE,
145 num_maskmem: 7,
146 max_cond_frames_in_attn: 4,
147 memory_dim: 64,
148 transformer_dim: SAM3_DET_DIM,
149 transformer_layers: 4,
150 feat_hw: SAM3_PATCH_GRID,
151 }
152 }
153}
154
155#[derive(Debug, Clone, Deserialize)]
156pub struct Sam3Config {
157 pub vit: Sam3VitConfig,
158 pub text: Sam3TextConfig,
159 pub detector: Sam3DetectorConfig,
160 pub tracker: Sam3TrackerConfig,
161 pub enable_inst_interactivity: bool,
162 pub enable_video: bool,
163}
164
165impl Sam3Config {
166 pub fn base() -> Self {
167 Self {
168 vit: Sam3VitConfig::base(),
169 text: Sam3TextConfig::default(),
170 detector: Sam3DetectorConfig::default(),
171 tracker: Sam3TrackerConfig::default(),
172 enable_inst_interactivity: false,
173 enable_video: true,
174 }
175 }
176}