1use crate::{Sam2, Sam2Config};
18use anyhow::{Result, anyhow, bail};
19use rlx_cli::{parse_sam_device, req};
20use std::path::PathBuf;
21
22pub fn run(args: &[String]) -> Result<()> {
23 let mut weights: Option<PathBuf> = None;
24 let mut device = "cpu".to_string();
25 let mut point: Option<(f32, f32)> = None;
26 let mut dry = false;
27 let mut i = 0;
28 while i < args.len() {
29 match args[i].as_str() {
30 "--weights" => weights = Some(req(args, &mut i)?.into()),
31 "--device" => device = req(args, &mut i)?,
32 "--point" => {
33 let v = req(args, &mut i)?;
34 let parts: Vec<&str> = v.split(',').collect();
35 if parts.len() != 2 {
36 bail!("--point expects X,Y");
37 }
38 point = Some((parts[0].trim().parse()?, parts[1].trim().parse()?));
39 }
40 "--dry" => {
41 dry = true;
42 i += 1;
43 }
44 "--help" | "-h" => {
45 eprintln!(
46 "rlx-sam2 — flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan) --point --dry"
47 );
48 return Ok(());
49 }
50 other => bail!("unknown flag: {other}"),
51 }
52 }
53 let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
54 let device = parse_sam_device("sam2", &device)?;
55 let path = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
56 let variant = rlx_ir::env::var("RLX_SAM2_VARIANT").unwrap_or_else(|| "tiny".to_string());
57 let cfg = match variant.as_str() {
58 "tiny" => Sam2Config::hiera_tiny(),
59 "small" => Sam2Config::hiera_small(),
60 "base_plus" => Sam2Config::hiera_base_plus(),
61 "large" => Sam2Config::hiera_large(),
62 other => bail!("RLX_SAM2_VARIANT must be tiny|small|base_plus|large, got {other}"),
63 };
64 if dry {
65 return Ok(());
66 }
67 let h_in = 1024usize;
68 let w_in = 1024usize;
69 let rgb = vec![128u8; h_in * w_in * 3];
70 let (cx, cy) = point.unwrap_or((512.0, 512.0));
71 let mut sam = Sam2::from_safetensors_on(path, cfg, device)?;
72 let pred = sam.predict_image(
73 &rgb,
74 h_in,
75 w_in,
76 Some((&[cx, cy], &[1.0f32])),
77 None,
78 None,
79 true,
80 )?;
81 eprintln!(
82 "[rlx-sam2] masks={} out={}x{} iou={:?}",
83 pred.num_masks,
84 pred.h_out,
85 pred.w_out,
86 &pred.iou_pred[..pred.iou_pred.len().min(pred.num_masks)]
87 );
88 Ok(())
89}