#![allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
use anyhow::{bail, Context, Result};
use clap::Parser;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::ggml_time_us;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::time::Duration;
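
/// Command-line arguments for this example.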
#[derive(clap::Parser)]
struct Args {
    model_path: PathBuf,
    #[clap(default_value = "Hello my name is")]
    prompt: String,
    #[cfg(feature = "cublas")]
    #[clap(long)]
    disable_gpu: bool,
}

fn main() -> Result<()> {
    let params = Args::parse();
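
    // Initialize the llama.cpp backend.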
    let backend = LlamaBackend::init()?;
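
    // Total length of the sequence to generate, prompt tokens included.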
    let n_len: i32 = 32;
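
    // Offload up to 1000 layers (effectively the whole model) to the GPU when built
    // with the `cublas` feature, unless --disable-gpu is passed.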
    let model_params = {
        #[cfg(feature = "cublas")]
        if !params.disable_gpu {
            LlamaModelParams::default().with_n_gpu_layers(1000)
        } else {
            LlamaModelParams::default()
        }
        #[cfg(not(feature = "cublas"))]
        LlamaModelParams::default()
    };

    let model = LlamaModel::load_from_file(&backend, params.model_path, &model_params)
        .with_context(|| "unable to load model")?;
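
    // Create a context with a fixed seed and a 2048-token context window.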
    let ctx_params = LlamaContextParams {
        seed: 1234,
        n_ctx: NonZeroU32::new(2048),
        ..LlamaContextParams::default()
    };

    let mut ctx = model.new_context(&backend, &ctx_params)
        .with_context(|| "unable to create the llama_context")?;
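
    // Tokenize the prompt.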
    let tokens_list = model.str_to_token(&params.prompt, true)
        .with_context(|| format!("failed to tokenize {}", params.prompt))?;
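
    // Make sure the KV cache is large enough to hold the prompt plus the generated tokens.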
    let n_ctx = ctx.n_ctx() as i32;
    let n_kv_req = tokens_list.len() as i32 + (n_len - tokens_list.len() as i32);
    eprintln!("n_len = {n_len}, n_ctx = {n_ctx}, n_kv_req = {n_kv_req}");
    if n_kv_req > n_ctx {
        bail!(
            "n_kv_req > n_ctx, the required kv cache size is not big enough
either reduce n_len or increase n_ctx"
        )
    }
    eprintln!();
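
    // Echo the prompt token by token so the output starts with the full prompt text.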
    for token in &tokens_list {
        eprint!("{}", model.token_to_str(*token)?);
    }
    std::io::stderr().flush()?;
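
    // Create a batch with room for 512 tokens in a single sequence;
    // logits are requested only for the last prompt token.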
    let mut batch = LlamaBatch::new(512, 1);
    let last_index: i32 = (tokens_list.len() - 1) as i32;
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        let is_last = i == last_index;
        batch.add(token, i, &[0], is_last);
    }
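
    // Evaluate the whole prompt in a single decode call.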
    ctx.decode(&mut batch)
        .with_context(|| "llama_decode() failed")?;

    let mut n_cur = batch.n_tokens();
    let mut n_decode = 0;
    let t_main_start = ggml_time_us();
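
    // Main generation loop: greedily sample one token at a time until n_len is reached
    // or the model produces its end-of-stream token.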
    while n_cur <= n_len {
        {
            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
            let new_token_id = ctx.sample_token_greedy(candidates_p);
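
            // Stop once the end-of-stream token is produced.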
            if new_token_id == model.token_eos() {
                eprintln!();
                break;
            }
            print!("{}", model.token_to_str(new_token_id)?);
            std::io::stdout().flush()?;
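
            // The next batch contains only the newly sampled token.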
            batch.clear();
            batch.add(new_token_id, n_cur, &[0], true);
        }
        n_cur += 1;
        ctx.decode(&mut batch).with_context(|| "failed to eval")?;
        n_decode += 1;
    }
    eprintln!("\n");
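
    // Report how many tokens were generated and the decoding speed.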
    let t_main_end = ggml_time_us();
    let duration = Duration::from_micros((t_main_end - t_main_start) as u64);
    eprintln!(
        "decoded {} tokens in {:.2} s, speed {:.2} t/s\n",
        n_decode,
        duration.as_secs_f32(),
        n_decode as f32 / duration.as_secs_f32()
    );
    println!("{}", ctx.timings());

    Ok(())
}