rlx_sam/config.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 model configuration. Mirrors Meta's `segment-anything` Python
17//! reference and candle's `segment_anything` module.
18//!
//! Provides presets for the three ViT image-encoder variants (B/L/H).
//! `EncoderKind` also names MobileSAM's TinyViT encoder, but no TinyViT
//! preset is defined in this module. Decoder and prompt-encoder
//! hyperparameters are the same for every variant.
22
23use serde::Deserialize;
24
/// Per-channel ImageNet mean in RGB order, expressed on the raw 0..=255
/// pixel scale (i.e. `[0.485, 0.456, 0.406] * 255`).
///
/// SAM normalizes the unscaled pixel values directly — there is no /255
/// step before or after, unlike most ViT preprocessing pipelines. Must match
/// `sam.rs::preprocess()` in candle exactly.
pub const SAM_PIXEL_MEAN: [f32; 3] = [123.675, 116.28, 103.53];
/// Per-channel ImageNet standard deviation in RGB order, on the same raw
/// 0..=255 scale as [`SAM_PIXEL_MEAN`] (i.e. `[0.229, 0.224, 0.225] * 255`).
pub const SAM_PIXEL_STD: [f32; 3] = [58.395, 57.12, 57.375];

/// Target image side after preprocessing. SAM always operates at
/// 1024×1024 internally: inputs are resized so their longest side equals
/// this value, then zero-padded to a square.
pub const SAM_IMG_SIZE: usize = 1024;
/// Side length, in pixels, of the square patches the image encoder turns
/// into tokens.
pub const SAM_PATCH_SIZE: usize = 16;
/// Spatial resolution (per side) of the image embeddings produced by the
/// encoder: `SAM_IMG_SIZE / SAM_PATCH_SIZE` = 64.
pub const SAM_EMBED_HW: usize = SAM_IMG_SIZE / SAM_PATCH_SIZE;

// The patch grid must tile the input exactly; otherwise the integer division
// above would silently drop a partial row/column of patches.
const _: () = assert!(SAM_IMG_SIZE % SAM_PATCH_SIZE == 0);

/// Channel count of the embeddings emitted by the encoder neck and
/// consumed by the prompt encoder + mask decoder.
pub const SAM_PROMPT_EMBED_DIM: usize = 256;
41
42/// Encoder configuration — ViT-B/L/H or TinyViT variants.
43#[derive(Debug, Clone, Deserialize)]
44pub struct SamEncoderConfig {
45 pub encoder_kind: EncoderKind,
46 pub embed_dim: usize,
47 pub depth: usize,
48 pub num_heads: usize,
49 /// Per-block flag: blocks listed here use global attention
50 /// (no windowing); all others use windowed attention with
51 /// `window_size`.
52 pub global_attn_indexes: Vec<usize>,
53 pub window_size: usize,
54 pub use_rel_pos: bool,
55 pub use_abs_pos: bool,
56 pub qkv_bias: bool,
57 /// LayerNorm eps used throughout the encoder.
58 pub layer_norm_eps: f64,
59 /// Channel count of the final image embeddings (after the neck).
60 pub out_chans: usize,
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
64pub enum EncoderKind {
65 ViT,
66 TinyViT,
67}
68
69impl SamEncoderConfig {
70 /// SAM ViT-B (default, ~91 M params).
71 pub fn vit_b() -> Self {
72 Self {
73 encoder_kind: EncoderKind::ViT,
74 embed_dim: 768,
75 depth: 12,
76 num_heads: 12,
77 global_attn_indexes: vec![2, 5, 8, 11],
78 window_size: 14,
79 use_rel_pos: true,
80 use_abs_pos: true,
81 qkv_bias: true,
82 layer_norm_eps: 1e-6,
83 out_chans: SAM_PROMPT_EMBED_DIM,
84 }
85 }
86 /// SAM ViT-L (~308 M params).
87 pub fn vit_l() -> Self {
88 Self {
89 embed_dim: 1024,
90 depth: 24,
91 num_heads: 16,
92 global_attn_indexes: vec![5, 11, 17, 23],
93 ..Self::vit_b()
94 }
95 }
96 /// SAM ViT-H (~632 M params).
97 pub fn vit_h() -> Self {
98 Self {
99 embed_dim: 1280,
100 depth: 32,
101 num_heads: 16,
102 global_attn_indexes: vec![7, 15, 23, 31],
103 ..Self::vit_b()
104 }
105 }
106
107 pub fn head_dim(&self) -> usize {
108 self.embed_dim / self.num_heads
109 }
110 pub fn num_patches_per_side(&self) -> usize {
111 SAM_EMBED_HW
112 }
113}
114
/// Mask-decoder configuration. Identical for every SAM v1 variant; obtain it
/// via `SamDecoderConfig::default()`.
#[derive(Debug, Clone)]
pub struct SamDecoderConfig {
    /// Channel width of the decoder's transformer.
    pub transformer_dim: usize,
    /// Number of layers in the decoder's transformer.
    pub transformer_depth: usize,
    /// Attention heads per transformer layer.
    pub transformer_num_heads: usize,
    /// Hidden width of the transformer's MLP blocks.
    pub transformer_mlp_dim: usize,
    /// Number of mask-output tokens. In Meta's reference this is
    /// `num_multimask_outputs + 1`: the first token produces the single-mask
    /// output and the remaining ones the multimask outputs; downstream code
    /// picks one or the other depending on `multimask_output`. The
    /// IoU-prediction token is a separate embedding and is NOT counted here.
    pub num_mask_tokens: usize,
    /// Number of layers in the IoU-prediction MLP head.
    pub iou_head_depth: usize,
    /// Hidden width of the IoU-prediction MLP head.
    pub iou_head_hidden_dim: usize,
    /// LayerNorm eps used in the decoder.
    pub layer_norm_eps: f64,
}
129
130impl Default for SamDecoderConfig {
131 fn default() -> Self {
132 Self {
133 transformer_dim: SAM_PROMPT_EMBED_DIM,
134 transformer_depth: 2,
135 transformer_num_heads: 8,
136 transformer_mlp_dim: 2048,
137 num_mask_tokens: 4,
138 iou_head_depth: 3,
139 iou_head_hidden_dim: SAM_PROMPT_EMBED_DIM,
140 layer_norm_eps: 1e-6,
141 }
142 }
143}
144
145/// Top-level SAM configuration — encoder + decoder + a few constants
146/// shared between them.
147#[derive(Debug, Clone)]
148pub struct SamConfig {
149 pub encoder: SamEncoderConfig,
150 pub decoder: SamDecoderConfig,
151}
152
153impl SamConfig {
154 pub fn vit_b() -> Self {
155 Self {
156 encoder: SamEncoderConfig::vit_b(),
157 decoder: SamDecoderConfig::default(),
158 }
159 }
160 pub fn vit_l() -> Self {
161 Self {
162 encoder: SamEncoderConfig::vit_l(),
163 decoder: SamDecoderConfig::default(),
164 }
165 }
166 pub fn vit_h() -> Self {
167 Self {
168 encoder: SamEncoderConfig::vit_h(),
169 decoder: SamDecoderConfig::default(),
170 }
171 }
172}