1use std::path::PathBuf;
29use std::str::FromStr;
30
31use anyhow::Result;
32use clap::Parser;
33use rlx_runtime::{ConfigSource, Device, LmRunnerBuilder, SampleOpts, WeightFormat};
34
/// Command-line arguments that configure a language-model runner.
///
/// Parsed with `clap`; `into_builder` turns the parsed values into an `LmRunnerBuilder`.
#[derive(Debug, Clone, Parser)]
pub struct LmCliArgs {
    /// Path to the model weights (required).
    #[arg(long)]
    pub weights: PathBuf,

    /// Device to run on, as a string accepted by `Device::from_str`.
    #[arg(long, default_value = "cpu")]
    pub device: String,

    /// Weight format, validated while parsing via `WeightFormat::parse`.
    #[arg(long, value_parser = parse_format)]
    pub format: Option<WeightFormat>,

    /// JSON config file; when omitted the embedded config source is used.
    #[arg(long)]
    pub config: Option<PathBuf>,

    // NOTE(review): `prompt`, `prompt_ids`, `tokenizer` and `max_tokens` are not
    // forwarded by `into_builder`; presumably the calling binary reads them
    // directly. Confirm against the callers.
    /// Prompt text.
    #[arg(long)]
    pub prompt: Option<String>,

    // NOTE(review): presumably pre-tokenized token ids; verify against the caller.
    /// Prompt given as a comma-separated list of unsigned integer ids.
    #[arg(long, value_delimiter = ',')]
    pub prompt_ids: Option<Vec<u32>>,

    /// Tokenizer path.
    #[arg(long)]
    pub tokenizer: Option<PathBuf>,

    // NOTE(review): presumably the number of tokens to generate; TODO confirm.
    /// Maximum token count.
    #[arg(long, default_value_t = 32)]
    pub max_tokens: usize,

    /// Maximum sequence length passed to the runner builder.
    #[arg(long, default_value_t = 128)]
    pub max_seq: usize,

    /// Memory limit in gigabytes, passed to the runner builder when set.
    #[arg(long)]
    pub max_memory_gb: Option<f32>,

    /// Disable streaming (streaming stays enabled unless this flag is given).
    #[arg(long)]
    pub no_stream: bool,

    /// Force packed weights on.
    #[arg(long)]
    pub packed: bool,

    /// Force packed weights off; cannot be combined with `--packed`.
    #[arg(long, conflicts_with = "packed")]
    pub no_packed: bool,

    /// Sampling temperature.
    #[arg(long, default_value_t = 0.0)]
    pub temperature: f32,

    /// Top-p sampling parameter.
    #[arg(long, default_value_t = 1.0)]
    pub top_p: f32,

    /// Top-k sampling parameter; unset by default.
    #[arg(long)]
    pub top_k: Option<u32>,

    /// Repetition penalty.
    #[arg(long, default_value_t = 1.0)]
    pub repetition_penalty: f32,

    // NOTE(review): the meaning of this string is defined by the runner builder's
    // `prefer_gguf` field, which is not visible here; verify before
    // expanding the help text.
    /// Value passed to the runner builder's `prefer_gguf`; also accepted as `--prefer`.
    #[arg(long, alias = "prefer")]
    pub prefer_gguf: Option<String>,
}
110
111fn parse_format(s: &str) -> Result<WeightFormat, String> {
112 WeightFormat::parse(s).map_err(|e| e.to_string())
113}
114
115impl LmCliArgs {
116 pub fn device(&self) -> Result<Device> {
119 Device::from_str(&self.device).map_err(|e| anyhow::anyhow!("--device {}: {e}", self.device))
120 }
121
122 pub fn sample_opts(&self) -> SampleOpts {
124 SampleOpts {
125 temperature: self.temperature,
126 top_p: self.top_p,
127 top_k: self.top_k,
128 repetition_penalty: self.repetition_penalty,
129 }
130 }
131
132 pub fn into_builder<Cfg>(self) -> Result<LmRunnerBuilder<Cfg>> {
136 let device = self.device()?;
137 let packed = if self.packed {
138 Some(true)
139 } else if self.no_packed {
140 Some(false)
141 } else {
142 None
143 };
144 let config = self
145 .config
146 .clone()
147 .map(ConfigSource::JsonFile)
148 .unwrap_or(ConfigSource::Embedded);
149
150 let mut b = LmRunnerBuilder::<Cfg>::new()
151 .weights(self.weights.clone())
152 .device(device)
153 .max_seq(self.max_seq)
154 .stream(!self.no_stream)
155 .sample(self.sample_opts())
156 .config(config);
157 b.format = self.format;
158 b.packed_weights = packed;
159 b.max_memory_gb = self.max_memory_gb;
160 b.prefer_gguf = self.prefer_gguf.clone();
161 Ok(b)
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    /// Minimal argv accepted by the parser: program name plus the required `--weights`.
    const BASE_ARGV: [&str; 3] = ["x", "--weights", "/tmp/m.gguf"];

    /// Parses `BASE_ARGV` followed by `extra` flags.
    fn parse_with(extra: &[&str]) -> Result<LmCliArgs, clap::Error> {
        LmCliArgs::try_parse_from(BASE_ARGV.iter().chain(extra).copied())
    }

    /// clap's own consistency checks pass for the derived command.
    #[test]
    fn debug_assert_works() {
        let command = LmCliArgs::command();
        command.debug_assert();
    }

    /// With only `--weights` given, every defaulted field carries its declared default.
    #[test]
    fn defaults() {
        let args = parse_with(&[]).unwrap();
        assert_eq!(args.device, "cpu");
        assert_eq!(args.max_seq, 128);
        assert_eq!(args.max_tokens, 32);
        assert!(!args.no_stream);
        assert_eq!(args.temperature, 0.0);
    }

    /// `--packed` and `--no-packed` together are rejected at parse time.
    #[test]
    fn packed_conflict() {
        let outcome = parse_with(&["--packed", "--no-packed"]);
        assert!(outcome.is_err());
    }

    /// `--no-packed` reaches the builder as an explicit `Some(false)`.
    #[test]
    fn builder_propagates_packed_override() {
        let args = parse_with(&["--no-packed"]).unwrap();
        let builder: LmRunnerBuilder<()> = args.into_builder().unwrap();
        assert_eq!(builder.packed_weights, Some(false));
    }
}