llama-cpp-bindings 0.4.2

llama.cpp bindings for Rust
Documentation
#![cfg(feature = "tests_that_use_llms")]

use std::io::Write;
use std::time::Duration;

use anyhow::{Context, Result};
use llama_cpp_bindings::context::params::LlamaContextParams;
use llama_cpp_bindings::ggml_time_us;
use llama_cpp_bindings::llama_backend::LlamaBackend;
use llama_cpp_bindings::llama_batch::LlamaBatch;
use llama_cpp_bindings::model::params::LlamaModelParams;
use llama_cpp_bindings::model::{AddBos, LlamaChatMessage, LlamaModel};
use llama_cpp_bindings::sampling::LlamaSampler;
use llama_cpp_bindings::test_model;

#[test]
fn raw_prompt_completion_with_timing() -> Result<()> {
    let backend = LlamaBackend::init()?;
    let model_params = LlamaModelParams::default();
    let model_path = test_model::download_model()?;
    let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
        .with_context(|| "unable to load model")?;

    let ctx_params = LlamaContextParams::default();
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .with_context(|| "unable to create context")?;

    let prompt = "Hello my name is";
    let n_len: i32 = 64;

    let tokens_list = model
        .str_to_token(prompt, AddBos::Always)
        .with_context(|| format!("failed to tokenize {prompt}"))?;

    let mut decoder = encoding_rs::UTF_8.new_decoder();

    for token in &tokens_list {
        eprint!(
            "{}",
            model.token_to_piece(*token, &mut decoder, true, None)?
        );
    }
    std::io::stderr().flush()?;

    let mut batch = LlamaBatch::new(512, 1)?;
    let last_index = i32::try_from(tokens_list.len() - 1)?;

    for (index, token) in (0_i32..).zip(tokens_list.into_iter()) {
        let is_last = index == last_index;
        batch.add(token, index, &[0], is_last)?;
    }

    ctx.decode(&mut batch)
        .with_context(|| "llama_decode() failed")?;

    let mut n_cur = batch.n_tokens();
    let mut n_decode: i32 = 0;
    let t_main_start = ggml_time_us();

    let mut sampler =
        LlamaSampler::chain_simple([LlamaSampler::dist(1234), LlamaSampler::greedy()]);

    let mut generated = String::new();

    while n_cur <= n_len {
        let token = sampler.sample(&ctx, batch.n_tokens() - 1)?;

        if model.is_eog_token(token) {
            break;
        }

        let output_string = model.token_to_piece(token, &mut decoder, true, None)?;
        generated.push_str(&output_string);
        print!("{output_string}");
        std::io::stdout().flush()?;

        batch.clear();
        batch.add(token, n_cur, &[0], true)?;
        n_cur += 1;

        ctx.decode(&mut batch).with_context(|| "failed to eval")?;
        n_decode += 1;
    }

    let t_main_end = ggml_time_us();
    let duration = Duration::from_micros(u64::try_from(t_main_end - t_main_start)?);

    #[allow(clippy::cast_precision_loss)]
    let tokens_per_second = n_decode as f32 / duration.as_secs_f32();

    eprintln!(
        "\ndecoded {n_decode} tokens in {:.2} s, speed {tokens_per_second:.2} t/s",
        duration.as_secs_f32(),
    );

    assert!(
        !generated.is_empty(),
        "model should generate at least one token"
    );

    Ok(())
}

#[test]
fn chat_inference_produces_coherent_output() -> Result<()> {
    let model_path = test_model::download_model()?;

    let backend = LlamaBackend::init()?;
    let model_params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)?;
    let context_params = LlamaContextParams::default();
    let mut context = model.new_context(&backend, context_params)?;

    let chat_template = model.chat_template(None)?;
    let messages = vec![LlamaChatMessage::new(
        "user".to_string(),
        "Hello! How are you?".to_string(),
    )?];
    let prompt = model.apply_chat_template(&chat_template, &messages, true)?;

    let tokens = model.str_to_token(&prompt, AddBos::Always)?;
    let mut batch = LlamaBatch::new(512, 1)?;

    let last_index = i32::try_from(tokens.len())? - 1;
    for (position, token) in (0_i32..).zip(tokens.into_iter()) {
        let output_logits = position == last_index;
        batch.add(token, position, &[0], output_logits)?;
    }
    context.decode(&mut batch)?;

    let mut decoder = encoding_rs::UTF_8.new_decoder();
    let mut sampler = LlamaSampler::greedy();
    let mut position = batch.n_tokens();
    let max_tokens = 1024;
    let mut generated = String::new();

    while position <= max_tokens {
        let token = sampler.sample(&context, batch.n_tokens() - 1)?;

        if model.is_eog_token(token) {
            break;
        }

        let piece = model.token_to_piece(token, &mut decoder, true, None)?;
        generated.push_str(&piece);
        print!("{piece}");
        std::io::stdout().flush()?;

        batch.clear();
        batch.add(token, position, &[0], true)?;
        position += 1;

        context.decode(&mut batch)?;
    }

    println!();

    assert!(
        !generated.is_empty(),
        "model should generate at least one token"
    );

    Ok(())
}