rlx_sam2/config.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 model configuration. Mirrors Meta's `segment-anything-2` (a.k.a.
17//! `facebookresearch/sam2`) reference exactly, so the published
18//! `sam2_hiera_{t,s,b+,l}.pt` checkpoints can load without remapping.
19//!
//! The image encoder is `Hiera` (Ryali et al. 2023) — a hierarchical,
//! multi-scale ViT with windowed (mask-unit) attention and Q-pooling —
//! wrapped by an FPN-style neck that emits feature maps at strides
//! 4/8/16/32. The reference image encoder then drops the stride-32 level
//! (`scalp: 1` in the published YAMLs), so downstream heads see strides
//! 4/8/16. The prompt encoder + mask decoder are similar in spirit to
//! SAM v1 but add an object-pointer token, an object-score head, and a
//! high-res mask path that consumes the FPN's stride-4 / stride-8 features.
26//!
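//! Phase split (mirrors `crate::sam`):
//! - Phase 1: Hiera image encoder graph + FpnNeck.
//! - Phase 2: prompt encoder + mask decoder + IoU/object-score heads.
//! - Phase 3: memory attention + memory encoder for video tracking.
//!
//! This module already carries the configuration for all three phases
//! (`Sam2HieraConfig` / `Sam2FpnConfig`, `Sam2DecoderConfig`, and
//! `Sam2MemoryEncoderConfig` / `Sam2MemoryConfig`).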
31
32use serde::Deserialize;
33
/// Per-channel RGB mean for SAM 2 input normalisation.
///
/// SAM 2 first rescales pixels into `0..=1` (divide by 255) and only then
/// applies these ImageNet statistics. SAM v1 instead normalises the raw
/// `0..=255` values. Matches `sam2/utils/transforms.py`.
pub const SAM2_PIXEL_MEAN: [f32; 3] = [0.485_f32, 0.456_f32, 0.406_f32];
/// Per-channel RGB standard deviation paired with [`SAM2_PIXEL_MEAN`].
pub const SAM2_PIXEL_STD: [f32; 3] = [0.229_f32, 0.224_f32, 0.225_f32];

/// Side length, in pixels, of the square image SAM 2 runs on after
/// preprocessing. Like SAM v1, the model is always fed 1024×1024.
pub const SAM2_IMG_SIZE: usize = 1_024;

// Hiera patch embedding: `Conv2d(in=3, out=embed_dim, k=7, s=4, p=3)`.

/// Kernel size of the Hiera patch-embedding conv.
pub const SAM2_PATCH_KERNEL: usize = 7;
/// Stride of the Hiera patch-embedding conv. It sets the stage-0 resolution.
pub const SAM2_PATCH_STRIDE: usize = 4;
/// Zero padding of the Hiera patch-embedding conv.
pub const SAM2_PATCH_PADDING: usize = 3;

/// Per-axis side of the token grid produced by the patch embedding, i.e.
/// the stage-0 input (`1024 / 4 = 256`).
pub const SAM2_PATCH_GRID: usize = SAM2_IMG_SIZE / SAM2_PATCH_STRIDE;

/// Number of Q-pooling steps Hiera applies: 3 in the reference, for every
/// variant.
///
/// Each step halves both spatial axes (4× fewer tokens), while the channel
/// width and the head count each double.
pub const SAM2_Q_POOL_COUNT: usize = 3;
/// Per-axis downsampling factor of a single Q-pool step.
pub const SAM2_Q_STRIDE: usize = 2;

/// Channel width shared by the FPN neck outputs, the prompt encoder and the
/// mask decoder.
pub const SAM2_PROMPT_EMBED_DIM: usize = 256;
63
/// Hiera image-encoder configuration — Tiny, Small, Base+ or Large.
///
/// Field names mirror Hiera's Python kwargs so the values map 1:1 to
/// the published checkpoints. The struct derives serde `Deserialize`.
/// No field carries a `#[serde(default)]` attribute, so every field must
/// be present in the serialised form. Preset values come from the
/// `tiny` / `small` / `base_plus` / `large` constructors.
#[derive(Debug, Clone, Deserialize)]
pub struct Sam2HieraConfig {
    /// Stage-0 embedding dimension. Doubles after each Q-pool.
    pub embed_dim: usize,
    /// Stage-0 head count. Doubles after each Q-pool.
    pub num_heads: usize,
    /// Number of blocks per stage. `stages.len() == 4` for every
    /// published Hiera variant (Tiny/Small/Base+/Large).
    pub stages: Vec<usize>,
    /// Indices of the blocks that use *global* attention instead of
    /// mask-unit (windowed) attention. Indices are in the *flattened* block
    /// enumeration across stages. Published configs always have exactly
    /// 3 entries.
    pub global_att_blocks: Vec<usize>,
    /// Background-pos-embed spatial size `[Ph, Pw]` of the learned
    /// position table. The table is resized to the current token grid on
    /// each forward.
    ///
    /// NOTE(review): the reference `Hiera._get_pos_embed` appears to use
    /// `F.interpolate(..., mode="bicubic")` rather than bilinear. Confirm
    /// which mode the graph builder lowers.
    pub window_pos_embed_bkg_spatial_size: [usize; 2],
    /// Mask-unit window size per stage, in units of post-Q-pool tokens at
    /// that stage. Always 4 entries, one per stage.
    ///
    /// NOTE(review): in the reference `Hiera.__init__` the window size
    /// "lags by a block". The first (Q-pool) block of stage `s` still uses
    /// `window_spec[s - 1]`, and global-attention blocks use window 0.
    /// Confirm the graph builder accounts for this rather than indexing
    /// this array purely by stage.
    pub window_spec: [usize; 4],
    /// LayerNorm eps used throughout the encoder.
    pub layer_norm_eps: f64,
    /// MLP expansion ratio (FFN hidden = `mlp_ratio · embed_dim`).
    ///
    /// NOTE(review): the reference sizes each block's MLP from that
    /// block's output dim, i.e. the per-stage dimension, not the stage-0
    /// `embed_dim`. Confirm.
    pub mlp_ratio: f64,
    /// QKV linear bias toggle (always true in published configs).
    pub qkv_bias: bool,
    /// Output channels per FPN level (256 for every published config).
    pub fpn_out_chans: usize,
}
98
99impl Sam2HieraConfig {
100 /// `sam2_hiera_tiny` — ~30 M params.
101 pub fn tiny() -> Self {
102 Self {
103 embed_dim: 96,
104 num_heads: 1,
105 stages: vec![1, 2, 7, 2],
106 global_att_blocks: vec![5, 7, 9],
107 window_pos_embed_bkg_spatial_size: [7, 7],
108 window_spec: [8, 4, 14, 7],
109 layer_norm_eps: 1e-6,
110 mlp_ratio: 4.0,
111 qkv_bias: true,
112 fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
113 }
114 }
115 /// `sam2_hiera_small` — ~46 M params.
116 pub fn small() -> Self {
117 Self {
118 stages: vec![1, 2, 11, 2],
119 global_att_blocks: vec![7, 10, 13],
120 ..Self::tiny()
121 }
122 }
123 /// `sam2_hiera_base_plus` — ~80 M params, the recommended default.
124 pub fn base_plus() -> Self {
125 Self {
126 embed_dim: 112,
127 num_heads: 2,
128 stages: vec![2, 3, 16, 3],
129 global_att_blocks: vec![12, 16, 20],
130 window_pos_embed_bkg_spatial_size: [14, 14],
131 window_spec: [8, 4, 14, 7],
132 layer_norm_eps: 1e-6,
133 mlp_ratio: 4.0,
134 qkv_bias: true,
135 fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
136 }
137 }
138 /// `sam2_hiera_large` — ~224 M params. The YAML overrides
139 /// `window_pos_embed_bkg_spatial_size` to `[7, 7]` (vs `[14, 14]`
140 /// for base+).
141 pub fn large() -> Self {
142 Self {
143 embed_dim: 144,
144 num_heads: 2,
145 stages: vec![2, 6, 36, 4],
146 global_att_blocks: vec![23, 33, 43],
147 window_pos_embed_bkg_spatial_size: [7, 7],
148 window_spec: [8, 4, 16, 8],
149 ..Self::base_plus()
150 }
151 }
152
153 /// Total number of transformer blocks across all stages.
154 pub fn total_blocks(&self) -> usize {
155 self.stages.iter().sum()
156 }
157
158 /// Indices (in the flattened block enumeration) where a Q-pool
159 /// happens — at the *first* block of every stage after stage 0.
160 ///
161 /// Reference: in `Hiera.__init__` the Q-pool boundaries are
162 /// `cumulative_sum(stages)[:-1]`, i.e. `[s0, s0+s1, s0+s1+s2]`.
163 pub fn q_pool_block_indices(&self) -> Vec<usize> {
164 let mut acc = 0usize;
165 let mut out = Vec::with_capacity(SAM2_Q_POOL_COUNT);
166 for &n in &self.stages[..self.stages.len() - 1] {
167 acc += n;
168 out.push(acc);
169 }
170 out
171 }
172
173 /// Stage index for the i-th flattened block.
174 pub fn stage_of_block(&self, block_idx: usize) -> usize {
175 let mut acc = 0usize;
176 for (si, &n) in self.stages.iter().enumerate() {
177 acc += n;
178 if block_idx < acc {
179 return si;
180 }
181 }
182 self.stages.len() - 1
183 }
184
185 /// Embedding dimension at stage `s` (doubles per Q-pool).
186 pub fn embed_dim_at_stage(&self, s: usize) -> usize {
187 self.embed_dim * (1 << s)
188 }
189 /// Number of heads at stage `s` (doubles per Q-pool).
190 pub fn num_heads_at_stage(&self, s: usize) -> usize {
191 self.num_heads * (1 << s)
192 }
193 /// Mask-unit window size at stage `s`.
194 pub fn window_size_at_stage(&self, s: usize) -> usize {
195 self.window_spec[s]
196 }
197 /// Per-axis spatial size of the token grid at stage `s` (before any
198 /// Q-pool *inside* the stage — i.e. the size at the stage's first
199 /// post-Q-pool block, for s>0, or stage-0 patch grid for s=0).
200 pub fn grid_size_at_stage(&self, s: usize) -> usize {
201 SAM2_PATCH_GRID / (1 << s)
202 }
203}
204
/// FPN neck configuration. Mirrors `FpnNeck` in the reference.
///
/// SAM 2's neck takes the per-stage outputs from Hiera and runs a top-down
/// pyramid, from the coarsest level to the finest. Stage 0 is the finest
/// (stride 4) and stage 3 the coarsest (stride 32).
/// 1. Every level gets a 1×1 lateral conv to `d_model` (256).
/// 2. Some levels also receive a nearest-neighbour ×2 upsample of the
///    next-coarser level's output, summed in. These are the levels whose
///    *Hiera stage index* is listed in `fpn_top_down_levels`. The coarsest
///    level has no coarser neighbour, so it is always a plain lateral even
///    when listed.
///
/// Every published config uses `fpn_top_down_levels = [2, 3]`. With that
/// setting:
/// - stage 2 (stride 16) fuses with the upsampled stage-3 output;
/// - stage 3 is emitted as a plain lateral (nothing coarser to fuse);
/// - stages 0 and 1 (strides 4 / 8) are emitted as plain laterals.
///
/// Note on indexing: the two lists use **opposite** orderings, exactly as
/// in the reference YAML.
/// - `backbone_channel_list` is **coarse → fine**
///   (`[stage3_dim, …, stage0_dim]`). Entry `j` lines up with checkpoint
///   key `image_encoder.neck.convs.{j}`, so Hiera stage `i` uses
///   `convs.{n - i}`, where `n = num_stages - 1`.
/// - `fpn_top_down_levels` holds **Hiera stage indices** (fine → coarse,
///   0 = stride 4). These are *not* positions in `backbone_channel_list`.
#[derive(Debug, Clone)]
pub struct Sam2FpnConfig {
    /// Channel width of every lateral conv output / emitted FPN level.
    pub d_model: usize,
    /// Per-stage Hiera output channels, **coarse → fine** order:
    /// `[stage3_dim, stage2_dim, stage1_dim, stage0_dim]`.
    pub backbone_channel_list: Vec<usize>,
    /// Hiera stage indices (fine → coarse, 0 = finest / stride 4) whose
    /// lateral output is summed with the ×2-upsampled next-coarser output.
    /// This ordering is the *reverse* of `backbone_channel_list`. `[2, 3]`
    /// in every published config selects the two coarsest stages.
    pub fpn_top_down_levels: Vec<usize>,
    /// Interpolation mode for the top-down upsample. The reference uses
    /// `"nearest"`, which we lower as cheap host-side replicate.
    pub interpolation_nearest: bool,
}
239
240impl Sam2FpnConfig {
241 pub fn for_hiera(cfg: &Sam2HieraConfig) -> Self {
242 // Coarsest stage first → finest last, matching the reference's
243 // YAML ordering and the `image_encoder.neck.convs.{i}` key layout.
244 let channels: Vec<usize> = (0..cfg.stages.len())
245 .rev()
246 .map(|s| cfg.embed_dim_at_stage(s))
247 .collect();
248 // Preflight: B+ should give [896, 448, 224, 112].
249 debug_assert!(
250 channels.first().copied().unwrap_or(0) >= channels.last().copied().unwrap_or(0),
251 "backbone_channel_list must be coarse → fine"
252 );
253 Self {
254 d_model: cfg.fpn_out_chans,
255 backbone_channel_list: channels,
256 fpn_top_down_levels: vec![2, 3],
257 interpolation_nearest: true,
258 }
259 }
260}
261
/// Mask decoder configuration. Field names + defaults mirror
/// `sam2/modeling/sam/mask_decoder.py::MaskDecoder.__init__` and the
/// published `sam2_hiera_*.yaml` `model.sam_mask_decoder_extra_args`.
///
/// The preset values live in this type's `Default` impl.
#[derive(Debug, Clone)]
pub struct Sam2DecoderConfig {
    /// Channel width of the decoder transformer (256 by default, the
    /// shared prompt-embedding width).
    pub transformer_dim: usize,
    /// Number of decoder transformer blocks (2 by default).
    pub transformer_depth: usize,
    /// Attention heads per decoder transformer block (8 by default).
    pub transformer_num_heads: usize,
    /// Hidden width of each decoder transformer block's MLP (2048 by
    /// default).
    pub transformer_mlp_dim: usize,
    /// 4 = 1 best-mask token + 3 multimask tokens (`num_multimask_outputs=3`
    /// in the YAML; total tokens = `num_multimask_outputs + 1`).
    pub num_mask_tokens: usize,
    /// Depth of the IoU-prediction head (3 by default).
    pub iou_head_depth: usize,
    /// Hidden width of the IoU-prediction head (256 by default).
    pub iou_head_hidden_dim: usize,
    /// `iou_prediction_use_sigmoid` flag (true in the published YAML).
    pub iou_prediction_use_sigmoid: bool,
    /// SAM 2 emits an additional object-pointer token. Always true for
    /// the published video configs.
    pub use_object_pointer: bool,
    /// If true, `obj_ptr_proj` is a 3-layer MLP; else a plain Linear.
    /// True in every published config.
    pub use_mlp_for_obj_ptr_proj: bool,
    /// Predict an object-score logit (whether an object is present).
    /// True in every published config.
    pub pred_obj_scores: bool,
    /// If true, `pred_obj_score_head` is a 3-layer MLP; else a Linear.
    /// True in every published config.
    pub pred_obj_scores_mlp: bool,
    /// When multimask is selected, use the three multimask tokens
    /// (rather than the best-mask token) for the object pointer.
    /// True in every published config.
    pub use_multimask_token_for_obj_ptr: bool,
    /// Use the FpnNeck's stride-4 + stride-8 features to refine the
    /// upscaled mask. True in every published config.
    pub use_high_res_features: bool,
    /// Fall back to the best multimask output when the single-mask
    /// token's stability score is below threshold. True in every
    /// published video config (`dynamic_multimask_via_stability=True`).
    pub dynamic_multimask_via_stability: bool,
    /// Stability-score delta used by the dynamic-multimask fallback
    /// (0.05 by default).
    pub dynamic_multimask_stability_delta: f32,
    /// Stability-score threshold below which the dynamic-multimask
    /// fallback kicks in (0.98 by default).
    pub dynamic_multimask_stability_thresh: f32,
    /// LayerNorm eps for the decoder (1e-6 by default).
    ///
    /// NOTE(review): the reference appears to use two different values.
    /// The two-way transformer's `nn.LayerNorm`s keep PyTorch's default
    /// 1e-5, while `LayerNorm2d` in the output upscaling path uses 1e-6.
    /// Confirm which norms this single field feeds.
    pub layer_norm_eps: f64,
}
305
306impl Default for Sam2DecoderConfig {
307 fn default() -> Self {
308 Self {
309 transformer_dim: SAM2_PROMPT_EMBED_DIM,
310 transformer_depth: 2,
311 transformer_num_heads: 8,
312 transformer_mlp_dim: 2048,
313 num_mask_tokens: 4,
314 iou_head_depth: 3,
315 iou_head_hidden_dim: SAM2_PROMPT_EMBED_DIM,
316 iou_prediction_use_sigmoid: true,
317 use_object_pointer: true,
318 use_mlp_for_obj_ptr_proj: true,
319 pred_obj_scores: true,
320 pred_obj_scores_mlp: true,
321 use_multimask_token_for_obj_ptr: true,
322 use_high_res_features: true,
323 dynamic_multimask_via_stability: true,
324 dynamic_multimask_stability_delta: 0.05,
325 dynamic_multimask_stability_thresh: 0.98,
326 layer_norm_eps: 1e-6,
327 }
328 }
329}
330
/// Memory-encoder configuration. Mirrors
/// `sam2/modeling/memory_encoder.py::MemoryEncoder` + its
/// `MaskDownSampler` and `Fuser`. Defaults match every published
/// `sam2_hiera_*.yaml` `memory_encoder:` block.
///
/// The preset values live in this type's `Default` impl.
#[derive(Debug, Clone)]
pub struct Sam2MemoryEncoderConfig {
    /// Input feature dim from the FpnNeck stride-16 level.
    pub in_dim: usize,
    /// Output memory token dim (smaller than `in_dim` in published
    /// configs: 64 vs 256, so memory bank tokens are cheap to store).
    pub out_dim: usize,
    /// MaskDownSampler: per-step kernel/stride/padding + total stride.
    /// total_stride must be a power of `stride`; reference uses
    /// kernel=3 stride=2 padding=1 total=16 → 4 down-sampling levels.
    pub mask_downsampler_kernel: usize,
    /// MaskDownSampler per-step stride (2 by default).
    pub mask_downsampler_stride: usize,
    /// MaskDownSampler per-step padding (1 by default).
    pub mask_downsampler_padding: usize,
    /// MaskDownSampler overall down-sampling factor (16 by default). It is
    /// reached by repeating the per-step stride.
    pub mask_downsampler_total_stride: usize,
    /// Fuser: `num_layers` × CXBlock(dim, kernel, padding, ls_init).
    pub fuser_num_layers: usize,
    /// Channel width of each fuser CXBlock (256 by default, equal to
    /// `in_dim`).
    pub fuser_dim: usize,
    /// Conv kernel size inside each fuser CXBlock (7 by default).
    pub fuser_kernel: usize,
    /// Conv padding inside each fuser CXBlock (3 by default).
    pub fuser_padding: usize,
    /// Layer-scale init value of each fuser CXBlock (1e-6 by default).
    pub fuser_layer_scale_init_value: f32,
    /// Whether the fuser CXBlocks use depth-wise convs (true by default).
    pub fuser_use_dwconv: bool,
    /// Fuser `input_projection` flag (false by default).
    pub fuser_input_projection: bool,
    /// Memory-encoder output PE: `num_pos_feats * 2` is the channel
    /// count of the emitted position encoding. This field therefore stores
    /// *half* the PE width (32 by default → 64 channels, equal to
    /// `out_dim`).
    ///
    /// NOTE(review): the reference YAML writes `num_pos_feats: 64`, which
    /// appears to be the *total* width that SAM 2's `PositionEmbeddingSine`
    /// halves internally. Confirm before feeding YAML values here directly.
    pub pe_num_pos_feats: usize,
    /// Temperature of the memory-encoder position encoding (10000 by
    /// default).
    pub pe_temperature: f32,
}
362
363impl Default for Sam2MemoryEncoderConfig {
364 fn default() -> Self {
365 Self {
366 in_dim: SAM2_PROMPT_EMBED_DIM,
367 out_dim: 64,
368 mask_downsampler_kernel: 3,
369 mask_downsampler_stride: 2,
370 mask_downsampler_padding: 1,
371 mask_downsampler_total_stride: 16,
372 fuser_num_layers: 2,
373 fuser_dim: SAM2_PROMPT_EMBED_DIM,
374 fuser_kernel: 7,
375 fuser_padding: 3,
376 fuser_layer_scale_init_value: 1e-6,
377 fuser_use_dwconv: true,
378 fuser_input_projection: false,
379 // num_pos_feats * 2 == out_dim so PE shape matches memory
380 // tokens for `memory + pos` addition in MemoryAttention.
381 pe_num_pos_feats: 32,
382 pe_temperature: 10000.0,
383 }
384 }
385}
386
/// Memory-attention configuration (video path).
///
/// `d_model` (256) is the working dim of the memory-attention layers.
/// Memory bank tokens come in at `kv_in_dim`, the memory encoder's
/// `out_dim` (64 in published configs), and are projected up by the
/// cross-attention's k/v linear layers.
///
/// The preset values live in this type's `Default` impl.
#[derive(Debug, Clone)]
pub struct Sam2MemoryConfig {
    /// Working channel width of every memory-attention layer.
    pub d_model: usize,
    /// Number of stacked memory-attention layers.
    pub num_layers: usize,
    /// Attention heads in the memory-attention self-/cross-attention.
    pub num_heads: usize,
    /// Hidden width of each layer's feed-forward block.
    pub dim_feedforward: usize,
    /// LayerNorm eps used by the memory-attention layers.
    pub layer_norm_eps: f64,
    /// Memory bank kv channel dim — matches memory-encoder out_dim.
    pub kv_in_dim: usize,
    /// RoPE base frequency.
    pub rope_theta: f32,
    /// Spatial size `[h, w]` used to build the RoPE table for the image
    /// features that memory attention runs on. This is the stride-16 grid:
    /// `[64, 64]` for a 1024×1024 input.
    ///
    /// NOTE(review): some reference YAMLs list `feat_sizes: [32, 32]`
    /// here. The reference `RoPEAttention` appears to rebuild its table
    /// whenever the token count does not match. Confirm that the RLX path
    /// sizes the table from the actual grid.
    pub rope_feat_size: [usize; 2],
    /// `rope_k_repeat` flag of the reference cross-attention (true by
    /// default): repeat the RoPE table across a key sequence longer than
    /// the queries.
    pub rope_k_repeat: bool,
    /// Whether to add the input PE to the queries at layer input
    /// (`pos_enc_at_input` in the reference YAML; true).
    pub pos_enc_at_input: bool,
    /// Add the PE inside self-attention (`pos_enc_at_attn`; false).
    pub pos_enc_at_attn: bool,
    /// Add the PE to cross-attention keys (`pos_enc_at_cross_attn_keys`;
    /// true).
    pub pos_enc_at_cross_attn_keys: bool,
    /// Add the PE to cross-attention queries
    /// (`pos_enc_at_cross_attn_queries`; false).
    pub pos_enc_at_cross_attn_queries: bool,
    /// Maximum number of object pointers preserved across frames.
    pub max_obj_ptrs_in_encoder: usize,
    /// Fuse each memory-attention layer into one graph with [`Op::AxialRope2d`]
    /// (faster than the default five-graph + host-RoPE path). Requires axial RoPE
    /// on the compile device (CPU today; Metal uses host fallback when enabled).
    pub mem_attn_in_graph_rope: bool,
    /// Channel width of memory tokens. This is `SAM2Base.mem_dim` in the
    /// reference, taken from the memory encoder's output projection: 64 in
    /// published checkpoints, equal to `kv_in_dim`.
    ///
    /// In the reference, each `d_model`-wide object pointer is split into
    /// `d_model / mem_dim` tokens of this width (4 × 64 for 256) before it
    /// joins the memory bank.
    pub mem_dim: usize,
}
423
424impl Default for Sam2MemoryConfig {
425 fn default() -> Self {
426 Self {
427 d_model: SAM2_PROMPT_EMBED_DIM,
428 num_layers: 4,
429 num_heads: 1,
430 dim_feedforward: 2048,
431 layer_norm_eps: 1e-5,
432 kv_in_dim: 64,
433 rope_theta: 10000.0,
434 // Published YAMLs use feat_sizes: [64, 64] for memory
435 // attention RoPE — matches the stride-16 feature grid.
436 rope_feat_size: [64, 64],
437 rope_k_repeat: true,
438 pos_enc_at_input: true,
439 pos_enc_at_attn: false,
440 pos_enc_at_cross_attn_keys: true,
441 pos_enc_at_cross_attn_queries: false,
442 max_obj_ptrs_in_encoder: 16,
443 mem_dim: 64,
444 mem_attn_in_graph_rope: false,
445 }
446 }
447}
448
/// Top-level SAM 2 configuration — Hiera + FPN + decoder + memory
/// (encoder + attention) for the video path. Mirrors `SAM2Base` in
/// the reference.
///
/// The `hiera_*` constructors pair a Hiera preset with the FPN neck
/// derived from it and the `Default` decoder / memory settings.
#[derive(Debug, Clone)]
pub struct Sam2Config {
    /// Hiera image-encoder trunk.
    pub hiera: Sam2HieraConfig,
    /// FPN neck. The presets derive it from `hiera` via
    /// `Sam2FpnConfig::for_hiera`.
    pub fpn: Sam2FpnConfig,
    /// Mask decoder + IoU / object-score heads.
    pub decoder: Sam2DecoderConfig,
    /// Memory attention (video path).
    pub memory: Sam2MemoryConfig,
    /// Memory encoder (video path).
    pub memory_encoder: Sam2MemoryEncoderConfig,
}
460
461impl Sam2Config {
462 pub fn hiera_tiny() -> Self {
463 let hiera = Sam2HieraConfig::tiny();
464 let fpn = Sam2FpnConfig::for_hiera(&hiera);
465 Self {
466 hiera,
467 fpn,
468 decoder: Sam2DecoderConfig::default(),
469 memory: Sam2MemoryConfig::default(),
470 memory_encoder: Sam2MemoryEncoderConfig::default(),
471 }
472 }
473 pub fn hiera_small() -> Self {
474 let hiera = Sam2HieraConfig::small();
475 let fpn = Sam2FpnConfig::for_hiera(&hiera);
476 Self {
477 hiera,
478 fpn,
479 decoder: Sam2DecoderConfig::default(),
480 memory: Sam2MemoryConfig::default(),
481 memory_encoder: Sam2MemoryEncoderConfig::default(),
482 }
483 }
484 pub fn hiera_base_plus() -> Self {
485 let hiera = Sam2HieraConfig::base_plus();
486 let fpn = Sam2FpnConfig::for_hiera(&hiera);
487 Self {
488 hiera,
489 fpn,
490 decoder: Sam2DecoderConfig::default(),
491 memory: Sam2MemoryConfig::default(),
492 memory_encoder: Sam2MemoryEncoderConfig::default(),
493 }
494 }
495 pub fn hiera_large() -> Self {
496 let hiera = Sam2HieraConfig::large();
497 let fpn = Sam2FpnConfig::for_hiera(&hiera);
498 Self {
499 hiera,
500 fpn,
501 decoder: Sam2DecoderConfig::default(),
502 memory: Sam2MemoryConfig::default(),
503 memory_encoder: Sam2MemoryEncoderConfig::default(),
504 }
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn q_pool_indices_match_reference() {
        // Reference (sam2/modeling/backbones/hieradet.py): q_pool
        // happens at `cumulative_sum(stages)[:-1]`.
        assert_eq!(
            Sam2HieraConfig::tiny().q_pool_block_indices(),
            vec![1, 3, 10]
        );
        assert_eq!(
            Sam2HieraConfig::small().q_pool_block_indices(),
            vec![1, 3, 14]
        );
        assert_eq!(
            Sam2HieraConfig::base_plus().q_pool_block_indices(),
            vec![2, 5, 21]
        );
        assert_eq!(
            Sam2HieraConfig::large().q_pool_block_indices(),
            vec![2, 8, 44]
        );
    }

    #[test]
    fn stage_dim_and_head_doubling() {
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.embed_dim_at_stage(0), 112);
        assert_eq!(cfg.embed_dim_at_stage(1), 224);
        assert_eq!(cfg.embed_dim_at_stage(2), 448);
        assert_eq!(cfg.embed_dim_at_stage(3), 896);
        assert_eq!(cfg.num_heads_at_stage(3), 16);
    }

    #[test]
    fn grid_halves_per_stage() {
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.grid_size_at_stage(0), 256);
        assert_eq!(cfg.grid_size_at_stage(1), 128);
        assert_eq!(cfg.grid_size_at_stage(2), 64);
        assert_eq!(cfg.grid_size_at_stage(3), 32);
    }

    #[test]
    fn total_blocks_per_variant() {
        assert_eq!(Sam2HieraConfig::tiny().total_blocks(), 12);
        assert_eq!(Sam2HieraConfig::small().total_blocks(), 16);
        assert_eq!(Sam2HieraConfig::base_plus().total_blocks(), 24);
        assert_eq!(Sam2HieraConfig::large().total_blocks(), 48);
    }

    #[test]
    fn global_attention_blocks_are_in_range() {
        let variants = [
            Sam2HieraConfig::tiny(),
            Sam2HieraConfig::small(),
            Sam2HieraConfig::base_plus(),
            Sam2HieraConfig::large(),
        ];
        for cfg in &variants {
            let total = cfg.total_blocks();
            assert_eq!(cfg.global_att_blocks.len(), 3, "{:?}", cfg);
            assert!(
                cfg.global_att_blocks.iter().all(|&b| b < total),
                "{:?}",
                cfg
            );
        }
    }

    #[test]
    fn stage_of_block_follows_stage_boundaries() {
        // base_plus stages: [2, 3, 16, 3].
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.stage_of_block(0), 0);
        assert_eq!(cfg.stage_of_block(1), 0);
        assert_eq!(cfg.stage_of_block(2), 1);
        assert_eq!(cfg.stage_of_block(4), 1);
        assert_eq!(cfg.stage_of_block(5), 2);
        assert_eq!(cfg.stage_of_block(20), 2);
        assert_eq!(cfg.stage_of_block(21), 3);
        assert_eq!(cfg.stage_of_block(23), 3);
        // Out-of-range indices clamp to the last stage.
        assert_eq!(cfg.stage_of_block(1_000), 3);
    }

    #[test]
    fn q_pool_blocks_open_new_stages() {
        let variants = [
            Sam2HieraConfig::tiny(),
            Sam2HieraConfig::small(),
            Sam2HieraConfig::base_plus(),
            Sam2HieraConfig::large(),
        ];
        for cfg in &variants {
            for (s, &b) in cfg.q_pool_block_indices().iter().enumerate() {
                assert_eq!(cfg.stage_of_block(b), s + 1);
                assert_eq!(cfg.stage_of_block(b - 1), s);
            }
        }
    }

    #[test]
    fn fpn_channels_are_coarse_to_fine() {
        let fpn = Sam2FpnConfig::for_hiera(&Sam2HieraConfig::base_plus());
        assert_eq!(fpn.backbone_channel_list, vec![896, 448, 224, 112]);
        assert_eq!(fpn.d_model, SAM2_PROMPT_EMBED_DIM);
        assert_eq!(fpn.fpn_top_down_levels, vec![2, 3]);
        assert!(fpn.interpolation_nearest);

        let tiny = Sam2FpnConfig::for_hiera(&Sam2HieraConfig::tiny());
        assert_eq!(tiny.backbone_channel_list, vec![768, 384, 192, 96]);
        let large = Sam2FpnConfig::for_hiera(&Sam2HieraConfig::large());
        assert_eq!(large.backbone_channel_list, vec![1152, 576, 288, 144]);
    }

    #[test]
    fn presets_wire_fpn_to_their_trunk() {
        let presets = [
            Sam2Config::hiera_tiny(),
            Sam2Config::hiera_small(),
            Sam2Config::hiera_base_plus(),
            Sam2Config::hiera_large(),
        ];
        for cfg in &presets {
            let finest = *cfg.fpn.backbone_channel_list.last().unwrap();
            assert_eq!(finest, cfg.hiera.embed_dim);
            assert_eq!(
                cfg.fpn.backbone_channel_list.len(),
                cfg.hiera.stages.len()
            );
        }
    }

    #[test]
    fn memory_dims_are_consistent() {
        let enc = Sam2MemoryEncoderConfig::default();
        let mem = Sam2MemoryConfig::default();
        // PE width must match memory-token width for `memory + pos`.
        assert_eq!(enc.pe_num_pos_feats * 2, enc.out_dim);
        assert_eq!(mem.kv_in_dim, enc.out_dim);
        assert_eq!(mem.mem_dim, enc.out_dim);
        assert_eq!(mem.d_model, enc.in_dim);
    }

    #[test]
    fn decoder_token_count() {
        let dec = Sam2DecoderConfig::default();
        assert_eq!(dec.num_mask_tokens, 4);
        assert_eq!(dec.transformer_dim, SAM2_PROMPT_EMBED_DIM);
    }
}