Skip to main content

rlx_models/
sam_runner.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

16use anyhow::{Result, anyhow, bail};
17use rlx_core::validate_sam_device;
18use rlx_runtime::Device;
19use std::path::PathBuf;
20
/// SAM model generation that a [`SamRunner`] targets.
///
/// Selects both the weight loader and the forward entrypoint used by
/// [`SamRunner::predict_image`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SamArch {
    /// First generation; prompted with points and/or boxes.
    Sam1,
    /// Second generation; prompted with points and/or boxes.
    Sam2,
    /// Third generation; text-conditioned (requires text tokens).
    Sam3,
}
28
29/// Builder for the SAM family.
30#[derive(Debug, Clone)]
31pub struct SamRunnerBuilder {
32    arch: SamArch,
33    weights: Option<PathBuf>,
34    device: Option<Device>,
35    config_path: Option<PathBuf>,
36}
37
38impl SamRunnerBuilder {
39    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
40        self.weights = Some(p.into());
41        self
42    }
43    pub fn device(mut self, d: Device) -> Self {
44        self.device = Some(d);
45        self
46    }
47    pub fn config<P: Into<PathBuf>>(mut self, p: P) -> Self {
48        self.config_path = Some(p.into());
49        self
50    }
51
52    /// Build (validates inputs, but does not load weights — SAM
53    /// loaders today take ownership of the file path and load on
54    /// demand to keep memory peaks lower).
55    pub fn build(self) -> Result<SamRunner> {
56        let weights = self
57            .weights
58            .ok_or_else(|| anyhow!("weights path required"))?;
59        if !weights.exists() {
60            bail!("weights file not found: {weights:?}");
61        }
62        let device = self.device.unwrap_or(Device::Cpu);
63        validate_sam_device("sam", device)?;
64        Ok(SamRunner {
65            arch: self.arch,
66            weights,
67            device,
68            config_path: self.config_path,
69        })
70    }
71}
72
73/// SAM runner — owns the resolved config and dispatches the
74/// per-arch forward pass. SAM 1 / 2 / 3 differ enough in their
75/// prompting that we keep the heavy result type
76/// (`SamPredictionAny`) as a discriminated union the caller
77/// matches on.
78pub struct SamRunner {
79    pub arch: SamArch,
80    pub weights: PathBuf,
81    pub device: Device,
82    pub config_path: Option<PathBuf>,
83}
84
85/// Union of per-arch image-prediction outputs. Caller matches on
86/// the arch they asked for.
87pub enum SamPredictionAny {
88    Sam1(rlx_sam::MaskPrediction),
89    Sam2(rlx_sam2::Sam2ImagePrediction),
90    Sam3(rlx_sam3::Sam3ImagePrediction),
91}
92
93impl SamRunner {
94    pub fn builder(arch: SamArch) -> SamRunnerBuilder {
95        SamRunnerBuilder {
96            arch,
97            weights: None,
98            device: None,
99            config_path: None,
100        }
101    }
102
103    /// Print a human-readable summary — what the CLI prints before
104    /// any per-arch image processing.
105    pub fn summary(&self) -> String {
106        format!(
107            "SAM{} runner — weights={:?} device={:?} config={:?}",
108            match self.arch {
109                SamArch::Sam1 => "1",
110                SamArch::Sam2 => "2",
111                SamArch::Sam3 => "3",
112            },
113            self.weights,
114            self.device,
115            self.config_path
116        )
117    }
118
119    /// End-to-end forward: image bytes → masks. Dispatches to the
120    /// right per-arch entrypoint:
121    ///   * SAM 1 → `Sam::forward` (multimask = true)
122    ///   * SAM 2 → `Sam2::predict_image` (multimask = true)
123    ///   * SAM 3 → `Sam3::predict_image_text` with the supplied
124    ///     `text_tokens` (required for SAM 3 — its decoder is
125    ///     text-conditioned). Pass an empty slice for arches that
126    ///     don't use it.
127    ///
128    /// `rgb` is HWC u8; `points` is `(xy_pairs, labels)` with one
129    /// label per (x, y) pair (1 = foreground, 0 = background).
130    ///
131    /// SAM-arch-specific defaults applied:
132    ///   * `cfg` derived from environment variables (`RLX_SAM_VARIANT`
133    ///     for v1: vit_b/l/h; `RLX_SAM2_VARIANT` for v2: tiny/small/
134    ///     base_plus/large); falls back to the smallest variant.
135    ///   * `multimask_output = true` for v1 + v2.
136    ///   * SAM 3 vit defaults to `base`.
137    pub fn predict_image(
138        &self,
139        rgb: &[u8],
140        h_in: usize,
141        w_in: usize,
142        points: Option<(&[f32], &[f32])>,
143        boxes: Option<&[f32]>,
144        text_tokens: &[u32],
145    ) -> Result<SamPredictionAny> {
146        let weights_str = self
147            .weights
148            .to_str()
149            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
150        match self.arch {
151            SamArch::Sam1 => {
152                use rlx_sam::{Sam, SamConfig};
153                let cfg = match rlx_ir::env::var("RLX_SAM_VARIANT")
154                    .unwrap_or_else(|| "vit_b".into())
155                    .as_str()
156                {
157                    "vit_b" => SamConfig::vit_b(),
158                    "vit_l" => SamConfig::vit_l(),
159                    "vit_h" => SamConfig::vit_h(),
160                    other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
161                };
162                let mut sam = Sam::from_safetensors_on(weights_str, cfg, self.device)?;
163                let (pred, _resized) = sam.forward(
164                    rgb, h_in, w_in, points, boxes, None, /*multimask*/ true,
165                )?;
166                Ok(SamPredictionAny::Sam1(pred))
167            }
168            SamArch::Sam2 => {
169                use rlx_sam2::{Sam2, Sam2Config};
170                let cfg = match rlx_ir::env::var("RLX_SAM2_VARIANT")
171                    .unwrap_or_else(|| "tiny".into())
172                    .as_str()
173                {
174                    "tiny" => Sam2Config::hiera_tiny(),
175                    "small" => Sam2Config::hiera_small(),
176                    "base_plus" => Sam2Config::hiera_base_plus(),
177                    "large" => Sam2Config::hiera_large(),
178                    other => {
179                        bail!("RLX_SAM2_VARIANT must be tiny|small|base_plus|large, got {other}")
180                    }
181                };
182                let mut sam = Sam2::from_safetensors_on(weights_str, cfg, self.device)?;
183                let pred = sam.predict_image(
184                    rgb, h_in, w_in, points, boxes, None, /*multimask*/ true,
185                )?;
186                Ok(SamPredictionAny::Sam2(pred))
187            }
188            SamArch::Sam3 => {
189                use rlx_sam3::{Sam3, Sam3Config};
190                let cfg = Sam3Config::base();
191                let mut sam = Sam3::from_checkpoint_on(weights_str, cfg, self.device)?;
192                if text_tokens.is_empty() {
193                    bail!("SAM 3 is text-conditioned — pass non-empty text_tokens");
194                }
195                let pred = sam.predict_image_text(rgb, h_in, w_in, text_tokens)?;
196                Ok(SamPredictionAny::Sam3(pred))
197            }
198        }
199    }
200}