Skip to main content

rlx_sam/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 model configuration. Mirrors Meta's `segment-anything` Python
17//! reference and candle's `segment_anything` module.
18//!
19//! Three ViT image-encoder variants (B/L/H) and one MobileSAM TinyViT
20//! variant. Decoder + prompt-encoder hyperparameters are fixed across
21//! all variants.
22
23use serde::Deserialize;
24
/// Per-channel pixel mean, expressed on the raw 0..255 pixel scale.
///
/// These are the ImageNet statistics already multiplied by 255, so they
/// are applied directly to raw pixel values — there is no separate /255
/// scaling step, unlike most ViT preprocessing. Must match
/// `sam.rs::preprocess()` in candle exactly.
pub const SAM_PIXEL_MEAN: [f32; 3] = [123.675, 116.28, 103.53];
/// Per-channel pixel standard deviation, on the same raw 0..255 scale as
/// [`SAM_PIXEL_MEAN`].
pub const SAM_PIXEL_STD: [f32; 3] = [58.395, 57.12, 57.375];
30
/// Target image side after preprocessing. SAM always operates at
/// 1024×1024 internally; smaller inputs are resized + zero-padded.
pub const SAM_IMG_SIZE: usize = 1024;
/// Patch side used to derive [`SAM_EMBED_HW`] from [`SAM_IMG_SIZE`]
/// (presumably the patch-embedding stride in pixels — confirm at the
/// use site).
pub const SAM_PATCH_SIZE: usize = 16;
/// Spatial resolution (per side) of the image embeddings produced by the
/// encoder: `SAM_IMG_SIZE / SAM_PATCH_SIZE`.
pub const SAM_EMBED_HW: usize = SAM_IMG_SIZE / SAM_PATCH_SIZE; // 64

/// Channel count of the embeddings emitted by the encoder neck and
/// consumed by the prompt encoder + mask decoder.
pub const SAM_PROMPT_EMBED_DIM: usize = 256;
41
/// Encoder configuration — ViT-B/L/H or TinyViT variants.
///
/// Can be deserialized through serde; field names are the keys.
#[derive(Debug, Clone, Deserialize)]
pub struct SamEncoderConfig {
    /// Which encoder architecture this configuration describes.
    pub encoder_kind: EncoderKind,
    /// Embedding width; `embed_dim / num_heads` is the per-head width.
    pub embed_dim: usize,
    /// Number of encoder blocks. In every ViT preset the last entry of
    /// `global_attn_indexes` equals `depth - 1`.
    pub depth: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Per-block flag: blocks listed here use global attention
    /// (no windowing); all others use windowed attention with
    /// `window_size`.
    pub global_attn_indexes: Vec<usize>,
    /// Window size for the blocks that use windowed attention.
    pub window_size: usize,
    /// Presumably enables relative position embeddings in attention —
    /// confirm at the use site.
    pub use_rel_pos: bool,
    /// Presumably enables absolute position embeddings — confirm at the
    /// use site.
    pub use_abs_pos: bool,
    /// Presumably toggles the bias term of the QKV projection — confirm
    /// at the use site.
    pub qkv_bias: bool,
    /// LayerNorm eps used throughout the encoder.
    pub layer_norm_eps: f64,
    /// Channel count of the final image embeddings (after the neck).
    pub out_chans: usize,
}
62
/// Image-encoder architecture selected by
/// [`SamEncoderConfig::encoder_kind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum EncoderKind {
    /// ViT image encoder; used by the ViT-B/L/H presets.
    ViT,
    /// TinyViT image encoder (the MobileSAM variant).
    TinyViT,
}
68
69impl SamEncoderConfig {
70    /// SAM ViT-B (default, ~91 M params).
71    pub fn vit_b() -> Self {
72        Self {
73            encoder_kind: EncoderKind::ViT,
74            embed_dim: 768,
75            depth: 12,
76            num_heads: 12,
77            global_attn_indexes: vec![2, 5, 8, 11],
78            window_size: 14,
79            use_rel_pos: true,
80            use_abs_pos: true,
81            qkv_bias: true,
82            layer_norm_eps: 1e-6,
83            out_chans: SAM_PROMPT_EMBED_DIM,
84        }
85    }
86    /// SAM ViT-L (~308 M params).
87    pub fn vit_l() -> Self {
88        Self {
89            embed_dim: 1024,
90            depth: 24,
91            num_heads: 16,
92            global_attn_indexes: vec![5, 11, 17, 23],
93            ..Self::vit_b()
94        }
95    }
96    /// SAM ViT-H (~632 M params).
97    pub fn vit_h() -> Self {
98        Self {
99            embed_dim: 1280,
100            depth: 32,
101            num_heads: 16,
102            global_attn_indexes: vec![7, 15, 23, 31],
103            ..Self::vit_b()
104        }
105    }
106
107    pub fn head_dim(&self) -> usize {
108        self.embed_dim / self.num_heads
109    }
110    pub fn num_patches_per_side(&self) -> usize {
111        SAM_EMBED_HW
112    }
113}
114
/// Mask decoder configuration. Same across SAM variants.
///
/// The fixed values are provided by the [`Default`] implementation.
#[derive(Debug, Clone)]
pub struct SamDecoderConfig {
    /// Width of the decoder transformer; defaults to
    /// `SAM_PROMPT_EMBED_DIM`.
    pub transformer_dim: usize,
    /// Number of decoder transformer layers (default 2).
    pub transformer_depth: usize,
    /// Attention heads in the decoder transformer (default 8).
    pub transformer_num_heads: usize,
    /// Hidden width of the decoder transformer MLP (default 2048).
    pub transformer_mlp_dim: usize,
    /// Number of mask tokens (default 4); downstream code picks one or
    /// all three multi-mask outputs depending on `multimask_output`.
    ///
    /// NOTE(review): this was previously described as "1 IoU token + 3
    /// mask tokens". In Meta's reference decoder the count is
    /// `num_multimask_outputs + 1` (one single-mask token plus three
    /// multi-mask tokens) and the IoU token is a separate embedding —
    /// confirm which reading the consumer of this field expects.
    pub num_mask_tokens: usize,
    /// Depth of the IoU prediction head (default 3).
    pub iou_head_depth: usize,
    /// Hidden width of the IoU prediction head; defaults to
    /// `SAM_PROMPT_EMBED_DIM`.
    pub iou_head_hidden_dim: usize,
    /// LayerNorm eps for the decoder (default 1e-6).
    ///
    /// NOTE(review): Meta's reference uses PyTorch's default eps (1e-5)
    /// for the two-way transformer LayerNorms and 1e-6 only for the
    /// `LayerNorm2d` in the upscaling path — confirm which norms read
    /// this field.
    pub layer_norm_eps: f64,
}
129
130impl Default for SamDecoderConfig {
131    fn default() -> Self {
132        Self {
133            transformer_dim: SAM_PROMPT_EMBED_DIM,
134            transformer_depth: 2,
135            transformer_num_heads: 8,
136            transformer_mlp_dim: 2048,
137            num_mask_tokens: 4,
138            iou_head_depth: 3,
139            iou_head_hidden_dim: SAM_PROMPT_EMBED_DIM,
140            layer_norm_eps: 1e-6,
141        }
142    }
143}
144
/// Top-level SAM configuration — pairs an encoder configuration with a
/// decoder configuration. Values shared between the two (image size,
/// prompt-embedding width, …) are module-level constants rather than
/// fields of this struct.
#[derive(Debug, Clone)]
pub struct SamConfig {
    /// Image-encoder hyperparameters; varies per SAM variant.
    pub encoder: SamEncoderConfig,
    /// Mask-decoder hyperparameters; the presets all use
    /// `SamDecoderConfig::default()`.
    pub decoder: SamDecoderConfig,
}
152
153impl SamConfig {
154    pub fn vit_b() -> Self {
155        Self {
156            encoder: SamEncoderConfig::vit_b(),
157            decoder: SamDecoderConfig::default(),
158        }
159    }
160    pub fn vit_l() -> Self {
161        Self {
162            encoder: SamEncoderConfig::vit_l(),
163            decoder: SamDecoderConfig::default(),
164        }
165    }
166    pub fn vit_h() -> Self {
167        Self {
168            encoder: SamEncoderConfig::vit_h(),
169            decoder: SamDecoderConfig::default(),
170        }
171    }
172}