Skip to main content

rlx_sam3/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 3 configuration.
17//!
18//! The defaults mirror `facebookresearch/sam3::model_builder` for the
19//! base SAM3 release. SAM3.1 multiplex is a distinct architecture and is
20//! intentionally not represented by this config.
21
22use serde::Deserialize;
23
/// Per-channel mean used when SAM3 normalizes RGB values that have already
/// been scaled to `[0, 1]`.
pub const SAM3_PIXEL_MEAN: [f32; 3] = [0.5; 3];
/// Per-channel standard deviation paired with [`SAM3_PIXEL_MEAN`] for the
/// same normalization step.
pub const SAM3_PIXEL_STD: [f32; 3] = [0.5; 3];

/// Image side length the public model builder uses for base SAM3.
pub const SAM3_IMG_SIZE: usize = 1008;
/// Patch side length; it divides [`SAM3_IMG_SIZE`] exactly (1008 = 14 * 72).
pub const SAM3_PATCH_SIZE: usize = 14;
/// Patches along one image side: `1008 / 14 = 72`.
pub const SAM3_PATCH_GRID: usize = SAM3_IMG_SIZE / SAM3_PATCH_SIZE;
/// Embedding width assigned to `Sam3VitConfig::base().embed_dim`.
pub const SAM3_VISION_DIM: usize = 1024;
/// Model width shared by the text, detector and tracker defaults in this file.
pub const SAM3_DET_DIM: usize = 256;
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct Sam3VitConfig {
37    pub img_size: usize,
38    pub pretrain_img_size: usize,
39    pub patch_size: usize,
40    pub embed_dim: usize,
41    pub depth: usize,
42    pub num_heads: usize,
43    pub mlp_ratio: f64,
44    pub qkv_bias: bool,
45    pub bias_patch_embed: bool,
46    pub use_abs_pos: bool,
47    pub tile_abs_pos: bool,
48    pub use_rope: bool,
49    pub use_interp_rope: bool,
50    pub window_size: usize,
51    pub global_att_blocks: Vec<usize>,
52    pub layer_norm_eps: f64,
53}
54
55impl Sam3VitConfig {
56    pub fn base() -> Self {
57        Self {
58            img_size: SAM3_IMG_SIZE,
59            pretrain_img_size: 336,
60            patch_size: SAM3_PATCH_SIZE,
61            embed_dim: SAM3_VISION_DIM,
62            depth: 32,
63            num_heads: 16,
64            mlp_ratio: 4.625,
65            qkv_bias: true,
66            bias_patch_embed: false,
67            use_abs_pos: true,
68            tile_abs_pos: true,
69            use_rope: true,
70            use_interp_rope: true,
71            window_size: 24,
72            global_att_blocks: vec![7, 15, 23, 31],
73            layer_norm_eps: 1e-6,
74        }
75    }
76
77    pub fn patch_grid(&self) -> usize {
78        self.img_size / self.patch_size
79    }
80}
81
82#[derive(Debug, Clone, Deserialize)]
83pub struct Sam3TextConfig {
84    pub d_model: usize,
85    pub width: usize,
86    pub heads: usize,
87    pub layers: usize,
88}
89
90impl Default for Sam3TextConfig {
91    fn default() -> Self {
92        Self {
93            d_model: SAM3_DET_DIM,
94            width: 1024,
95            heads: 16,
96            layers: 24,
97        }
98    }
99}
100
101#[derive(Debug, Clone, Deserialize)]
102pub struct Sam3DetectorConfig {
103    pub d_model: usize,
104    pub num_queries: usize,
105    pub encoder_layers: usize,
106    pub decoder_layers: usize,
107    pub transformer_heads: usize,
108    pub dim_feedforward: usize,
109    pub presence_token: bool,
110    pub num_feature_levels: usize,
111}
112
113impl Default for Sam3DetectorConfig {
114    fn default() -> Self {
115        Self {
116            d_model: SAM3_DET_DIM,
117            num_queries: 200,
118            encoder_layers: 6,
119            decoder_layers: 6,
120            transformer_heads: 8,
121            dim_feedforward: 2048,
122            presence_token: true,
123            num_feature_levels: 1,
124        }
125    }
126}
127
128#[derive(Debug, Clone, Deserialize)]
129pub struct Sam3TrackerConfig {
130    pub image_size: usize,
131    pub backbone_stride: usize,
132    pub num_maskmem: usize,
133    pub max_cond_frames_in_attn: usize,
134    pub memory_dim: usize,
135    pub transformer_dim: usize,
136    pub transformer_layers: usize,
137    pub feat_hw: usize,
138}
139
140impl Default for Sam3TrackerConfig {
141    fn default() -> Self {
142        Self {
143            image_size: SAM3_IMG_SIZE,
144            backbone_stride: SAM3_PATCH_SIZE,
145            num_maskmem: 7,
146            max_cond_frames_in_attn: 4,
147            memory_dim: 64,
148            transformer_dim: SAM3_DET_DIM,
149            transformer_layers: 4,
150            feat_hw: SAM3_PATCH_GRID,
151        }
152    }
153}
154
155#[derive(Debug, Clone, Deserialize)]
156pub struct Sam3Config {
157    pub vit: Sam3VitConfig,
158    pub text: Sam3TextConfig,
159    pub detector: Sam3DetectorConfig,
160    pub tracker: Sam3TrackerConfig,
161    pub enable_inst_interactivity: bool,
162    pub enable_video: bool,
163}
164
165impl Sam3Config {
166    pub fn base() -> Self {
167        Self {
168            vit: Sam3VitConfig::base(),
169            text: Sam3TextConfig::default(),
170            detector: Sam3DetectorConfig::default(),
171            tracker: Sam3TrackerConfig::default(),
172            enable_inst_interactivity: false,
173            enable_video: true,
174        }
175    }
176}