llama-cpp-v3 0.1.6

A safe, ergonomic Rust wrapper for llama.cpp that loads the native library dynamically.
Documentation
use llama_cpp_v3::backend::Backend;
use llama_cpp_v3::{LlamaBackend, LlamaBatch, LlamaContext, LlamaModel, LlamaSampler, LoadOptions};
use std::env;
use std::path::PathBuf;

fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: simple <backend> <model.gguf> [cache_dir]");
        eprintln!("       backend: cpu | cuda | vulkan | hip | sycl | opencl");
        std::process::exit(1);
    }

    let backend_str = &args[1];
    let model_path = &args[2];

    let cache_dir = if args.len() > 3 {
        Some(PathBuf::from(&args[3]))
    } else {
        None
    };

    let backend_type = match backend_str.to_lowercase().as_str() {
        "cpu" => Backend::Cpu,
        "cuda" => Backend::Cuda,
        "vulkan" => Backend::Vulkan,
        "hip" => Backend::Hip,
        "sycl" => Backend::Sycl,
        "opencl" => Backend::OpenCl,
        _ => {
            eprintln!("Unknown backend: {}", backend_str);
            std::process::exit(1);
        }
    };

    println!("Initializing backend: {:?}", backend_type);
    let options = LoadOptions {
        backend: backend_type,
        app_name: "llama-cpp-v3-simple-test",
        version: None,
        explicit_path: None,
        cache_dir,
    };

    let backend = match LlamaBackend::load(options) {
        Ok(b) => b,
        Err(e) => {
            eprintln!("Failed to load backend: {}", e);
            std::process::exit(1);
        }
    };

    println!("Backend loaded successfully. Initializing model...");

    let mut model_params = LlamaModel::default_params(&backend);

    // Set to 0 if we assume CPU test, or conditionally set it based on backend
    if matches!(backend_type, Backend::Cpu) {
        model_params.n_gpu_layers = 0;
    } else {
        model_params.n_gpu_layers = 99;
    }

    println!(
        "Size of llama_model_params: {}",
        std::mem::size_of::<llama_cpp_sys_v3::llama_model_params>()
    );
    println!(
        "Size of llama_context_params: {}",
        std::mem::size_of::<llama_cpp_sys_v3::llama_context_params>()
    );
    println!("Loading from path: {}", model_path);

    let model = match LlamaModel::load_from_file(&backend, model_path, model_params) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("Failed to load model: {}", e);
            std::process::exit(1);
        }
    };

    println!("Model initialized successfully. Creating context...");
    let ctx_params = LlamaContext::default_params(&model);

    let mut ctx = match LlamaContext::new(&model, ctx_params) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("Failed to create context: {}", e);
            std::process::exit(1);
        }
    };

    println!("Context created successfully. Checking n_vocab...");
    let vocab = model.get_vocab();
    let vocab_size = unsafe { (backend.lib.symbols.llama_vocab_n_tokens)(vocab.handle) };
    println!("Model vocab size: {}", vocab_size);

    let prompt = "Why is the sky blue?";
    println!("\nPrompt: {}", prompt);

    let tokens = model
        .tokenize(prompt, true, true)
        .expect("Tokenization failed");

    let mut batch = LlamaBatch::new(backend.lib.clone(), 2048, 0, 1);

    // Add prompt tokens to batch
    for (i, token) in tokens.iter().enumerate() {
        batch.add(
            *token,
            i as llama_cpp_sys_v3::llama_pos,
            &[0],
            i == tokens.len() - 1,
        );
    }

    ctx.decode(&batch).expect("Failed to decode prompt");

    let sampler = LlamaSampler::new_greedy(backend.lib.clone());

    print!("Response: ");
    std::io::Write::flush(&mut std::io::stdout()).unwrap();

    let mut n_cur = tokens.len();
    let mut n_decode = 0;

    while n_decode < 32 {
        let token = sampler.sample(&ctx, -1);

        if vocab.is_eog(token) {
            break;
        }

        let piece = model.token_to_piece(token);
        print!("{}", piece);
        std::io::Write::flush(&mut std::io::stdout()).unwrap();

        // Prepare next batch
        batch.clear();
        batch.add(token, n_cur as llama_cpp_sys_v3::llama_pos, &[0], true);

        ctx.decode(&batch)
            .expect("Failed to decode generated token");

        n_cur += 1;
        n_decode += 1;
    }

    println!("\n\nSuccess! Generated {} tokens.", n_decode);
}