Skip to main content

rlx_cli/
lm_args.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shared CLI flags for every per-family LM binary.
17//!
18//! Each `rlx-<family>/src/cli.rs` today hand-rolls the same
19//! `while i < args.len()` loop parsing `--weights / --device / --max-seq
20//! / --max-tokens / --prompt / --prompt-ids / --tokenizer / --temperature
21//! / --top-p / --format / --packed / --no-stream / --max-memory-gb`.
22//!
23//! [`LmCliArgs`] is a `clap`-derived struct that captures the shared set
24//! and provides [`LmCliArgs::into_builder`] which seeds a generic
25//! [`rlx_runtime::LmRunnerBuilder`]. Per-family CLIs can mix in their
26//! own structs with `#[command(flatten)]` for arch-specific flags.
27
28use std::path::PathBuf;
29use std::str::FromStr;
30
31use anyhow::Result;
32use clap::Parser;
33use rlx_runtime::{ConfigSource, Device, LmRunnerBuilder, SampleOpts, WeightFormat};
34
// NOTE: with clap's derive API the `///` doc comments on this struct and on
// each field are rendered as `--help` text, so editing them changes
// user-visible output; maintainer notes in this block use plain `//`.
/// Canonical LM CLI flags.
#[derive(Debug, Clone, Parser)]
pub struct LmCliArgs {
    /// Weights file (`.safetensors` / `.gguf`) or directory.
    #[arg(long)]
    pub weights: PathBuf,

    // Kept as the raw string; `LmCliArgs::device` parses it with
    // `Device::from_str` and its error message quotes the `--device` value.
    /// Inference device.
    #[arg(long, default_value = "cpu")]
    pub device: String,

    // `into_builder` copies this into the builder's `format` field unchanged;
    // `None` keeps the auto-detection mentioned in the help text.
    /// Override the auto-detected weight format.
    #[arg(long, value_parser = parse_format)]
    pub format: Option<WeightFormat>,

    // NOTE(review): when this flag is omitted, `into_builder` passes
    // `ConfigSource::Embedded` rather than an explicit sibling path — confirm
    // that variant really resolves the sibling `config.json` promised below.
    /// Path to a HF `config.json` (default: sibling of `--weights`).
    #[arg(long)]
    pub config: Option<PathBuf>,

    // `prompt`, `prompt_ids`, `tokenizer` and `max_tokens` are not forwarded
    // by `into_builder`, which consumes the args, so callers must read (or
    // clone) them first. clap does not reject `--prompt` and `--prompt-ids`
    // given together — NOTE(review): confirm which one callers honour.
    /// Prompt text (tokenized via `--tokenizer`).
    #[arg(long)]
    pub prompt: Option<String>,

    // `value_delimiter` splits a single occurrence, e.g. `--prompt-ids 1,2,3`;
    // each piece must parse as `u32`.
    /// Pre-tokenized comma-separated u32 ids.
    #[arg(long, value_delimiter = ',')]
    pub prompt_ids: Option<Vec<u32>>,

    /// Tokenizer file (`tokenizer.json`) for `--prompt` / decode.
    #[arg(long)]
    pub tokenizer: Option<PathBuf>,

    /// Tokens to generate.
    #[arg(long, default_value_t = 32)]
    pub max_tokens: usize,

    /// Maximum prefill sequence length.
    #[arg(long, default_value_t = 128)]
    pub max_seq: usize,

    // Forwarded unchanged to the builder's `max_memory_gb`; the estimate and
    // the refusal described below are not implemented in this file.
    /// Refuse to load if F32-dequant estimate exceeds this many GB.
    #[arg(long)]
    pub max_memory_gb: Option<f32>,

    // Inverted when forwarded: `into_builder` calls `.stream(!no_stream)`.
    /// Disable streaming (print only the final string).
    #[arg(long)]
    pub no_stream: bool,

    // `packed` / `no_packed` form a tri-state that `into_builder` collapses
    // into `Option<bool>`: `Some(true)`, `Some(false)`, or `None` (leave it to
    // auto-detection). `conflicts_with` on `no_packed` makes the pair mutually
    // exclusive at parse time.
    /// Force packed GGUF loading (`Op::DequantMatMul`).
    #[arg(long)]
    pub packed: bool,

    /// Disable packed GGUF loading (overrides auto-detection).
    #[arg(long, conflicts_with = "packed")]
    pub no_packed: bool,

    // The four sampling flags below are bundled by `LmCliArgs::sample_opts`.
    /// Sampling temperature. `0` = greedy.
    #[arg(long, default_value_t = 0.0)]
    pub temperature: f32,

    /// Nucleus sampling top-p.
    #[arg(long, default_value_t = 1.0)]
    pub top_p: f32,

    /// Top-k sampling cutoff.
    #[arg(long)]
    pub top_k: Option<u32>,

    /// Repetition penalty.
    #[arg(long, default_value_t = 1.0)]
    pub repetition_penalty: f32,

    // `alias` makes `--prefer` an accepted spelling of `--prefer-gguf`.
    /// GGUF quant preference (e.g. `Q4_K_M`) when `--weights` is a directory.
    #[arg(long, alias = "prefer")]
    pub prefer_gguf: Option<String>,
}
110
111fn parse_format(s: &str) -> Result<WeightFormat, String> {
112    WeightFormat::parse(s).map_err(|e| e.to_string())
113}
114
115impl LmCliArgs {
116    /// Parse a `Device` from the `--device` string using the upstream
117    /// `FromStr for Device` impl.
118    pub fn device(&self) -> Result<Device> {
119        Device::from_str(&self.device).map_err(|e| anyhow::anyhow!("--device {}: {e}", self.device))
120    }
121
122    /// Build a sampling option set from the relevant flags.
123    pub fn sample_opts(&self) -> SampleOpts {
124        SampleOpts {
125            temperature: self.temperature,
126            top_p: self.top_p,
127            top_k: self.top_k,
128            repetition_penalty: self.repetition_penalty,
129        }
130    }
131
132    /// Construct an [`LmRunnerBuilder`] pre-populated from the flags.
133    /// Per-family runners that wrap [`LmRunnerBuilder`] can call this
134    /// once and then layer family-specific options on top.
135    pub fn into_builder<Cfg>(self) -> Result<LmRunnerBuilder<Cfg>> {
136        let device = self.device()?;
137        let packed = if self.packed {
138            Some(true)
139        } else if self.no_packed {
140            Some(false)
141        } else {
142            None
143        };
144        let config = self
145            .config
146            .clone()
147            .map(ConfigSource::JsonFile)
148            .unwrap_or(ConfigSource::Embedded);
149
150        let mut b = LmRunnerBuilder::<Cfg>::new()
151            .weights(self.weights.clone())
152            .device(device)
153            .max_seq(self.max_seq)
154            .stream(!self.no_stream)
155            .sample(self.sample_opts())
156            .config(config);
157        b.format = self.format;
158        b.packed_weights = packed;
159        b.max_memory_gb = self.max_memory_gb;
160        b.prefer_gguf = self.prefer_gguf.clone();
161        Ok(b)
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    /// Parse `extra` flags on top of the mandatory `--weights`, panicking on
    /// any parse error.
    fn parse(extra: &[&str]) -> LmCliArgs {
        let mut argv = vec!["x", "--weights", "/tmp/m.gguf"];
        argv.extend_from_slice(extra);
        LmCliArgs::try_parse_from(argv).unwrap()
    }

    #[test]
    fn debug_assert_works() {
        LmCliArgs::command().debug_assert();
    }

    #[test]
    fn defaults() {
        let a = parse(&[]);
        assert_eq!(a.device, "cpu");
        assert_eq!(a.max_seq, 128);
        assert_eq!(a.max_tokens, 32);
        assert!(!a.no_stream);
        assert!(!a.packed);
        assert!(!a.no_packed);
        assert_eq!(a.temperature, 0.0);
        assert_eq!(a.top_p, 1.0);
        assert_eq!(a.top_k, None);
        assert_eq!(a.repetition_penalty, 1.0);
        assert!(a.format.is_none());
        assert!(a.config.is_none());
        assert!(a.prompt.is_none());
        assert!(a.prompt_ids.is_none());
        assert!(a.max_memory_gb.is_none());
        assert!(a.prefer_gguf.is_none());
    }

    #[test]
    fn prompt_ids_are_comma_separated() {
        let a = parse(&["--prompt-ids", "1,2,3"]);
        assert_eq!(a.prompt_ids, Some(vec![1, 2, 3]));
    }

    #[test]
    fn prompt_ids_reject_non_integers() {
        let r = LmCliArgs::try_parse_from([
            "x",
            "--weights",
            "/tmp/m.gguf",
            "--prompt-ids",
            "1,x",
        ]);
        assert!(r.is_err());
    }

    #[test]
    fn prefer_alias() {
        let a = parse(&["--prefer", "Q4_K_M"]);
        assert_eq!(a.prefer_gguf.as_deref(), Some("Q4_K_M"));
    }

    #[test]
    fn packed_conflict() {
        let r =
            LmCliArgs::try_parse_from(["x", "--weights", "/tmp/m.gguf", "--packed", "--no-packed"]);
        assert!(r.is_err());
    }

    #[test]
    fn builder_propagates_packed_override() {
        // Covers all three arms of the packed tri-state.
        let cases: [(&[&str], Option<bool>); 3] = [
            (&[], None),
            (&["--packed"], Some(true)),
            (&["--no-packed"], Some(false)),
        ];
        for (flags, expected) in cases {
            let b: LmRunnerBuilder<()> = parse(flags).into_builder().unwrap();
            assert_eq!(b.packed_weights, expected, "flags: {flags:?}");
        }
    }

    #[test]
    fn builder_forwards_memory_and_gguf_preference() {
        let a = parse(&["--max-memory-gb", "8", "--prefer-gguf", "Q4_K_M"]);
        let b: LmRunnerBuilder<()> = a.into_builder().unwrap();
        assert_eq!(b.max_memory_gb, Some(8.0));
        assert_eq!(b.prefer_gguf.as_deref(), Some("Q4_K_M"));
        assert!(b.format.is_none());
    }

    #[test]
    fn sample_opts_collects_sampling_flags() {
        let a = parse(&[
            "--temperature",
            "0.5",
            "--top-p",
            "0.9",
            "--top-k",
            "40",
            "--repetition-penalty",
            "1.5",
        ]);
        let s = a.sample_opts();
        assert_eq!(s.temperature, 0.5);
        assert_eq!(s.top_p, 0.9);
        assert_eq!(s.top_k, Some(40));
        assert_eq!(s.repetition_penalty, 1.5);
    }
}