Skip to main content

rlx_sam2/
cli.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16// RLX CLI for SAM 2
17use crate::{Sam2, Sam2Config};
18use anyhow::{Result, anyhow, bail};
19use rlx_cli::{parse_sam_device, req};
20use std::path::PathBuf;
21
22pub fn run(args: &[String]) -> Result<()> {
23    let mut weights: Option<PathBuf> = None;
24    let mut device = "cpu".to_string();
25    let mut point: Option<(f32, f32)> = None;
26    let mut dry = false;
27    let mut i = 0;
28    while i < args.len() {
29        match args[i].as_str() {
30            "--weights" => weights = Some(req(args, &mut i)?.into()),
31            "--device" => device = req(args, &mut i)?,
32            "--point" => {
33                let v = req(args, &mut i)?;
34                let parts: Vec<&str> = v.split(',').collect();
35                if parts.len() != 2 {
36                    bail!("--point expects X,Y");
37                }
38                point = Some((parts[0].trim().parse()?, parts[1].trim().parse()?));
39            }
40            "--dry" => {
41                dry = true;
42                i += 1;
43            }
44            "--help" | "-h" => {
45                eprintln!(
46                    "rlx-sam2 — flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan) --point --dry"
47                );
48                return Ok(());
49            }
50            other => bail!("unknown flag: {other}"),
51        }
52    }
53    let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
54    let device = parse_sam_device("sam2", &device)?;
55    let path = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
56    let variant = rlx_ir::env::var("RLX_SAM2_VARIANT").unwrap_or_else(|| "tiny".to_string());
57    let cfg = match variant.as_str() {
58        "tiny" => Sam2Config::hiera_tiny(),
59        "small" => Sam2Config::hiera_small(),
60        "base_plus" => Sam2Config::hiera_base_plus(),
61        "large" => Sam2Config::hiera_large(),
62        other => bail!("RLX_SAM2_VARIANT must be tiny|small|base_plus|large, got {other}"),
63    };
64    if dry {
65        return Ok(());
66    }
67    let h_in = 1024usize;
68    let w_in = 1024usize;
69    let rgb = vec![128u8; h_in * w_in * 3];
70    let (cx, cy) = point.unwrap_or((512.0, 512.0));
71    let mut sam = Sam2::from_safetensors_on(path, cfg, device)?;
72    let pred = sam.predict_image(
73        &rgb,
74        h_in,
75        w_in,
76        Some((&[cx, cy], &[1.0f32])),
77        None,
78        None,
79        true,
80    )?;
81    eprintln!(
82        "[rlx-sam2] masks={} out={}x{} iou={:?}",
83        pred.num_masks,
84        pred.h_out,
85        pred.w_out,
86        &pred.iou_pred[..pred.iou_pred.len().min(pred.num_masks)]
87    );
88    Ok(())
89}