Skip to main content

rlx_sam/
cli.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16// RLX CLI for SAM 1
17use crate::{Sam, SamConfig};
18use anyhow::{Result, anyhow, bail};
19use rlx_cli::{parse_sam_device, req};
20use std::path::PathBuf;
21
22pub fn run_sam1(args: &[String]) -> Result<()> {
23    run(args)
24}
25
26pub fn run(args: &[String]) -> Result<()> {
27    let mut weights: Option<PathBuf> = None;
28    let mut device = "cpu".to_string();
29    let mut point: Option<(f32, f32)> = None;
30    let mut dry = false;
31    let mut i = 0;
32    while i < args.len() {
33        match args[i].as_str() {
34            "--weights" => weights = Some(req(args, &mut i)?.into()),
35            "--device" => device = req(args, &mut i)?,
36            "--point" => {
37                let v = req(args, &mut i)?;
38                let parts: Vec<&str> = v.split(',').collect();
39                if parts.len() != 2 {
40                    bail!("--point expects X,Y, got {v}");
41                }
42                point = Some((
43                    parts[0].trim().parse().map_err(|_| anyhow!("--point X"))?,
44                    parts[1].trim().parse().map_err(|_| anyhow!("--point Y"))?,
45                ));
46            }
47            "--dry" => {
48                dry = true;
49                i += 1;
50            }
51            "--help" | "-h" => {
52                eprintln!(
53                    "rlx-sam1 — SAM v1; flags: --weights --device (cpu|metal|mlx|cuda|rocm|gpu|vulkan|tpu) --point --dry"
54                );
55                return Ok(());
56            }
57            other => bail!("unknown flag: {other}"),
58        }
59    }
60    let weights = weights.ok_or_else(|| anyhow!("--weights is required"))?;
61    let device = parse_sam_device("sam", &device)?;
62    let weights_str = weights.to_str().ok_or_else(|| anyhow!("non-utf8 path"))?;
63    let variant = rlx_ir::env::var("RLX_SAM_VARIANT").unwrap_or_else(|| "vit_b".to_string());
64    let cfg = match variant.as_str() {
65        "vit_b" => SamConfig::vit_b(),
66        "vit_l" => SamConfig::vit_l(),
67        "vit_h" => SamConfig::vit_h(),
68        other => bail!("RLX_SAM_VARIANT must be vit_b|vit_l|vit_h, got {other}"),
69    };
70    eprintln!("[rlx-sam1] weights={weights:?} device={device:?}");
71    if dry {
72        return Ok(());
73    }
74    let h_in = 1024usize;
75    let w_in = 1024usize;
76    let mut rgb = vec![0u8; h_in * w_in * 3];
77    for y in 0..h_in {
78        for x in 0..w_in {
79            let base = (y * w_in + x) * 3;
80            rgb[base] = (x * 255 / w_in) as u8;
81            rgb[base + 1] = (y * 255 / h_in) as u8;
82            rgb[base + 2] = ((x + y) * 127 / (h_in + w_in)) as u8;
83        }
84    }
85    let (cx, cy) = point.unwrap_or((w_in as f32 / 2.0, h_in as f32 / 2.0));
86    let mut sam = Sam::from_safetensors_on(weights_str, cfg, device)?;
87    let (pred, _) = sam.forward(
88        &rgb,
89        h_in,
90        w_in,
91        Some((&[cx, cy], &[1.0f32])),
92        None,
93        None,
94        true,
95    )?;
96    eprintln!(
97        "[rlx-sam1] masks={} side={} iou={:?}",
98        pred.num_masks,
99        pred.mask_side,
100        &pred.iou_pred[..pred.iou_pred.len().min(pred.num_masks)]
101    );
102    Ok(())
103}
104
105// Stubs for multiplexer — use `rlx-sam2` / `rlx-sam3` binaries.
106pub fn run_sam2(_args: &[String]) -> Result<()> {
107    bail!("use `rlx-sam2` (or `rlx-run sam2`)")
108}
109pub fn run_sam3(_args: &[String]) -> Result<()> {
110    bail!("use `rlx-sam3` (or `rlx-run sam3`)")
111}