//! Greedy text generation from a memory-mapped GGUF model.
use realizar::gguf::{MappedGGUFModel, OwnedQuantizedKVCache, OwnedQuantizedModel};
use std::env;
use std::io::Write;
fn main() {
    // The model path can be overridden via the first CLI argument.
    let args: Vec<String> = env::args().collect();
    let path = args.get(1).map_or(
        "/home/noah/src/aprender/tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
        |s| s.as_str(),
    );
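
    // Map the GGUF file into memory, then build the quantized model,
    // tokenizer vocabulary, and prompt token IDs from the mapping.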
    let mapped = MappedGGUFModel::from_path(path).expect("failed to load model");
    let model = OwnedQuantizedModel::from_mapped(&mapped).expect("failed to build quantized model");
    let vocab = mapped.model.vocabulary().expect("model has no vocabulary");
    let prompt = "Once upon a time";
    let prompt_tokens = mapped.model.encode(prompt).expect("failed to encode prompt");
println!("Prompt: '{}'", prompt);
println!("Tokens: {:?}", prompt_tokens);
    let max_seq_len = 256;
    let head_dim = model.config().hidden_dim / model.config().num_heads;
    let kv_dim = model.config().num_kv_heads * head_dim;
    let mut cache = OwnedQuantizedKVCache::new(model.config().num_layers, kv_dim, max_seq_len);
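
    // Prefill: run each prompt token through the model to populate the KV
    // cache. Only the logits from the final prompt token are kept; they seed
    // the first generation step.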
    let mut logits = Vec::new();
    for (pos, &tok) in prompt_tokens.iter().enumerate() {
        logits = model
            .forward_single_with_cache(tok, &mut cache, pos)
            .expect("prefill forward pass failed");
    }
    let mut generated_tokens = prompt_tokens.clone();
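
    // Greedy decoding: take the argmax over the logits at each step, print
    // the token, and feed it back as the next input.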
    for i in 0..20 {
        let (best_idx, _best_logit) = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .expect("logits are empty");
        let tok_str = if best_idx < vocab.len() {
            &vocab[best_idx]
        } else {
            "?"
        };
        // "▁" (SentencePiece) and "Ġ" (U+0120, GPT-2 BPE) mark word
        // boundaries in token pieces; render both as spaces.
        print!("{}", tok_str.replace('▁', " ").replace('\u{0120}', " "));
        std::io::stdout().flush().ok(); // stream tokens as they are chosen
        generated_tokens.push(best_idx as u32);
        let pos = prompt_tokens.len() + i;
        logits = model
            .forward_single_with_cache(best_idx as u32, &mut cache, pos)
            .expect("decode forward pass failed");
    }
println!("\n");
println!("Full tokens: {:?}", generated_tokens);
}