Skip to main content

rlx_sam2/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! SAM 2 model configuration. Mirrors Meta's `segment-anything-2` (a.k.a.
//! `facebookresearch/sam2`) reference exactly, so the published
//! `sam2_hiera_{tiny,small,base_plus,large}.pt` checkpoints (built from the
//! `sam2_hiera_{t,s,b+,l}.yaml` configs) can load without remapping.
19//!
20//! The image encoder is `Hiera` (Ryali et al. 2023) — a hierarchical,
21//! multi-scale ViT with mask-unit attention and Q-pooling — wrapped by
22//! an FPN-style neck that emits feature maps at strides 4/8/16/32. The
23//! prompt encoder + mask decoder are similar in spirit to SAM v1 but
24//! add an object-pointer token, an object-score head, and a high-res
25//! mask path that consumes the FPN's stride-4 / stride-8 features.
26//!
27//! Phase split (mirrors `crate::sam`):
28//!   - **Phase 1 (this commit)**: Hiera image encoder graph + FpnNeck.
29//!   - Phase 2: prompt encoder + mask decoder + IoU/object-score heads.
30//!   - Phase 3: memory attention + memory encoder for video tracking.
31
32use serde::Deserialize;
33
/// Per-channel ImageNet mean, applied to RGB values that have already been
/// divided by 255 (i.e. in `[0, 1]`).
///
/// This is where SAM 2 departs from SAM v1, whose mean/std act on raw
/// `0..=255` pixel values. Mirrors `sam2/utils/transforms.py`.
pub const SAM2_PIXEL_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Per-channel ImageNet standard deviation, on the same `[0, 1]` scale as
/// [`SAM2_PIXEL_MEAN`].
pub const SAM2_PIXEL_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Side length (pixels) of the square, preprocessed model input. SAM 2 runs
/// at 1024×1024 internally, exactly like SAM v1.
pub const SAM2_IMG_SIZE: usize = 1024;

/// Hiera patch-embedding convolution: `Conv2d(3 → embed_dim, kernel=7,
/// stride=4, padding=3)`. The stride of 4 turns the 1024² input into the
/// 256×256 stage-0 token grid.
pub const SAM2_PATCH_KERNEL: usize = 7;
pub const SAM2_PATCH_STRIDE: usize = 4;
pub const SAM2_PATCH_PADDING: usize = 3;

/// Side length of the stage-0 token grid produced by the patch embedding
/// (`1024 / 4 = 256`).
pub const SAM2_PATCH_GRID: usize = SAM2_IMG_SIZE / SAM2_PATCH_STRIDE;

/// How many Q-pool operations Hiera performs — 3 for every published
/// variant. Each one halves both spatial axes (a 4× token reduction), and the
/// channel width and head count double at the same stage boundary.
pub const SAM2_Q_POOL_COUNT: usize = 3;
/// Per-axis downsampling factor applied by a single Q-pool.
pub const SAM2_Q_STRIDE: usize = 2;

/// Channel width of the FPN-neck outputs, which the prompt encoder and the
/// mask decoder both consume.
pub const SAM2_PROMPT_EMBED_DIM: usize = 256;
63
64/// Hiera image-encoder configuration — Tiny, Small, Base+ or Large.
65///
66/// Field names mirror Hiera's Python kwargs so the values map 1:1 to
67/// the published checkpoints.
68#[derive(Debug, Clone, Deserialize)]
69pub struct Sam2HieraConfig {
70    /// Stage-0 embedding dimension. Doubles after each Q-pool.
71    pub embed_dim: usize,
72    /// Stage-0 head count. Doubles after each Q-pool.
73    pub num_heads: usize,
74    /// Number of blocks per stage. `stages.len() == 4` for every
75    /// published Hiera variant (Tiny/Small/Base+/Large).
76    pub stages: Vec<usize>,
77    /// Indices (in the *flattened* block enumeration across stages) of
78    /// blocks that use *global* attention rather than mask-unit
79    /// (windowed) attention. Always exactly 3 entries in published
80    /// configs.
81    pub global_att_blocks: Vec<usize>,
82    /// Background-pos-embed spatial size: `[Ph, Pw]` for the
83    /// learned-pos table that gets bilinear-interpolated to the
84    /// current grid each forward.
85    pub window_pos_embed_bkg_spatial_size: [usize; 2],
86    /// Mask-unit window size per stage (in units of post-Q-pool tokens
87    /// at that stage). Always 4 entries, one per stage.
88    pub window_spec: [usize; 4],
89    /// LayerNorm eps used throughout the encoder.
90    pub layer_norm_eps: f64,
91    /// MLP expansion ratio (FFN hidden = `mlp_ratio · embed_dim`).
92    pub mlp_ratio: f64,
93    /// QKV linear bias toggle (always true in published configs).
94    pub qkv_bias: bool,
95    /// Output channels per FPN level (256 for every published config).
96    pub fpn_out_chans: usize,
97}
98
99impl Sam2HieraConfig {
100    /// `sam2_hiera_tiny` — ~30 M params.
101    pub fn tiny() -> Self {
102        Self {
103            embed_dim: 96,
104            num_heads: 1,
105            stages: vec![1, 2, 7, 2],
106            global_att_blocks: vec![5, 7, 9],
107            window_pos_embed_bkg_spatial_size: [7, 7],
108            window_spec: [8, 4, 14, 7],
109            layer_norm_eps: 1e-6,
110            mlp_ratio: 4.0,
111            qkv_bias: true,
112            fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
113        }
114    }
115    /// `sam2_hiera_small` — ~46 M params.
116    pub fn small() -> Self {
117        Self {
118            stages: vec![1, 2, 11, 2],
119            global_att_blocks: vec![7, 10, 13],
120            ..Self::tiny()
121        }
122    }
123    /// `sam2_hiera_base_plus` — ~80 M params, the recommended default.
124    pub fn base_plus() -> Self {
125        Self {
126            embed_dim: 112,
127            num_heads: 2,
128            stages: vec![2, 3, 16, 3],
129            global_att_blocks: vec![12, 16, 20],
130            window_pos_embed_bkg_spatial_size: [14, 14],
131            window_spec: [8, 4, 14, 7],
132            layer_norm_eps: 1e-6,
133            mlp_ratio: 4.0,
134            qkv_bias: true,
135            fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
136        }
137    }
138    /// `sam2_hiera_large` — ~224 M params. The YAML overrides
139    /// `window_pos_embed_bkg_spatial_size` to `[7, 7]` (vs `[14, 14]`
140    /// for base+).
141    pub fn large() -> Self {
142        Self {
143            embed_dim: 144,
144            num_heads: 2,
145            stages: vec![2, 6, 36, 4],
146            global_att_blocks: vec![23, 33, 43],
147            window_pos_embed_bkg_spatial_size: [7, 7],
148            window_spec: [8, 4, 16, 8],
149            ..Self::base_plus()
150        }
151    }
152
153    /// Total number of transformer blocks across all stages.
154    pub fn total_blocks(&self) -> usize {
155        self.stages.iter().sum()
156    }
157
158    /// Indices (in the flattened block enumeration) where a Q-pool
159    /// happens — at the *first* block of every stage after stage 0.
160    ///
161    /// Reference: in `Hiera.__init__` the Q-pool boundaries are
162    /// `cumulative_sum(stages)[:-1]`, i.e. `[s0, s0+s1, s0+s1+s2]`.
163    pub fn q_pool_block_indices(&self) -> Vec<usize> {
164        let mut acc = 0usize;
165        let mut out = Vec::with_capacity(SAM2_Q_POOL_COUNT);
166        for &n in &self.stages[..self.stages.len() - 1] {
167            acc += n;
168            out.push(acc);
169        }
170        out
171    }
172
173    /// Stage index for the i-th flattened block.
174    pub fn stage_of_block(&self, block_idx: usize) -> usize {
175        let mut acc = 0usize;
176        for (si, &n) in self.stages.iter().enumerate() {
177            acc += n;
178            if block_idx < acc {
179                return si;
180            }
181        }
182        self.stages.len() - 1
183    }
184
185    /// Embedding dimension at stage `s` (doubles per Q-pool).
186    pub fn embed_dim_at_stage(&self, s: usize) -> usize {
187        self.embed_dim * (1 << s)
188    }
189    /// Number of heads at stage `s` (doubles per Q-pool).
190    pub fn num_heads_at_stage(&self, s: usize) -> usize {
191        self.num_heads * (1 << s)
192    }
193    /// Mask-unit window size at stage `s`.
194    pub fn window_size_at_stage(&self, s: usize) -> usize {
195        self.window_spec[s]
196    }
197    /// Per-axis spatial size of the token grid at stage `s` (before any
198    /// Q-pool *inside* the stage — i.e. the size at the stage's first
199    /// post-Q-pool block, for s>0, or stage-0 patch grid for s=0).
200    pub fn grid_size_at_stage(&self, s: usize) -> usize {
201        SAM2_PATCH_GRID / (1 << s)
202    }
203}
204
/// FPN neck configuration. Mirrors `FpnNeck` in the reference.
///
/// SAM 2's neck takes the per-stage outputs from Hiera (finest →
/// coarsest, i.e. stage 0..3) and runs a top-down pyramid:
///   1. Each level gets a 1×1 lateral conv to `d_model=256`.
///   2. Going coarse → fine, levels listed in `fpn_top_down_levels`
///      *also* receive a nearest-neighbour ×2 upsample of the next-
///      coarser level's output summed in.
///
/// Every published config uses `fpn_top_down_levels=[2, 3]`. Stage 3 is
/// the coarsest level and has no coarser neighbour, so in practice only
/// stage 2 receives a top-down sum. Stages 0 and 1 are emitted as plain
/// laterals.
///
/// Note on indexing — the two lists use *different* orderings:
///   - `backbone_channel_list` is stored **coarse → fine**
///     (`[stage3_dim, …, stage0_dim]`) to match the reference YAML, so
///     `image_encoder.neck.convs.{j}` is the lateral conv for entry `j`
///     (the reference applies `convs[n - i]` to Hiera stage `i`).
///   - `fpn_top_down_levels` holds Hiera **stage** indices (fine →
///     coarse: 0 = stride 4, 3 = stride 32). That is why `[2, 3]`
///     selects the two coarsest levels.
#[derive(Debug, Clone)]
pub struct Sam2FpnConfig {
    /// Output width of every lateral conv (and thus of every FPN level).
    pub d_model: usize,
    /// Per-stage Hiera output channels, **coarse → fine** order:
    /// `[stage3_dim, stage2_dim, stage1_dim, stage0_dim]`.
    pub backbone_channel_list: Vec<usize>,
    /// Hiera stage indices (fine → coarse, i.e. *not* the ordering of
    /// `backbone_channel_list`) whose lateral gets the upsampled
    /// coarser level added. `[2, 3]` in every published config.
    pub fpn_top_down_levels: Vec<usize>,
    /// Interpolation mode for the top-down upsample. Reference uses
    /// `"nearest"`, which we lower as cheap host-side replicate.
    pub interpolation_nearest: bool,
}
239
240impl Sam2FpnConfig {
241    pub fn for_hiera(cfg: &Sam2HieraConfig) -> Self {
242        // Coarsest stage first → finest last, matching the reference's
243        // YAML ordering and the `image_encoder.neck.convs.{i}` key layout.
244        let channels: Vec<usize> = (0..cfg.stages.len())
245            .rev()
246            .map(|s| cfg.embed_dim_at_stage(s))
247            .collect();
248        // Preflight: B+ should give [896, 448, 224, 112].
249        debug_assert!(
250            channels.first().copied().unwrap_or(0) >= channels.last().copied().unwrap_or(0),
251            "backbone_channel_list must be coarse → fine"
252        );
253        Self {
254            d_model: cfg.fpn_out_chans,
255            backbone_channel_list: channels,
256            fpn_top_down_levels: vec![2, 3],
257            interpolation_nearest: true,
258        }
259    }
260}
261
/// Mask-decoder settings.
///
/// Names and default values follow
/// `sam2/modeling/sam/mask_decoder.py::MaskDecoder.__init__` together with
/// the `model.sam_mask_decoder_extra_args` block of the published
/// `sam2_hiera_*.yaml` files.
#[derive(Clone, Debug)]
pub struct Sam2DecoderConfig {
    /// Token/feature width inside the two-way transformer.
    pub transformer_dim: usize,
    /// Number of two-way transformer layers.
    pub transformer_depth: usize,
    /// Attention heads per two-way transformer layer.
    pub transformer_num_heads: usize,
    /// Hidden width of the transformer MLPs.
    pub transformer_mlp_dim: usize,
    /// Output mask tokens: one single-mask token plus
    /// `num_multimask_outputs` (3 in the YAML) multimask tokens, so 4.
    pub num_mask_tokens: usize,
    /// Layer count of the IoU-prediction MLP.
    pub iou_head_depth: usize,
    /// Hidden width of the IoU-prediction MLP.
    pub iou_head_hidden_dim: usize,
    /// Mirrors the YAML's `iou_prediction_use_sigmoid` (true when published).
    pub iou_prediction_use_sigmoid: bool,
    /// Whether SAM 2's extra object-pointer output is produced. Enabled in
    /// the published video configs.
    pub use_object_pointer: bool,
    /// `obj_ptr_proj` is a 3-layer MLP when true, a single Linear otherwise.
    /// Published configs: true.
    pub use_mlp_for_obj_ptr_proj: bool,
    /// Emit an object-presence score logit. Published configs: true.
    pub pred_obj_scores: bool,
    /// `pred_obj_score_head` is a 3-layer MLP when true, a single Linear
    /// otherwise. Published configs: true.
    pub pred_obj_scores_mlp: bool,
    /// In multimask mode, derive the object pointer from the three multimask
    /// tokens instead of the single-mask token. Published configs: true.
    pub use_multimask_token_for_obj_ptr: bool,
    /// Refine the upscaled mask with the neck's stride-4 and stride-8
    /// features. Published configs: true.
    pub use_high_res_features: bool,
    /// When the single-mask output's stability score falls below the
    /// threshold, fall back to the best multimask output
    /// (`dynamic_multimask_via_stability=True` in every published video
    /// config).
    pub dynamic_multimask_via_stability: bool,
    /// Logit offset used when computing the stability score.
    pub dynamic_multimask_stability_delta: f32,
    /// Stability score below which the multimask fallback kicks in.
    pub dynamic_multimask_stability_thresh: f32,
    /// LayerNorm epsilon for the decoder.
    /// NOTE(review): the reference mixes `LayerNorm2d` (eps 1e-6) in the
    /// upscaling path with default-eps `nn.LayerNorm` inside the two-way
    /// transformer — confirm which norms this value feeds.
    pub layer_norm_eps: f64,
}
305
306impl Default for Sam2DecoderConfig {
307    fn default() -> Self {
308        Self {
309            transformer_dim: SAM2_PROMPT_EMBED_DIM,
310            transformer_depth: 2,
311            transformer_num_heads: 8,
312            transformer_mlp_dim: 2048,
313            num_mask_tokens: 4,
314            iou_head_depth: 3,
315            iou_head_hidden_dim: SAM2_PROMPT_EMBED_DIM,
316            iou_prediction_use_sigmoid: true,
317            use_object_pointer: true,
318            use_mlp_for_obj_ptr_proj: true,
319            pred_obj_scores: true,
320            pred_obj_scores_mlp: true,
321            use_multimask_token_for_obj_ptr: true,
322            use_high_res_features: true,
323            dynamic_multimask_via_stability: true,
324            dynamic_multimask_stability_delta: 0.05,
325            dynamic_multimask_stability_thresh: 0.98,
326            layer_norm_eps: 1e-6,
327        }
328    }
329}
330
/// Memory-encoder configuration. Mirrors
/// `sam2/modeling/memory_encoder.py::MemoryEncoder` + its
/// `MaskDownSampler` and `Fuser`. Defaults match every published
/// `sam2_hiera_*.yaml` `memory_encoder:` block. The one naming difference
/// is `pe_num_pos_feats`, which stores the value the reference uses
/// *after* `PositionEmbeddingSine` halves the YAML's `num_pos_feats`
/// (see the field doc).
#[derive(Debug, Clone)]
pub struct Sam2MemoryEncoderConfig {
    /// Input feature dim from the FpnNeck stride-16 level.
    pub in_dim: usize,
    /// Output memory token dim (smaller than `in_dim` in published
    /// configs: 64 vs 256, so memory bank tokens are cheap to store).
    pub out_dim: usize,
    /// MaskDownSampler: per-step kernel/stride/padding + total stride.
    /// `total_stride` must be a power of `stride`. The reference uses
    /// kernel=3 stride=2 padding=1 total=16, giving 4 down-sampling levels.
    pub mask_downsampler_kernel: usize,
    pub mask_downsampler_stride: usize,
    pub mask_downsampler_padding: usize,
    pub mask_downsampler_total_stride: usize,
    /// Fuser: `num_layers` × CXBlock(dim, kernel, padding, ls_init).
    pub fuser_num_layers: usize,
    pub fuser_dim: usize,
    pub fuser_kernel: usize,
    pub fuser_padding: usize,
    pub fuser_layer_scale_init_value: f32,
    pub fuser_use_dwconv: bool,
    pub fuser_input_projection: bool,
    /// Per-axis feature count of the sine position encoding. The emitted
    /// encoding has `2 * pe_num_pos_feats` channels, which must equal
    /// `out_dim` so it can be added to the memory tokens.
    ///
    /// The reference YAML writes `num_pos_feats: 64`, and
    /// `PositionEmbeddingSine.__init__` halves it internally. The
    /// equivalent value here is therefore 32, not 64.
    pub pe_num_pos_feats: usize,
    /// Temperature of the sine position encoding.
    pub pe_temperature: f32,
}
362
363impl Default for Sam2MemoryEncoderConfig {
364    fn default() -> Self {
365        Self {
366            in_dim: SAM2_PROMPT_EMBED_DIM,
367            out_dim: 64,
368            mask_downsampler_kernel: 3,
369            mask_downsampler_stride: 2,
370            mask_downsampler_padding: 1,
371            mask_downsampler_total_stride: 16,
372            fuser_num_layers: 2,
373            fuser_dim: SAM2_PROMPT_EMBED_DIM,
374            fuser_kernel: 7,
375            fuser_padding: 3,
376            fuser_layer_scale_init_value: 1e-6,
377            fuser_use_dwconv: true,
378            fuser_input_projection: false,
379            // num_pos_feats * 2 == out_dim so PE shape matches memory
380            // tokens for `memory + pos` addition in MemoryAttention.
381            pe_num_pos_feats: 32,
382            pe_temperature: 10000.0,
383        }
384    }
385}
386
/// Memory-attention configuration (video path).
///
/// `d_model` (256) is the working dim of memory-attention layers;
/// memory bank tokens come in at `kv_in_dim=out_dim` of the memory
/// encoder (64 in published configs) and are projected up by the
/// cross-attention's k/v linear layers.
#[derive(Debug, Clone)]
pub struct Sam2MemoryConfig {
    pub d_model: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub dim_feedforward: usize,
    pub layer_norm_eps: f64,
    /// Memory bank kv channel dim — matches memory-encoder out_dim.
    pub kv_in_dim: usize,
    pub rope_theta: f32,
    /// Spatial size used to build the RoPE table for the image features
    /// that memory attention operates on: the stride-16 grid, i.e. 64×64
    /// for a 1024² input (the default).
    /// NOTE(review): the reference `RoPEAttention` rebuilds its table
    /// when the live token count differs from this size. Confirm whether
    /// this crate does the same before relying on a non-default value.
    pub rope_feat_size: [usize; 2],
    pub rope_k_repeat: bool,
    /// Whether to add the input PE to the queries at layer input
    /// (`pos_enc_at_input` in the reference YAML; true).
    pub pos_enc_at_input: bool,
    pub pos_enc_at_attn: bool,
    pub pos_enc_at_cross_attn_keys: bool,
    pub pos_enc_at_cross_attn_queries: bool,
    /// Maximum number of object pointers preserved across frames.
    pub max_obj_ptrs_in_encoder: usize,
    /// Fuse each memory-attention layer into one graph with [`Op::AxialRope2d`]
    /// (faster than the default five-graph + host-RoPE path). Requires axial RoPE
    /// on the compile device (CPU today; Metal uses host fallback when enabled).
    pub mem_attn_in_graph_rope: bool,
    /// Channel width of memory-bank tokens, corresponding to the reference's
    /// `SAM2Base.mem_dim` (taken from the memory encoder's output
    /// projection; 64 in published checkpoints). In the reference,
    /// `d_model`-wide object pointers are split into `d_model / mem_dim`
    /// tokens of this width when `mem_dim < d_model`.
    /// NOTE(review): confirm the memory-attention consumer uses it this way.
    pub mem_dim: usize,
}
423
424impl Default for Sam2MemoryConfig {
425    fn default() -> Self {
426        Self {
427            d_model: SAM2_PROMPT_EMBED_DIM,
428            num_layers: 4,
429            num_heads: 1,
430            dim_feedforward: 2048,
431            layer_norm_eps: 1e-5,
432            kv_in_dim: 64,
433            rope_theta: 10000.0,
434            // Published YAMLs use feat_sizes: [64, 64] for memory
435            // attention RoPE — matches the stride-16 feature grid.
436            rope_feat_size: [64, 64],
437            rope_k_repeat: true,
438            pos_enc_at_input: true,
439            pos_enc_at_attn: false,
440            pos_enc_at_cross_attn_keys: true,
441            pos_enc_at_cross_attn_queries: false,
442            max_obj_ptrs_in_encoder: 16,
443            mem_dim: 64,
444            mem_attn_in_graph_rope: false,
445        }
446    }
447}
448
449/// Top-level SAM 2 configuration — Hiera + FPN + decoder + memory
450/// (encoder + attention) for the video path. Mirrors `SAM2Base` in
451/// the reference.
452#[derive(Debug, Clone)]
453pub struct Sam2Config {
454    pub hiera: Sam2HieraConfig,
455    pub fpn: Sam2FpnConfig,
456    pub decoder: Sam2DecoderConfig,
457    pub memory: Sam2MemoryConfig,
458    pub memory_encoder: Sam2MemoryEncoderConfig,
459}
460
461impl Sam2Config {
462    pub fn hiera_tiny() -> Self {
463        let hiera = Sam2HieraConfig::tiny();
464        let fpn = Sam2FpnConfig::for_hiera(&hiera);
465        Self {
466            hiera,
467            fpn,
468            decoder: Sam2DecoderConfig::default(),
469            memory: Sam2MemoryConfig::default(),
470            memory_encoder: Sam2MemoryEncoderConfig::default(),
471        }
472    }
473    pub fn hiera_small() -> Self {
474        let hiera = Sam2HieraConfig::small();
475        let fpn = Sam2FpnConfig::for_hiera(&hiera);
476        Self {
477            hiera,
478            fpn,
479            decoder: Sam2DecoderConfig::default(),
480            memory: Sam2MemoryConfig::default(),
481            memory_encoder: Sam2MemoryEncoderConfig::default(),
482        }
483    }
484    pub fn hiera_base_plus() -> Self {
485        let hiera = Sam2HieraConfig::base_plus();
486        let fpn = Sam2FpnConfig::for_hiera(&hiera);
487        Self {
488            hiera,
489            fpn,
490            decoder: Sam2DecoderConfig::default(),
491            memory: Sam2MemoryConfig::default(),
492            memory_encoder: Sam2MemoryEncoderConfig::default(),
493        }
494    }
495    pub fn hiera_large() -> Self {
496        let hiera = Sam2HieraConfig::large();
497        let fpn = Sam2FpnConfig::for_hiera(&hiera);
498        Self {
499            hiera,
500            fpn,
501            decoder: Sam2DecoderConfig::default(),
502            memory: Sam2MemoryConfig::default(),
503            memory_encoder: Sam2MemoryEncoderConfig::default(),
504        }
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// All four published Hiera variants.
    fn variants() -> [Sam2HieraConfig; 4] {
        [
            Sam2HieraConfig::tiny(),
            Sam2HieraConfig::small(),
            Sam2HieraConfig::base_plus(),
            Sam2HieraConfig::large(),
        ]
    }

    #[test]
    fn q_pool_indices_match_reference() {
        // Reference (sam2/modeling/backbones/hieradet.py): q_pool
        // happens at `cumulative_sum(stages)[:-1]`.
        assert_eq!(
            Sam2HieraConfig::tiny().q_pool_block_indices(),
            vec![1, 3, 10]
        );
        assert_eq!(
            Sam2HieraConfig::small().q_pool_block_indices(),
            vec![1, 3, 14]
        );
        assert_eq!(
            Sam2HieraConfig::base_plus().q_pool_block_indices(),
            vec![2, 5, 21]
        );
        assert_eq!(
            Sam2HieraConfig::large().q_pool_block_indices(),
            vec![2, 8, 44]
        );
    }

    #[test]
    fn stage_dim_and_head_doubling() {
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.embed_dim_at_stage(0), 112);
        assert_eq!(cfg.embed_dim_at_stage(1), 224);
        assert_eq!(cfg.embed_dim_at_stage(2), 448);
        assert_eq!(cfg.embed_dim_at_stage(3), 896);
        assert_eq!(cfg.num_heads_at_stage(3), 16);
    }

    #[test]
    fn grid_halves_per_stage() {
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.grid_size_at_stage(0), 256);
        assert_eq!(cfg.grid_size_at_stage(1), 128);
        assert_eq!(cfg.grid_size_at_stage(2), 64);
        assert_eq!(cfg.grid_size_at_stage(3), 32);
    }

    #[test]
    fn stage_of_block_follows_stage_boundaries() {
        // base_plus stages: [2, 3, 16, 3].
        let cfg = Sam2HieraConfig::base_plus();
        assert_eq!(cfg.total_blocks(), 24);
        assert_eq!(cfg.stage_of_block(0), 0);
        assert_eq!(cfg.stage_of_block(1), 0);
        assert_eq!(cfg.stage_of_block(2), 1);
        assert_eq!(cfg.stage_of_block(4), 1);
        assert_eq!(cfg.stage_of_block(5), 2);
        assert_eq!(cfg.stage_of_block(20), 2);
        assert_eq!(cfg.stage_of_block(21), 3);
        assert_eq!(cfg.stage_of_block(23), 3);
        // Out-of-range indices clamp to the last stage.
        assert_eq!(cfg.stage_of_block(1_000), 3);
    }

    #[test]
    fn q_pool_blocks_open_new_stages() {
        for cfg in &variants() {
            for idx in cfg.q_pool_block_indices() {
                assert_eq!(cfg.stage_of_block(idx), cfg.stage_of_block(idx - 1) + 1);
            }
        }
    }

    #[test]
    fn global_attention_blocks_are_in_range() {
        for cfg in &variants() {
            let total = cfg.total_blocks();
            assert_eq!(cfg.global_att_blocks.len(), 3);
            assert!(cfg.global_att_blocks.iter().all(|&b| b < total));
        }
    }

    #[test]
    fn fpn_channels_are_coarse_to_fine() {
        assert_eq!(
            Sam2FpnConfig::for_hiera(&Sam2HieraConfig::tiny()).backbone_channel_list,
            vec![768, 384, 192, 96]
        );
        assert_eq!(
            Sam2FpnConfig::for_hiera(&Sam2HieraConfig::base_plus()).backbone_channel_list,
            vec![896, 448, 224, 112]
        );
        assert_eq!(
            Sam2FpnConfig::for_hiera(&Sam2HieraConfig::large()).backbone_channel_list,
            vec![1152, 576, 288, 144]
        );
        let fpn = Sam2Config::hiera_base_plus().fpn;
        assert_eq!(fpn.d_model, SAM2_PROMPT_EMBED_DIM);
        assert_eq!(fpn.fpn_top_down_levels, vec![2, 3]);
    }

    #[test]
    fn memory_widths_are_consistent() {
        let enc = Sam2MemoryEncoderConfig::default();
        let mem = Sam2MemoryConfig::default();
        // PE must be addable to memory tokens.
        assert_eq!(2 * enc.pe_num_pos_feats, enc.out_dim);
        assert_eq!(mem.kv_in_dim, enc.out_dim);
        assert_eq!(mem.mem_dim, enc.out_dim);
        assert_eq!(enc.in_dim, mem.d_model);
        // total_stride must be a power of the per-step stride.
        let mut reached = 1;
        while reached < enc.mask_downsampler_total_stride {
            reached *= enc.mask_downsampler_stride;
        }
        assert_eq!(reached, enc.mask_downsampler_total_stride);
    }

    #[test]
    fn decoder_has_single_plus_three_multimask_tokens() {
        assert_eq!(Sam2DecoderConfig::default().num_mask_tokens, 4);
    }
}