use anyhow::{Context, Result, bail};
use rlx_models::run::{ChatMessage, auto_chat_template, auto_runner, auto_tokenize};
use std::path::PathBuf;
/// Command-line options accepted by this program.
///
/// Filled in by `parse_args`; both `weights` and `prompt` are required there,
/// so they are `Some` whenever `parse_args` returns `Ok`.
#[derive(Debug, Default)]
struct Args {
/// Value of `--weights`: path to the model file (the usage text names a `.gguf`).
weights: Option<PathBuf>,
/// Value of `--prompt`: the text sent as the user message.
prompt: Option<String>,
/// Value of `--system`: optional text sent as a system message before the prompt.
system: Option<String>,
/// Value of `--n-new`: token count handed to the runner's `generate` call.
/// The derived `Default` is 0; `parse_args` starts from 32 instead.
n_new: usize,
}
fn parse_args() -> Result<Args> {
let mut args = Args {
n_new: 32,
..Default::default()
};
let mut it = std::env::args().skip(1);
while let Some(a) = it.next() {
match a.as_str() {
"--weights" => args.weights = it.next().map(PathBuf::from),
"--prompt" => args.prompt = it.next(),
"--system" => args.system = it.next(),
"--n-new" => args.n_new = it.next().and_then(|s| s.parse().ok()).unwrap_or(32),
other => bail!("unknown arg: {other}"),
}
}
if args.weights.is_none() {
bail!("--weights <path/to.gguf> required");
}
if args.prompt.is_none() {
bail!("--prompt <text> required");
}
Ok(args)
}
/// Entry point: loads a model's chat template, renders and tokenizes the
/// prompt, builds a runner and streams the generated token ids.
///
/// Progress and diagnostics go to stderr (lines prefixed with `#`); the
/// generated token ids are written to stdout, space-separated, as they arrive.
///
/// # Errors
///
/// Propagates failures from argument parsing, template loading/rendering,
/// tokenization, runner construction and generation.
fn main() -> Result<()> {
    use std::io::Write;

    /// Maximum number of rendered-prompt lines echoed to stderr.
    const PREVIEW_LINES: usize = 8;

    let args = parse_args()?;
    // `parse_args` already enforces both options; surfacing an error here keeps
    // this function panic-free even if that guarantee ever changes.
    let weights = args
        .weights
        .as_ref()
        .context("--weights <path/to.gguf> required")?;
    let prompt = args.prompt.as_ref().context("--prompt <text> required")?;

    eprintln!("# 1) sniff arch + load chat template from {weights:?}");
    let template =
        auto_chat_template(weights).with_context(|| format!("auto_chat_template({weights:?})"))?;
    eprintln!(
        "# bos = {:?}, eos = {:?}",
        template.bos_token(),
        template.eos_token()
    );

    eprintln!("# 2) render chat template");
    let mut messages: Vec<ChatMessage> = Vec::new();
    if let Some(sys) = &args.system {
        messages.push(ChatMessage::system(sys.clone()));
    }
    messages.push(ChatMessage::user(prompt.clone()));
    // NOTE(review): the `true` flag is presumably "add generation prompt" —
    // confirm against the template API.
    let rendered = template.render(&messages, true)?;
    // Count characters, not bytes, so the label is accurate for non-ASCII text.
    eprintln!("# rendered prompt ({} chars):", rendered.chars().count());
    for line in rendered.lines().take(PREVIEW_LINES) {
        eprintln!("# {line}");
    }
    if rendered.lines().count() > PREVIEW_LINES {
        eprintln!("# …");
    }

    eprintln!("# 3) tokenize");
    let prompt_ids = auto_tokenize(weights, &rendered, None)
        .with_context(|| format!("auto_tokenize({weights:?})"))?;
    eprintln!("# {} prompt tokens", prompt_ids.len());

    eprintln!("# 4) build runner");
    let mut runner = auto_runner(weights).with_context(|| format!("auto_runner({weights:?})"))?;
    eprintln!("# family = {}", runner.family());

    eprintln!("# 5) generate {} tokens", args.n_new);
    let generated = runner.generate(&prompt_ids, args.n_new, &mut |tok: u32| -> bool {
        print!("{tok} ");
        // Best-effort flush so ids appear as they are produced; a failed flush
        // must not abort generation.
        let _ = std::io::stdout().flush();
        // NOTE(review): always returns `true`; presumably this tells the runner
        // to keep generating — confirm against `generate`'s callback contract.
        true
    })?;
    println!();
    eprintln!("# generated {} ids: {:?}", generated.len(), generated);
    Ok(())
}