use candle_core::{IndexOp, Tensor};
use cortex_rust::model::{BitLlamaConfig, Llama};
use std::fs;
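
/// Debug driver: loads a converted BitLlama checkpoint, runs a single prefill
/// pass over the prompt, then greedily decodes a few tokens while printing the
/// top logits and the model's KV-cache position at each step.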
fn main() -> anyhow::Result<()> {
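    // Minimal hand-rolled CLI parsing: `--model`/`--model-path` selects the
    // converted model directory, `--prompt` sets the prompt text, and
    // `--bypass-kv` disables KV-cache quantization.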
    let args: Vec<String> = std::env::args().collect();
    let model_path = args
        .iter()
        .position(|s| s == "--model" || s == "--model-path")
        .and_then(|i| args.get(i + 1))
        .map(|s| s.as_str())
        .unwrap_or("benchmark/tinyllama-1.1b-converted");
    let prompt = args
        .iter()
        .position(|s| s == "--prompt")
        .and_then(|i| args.get(i + 1))
        .map(|s| s.as_str())
        .unwrap_or("Hello, my name is");
    let bypass_kv = args.iter().any(|s| s == "--bypass-kv");

    // Load the model config and weights from the converted checkpoint directory.
    let config_str = fs::read_to_string(format!("{}/config.json", model_path))?;
    let config: BitLlamaConfig = serde_json::from_str(&config_str)?;
    println!("Loading model...");
    let mut llama = Llama::load_direct(
        format!("{}/model.safetensors", model_path),
        format!("{}/tokenizer.json", model_path),
        config,
    )?;
    println!("Model loaded!");

    if bypass_kv {
        println!("⚠️ Bypass mode enabled: KV cache will use f32 (no quantization)");
        llama.model.set_kv_bypass(true);
    }

    // Tokenize the prompt and run a single prefill pass over all prompt tokens.
    let tokens = llama
        .tokenizer
        .encode(prompt, true)
        .map_err(candle_core::Error::wrap)?;
    let mut token_ids: Vec<u32> = tokens.get_ids().to_vec();
    println!("\nPrompt: {:?}", prompt);
    println!("Token IDs: {:?}", token_ids);

    llama.model.reset_kv_cache();
    let device = &llama.device;
    let input = Tensor::new(&token_ids[..], device)?.unsqueeze(0)?; // [1, seq_len]
    println!("\nPrefill input shape: {:?}", input.dims());
    let logits = llama.model.forward(&input, &mut llama.w_states)?;
    println!("Prefill logits shape: {:?}", logits.dims());

    // Only the logits at the last prompt position matter for the next token.
    let last_logits = logits.i((0, logits.dim(1)? - 1))?;
    let logits_vec: Vec<f32> = last_logits.to_vec1()?;
    let mut indexed: Vec<(usize, f32)> = logits_vec
        .iter()
        .enumerate()
        .map(|(i, &v)| (i, v))
        .collect();
    // `total_cmp` gives a total order over f32, so this cannot panic on NaN.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    println!("\nTop 5 logits:");
    for (i, (idx, val)) in indexed.iter().take(5).enumerate() {
        println!(" {}. {} = {:.4}", i + 1, idx, val);
    }

    // Greedy decoding: take the argmax token and append it to the sequence.
    let next_token = argmax(&last_logits)?;
    println!("\nNext token (greedy): {}", next_token);
    let decoded = llama
        .tokenizer
        .decode(&[next_token], true)
        .unwrap_or_else(|_| "<error>".to_string());
    println!("Decoded: {:?}", decoded);
    token_ids.push(next_token);

    // Generate two more tokens one at a time; each step feeds only the newest
    // token and relies on the KV cache for the earlier positions.
    for i in 0..2 {
        println!("\n--- Generate token {} ---", i + 2);
        println!("Current pos before forward: {}", llama.model.current_pos);
        let input = Tensor::new(&[*token_ids.last().unwrap()], device)?.unsqueeze(0)?; // [1, 1]
        let logits = llama.model.forward(&input, &mut llama.w_states)?;
        println!("Current pos after forward: {}", llama.model.current_pos);
        let last_logits = logits.i((0, logits.dim(1)? - 1))?;
        let logits_vec: Vec<f32> = last_logits.to_vec1()?;
        let mut indexed: Vec<(usize, f32)> = logits_vec
            .iter()
            .enumerate()
            .map(|(i, &v)| (i, v))
            .collect();
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
        println!("Top 5 logits:");
        for (j, (idx, val)) in indexed.iter().take(5).enumerate() {
            println!(" {}. {} = {:.4}", j + 1, idx, val);
        }
        let next_token = argmax(&last_logits)?;
        println!("Next token: {}", next_token);
        let decoded = llama
            .tokenizer
            .decode(&[next_token], true)
            .unwrap_or_else(|_| "<error>".to_string());
        println!("Decoded: {:?}", decoded);
        token_ids.push(next_token);
    }

    let full_output = llama
        .tokenizer
        .decode(&token_ids, true)
        .unwrap_or_else(|_| "<error>".to_string());
    println!("\n=== Full output ===");
    println!("{}", full_output);
    Ok(())
}
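
/// Greedy argmax over a 1-D logits tensor: returns the index of the largest
/// logit as the next token id.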
fn argmax(logits: &Tensor) -> candle_core::Result<u32> {
    let logits_vec: Vec<f32> = logits.to_vec1()?;
    let (idx, _) = logits_vec
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .expect("logits tensor must be non-empty");
    Ok(idx as u32)
}
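
// Example invocation (a sketch: the bin target name `inspect` is a placeholder;
// substitute this crate's actual binary name):
//
//   cargo run --release --bin inspect -- \
//       --model benchmark/tinyllama-1.1b-converted \
//       --prompt "Hello, my name is" \
//       --bypass-kv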