Skip to main content

rlx_sam/
sam.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM v1 top-level orchestrator — ties the IR-graph image encoder
17//! together with the host-side prompt encoder + mask decoder.
18//!
19//! Mirrors `candle-transformers/src/models/segment_anything/sam.rs`
20//! at the API level. The image encoder runs on the rlx-runtime
21//! `Session`; mask/prompt conv stacks use IR ops on the same device.
22
23use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM, SamConfig};
24use super::image_encoder::build_sam_encoder_graph;
25use super::mask_decoder::{MaskDecoderWeights, extract_mask_decoder_weights, mask_decoder_forward};
26use super::preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
27use super::prompt_encoder::{
28    PromptEncoderOutput, PromptEncoderWeights, extract_prompt_encoder_weights,
29    prompt_encoder_forward,
30};
31use super::prompt_mask_ir::SamPromptMaskCompiled;
32use super::upscale_ir::SamMaskUpscaleCompiled;
33use anyhow::Result;
34use rlx_runtime::{CompiledGraph, Device, Session};
35use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
36use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
37use std::path::Path;
38
/// Channel count of the mask input that runs through the prompt
/// encoder's mask-downscaling stack. candle's `Sam::new` hardcodes 16
/// for every ViT variant, so it is a fixed constant here rather than a
/// `SamConfig` field.
pub const SAM_MASK_IN_CHANS: usize = 16;
42
/// Full SAM model — owns the compiled image encoder + all decoder
/// weights. Stateless wrt prompts: every call to [`Sam::forward`]
/// runs the cached encoder + a fresh decoder forward.
///
/// Construct via [`Sam::from_safetensors`] or [`Sam::from_safetensors_on`].
// NOTE: struct fields are dropped in declaration order. Several of these
// hold compiled graphs that may own backend resources, so think twice
// before reordering them.
pub struct Sam {
    /// Configuration the model was built with (exposed via `Sam::config`).
    cfg: SamConfig,
    /// Compiled image-encoder graph; parameters are uploaded at load time.
    encoder: CompiledGraph,
    /// Preprocess weights returned by `build_sam_encoder_graph` and
    /// consumed by `assemble_patch_tokens` before the encoder runs.
    pre: SamPreprocessWeights,
    /// Prompt-encoder weights.
    prompt_enc: PromptEncoderWeights,
    /// Compiled prompt mask-downscaling stack, built from `prompt_enc`.
    mask_stack: SamPromptMaskCompiled,
    /// Mask-decoder weights.
    mask_dec: MaskDecoderWeights,
    /// Compiled mask-upscaling IR, built from `mask_dec`.
    upscale: SamMaskUpscaleCompiled,
    /// Compiled hyper-matmul IR, sized from the mask-token count,
    /// `transformer_dim / 8` and `SAM_EMBED_HW`.
    hyper_matmul: MaskHyperMatmulCompiled,
    /// Compiled hypernetwork MLPs, built from `mask_dec.hyper_mlps`.
    hyper_mlps_ir: Vec<MlpReluCompiled>,
    /// Compiled IoU-prediction head, built from `mask_dec.iou_head`.
    iou_head_ir: MlpReluCompiled,
    /// Compiled two-way transformer, built from `mask_dec.transformer`.
    tw_ir: rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled,
}
59
60impl Sam {
61    /// Load SAM ViT-B (or L/H — pass the matching config) from a
62    /// safetensors checkpoint, compiling the image encoder for the
63    /// CPU backend. For GPU/Metal/MLX, use
64    /// [`Sam::from_safetensors_on`].
65    pub fn from_safetensors(weights_path: &str, cfg: SamConfig) -> Result<Self> {
66        Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
67    }
68
69    /// Same as [`Sam::from_safetensors`] but compiles the image
70    /// encoder for the given device. Requires the matching backend
71    /// feature on `rlx-models`:
72    ///
73    /// | feature   | backend           |
74    /// |-----------|-------------------|
75    /// | `metal`   | `Device::Metal`   |
76    /// | `mlx`     | `Device::Mlx`     |
77    /// | `gpu`     | `Device::Gpu`     |
78    /// | `cuda`    | `Device::Cuda`    |
79    /// | `rocm`    | `Device::Rocm`    |
80    /// | `tpu`     | `Device::Tpu`     |
81    ///
82    pub fn from_safetensors_on(weights_path: &str, cfg: SamConfig, device: Device) -> Result<Self> {
83        rlx_core::validate_sam_device("sam", device)?;
84        let mut wm = rlx_core::load_weight_map(Path::new(weights_path), rlx_core::SAM_GGUF_ARCHES)?;
85        let (graph, params, pre) = build_sam_encoder_graph(&cfg.encoder, &mut wm)?;
86        let profile = crate::profile::sam_profile_near_weights(std::path::Path::new(weights_path));
87        let opts = rlx_core::flow_bridge::compile_options_for_profile(&profile, device);
88        let mut encoder = Session::new(device).compile_with(graph, &opts);
89        for (name, data) in &params {
90            encoder.set_param(name, data);
91        }
92        let prompt_enc =
93            extract_prompt_encoder_weights(&mut wm, cfg.encoder.out_chans, SAM_MASK_IN_CHANS)?;
94        let mask_stack =
95            SamPromptMaskCompiled::compile_with_profile(&prompt_enc, device, &profile)?;
96        let mask_dec = extract_mask_decoder_weights(
97            &mut wm,
98            cfg.decoder.transformer_dim,
99            cfg.decoder.num_mask_tokens,
100            cfg.decoder.iou_head_depth,
101            cfg.decoder.iou_head_hidden_dim,
102            cfg.decoder.transformer_depth,
103            cfg.decoder.transformer_num_heads,
104            cfg.decoder.transformer_mlp_dim,
105        )?;
106        let upscale = SamMaskUpscaleCompiled::compile_with_profile(&mask_dec, device, &profile)?;
107        let hyper_matmul = MaskHyperMatmulCompiled::compile_with_profile(
108            mask_dec.num_mask_tokens,
109            cfg.decoder.transformer_dim / 8,
110            SAM_EMBED_HW,
111            device,
112            &profile,
113        )?;
114        let hyper_mlps_ir =
115            super::mlp_ir::compile_hyper_mlps_with_profile(&mask_dec.hyper_mlps, device, &profile)?;
116        let iou_head_ir =
117            super::mlp_ir::compile_iou_head_with_profile(&mask_dec.iou_head, device, &profile)?;
118        let base_q_n = 1 + mask_dec.num_mask_tokens;
119        let tw_ir = super::transformer_ir::compile_two_way_transformer_with_profile(
120            &mask_dec.transformer,
121            base_q_n,
122            SAM_EMBED_HW,
123            device,
124            &profile,
125        )?;
126        Ok(Self {
127            cfg,
128            encoder,
129            pre,
130            prompt_enc,
131            mask_stack,
132            mask_dec,
133            upscale,
134            hyper_matmul,
135            hyper_mlps_ir,
136            iou_head_ir,
137            tw_ir,
138        })
139    }
140
141    /// Encode an image into the `[256, 64, 64]` image embedding.
142    /// `image_nchw`: pre-padded `[3, 1024, 1024]` NCHW f32 tensor
143    /// (see [`super::preprocess::preprocess_image`]).
144    pub fn encode_image(&mut self, image_nchw: &[f32]) -> Vec<f32> {
145        let hidden = assemble_patch_tokens(&self.pre, image_nchw).expect("assemble_patch_tokens");
146        let outputs = self.encoder.run(&[("hidden", hidden.as_slice())]);
147        outputs.into_iter().next().expect("encoder output")
148    }
149
150    /// Run the prompt encoder + mask decoder on a pre-encoded image.
151    pub fn predict_masks(
152        &mut self,
153        image_embeddings: &[f32],
154        points: Option<(&[f32], &[f32])>,
155        boxes: Option<&[f32]>,
156        masks: Option<&[f32]>,
157        multimask_output: bool,
158    ) -> Result<MaskPrediction> {
159        let pe: PromptEncoderOutput =
160            prompt_encoder_forward(&self.prompt_enc, &mut self.mask_stack, points, boxes, masks)?;
161        let (mask_logits, iou_pred, num_masks, mask_side) = mask_decoder_forward(
162            &self.mask_dec,
163            &mut self.upscale,
164            Some(&mut self.hyper_matmul),
165            Some(&mut self.hyper_mlps_ir),
166            Some(&mut self.iou_head_ir),
167            Some(&mut self.tw_ir),
168            image_embeddings,
169            &pe.image_pe,
170            &pe.sparse_embeddings,
171            pe.num_sparse_tokens,
172            &pe.dense_embeddings,
173            multimask_output,
174        )?;
175        Ok(MaskPrediction {
176            mask_logits,
177            iou_pred,
178            num_masks,
179            mask_side,
180        })
181    }
182
183    /// End-to-end forward: image bytes → masks. `rgb` is HWC u8.
184    pub fn forward(
185        &mut self,
186        rgb: &[u8],
187        h_in: usize,
188        w_in: usize,
189        points: Option<(&[f32], &[f32])>,
190        boxes: Option<&[f32]>,
191        masks: Option<&[f32]>,
192        multimask_output: bool,
193    ) -> Result<(MaskPrediction, (usize, usize))> {
194        let (image_nchw, (resized_h, resized_w)) = preprocess_image(rgb, h_in, w_in);
195        let image_embeddings = self.encode_image(&image_nchw);
196        let pred = self.predict_masks(&image_embeddings, points, boxes, masks, multimask_output)?;
197        Ok((pred, (resized_h, resized_w)))
198    }
199
200    pub fn config(&self) -> &SamConfig {
201        &self.cfg
202    }
203
204    /// Spatial side length of the predicted mask logits (= 4 · hw = 256
205    /// for ViT-B at 1024 input).
206    pub fn mask_side(&self) -> usize {
207        4 * SAM_EMBED_HW
208    }
209
210    /// Image side that the model operates on internally.
211    pub fn input_image_size(&self) -> usize {
212        SAM_IMG_SIZE
213    }
214}
215
/// Output of [`Sam::predict_masks`] / [`Sam::forward`].
#[derive(Debug, Clone)]
pub struct MaskPrediction {
    /// `[num_masks, mask_side, mask_side]` mask logits in the encoder's
    /// 4×-upscaled space. Threshold > 0 to get binary masks; further
    /// upscale + crop back to the original image as needed.
    pub mask_logits: Vec<f32>,
    /// `[num_masks]` per-mask IoU prediction (model-self-estimated
    /// mask quality).
    pub iou_pred: Vec<f32>,
    /// Number of masks in `mask_logits` / `iou_pred`.
    pub num_masks: usize,
    /// Side length of each square mask in `mask_logits`.
    pub mask_side: usize,
}

impl MaskPrediction {
    /// Convenience: pick the mask with the highest predicted IoU.
    ///
    /// Returns `Some((index, iou))`, or `None` when `iou_pred` is empty
    /// or contains only NaN. NaN scores are skipped: `f32::total_cmp`
    /// would otherwise rank a positive NaN above every finite score and
    /// report a degenerate mask as the best. Among equal maxima, the
    /// last index wins.
    pub fn best_by_iou(&self) -> Option<(usize, f32)> {
        self.iou_pred
            .iter()
            .copied()
            .enumerate()
            .filter(|(_, iou)| !iou.is_nan())
            .max_by(|a, b| a.1.total_cmp(&b.1))
    }
}
240
241/// Drop-in default config matching candle's `Sam::new()` for ViT-B
242/// (the `lmz/candle-sam/sam_vit_b_01ec64.safetensors` checkpoint).
243pub fn sam_vit_b_config() -> SamConfig {
244    SamConfig::vit_b()
245}
246
247#[allow(dead_code)]
248fn _silence_unused() {
249    let _ = SAM_PROMPT_EMBED_DIM;
250}