use infrastructure_llama_cpp::context::params::LlamaModelContextParams;
use infrastructure_llama_cpp::llama_backend::LlamaBackend;
use infrastructure_llama_cpp::llama_batch::LlamaBatch;
use infrastructure_llama_cpp::model::params::LlamaModelParams;
use infrastructure_llama_cpp::model::LlamaModel;
use infrastructure_llama_cpp::model::{AddBos, Special};
use infrastructure_llama_cpp::sampling::LlamaSampler;
use std::io::Write;
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
fn main() {
let model_path = std::env::args().nth(1).expect("Please specify model path");
let backend = LlamaBackend::init_or_get().unwrap();
let params = LlamaModelParams::default();
let prompt =
"<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();
LlamaModelContextParams::default();
let model =
LlamaModel::load_from_file(&backend, model_path, ¶ms).expect("unable to load model");
let ctx_params = LlamaModelContextParams::default();
let mut ctx = model
.new_context(&backend, ctx_params)
.expect("unable to create the llama_context");
let tokens_list = model
.str_to_token(&prompt, AddBos::Always)
.unwrap_or_else(|_| panic!("failed to tokenize {prompt}"));
let n_len = 64;
let mut batch = LlamaBatch::new(512, 1);
let last_index = tokens_list.len() as i32 - 1;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
let is_last = i == last_index;
batch.add(token, i, &[0], is_last).unwrap();
}
ctx.decode(&mut batch).expect("llama_decode() failed");
let mut n_cur = batch.n_tokens();
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut sampler = LlamaSampler::greedy();
while n_cur <= n_len {
{
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if token == model.token_eos() {
eprintln!();
break;
}
let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
std::io::stdout().flush().unwrap();
batch.clear();
batch.add(token, n_cur, &[0], true).unwrap();
}
n_cur += 1;
ctx.decode(&mut batch).expect("failed to eval");
}
}