Skip to main content

rlx_sam/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 — Meta's Segment Anything image-segmentation model.
17//!
18//! ## Phasing
19//!
//! Phase 1 lands the **image encoder** end-to-end:
21//!   - Host-side preprocessing (resize-to-1024, ImageNet pixel
22//!     normalization, zero-pad to 1024×1024, patch embedding via
23//!     Conv2d-as-matmul).
24//!   - IR graph for the 12 encoder blocks with **windowed + global**
25//!     attention, **decomposed relative position embeddings**, plain
26//!     GELU-tanh MLPs, pre-norm residual structure.
27//!   - IR neck (Conv2d 1×1 → LN2d → Conv2d 3×3 → LN2d → `[256, 64, 64]`).
28//!
//! **Phase 1 status:** numerical parity with candle's
30//! `ImageEncoderViT::forward()` on real `sam_vit_b_01ec64.safetensors`
31//! weights — `max |Δ| = 7.15e-6` on the 1×256×64×64 image embeddings
32//! (full 12-layer ViT-B at 1024×1024 input). Phase-1 bisect env vars
33//! remain in `tests/sam_parity.rs` for future debugging:
34//!   - `RLX_SAM_DEBUG_DEPTH=N` — run only the first N encoder blocks
35//!   - `RLX_SAM_DEBUG_NO_RELPOS=1` — disable decomposed relative pos
36//!   - `RLX_SAM_DEBUG_FORCE_GLOBAL=1` — force every block to use global attn
37//!   - `RLX_SAM_DEBUG_ZERO_RELH=1` / `RLX_SAM_DEBUG_ZERO_RELW=1` — zero
38//!     a single rel_pos axis (data only — the matmul + add still execute)
39//!
//! Phase 2 adds the **prompt encoder** + **mask decoder**:
41//!   - Random Fourier positional encoding, point/box/mask embeddings.
42//!   - Two-way transformer between prompt tokens and image embeddings.
43//!   - ConvTranspose2d upscaling (IR) + hypernetwork
44//!     MLPs for mask + IoU output.
45//!
46//! Weight key convention matches Meta / candle exactly so the
47//! `lmz/candle-sam` safetensors checkpoints load with no remapping.
48
49pub mod cli;
50pub mod config;
51pub mod flow;
52pub mod image_encoder;
53pub mod mask_decoder;
54pub mod mlp_ir;
55pub mod preprocess;
56pub mod profile;
57pub mod prompt_encoder;
58pub mod prompt_mask_ir;
59#[allow(clippy::module_inception)]
60pub mod sam;
61pub mod transformer;
62pub mod transformer_ir;
63pub mod upscale_ir;
64
65pub use config::{
66    EncoderKind, SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD,
67    SAM_PROMPT_EMBED_DIM, SamConfig, SamDecoderConfig, SamEncoderConfig,
68};
69pub use flow::{SamEncoderBuilt, SamEncoderFlow, build_sam_encoder_built};
70pub use image_encoder::{
71    NeckWeights, apply_neck_host, build_sam_encoder_graph, build_sam_encoder_hir,
72};
73pub use mask_decoder::{MaskDecoderWeights, mask_decoder_forward};
74pub use preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
75pub use profile::{
76    SAM_PROFILE_FILE, sam_profile_default, sam_profile_near_weights, sam2_profile_default,
77    sam2_profile_near_weights, sam3_profile_default, sam3_profile_near_weights,
78};
79pub use prompt_encoder::{PromptEncoderOutput, PromptEncoderWeights, prompt_encoder_forward};
80pub use sam::{MaskPrediction, SAM_MASK_IN_CHANS, Sam, sam_vit_b_config};
81
82/// Re-export `Device` so callers can construct it without depending
83/// on `rlx-runtime` themselves.
84pub use rlx_runtime::Device;
85pub use transformer::{TwoWayTransformerWeights, attention_forward, two_way_transformer_forward};
86
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Raw tensor table handed to `WeightMap::from_tensors`:
    /// checkpoint key → (flat f32 data, shape).
    type TensorTable = HashMap<String, (Vec<f32>, Vec<usize>)>;

    /// Store an all-zero tensor of `shape` under `key`. The element count is
    /// derived from the shape, so data length and shape can never disagree.
    fn put_zeros(table: &mut TensorTable, key: impl Into<String>, shape: &[usize]) {
        let numel = shape.iter().product::<usize>();
        table.insert(key.into(), (vec![0.0f32; numel], shape.to_vec()));
    }

    /// Assemble an all-zero ViT-B weight set keyed like the real checkpoint,
    /// so the encoder graph can be built without touching the filesystem.
    /// The values are meaningless; numerical parity against real weights is
    /// covered by `tests/sam_parity.rs`.
    fn synthetic_vit_b_weights() -> WeightMap {
        let cfg = SamEncoderConfig::vit_b();
        let embed = cfg.embed_dim;
        let head = cfg.head_dim();
        let hidden = embed * 4;
        let grid = SAM_EMBED_HW;
        let patch = SAM_PATCH_SIZE;
        let out = cfg.out_chans;

        let mut table = TensorTable::new();

        // Patch-embedding conv (3 input channels) plus the absolute
        // position embedding laid out over the token grid.
        put_zeros(
            &mut table,
            "image_encoder.patch_embed.proj.weight",
            &[embed, 3, patch, patch],
        );
        put_zeros(&mut table, "image_encoder.patch_embed.proj.bias", &[embed]);
        put_zeros(&mut table, "image_encoder.pos_embed", &[1, grid, grid, embed]);

        for block in 0..cfg.depth {
            let prefix = format!("image_encoder.blocks.{block}");
            // Global-attention blocks get relative-position tables spanning
            // the whole grid; windowed blocks only span a single window.
            let span = if cfg.global_attn_indexes.contains(&block) {
                grid
            } else {
                cfg.window_size
            };
            let rel_rows = 2 * span - 1;

            let per_block = [
                ("norm1.weight", vec![embed]),
                ("norm1.bias", vec![embed]),
                ("attn.qkv.weight", vec![3 * embed, embed]),
                ("attn.qkv.bias", vec![3 * embed]),
                ("attn.proj.weight", vec![embed, embed]),
                ("attn.proj.bias", vec![embed]),
                ("attn.rel_pos_h", vec![rel_rows, head]),
                ("attn.rel_pos_w", vec![rel_rows, head]),
                ("norm2.weight", vec![embed]),
                ("norm2.bias", vec![embed]),
                ("mlp.lin1.weight", vec![hidden, embed]),
                ("mlp.lin1.bias", vec![hidden]),
                ("mlp.lin2.weight", vec![embed, hidden]),
                ("mlp.lin2.bias", vec![embed]),
            ];
            for (suffix, shape) in per_block {
                put_zeros(&mut table, format!("{prefix}.{suffix}"), &shape);
            }
        }

        // Neck: the 1×1 conv carries a bias, the 3×3 conv is weight-only;
        // indices 1 and 3 are the norms (weight + bias each).
        put_zeros(&mut table, "image_encoder.neck.0.weight", &[out, embed, 1, 1]);
        put_zeros(&mut table, "image_encoder.neck.0.bias", &[out]);
        put_zeros(&mut table, "image_encoder.neck.1.weight", &[out]);
        put_zeros(&mut table, "image_encoder.neck.1.bias", &[out]);
        put_zeros(&mut table, "image_encoder.neck.2.weight", &[out, out, 3, 3]);
        put_zeros(&mut table, "image_encoder.neck.3.weight", &[out]);
        put_zeros(&mut table, "image_encoder.neck.3.bias", &[out]);

        WeightMap::from_tensors(table)
    }

    /// Building the ViT-B encoder graph yields a single NCHW output of
    /// `[1, out_chans, SAM_EMBED_HW, SAM_EMBED_HW]` and consumes every
    /// weight left in the map.
    #[test]
    fn encoder_graph_builds() {
        let cfg = SamEncoderConfig::vit_b();
        let mut weights = synthetic_vit_b_weights();
        let (graph, _params, _pre) = build_sam_encoder_graph(&cfg, &mut weights).unwrap();

        assert_eq!(graph.outputs.len(), 1);

        let out_shape = graph.shape(graph.outputs[0]);
        let dims = out_shape
            .dims()
            .iter()
            .map(|d| d.unwrap_static())
            .collect::<Vec<usize>>();
        assert_eq!(dims, vec![1, cfg.out_chans, SAM_EMBED_HW, SAM_EMBED_HW]);

        // Anything still in the map was never drained by the builder.
        let leftovers: Vec<&str> = weights.keys().collect();
        assert!(leftovers.is_empty(), "leftover weights: {leftovers:?}");
    }

    /// The shipped `src/sam.rlx.toml` parses and selects direct fusion.
    #[test]
    fn sam_rlx_toml_profile_loads() {
        use rlx_flow::{CompileProfile, FusionPolicyKind};

        let crate_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        let profile_path = crate_root.join("src/sam.rlx.toml");
        let profile = CompileProfile::from_toml_path(&profile_path).unwrap();
        assert_eq!(profile.fusion.policy, FusionPolicyKind::Direct);
    }

    /// A 100×80 RGB image comes back as a padded 1024×1024 NCHW buffer,
    /// with the resized size keeping the aspect ratio (long side = 1024).
    #[test]
    fn preprocess_round_trip_shapes() {
        let pixels = vec![128u8; 100 * 80 * 3];
        let (nchw, (h, w)) = preprocess_image(&pixels, 100, 80);

        let expected_w = (80.0_f32 * (1024.0 / 100.0)).round() as usize;
        assert_eq!(nchw.len(), 3 * 1024 * 1024);
        assert_eq!(h, 1024);
        assert_eq!(w, expected_w);
    }
}