use crate::{Sam3, Sam3Config};
use anyhow::{Result, anyhow, bail};
use rlx_cli::{parse_sam_device, req};
use std::path::PathBuf;
/// Entry point for the `rlx-sam3` command-line smoke test.
///
/// Parses `args`, loads a SAM3 checkpoint with `Sam3Config::base()` on the
/// requested device and runs one image+text prediction. The input image is a
/// uniform gray (value 128) 1024x1024, 3-channel `u8` buffer. It then prints
/// the instance count, mask shape and up to the first five scores to stderr.
///
/// `args` must contain only flags and their values. Parsing starts at index 0,
/// so a leading program name would be rejected as an unknown flag.
///
/// Flags:
/// - `--weights <path>` (required): checkpoint path; must be valid UTF-8.
/// - `--device <name>`: backend name, default `cpu`; validated by
///   `parse_sam_device`.
/// - `--text-tokens <a,b,...>`: comma-separated `u32` token ids; whitespace
///   around each id is ignored. When absent, placeholder ids `0..32` are used
///   and a notice is printed.
/// - `--dry`: validate the arguments, then return before loading the model.
/// - `--help` / `-h`: print a flag summary to stderr and return `Ok(())`.
///
/// # Errors
/// Returns an error for an unknown flag, a missing `--weights`, a flag value
/// rejected by `req`, an unparsable token id, a device string rejected by
/// `parse_sam_device`, a non-UTF-8 weights path, or any failure from
/// checkpoint loading or prediction.
pub fn run(args: &[String]) -> Result<()> {
    let mut weights: Option<PathBuf> = None;
    let mut device = "cpu".to_string();
    let mut text_tokens: Vec<u32> = Vec::new();
    let mut dry = false;
    let mut i = 0;
    while i < args.len() {
        match args[i].as_str() {
            // NOTE(review): `req` presumably returns the value following the flag
            // and advances `i` past both tokens — confirm in `rlx_cli`.
            "--weights" => weights = Some(req(args, &mut i)?.into()),
            "--device" => device = req(args, &mut i)?,
            "--text-tokens" => {
                // Name the flag and the offending token in the error; a bare
                // `ParseIntError` ("invalid digit found in string") gives the
                // user no hint which argument was wrong.
                text_tokens = req(args, &mut i)?
                    .split(',')
                    .map(|tok| {
                        let tok = tok.trim();
                        tok.parse::<u32>().map_err(|e| {
                            anyhow!("--text-tokens: invalid token id {tok:?}: {e}")
                        })
                    })
                    .collect::<Result<_>>()?;
            }
            "--dry" => {
                dry = true;
                // Value-less flag: advance past it manually.
                i += 1;
            }
            "--help" | "-h" => {
                eprintln!(
                    "rlx-sam3 — flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan) --text-tokens --dry"
                );
                return Ok(());
            }
            other => bail!("unknown flag: {other}"),
        }
    }
    let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
    let device = parse_sam_device("sam3", &device)?;
    // The checkpoint loader takes `&str`, so reject non-UTF-8 paths up front
    // (this also runs in `--dry` mode).
    let path = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
    if text_tokens.is_empty() {
        text_tokens = (0..32u32).collect();
        eprintln!("[rlx-sam3] using placeholder text tokens 0..32");
    }
    // Dry run: all argument validation above has completed; skip model loading.
    if dry {
        return Ok(());
    }
    // Synthetic input: uniform mid-gray image, interleaved 3 bytes per pixel.
    let h_in = 1024usize;
    let w_in = 1024usize;
    let rgb = vec![128u8; h_in * w_in * 3];
    let mut sam = Sam3::from_checkpoint_on(path, Sam3Config::base(), device)?;
    let pred = sam.predict_image_text(&rgb, h_in, w_in, &text_tokens)?;
    eprintln!(
        "[rlx-sam3] instances={} mask_shape={:?} scores[..5]={:?}",
        pred.num_instances,
        pred.mask_shape,
        // Clamp so fewer than five scores does not panic on slicing.
        &pred.scores[..pred.scores.len().min(5)]
    );
    Ok(())
}