use std::path::PathBuf;
use std::str::FromStr;
use anyhow::Result;
use clap::Parser;
use rlx_runtime::{ConfigSource, Device, LmRunnerBuilder, SampleOpts, WeightFormat};
// Command-line arguments that `LmCliArgs::into_builder` converts into an
// `LmRunnerBuilder`. Parsed with clap's derive API; each field maps to a
// `--kebab-case` long flag.
//
// NOTE(review): plain `//` comments are used on purpose. clap's derive turns
// `///` doc comments into `--help` text, which would change the CLI output.
#[derive(Debug, Clone, Parser)]
pub struct LmCliArgs {
// Required path to the weights; passed to the builder's `weights(..)`.
#[arg(long)]
pub weights: PathBuf,
// Device name kept as a raw string (default "cpu"); it is only validated
// when `LmCliArgs::device` parses it into a `Device`.
#[arg(long, default_value = "cpu")]
pub device: String,
// Optional weight format, parsed at argument-parsing time by `parse_format`;
// copied to the builder's `format` field.
#[arg(long, value_parser = parse_format)]
pub format: Option<WeightFormat>,
// Optional config path. When present `into_builder` uses
// `ConfigSource::JsonFile(path)`, otherwise `ConfigSource::Embedded`.
#[arg(long)]
pub config: Option<PathBuf>,
// Optional prompt text. Not forwarded by `into_builder`; presumably read
// directly by the caller — TODO confirm.
#[arg(long)]
pub prompt: Option<String>,
// Optional comma-separated list of `u32` values (e.g. `--prompt-ids 1,2,3`).
// Presumably pre-tokenized prompt ids; not forwarded by `into_builder`.
#[arg(long, value_delimiter = ',')]
pub prompt_ids: Option<Vec<u32>>,
// Optional path; presumably a tokenizer file — not consumed by any code
// visible in this file, verify against the caller.
#[arg(long)]
pub tokenizer: Option<PathBuf>,
// Defaults to 32. Not forwarded by `into_builder`; presumably the generation
// length limit applied by the caller — TODO confirm.
#[arg(long, default_value_t = 32)]
pub max_tokens: usize,
// Defaults to 128; passed to the builder's `max_seq(..)`.
#[arg(long, default_value_t = 128)]
pub max_seq: usize,
// Optional value copied unchanged to the builder's `max_memory_gb` field.
#[arg(long)]
pub max_memory_gb: Option<f32>,
// Boolean flag; `into_builder` passes its negation to the builder's
// `stream(..)`, so streaming is requested unless this flag is given.
#[arg(long)]
pub no_stream: bool,
// Boolean flag; when set, `into_builder` sets `packed_weights` to `Some(true)`.
#[arg(long)]
pub packed: bool,
// Boolean flag; when set, `into_builder` sets `packed_weights` to
// `Some(false)`. clap rejects it together with `--packed`. With neither flag
// the builder field is `None`.
#[arg(long, conflicts_with = "packed")]
pub no_packed: bool,
// The four sampling fields below are copied verbatim into `SampleOpts` by
// `LmCliArgs::sample_opts`. Defaults: temperature 0.0, top_p 1.0, no top_k,
// repetition_penalty 1.0.
#[arg(long, default_value_t = 0.0)]
pub temperature: f32,
#[arg(long, default_value_t = 1.0)]
pub top_p: f32,
#[arg(long)]
pub top_k: Option<u32>,
#[arg(long, default_value_t = 1.0)]
pub repetition_penalty: f32,
// Optional string, also accepted under the alias `--prefer`; copied to the
// builder's `prefer_gguf` field. Accepted values are defined by the runtime
// and are not visible here.
#[arg(long, alias = "prefer")]
pub prefer_gguf: Option<String>,
}
/// clap `value_parser` for `--format`.
///
/// Delegates to `WeightFormat::parse` and converts its error into a `String`
/// via `to_string()`.
fn parse_format(s: &str) -> Result<WeightFormat, String> {
    let parsed = WeightFormat::parse(s);
    parsed.map_err(|err| err.to_string())
}
impl LmCliArgs {
    /// Parses the raw `--device` string into a [`Device`].
    ///
    /// # Errors
    ///
    /// Returns an error that names the offending `--device` value when
    /// `Device::from_str` rejects it.
    pub fn device(&self) -> Result<Device> {
        Device::from_str(&self.device)
            .map_err(|e| anyhow::anyhow!("--device {}: {e}", self.device))
    }

    /// Collects the sampling flags (`temperature`, `top_p`, `top_k`,
    /// `repetition_penalty`) into a [`SampleOpts`], copying each value as-is.
    pub fn sample_opts(&self) -> SampleOpts {
        SampleOpts {
            temperature: self.temperature,
            top_p: self.top_p,
            top_k: self.top_k,
            repetition_penalty: self.repetition_penalty,
        }
    }

    /// Consumes the parsed arguments and produces a configured
    /// [`LmRunnerBuilder`].
    ///
    /// Mapping performed here:
    /// - `--packed` / `--no-packed` become `packed_weights = Some(true)` /
    ///   `Some(false)`; with neither flag the field is `None`.
    /// - `--config <path>` becomes `ConfigSource::JsonFile(path)`; without it
    ///   `ConfigSource::Embedded` is used.
    /// - `--no-stream` is inverted into the builder's `stream` setting.
    ///
    /// `prompt`, `prompt_ids`, `tokenizer` and `max_tokens` are not forwarded
    /// to the builder; callers that need them must read them before calling
    /// this method, since `self` is consumed.
    ///
    /// # Errors
    ///
    /// Fails if the `--device` value cannot be parsed (see [`Self::device`]).
    pub fn into_builder<Cfg>(self) -> Result<LmRunnerBuilder<Cfg>> {
        // The helpers below borrow all of `self`, so they must run before any
        // field is moved out of it.
        let device = self.device()?;
        let sample = self.sample_opts();

        let packed = match (self.packed, self.no_packed) {
            (true, _) => Some(true),
            (false, true) => Some(false),
            (false, false) => None,
        };

        // `self` is owned: move the heap-backed fields instead of cloning them.
        let config = self
            .config
            .map_or(ConfigSource::Embedded, ConfigSource::JsonFile);

        let mut b = LmRunnerBuilder::<Cfg>::new()
            .weights(self.weights)
            .device(device)
            .max_seq(self.max_seq)
            .stream(!self.no_stream)
            .sample(sample)
            .config(config);

        // These settings are assigned to the builder's fields directly, as the
        // builder is not shown to offer chained setters for them.
        b.format = self.format;
        b.packed_weights = packed;
        b.max_memory_gb = self.max_memory_gb;
        b.prefer_gguf = self.prefer_gguf;
        Ok(b)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    /// Parses a command line made of the program name, the mandatory
    /// `--weights` argument, and then `extra`.
    fn parse_with(extra: &[&str]) -> Result<LmCliArgs, clap::Error> {
        let mut argv = vec!["x", "--weights", "/tmp/m.gguf"];
        argv.extend_from_slice(extra);
        LmCliArgs::try_parse_from(argv)
    }

    /// clap's own consistency checks pass for the derived command.
    #[test]
    fn debug_assert_works() {
        let command = LmCliArgs::command();
        command.debug_assert();
    }

    /// Only `--weights` given: every other option takes its declared default.
    #[test]
    fn defaults() {
        let parsed = parse_with(&[]).unwrap();
        assert_eq!(parsed.device, "cpu");
        assert_eq!(parsed.max_seq, 128);
        assert_eq!(parsed.max_tokens, 32);
        assert!(!parsed.no_stream);
        assert_eq!(parsed.temperature, 0.0);
    }

    /// `--packed` and `--no-packed` together are rejected at parse time.
    #[test]
    fn packed_conflict() {
        let outcome = parse_with(&["--packed", "--no-packed"]);
        assert!(outcome.is_err());
    }

    /// `--no-packed` reaches the builder as an explicit `Some(false)`.
    #[test]
    fn builder_propagates_packed_override() {
        let parsed = parse_with(&["--no-packed"]).unwrap();
        let builder: LmRunnerBuilder<()> = parsed.into_builder().unwrap();
        assert_eq!(builder.packed_weights, Some(false));
    }
}