// llama-cpp-bindings 0.4.2: llama.cpp bindings for Rust.
// See the crate documentation for API details.
#![cfg(all(feature = "tests_that_use_llms", feature = "mtmd"))]

use std::num::NonZeroU32;

use anyhow::{Context, Result};
use llama_cpp_bindings::context::params::LlamaContextParams;
use llama_cpp_bindings::llama_backend::LlamaBackend;
use llama_cpp_bindings::model::params::LlamaModelParams;
use llama_cpp_bindings::model::{LlamaChatMessage, LlamaModel};
use llama_cpp_bindings::mtmd::{MtmdBitmap, MtmdContext, MtmdContextParams, MtmdInputText};
use llama_cpp_bindings::sampling::LlamaSampler;
use llama_cpp_bindings::test_model;

/// End-to-end smoke test for multimodal (vision) inference.
///
/// Obtains the test model and its multimodal projector via `test_model`, builds a
/// chat-formatted prompt containing one image (`llamas.jpg` from the fixtures
/// directory), evaluates it through the mtmd pipeline, then greedily generates up
/// to `MAX_GENERATED_TOKENS` tokens. The test passes when the projector reports
/// vision support and at least one non-end-of-generation token's text is produced.
#[test]
fn multimodal_vision_inference_produces_output() -> Result<()> {
    // Batch size shared by the context parameters and `eval_chunks`, so chunk
    // evaluation never submits more tokens per decode than the context was
    // configured for. Typed `u16` so `.into()` widens it losslessly to whichever
    // integer type each API expects.
    const N_BATCH: u16 = 512;
    // Upper bound on generated tokens so a model that never emits EOG still terminates.
    const MAX_GENERATED_TOKENS: usize = 512;

    let model_path = test_model::download_model()?;
    let mmproj_path = test_model::download_mmproj()?;

    let backend = LlamaBackend::init()?;
    let model_params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
        .context("unable to load model")?;

    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(4096))
        .with_n_batch(N_BATCH.into());
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .context("unable to create llama context")?;

    let mtmd_params = MtmdContextParams::default();
    let mmproj_path_str = mmproj_path
        .to_str()
        .context("mmproj path is not valid UTF-8")?;
    let mtmd_ctx = MtmdContext::init_from_file(mmproj_path_str, &model, &mtmd_params)
        .context("unable to create mtmd context")?;

    assert!(
        mtmd_ctx.support_vision(),
        "model should support vision input"
    );

    let image_path = test_model::fixtures_dir().join("llamas.jpg");
    let image_path_str = image_path
        .to_str()
        .context("image path is not valid UTF-8")?;
    let bitmap = MtmdBitmap::from_file(&mtmd_ctx, image_path_str)
        .context("failed to load image from file")?;

    // The default media marker marks where in the prompt the image belongs;
    // presumably mtmd tokenization splits the text at this marker and inserts
    // the bitmap's chunk there.
    let marker = llama_cpp_bindings::mtmd::mtmd_default_marker();
    let user_content = format!("{marker}What animals do you see in this image?");
    let chat_template = model.chat_template(None)?;
    let messages = [LlamaChatMessage::new("user".to_string(), user_content)?];
    // NOTE(review): the trailing `true` is assumed to be the "add assistant prompt"
    // flag, so the model answers instead of continuing the user turn — confirm.
    let formatted_prompt = model.apply_chat_template(&chat_template, &messages, true)?;

    let input_text = MtmdInputText {
        text: formatted_prompt,
        // Assumes the chat template already emitted any needed special tokens (e.g. BOS).
        add_special: false,
        // Let template control tokens in the text be tokenized as special tokens.
        parse_special: true,
    };

    let chunks = mtmd_ctx
        .tokenize(input_text, &[&bitmap])
        .context("failed to tokenize multimodal input")?;

    assert!(
        !chunks.is_empty(),
        "tokenization should produce at least one chunk"
    );

    eprintln!(
        "tokenized into {} chunks, {} total tokens",
        chunks.len(),
        chunks.total_tokens()
    );

    // Evaluate the whole prompt (text + image). Arguments are assumed to be:
    // starting position 0, sequence id 0, batch size, and "logits for the last
    // token only", which is what the first sampling step below needs.
    let n_past = chunks
        .eval_chunks(&mtmd_ctx, &ctx, 0, 0, N_BATCH.into(), true)
        .context("failed to evaluate chunks")?;

    eprintln!("evaluated chunks, n_past = {n_past}");

    let mut sampler = LlamaSampler::greedy();
    let mut generated = String::new();
    // A single stateful decoder is reused across tokens so UTF-8 sequences that
    // are split across token pieces can be reassembled.
    let mut decoder = encoding_rs::UTF_8.new_decoder();

    // Each decode step below feeds exactly one token, so a one-slot batch suffices.
    let mut batch = llama_cpp_bindings::llama_batch::LlamaBatch::new(1, 1)?;
    let mut current_position = n_past;

    for _ in 0..MAX_GENERATED_TOKENS {
        // Index -1 selects the most recent set of logits.
        let token = sampler.sample(&ctx, -1)?;

        if model.is_eog_token(token) {
            break;
        }

        let piece = model
            .token_to_piece(token, &mut decoder, false, None)
            .context("failed to convert token to piece")?;
        generated.push_str(&piece);

        batch.clear();
        // Request logits for this token: the next iteration samples from them.
        batch.add(token, current_position, &[0], true)?;
        current_position += 1;

        ctx.decode(&mut batch)
            .context("failed to decode generated token")?;
    }

    eprintln!("generated text: {generated}");

    assert!(
        !generated.is_empty(),
        "model should generate at least one token from image input"
    );

    Ok(())
}