use std::path::Path;
use std::time::Instant;
use trueno::inference::{generate, GgufFile, LlamaModel, SampleParams};
/// Entry point for the GGUF text-generation example.
///
/// Usage: `<model.gguf> [prompt] [--temp T] [--max-tokens N] [--top-k K] [--top-p P]`
/// Diagnostics and timing go to stderr; the generated text goes to stdout.
fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <model.gguf> [prompt] [--temp T] [--max-tokens N]", args[0]);
        eprintln!();
        eprintln!("Options:");
        eprintln!(" --temp T Temperature (default: 0.7, 0=greedy)");
        eprintln!(" --max-tokens N Max tokens to generate (default: 128)");
        eprintln!(" --top-k K Top-K sampling (default: 40)");
        eprintln!(" --top-p P Top-P nucleus sampling (default: 0.9)");
        std::process::exit(1);
    }
    let model_path = &args[1];
    let mut prompt = "The".to_string();
    let mut params = SampleParams::default();
    let mut max_tokens = 128usize;

    // Parse the remaining arguments: a bare word becomes the prompt,
    // `--flag value` pairs override the sampling parameters.
    let mut i = 2;
    while i < args.len() {
        match args[i].as_str() {
            "--temp" => {
                i += 1;
                params.temperature =
                    require_value(&args, i, "--temp").parse().expect("invalid temperature");
            }
            "--max-tokens" => {
                i += 1;
                max_tokens =
                    require_value(&args, i, "--max-tokens").parse().expect("invalid max-tokens");
            }
            "--top-k" => {
                i += 1;
                params.top_k = require_value(&args, i, "--top-k").parse().expect("invalid top-k");
            }
            "--top-p" => {
                i += 1;
                params.top_p = require_value(&args, i, "--top-p").parse().expect("invalid top-p");
            }
            s if !s.starts_with("--") => {
                prompt = s.to_string();
            }
            _ => {
                eprintln!("Unknown argument: {}", args[i]);
                std::process::exit(1);
            }
        }
        i += 1;
    }

    eprintln!("Loading model: {model_path}");
    let load_start = Instant::now();
    let gguf = GgufFile::load(Path::new(model_path)).expect("Failed to load GGUF");
    eprintln!(
        " GGUF: {} tensors, loaded in {:.1}s",
        gguf.tensor_count,
        load_start.elapsed().as_secs_f64()
    );
    if let Some(arch) = gguf.meta_str("general.architecture") {
        eprintln!(" Architecture: {arch}");
    }
    if let Some(name) = gguf.meta_str("general.name") {
        eprintln!(" Name: {name}");
    }

    let build_start = Instant::now();
    let model = LlamaModel::from_gguf(&gguf).expect("Failed to build model");
    eprintln!(" Model ready in {:.1}s", build_start.elapsed().as_secs_f64());

    let prompt_tokens = tokenize_simple(&gguf, &prompt);
    eprintln!(" Prompt: {:?} ({} tokens)", prompt, prompt_tokens.len());
    // EOS id 2 is a fallback when the GGUF lacks the metadata key
    // (assumed LLaMA-family convention — confirm for other architectures).
    let eos_token = gguf.meta_u32("tokenizer.ggml.eos_token_id").unwrap_or(2);
    eprintln!("Generating ({max_tokens} tokens max, temp={:.1})...", params.temperature);
    eprintln!();

    let gen_start = Instant::now();
    // Fixed: original source contained `¶ms` (mojibake for `&params`).
    let generated = generate(&model, &prompt_tokens, max_tokens, &params, eos_token)
        .expect("Generation failed");
    let gen_elapsed = gen_start.elapsed();

    // Echo the decoded prompt followed by the generated continuation.
    let output_text = detokenize_simple(&gguf, &generated);
    let prompt_text_decoded = detokenize_simple(&gguf, &prompt_tokens);
    print!("{prompt_text_decoded}{output_text}");
    println!();

    let total_tokens = prompt_tokens.len() + generated.len();
    let tok_per_sec = generated.len() as f64 / gen_elapsed.as_secs_f64();
    eprintln!();
    eprintln!(
        " Generated {} tokens in {:.2}s ({:.1} tok/s)",
        generated.len(),
        gen_elapsed.as_secs_f64(),
        tok_per_sec,
    );
    eprintln!(" Prefill: {} tokens, Decode: {} tokens", prompt_tokens.len(), generated.len());
    eprintln!(" Total: {total_tokens} tokens");
}

/// Return the flag value at `args[i]`, or exit with a usage error when the
/// flag was the last argument. (The original indexed `args[i]` unchecked and
/// would panic with an opaque out-of-bounds message.)
fn require_value<'a>(args: &'a [String], i: usize, flag: &str) -> &'a str {
    args.get(i).map(String::as_str).unwrap_or_else(|| {
        eprintln!("Missing value for {flag}");
        std::process::exit(1);
    })
}
/// Tokenize `text` against the vocabulary embedded in the GGUF metadata.
///
/// Falls back to emitting one token per raw byte when the metadata has no
/// `tokenizer.ggml.tokens` array or the array holds no string entries.
fn tokenize_simple(gguf: &GgufFile, text: &str) -> Vec<u32> {
    let vocab: Vec<String> = match gguf.metadata.get("tokenizer.ggml.tokens") {
        Some(trueno::inference::gguf::MetadataValue::Array(entries)) => entries
            .iter()
            .filter_map(|entry| entry.as_str().map(str::to_string))
            .collect(),
        _ => Vec::new(),
    };
    if vocab.is_empty() {
        // No usable vocabulary: byte-level fallback.
        text.bytes().map(u32::from).collect()
    } else {
        greedy_tokenize(text, &vocab)
    }
}
/// Greedy longest-match tokenizer over a flat vocabulary.
///
/// When any vocab entry contains the SentencePiece word-boundary marker
/// (U+2581), spaces in `text` are first rewritten to that marker so the
/// input matches the vocabulary's convention. At each position the first
/// vocab entry with the maximal byte length wins; a byte with no match is
/// skipped (dropped from the output). A BOS token (hardcoded id 1) is
/// prepended whenever at least one token was produced.
fn greedy_tokenize(text: &str, vocab: &[String]) -> Vec<u32> {
    let spaced = text.replace(' ', "\u{2581}");
    let sentencepiece = vocab.iter().any(|t| t.contains('\u{2581}'));
    // Work on raw bytes: `cursor` may land mid-UTF-8-char after a skip,
    // which would panic on a &str slice.
    let input: &[u8] = if sentencepiece { spaced.as_bytes() } else { text.as_bytes() };

    let mut out = Vec::new();
    let mut cursor = 0usize;
    while cursor < input.len() {
        // Track the longest matching vocab entry; strict `>` keeps the
        // earliest id on ties, matching the original scan order.
        let mut chosen: Option<(u32, usize)> = None;
        for (id, token) in vocab.iter().enumerate() {
            let len = token.len();
            if len > chosen.map_or(0, |(_, best)| best)
                && input[cursor..].starts_with(token.as_bytes())
            {
                chosen = Some((id as u32, len));
            }
        }
        match chosen {
            Some((id, len)) => {
                out.push(id);
                cursor += len;
            }
            // Nothing matched here: advance one byte, silently dropping it.
            None => cursor += 1,
        }
    }
    if !out.is_empty() {
        out.insert(0, 1u32); // BOS
    }
    out
}
/// Map token ids back to text using the GGUF vocabulary.
///
/// SentencePiece word-boundary markers (U+2581) are rewritten to spaces;
/// ids outside the vocabulary are silently skipped. When no vocabulary is
/// present, each id is treated as a raw byte (lossy on invalid UTF-8).
fn detokenize_simple(gguf: &GgufFile, token_ids: &[u32]) -> String {
    if let Some(trueno::inference::gguf::MetadataValue::Array(entries)) =
        gguf.metadata.get("tokenizer.ggml.tokens")
    {
        let vocab: Vec<String> = entries
            .iter()
            .filter_map(|entry| entry.as_str().map(str::to_string))
            .collect();
        if !vocab.is_empty() {
            let mut text = String::new();
            for &id in token_ids {
                if let Some(piece) = vocab.get(id as usize) {
                    text.push_str(&piece.replace('\u{2581}', " "));
                }
            }
            return text;
        }
    }
    // Byte-level fallback mirroring tokenize's fallback path.
    let bytes: Vec<u8> = token_ids.iter().map(|&id| id as u8).collect();
    String::from_utf8_lossy(&bytes).into_owned()
}