1use anyhow::{Result, anyhow, bail};
17use rlx_core::validate_sam_device;
18use rlx_runtime::Device;
19use std::path::PathBuf;
20
/// Segment Anything model family driven by a [`SamRunner`].
///
/// `Hash` is derived alongside `Eq` so the architecture can be used as a key in
/// hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamArch {
    /// Original SAM, backed by `rlx_sam`; the runner forwards point/box prompts.
    Sam1,
    /// SAM 2, backed by `rlx_sam2`; the runner forwards point/box prompts.
    Sam2,
    /// SAM 3, backed by `rlx_sam3`; text-conditioned (requires text tokens).
    Sam3,
}
28
29#[derive(Debug, Clone)]
31pub struct SamRunnerBuilder {
32 arch: SamArch,
33 weights: Option<PathBuf>,
34 device: Option<Device>,
35 config_path: Option<PathBuf>,
36}
37
38impl SamRunnerBuilder {
39 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
40 self.weights = Some(p.into());
41 self
42 }
43 pub fn device(mut self, d: Device) -> Self {
44 self.device = Some(d);
45 self
46 }
47 pub fn config<P: Into<PathBuf>>(mut self, p: P) -> Self {
48 self.config_path = Some(p.into());
49 self
50 }
51
52 pub fn build(self) -> Result<SamRunner> {
56 let weights = self
57 .weights
58 .ok_or_else(|| anyhow!("weights path required"))?;
59 if !weights.exists() {
60 bail!("weights file not found: {weights:?}");
61 }
62 let device = self.device.unwrap_or(Device::Cpu);
63 validate_sam_device("sam", device)?;
64 Ok(SamRunner {
65 arch: self.arch,
66 weights,
67 device,
68 config_path: self.config_path,
69 })
70 }
71}
72
73pub struct SamRunner {
79 pub arch: SamArch,
80 pub weights: PathBuf,
81 pub device: Device,
82 pub config_path: Option<PathBuf>,
83}
84
85pub enum SamPredictionAny {
88 Sam1(rlx_sam::MaskPrediction),
89 Sam2(rlx_sam2::Sam2ImagePrediction),
90 Sam3(rlx_sam3::Sam3ImagePrediction),
91}
92
93impl SamRunner {
94 pub fn builder(arch: SamArch) -> SamRunnerBuilder {
95 SamRunnerBuilder {
96 arch,
97 weights: None,
98 device: None,
99 config_path: None,
100 }
101 }
102
103 pub fn summary(&self) -> String {
106 format!(
107 "SAM{} runner — weights={:?} device={:?} config={:?}",
108 match self.arch {
109 SamArch::Sam1 => "1",
110 SamArch::Sam2 => "2",
111 SamArch::Sam3 => "3",
112 },
113 self.weights,
114 self.device,
115 self.config_path
116 )
117 }
118
119 pub fn predict_image(
138 &self,
139 rgb: &[u8],
140 h_in: usize,
141 w_in: usize,
142 points: Option<(&[f32], &[f32])>,
143 boxes: Option<&[f32]>,
144 text_tokens: &[u32],
145 ) -> Result<SamPredictionAny> {
146 let weights_str = self
147 .weights
148 .to_str()
149 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
150 match self.arch {
151 SamArch::Sam1 => {
152 use rlx_sam::{Sam, SamConfig};
153 let cfg = match rlx_ir::env::var("RLX_SAM_VARIANT")
154 .unwrap_or_else(|| "vit_b".into())
155 .as_str()
156 {
157 "vit_b" => SamConfig::vit_b(),
158 "vit_l" => SamConfig::vit_l(),
159 "vit_h" => SamConfig::vit_h(),
160 other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
161 };
162 let mut sam = Sam::from_safetensors_on(weights_str, cfg, self.device)?;
163 let (pred, _resized) = sam.forward(
164 rgb, h_in, w_in, points, boxes, None, true,
165 )?;
166 Ok(SamPredictionAny::Sam1(pred))
167 }
168 SamArch::Sam2 => {
169 use rlx_sam2::{Sam2, Sam2Config};
170 let cfg = match rlx_ir::env::var("RLX_SAM2_VARIANT")
171 .unwrap_or_else(|| "tiny".into())
172 .as_str()
173 {
174 "tiny" => Sam2Config::hiera_tiny(),
175 "small" => Sam2Config::hiera_small(),
176 "base_plus" => Sam2Config::hiera_base_plus(),
177 "large" => Sam2Config::hiera_large(),
178 other => {
179 bail!("RLX_SAM2_VARIANT must be tiny|small|base_plus|large, got {other}")
180 }
181 };
182 let mut sam = Sam2::from_safetensors_on(weights_str, cfg, self.device)?;
183 let pred = sam.predict_image(
184 rgb, h_in, w_in, points, boxes, None, true,
185 )?;
186 Ok(SamPredictionAny::Sam2(pred))
187 }
188 SamArch::Sam3 => {
189 use rlx_sam3::{Sam3, Sam3Config};
190 let cfg = Sam3Config::base();
191 let mut sam = Sam3::from_checkpoint_on(weights_str, cfg, self.device)?;
192 if text_tokens.is_empty() {
193 bail!("SAM 3 is text-conditioned — pass non-empty text_tokens");
194 }
195 let pred = sam.predict_image_text(rgb, h_in, w_in, text_tokens)?;
196 Ok(SamPredictionAny::Sam3(pred))
197 }
198 }
199 }
200}