use llama_cpp_v3::backend::Backend;
use llama_cpp_v3::{LlamaBackend, LlamaBatch, LlamaContext, LlamaModel, LlamaSampler, LoadOptions};
use std::env;
use std::path::PathBuf;
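
// Minimal end-to-end example: pick a compute backend from the command line,
// load a GGUF model, tokenize a prompt, and greedily generate up to 32 tokens.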
fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: simple <backend> <model.gguf> [cache_dir]");
        eprintln!(" backend: cpu | cuda | vulkan | hip | sycl | opencl");
        std::process::exit(1);
    }
    let backend_str = &args[1];
    let model_path = &args[2];
    let cache_dir = if args.len() > 3 {
        Some(PathBuf::from(&args[3]))
    } else {
        None
    };
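    // Map the requested backend name onto the corresponding Backend variant.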
    let backend_type = match backend_str.to_lowercase().as_str() {
        "cpu" => Backend::Cpu,
        "cuda" => Backend::Cuda,
        "vulkan" => Backend::Vulkan,
        "hip" => Backend::Hip,
        "sycl" => Backend::Sycl,
        "opencl" => Backend::OpenCl,
        _ => {
            eprintln!("Unknown backend: {}", backend_str);
            std::process::exit(1);
        }
    };
println!("Initializing backend: {:?}", backend_type);
    let options = LoadOptions {
        backend: backend_type,
        app_name: "llama-cpp-v3-simple-test",
        version: None,
        explicit_path: None,
        cache_dir,
    };
    let backend = match LlamaBackend::load(options) {
        Ok(b) => b,
        Err(e) => {
            eprintln!("Failed to load backend: {}", e);
            std::process::exit(1);
        }
    };
    println!("Backend loaded successfully. Initializing model...");
    let mut model_params = LlamaModel::default_params(&backend);
    if matches!(backend_type, Backend::Cpu) {
        model_params.n_gpu_layers = 0;
    } else {
        model_params.n_gpu_layers = 99;
    }
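    // Debug output: sizes of the raw FFI parameter structs, handy for spotting
    // an ABI mismatch between the bindings and the loaded library.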
    println!(
        "Size of llama_model_params: {}",
        std::mem::size_of::<llama_cpp_sys_v3::llama_model_params>()
    );
    println!(
        "Size of llama_context_params: {}",
        std::mem::size_of::<llama_cpp_sys_v3::llama_context_params>()
    );
    println!("Loading from path: {}", model_path);
    let model = match LlamaModel::load_from_file(&backend, model_path, model_params) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("Failed to load model: {}", e);
            std::process::exit(1);
        }
    };
println!("Model initialized successfully. Creating context...");
let ctx_params = LlamaContext::default_params(&model);
let mut ctx = match LlamaContext::new(&model, ctx_params) {
Ok(c) => c,
Err(e) => {
eprintln!("Failed to create context: {}", e);
std::process::exit(1);
}
};
println!("Context created successfully. Checking n_vocab...");
    let vocab = model.get_vocab();
    let vocab_size = unsafe { (backend.lib.symbols.llama_vocab_n_tokens)(vocab.handle) };
    println!("Model vocab size: {}", vocab_size);
    let prompt = "Why is the sky blue?";
    println!("\nPrompt: {}", prompt);
    let tokens = model
        .tokenize(prompt, true, true)
        .expect("Tokenization failed");
    let mut batch = LlamaBatch::new(backend.lib.clone(), 2048, 0, 1);
    for (i, token) in tokens.iter().enumerate() {
        batch.add(
            *token,
            i as llama_cpp_sys_v3::llama_pos,
            &[0],
            i == tokens.len() - 1,
        );
    }
    ctx.decode(&batch).expect("Failed to decode prompt");
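    // Greedy sampler: always take the highest-probability token.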
    let sampler = LlamaSampler::new_greedy(backend.lib.clone());
    print!("Response: ");
    std::io::Write::flush(&mut std::io::stdout()).unwrap();
    let mut n_cur = tokens.len();
    let mut n_decode = 0;
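    // Generation loop: sample from the last decoded logits, stopping after 32
    // tokens or on an end-of-generation token.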
    while n_decode < 32 {
        let token = sampler.sample(&ctx, -1);
        if vocab.is_eog(token) {
            break;
        }
        let piece = model.token_to_piece(token);
        print!("{}", piece);
        std::io::Write::flush(&mut std::io::stdout()).unwrap();
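        // Feed the sampled token back in as a single-token batch at the next
        // position so the model can continue from it.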
        batch.clear();
        batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);
        ctx.decode(&batch)
            .expect("Failed to decode generated token");
        n_cur += 1;
        n_decode += 1;
    }
    println!("\n\nSuccess! Generated {} tokens.", n_decode);
}