use anyhow::{Result, anyhow, bail};
use rlx_core::validate_sam_device;
use rlx_runtime::Device;
use std::path::PathBuf;
/// Segment Anything model generation that a [`SamRunner`] dispatches to.
///
/// `Hash` is derived alongside `Eq` so the architecture can be used as a
/// key in hash maps/sets (e.g. per-architecture caches or statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamArch {
    /// SAM 1 (`rlx_sam`); prompted with optional points and/or boxes.
    Sam1,
    /// SAM 2 (`rlx_sam2`); prompted with optional points and/or boxes.
    Sam2,
    /// SAM 3 (`rlx_sam3`); prompted with (non-empty) text tokens.
    Sam3,
}
/// Builder for [`SamRunner`]; obtain one with `SamRunner::builder(arch)`.
///
/// Every setter consumes and returns the builder, so dropping a setter's
/// result silently loses the setting — `#[must_use]` turns that into a
/// compiler warning. `build()` requires `weights` and defaults the device to
/// `Device::Cpu`.
#[derive(Debug, Clone)]
#[must_use = "a SamRunnerBuilder does nothing until `build()` is called"]
pub struct SamRunnerBuilder {
    /// Model generation the built runner will dispatch to.
    arch: SamArch,
    /// Checkpoint path; `build()` fails if this is unset or does not exist.
    weights: Option<PathBuf>,
    /// Target device; `None` becomes `Device::Cpu` in `build()`.
    device: Option<Device>,
    /// Optional config path, carried through unchanged to `SamRunner::config_path`.
    config_path: Option<PathBuf>,
}
impl SamRunnerBuilder {
pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
self.weights = Some(p.into());
self
}
pub fn device(mut self, d: Device) -> Self {
self.device = Some(d);
self
}
pub fn config<P: Into<PathBuf>>(mut self, p: P) -> Self {
self.config_path = Some(p.into());
self
}
pub fn build(self) -> Result<SamRunner> {
let weights = self
.weights
.ok_or_else(|| anyhow!("weights path required"))?;
if !weights.exists() {
bail!("weights file not found: {weights:?}");
}
let device = self.device.unwrap_or(Device::Cpu);
validate_sam_device("sam", device)?;
Ok(SamRunner {
arch: self.arch,
weights,
device,
config_path: self.config_path,
})
}
}
/// A configured SAM runner, normally produced by [`SamRunnerBuilder::build`].
///
/// It stores only paths and settings; `predict_image` loads the checkpoint
/// from `weights` on every call. `Debug` and `Clone` are derived so the
/// runner can be logged and duplicated (`Device` is `Copy + Debug`).
#[derive(Debug, Clone)]
pub struct SamRunner {
    /// Model generation that `predict_image` dispatches to.
    pub arch: SamArch,
    /// Checkpoint path (checked to exist when built via the builder).
    pub weights: PathBuf,
    /// Device the model is loaded onto (checked by `validate_sam_device`
    /// when built via the builder).
    pub device: Device,
    /// Optional config path passed through from the builder.
    pub config_path: Option<PathBuf>,
}
/// Prediction returned by `SamRunner::predict_image`, tagged with the
/// architecture that produced it.
pub enum SamPredictionAny {
    /// Result of `rlx_sam::Sam::forward` (SAM 1).
    Sam1(rlx_sam::MaskPrediction),
    /// Result of `rlx_sam2::Sam2::predict_image` (SAM 2).
    Sam2(rlx_sam2::Sam2ImagePrediction),
    /// Result of `rlx_sam3::Sam3::predict_image_text` (SAM 3).
    Sam3(rlx_sam3::Sam3ImagePrediction),
}
impl SamRunner {
pub fn builder(arch: SamArch) -> SamRunnerBuilder {
SamRunnerBuilder {
arch,
weights: None,
device: None,
config_path: None,
}
}
pub fn summary(&self) -> String {
format!(
"SAM{} runner — weights={:?} device={:?} config={:?}",
match self.arch {
SamArch::Sam1 => "1",
SamArch::Sam2 => "2",
SamArch::Sam3 => "3",
},
self.weights,
self.device,
self.config_path
)
}
pub fn predict_image(
&self,
rgb: &[u8],
h_in: usize,
w_in: usize,
points: Option<(&[f32], &[f32])>,
boxes: Option<&[f32]>,
text_tokens: &[u32],
) -> Result<SamPredictionAny> {
let weights_str = self
.weights
.to_str()
.ok_or_else(|| anyhow!("non-utf8 weights path"))?;
match self.arch {
SamArch::Sam1 => {
use rlx_sam::{Sam, SamConfig};
let cfg = match rlx_ir::env::var("RLX_SAM_VARIANT")
.unwrap_or_else(|| "vit_b".into())
.as_str()
{
"vit_b" => SamConfig::vit_b(),
"vit_l" => SamConfig::vit_l(),
"vit_h" => SamConfig::vit_h(),
other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
};
let mut sam = Sam::from_safetensors_on(weights_str, cfg, self.device)?;
let (pred, _resized) = sam.forward(
rgb, h_in, w_in, points, boxes, None, true,
)?;
Ok(SamPredictionAny::Sam1(pred))
}
SamArch::Sam2 => {
use rlx_sam2::{Sam2, Sam2Config};
let cfg = match rlx_ir::env::var("RLX_SAM2_VARIANT")
.unwrap_or_else(|| "tiny".into())
.as_str()
{
"tiny" => Sam2Config::hiera_tiny(),
"small" => Sam2Config::hiera_small(),
"base_plus" => Sam2Config::hiera_base_plus(),
"large" => Sam2Config::hiera_large(),
other => {
bail!("RLX_SAM2_VARIANT must be tiny|small|base_plus|large, got {other}")
}
};
let mut sam = Sam2::from_safetensors_on(weights_str, cfg, self.device)?;
let pred = sam.predict_image(
rgb, h_in, w_in, points, boxes, None, true,
)?;
Ok(SamPredictionAny::Sam2(pred))
}
SamArch::Sam3 => {
use rlx_sam3::{Sam3, Sam3Config};
let cfg = Sam3Config::base();
let mut sam = Sam3::from_checkpoint_on(weights_str, cfg, self.device)?;
if text_tokens.is_empty() {
bail!("SAM 3 is text-conditioned — pass non-empty text_tokens");
}
let pred = sam.predict_image_text(rgb, h_in, w_in, text_tokens)?;
Ok(SamPredictionAny::Sam3(pred))
}
}
}
}