use crate::{Sam, SamConfig};
use anyhow::{Result, anyhow, bail};
use rlx_cli::{parse_sam_device, req};
use std::path::PathBuf;
pub fn run_sam1(args: &[String]) -> Result<()> {
run(args)
}
pub fn run(args: &[String]) -> Result<()> {
let mut weights: Option<PathBuf> = None;
let mut device = "cpu".to_string();
let mut point: Option<(f32, f32)> = None;
let mut dry = false;
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--weights" => weights = Some(req(args, &mut i)?.into()),
"--device" => device = req(args, &mut i)?,
"--point" => {
let v = req(args, &mut i)?;
let parts: Vec<&str> = v.split(',').collect();
if parts.len() != 2 {
bail!("--point expects X,Y, got {v}");
}
point = Some((
parts[0].trim().parse().map_err(|_| anyhow!("--point X"))?,
parts[1].trim().parse().map_err(|_| anyhow!("--point Y"))?,
));
}
"--dry" => {
dry = true;
i += 1;
}
"--help" | "-h" => {
eprintln!(
"rlx-sam1 — SAM v1; flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan|tpu) --point --dry"
);
return Ok(());
}
other => bail!("unknown flag: {other}"),
}
}
let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
let device = parse_sam_device("sam", &device)?;
let weights_str = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
let variant = rlx_ir::env::var("RLX_SAM_VARIANT").unwrap_or_else(|| "vit_b".to_string());
let cfg = match variant.as_str() {
"vit_b" => SamConfig::vit_b(),
"vit_l" => SamConfig::vit_l(),
"vit_h" => SamConfig::vit_h(),
other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
};
eprintln!("[rlx-sam1] weights={weights:?} device={device:?}");
if dry {
return Ok(());
}
let h_in = 1024usize;
let w_in = 1024usize;
let mut rgb = vec![0u8; h_in * w_in * 3];
for y in 0..h_in {
for x in 0..w_in {
let base = (y * w_in + x) * 3;
rgb[base] = (x * 255 / w_in) as u8;
rgb[base + 1] = (y * 255 / h_in) as u8;
rgb[base + 2] = ((x + y) * 127 / (h_in + w_in)) as u8;
}
}
let (cx, cy) = point.unwrap_or((w_in as f32 / 2.0, h_in as f32 / 2.0));
let mut sam = Sam::from_safetensors_on(weights_str, cfg, device)?;
let (pred, _) = sam.forward(
&rgb,
h_in,
w_in,
Some((&[cx, cy], &[1.0f32])),
None,
None,
true,
)?;
eprintln!(
"[rlx-sam1] masks={} side={} iou={:?}",
pred.num_masks,
pred.mask_side,
&pred.iou_pred[..pred.iou_pred.len().min(pred.num_masks)]
);
Ok(())
}
pub fn run_sam2(_args: &[String]) -> Result<()> {
bail!("use `rlx-sam2` (or `rlx-run sam2`)")
}
pub fn run_sam3(_args: &[String]) -> Result<()> {
bail!("use `rlx-sam3` (or `rlx-run sam3`)")
}