1use crate::{Sam, SamConfig};
18use anyhow::{Result, anyhow, bail};
19use rlx_cli::{parse_sam_device, req};
20use std::path::PathBuf;
21
22pub fn run_sam1(args: &[String]) -> Result<()> {
23 run(args)
24}
25
26pub fn run(args: &[String]) -> Result<()> {
27 let mut weights: Option<PathBuf> = None;
28 let mut device = "cpu".to_string();
29 let mut point: Option<(f32, f32)> = None;
30 let mut dry = false;
31 let mut i = 0;
32 while i < args.len() {
33 match args[i].as_str() {
34 "--weights" => weights = Some(req(args, &mut i)?.into()),
35 "--device" => device = req(args, &mut i)?,
36 "--point" => {
37 let v = req(args, &mut i)?;
38 let parts: Vec<&str> = v.split(',').collect();
39 if parts.len() != 2 {
40 bail!("--point expects X,Y, got {v}");
41 }
42 point = Some((
43 parts[0].trim().parse().map_err(|_| anyhow!("--point X"))?,
44 parts[1].trim().parse().map_err(|_| anyhow!("--point Y"))?,
45 ));
46 }
47 "--dry" => {
48 dry = true;
49 i += 1;
50 }
51 "--help" | "-h" => {
52 eprintln!(
53 "rlx-sam1 — SAM v1; flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan|tpu) --point --dry"
54 );
55 return Ok(());
56 }
57 other => bail!("unknown flag: {other}"),
58 }
59 }
60 let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
61 let device = parse_sam_device("sam", &device)?;
62 let weights_str = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
63 let variant = rlx_ir::env::var("RLX_SAM_VARIANT").unwrap_or_else(|| "vit_b".to_string());
64 let cfg = match variant.as_str() {
65 "vit_b" => SamConfig::vit_b(),
66 "vit_l" => SamConfig::vit_l(),
67 "vit_h" => SamConfig::vit_h(),
68 other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
69 };
70 eprintln!("[rlx-sam1] weights={weights:?} device={device:?}");
71 if dry {
72 return Ok(());
73 }
74 let h_in = 1024usize;
75 let w_in = 1024usize;
76 let mut rgb = vec![0u8; h_in * w_in * 3];
77 for y in 0..h_in {
78 for x in 0..w_in {
79 let base = (y * w_in + x) * 3;
80 rgb[base] = (x * 255 / w_in) as u8;
81 rgb[base + 1] = (y * 255 / h_in) as u8;
82 rgb[base + 2] = ((x + y) * 127 / (h_in + w_in)) as u8;
83 }
84 }
85 let (cx, cy) = point.unwrap_or((w_in as f32 / 2.0, h_in as f32 / 2.0));
86 let mut sam = Sam::from_safetensors_on(weights_str, cfg, device)?;
87 let (pred, _) = sam.forward(
88 &rgb,
89 h_in,
90 w_in,
91 Some((&[cx, cy], &[1.0f32])),
92 None,
93 None,
94 true,
95 )?;
96 eprintln!(
97 "[rlx-sam1] masks={} side={} iou={:?}",
98 pred.num_masks,
99 pred.mask_side,
100 &pred.iou_pred[..pred.iou_pred.len().min(pred.num_masks)]
101 );
102 Ok(())
103}
104
105pub fn run_sam2(_args: &[String]) -> Result<()> {
107 bail!("use `rlx-sam2` (or `rlx-run sam2`)")
108}
109pub fn run_sam3(_args: &[String]) -> Result<()> {
110 bail!("use `rlx-sam3` (or `rlx-run sam3`)")
111}