use anyhow::{bail, Context, Result};
use clap::Parser;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::sample::sampler::Sampler;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use bonito::parse_a;
use bonito::parse_q;
use bonito::prepare_prompt;
use bonito::TaskType;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::AddBos;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
use llama_cpp_2::token::LlamaToken;
use std::num::NonZeroU32;
#[derive(clap::Parser, Debug, Clone)]
struct Args {
#[arg(short = 't', long = "test-chunk")]
test_chunk: String,
}
/// Generates a question/answer pair for the text passed via `--test-chunk`.
///
/// The function builds an extractive-QA prompt with `prepare_prompt`, fetches the GGUF model file
/// through the Hugging Face hub API, loads it with llama.cpp, feeds the prompt in a single batch,
/// samples tokens until EOS or until the sequence reaches `n_len` positions, and finally prints
/// the parsed question and answer (or the raw completion when parsing fails).
///
/// # Errors
///
/// Returns an error when the backend, hub API, model or context cannot be created, when the
/// prompt cannot be tokenized or is too long for `batch_size` / `n_len` / the context size, or
/// when decoding or token-to-string conversion fails.
#[tokio::main]
async fn main() -> Result<()> {
    let Args { test_chunk } = Args::parse();

    // Upper bound on the total sequence length (prompt plus generated tokens).
    let n_len: i32 = 1024;
    // Capacity of the batch; the whole prompt must fit into one batch.
    let batch_size: usize = 512;

    let task_type = TaskType::ExtractiveQuestionAnswering;
    let prompt = prepare_prompt(&test_chunk, &task_type);

    let model_repo = "alexandreteles/bonito-v1-gguf";
    let model_file = "bonito-v1_q4_k_m.gguf";
    let llama_cpp_log = false;

    let mut llama_cpp_backend = LlamaBackend::init()?;
    if !llama_cpp_log {
        llama_cpp_backend.void_logs();
    }
    let model_params = LlamaModelParams::default();

    let hf_hub_api = hf_hub::api::tokio::ApiBuilder::new()
        .with_progress(true)
        .build()
        .with_context(|| "unable to create huggingface api")?;
    let hf_model_path = hf_hub_api
        .model(model_repo.to_string())
        .get(model_file)
        .await
        .with_context(|| format!("unable to fetch {model_file} from {model_repo}"))?;

    let model = LlamaModel::load_from_file(&llama_cpp_backend, &hf_model_path, &model_params)
        .with_context(|| "unable to load model")?;

    // The context size is taken from the value reported by `model.n_ctx_train()`.
    let ctx_params =
        LlamaContextParams::default().with_n_ctx(NonZeroU32::new(model.n_ctx_train()));
    let mut ctx = model
        .new_context(&llama_cpp_backend, ctx_params)
        .with_context(|| "unable to create the llama_context")?;

    let tokens_list = model
        .str_to_token(&prompt, AddBos::Always)
        .with_context(|| format!("failed to tokenize {prompt}"))?;
    let tokens_list_len = tokens_list.len();

    let n_ctx = i32::try_from(ctx.n_ctx()).context("context size does not fit in i32")?;
    // Prompt tokens plus the tokens still to be generated always add up to `n_len`.
    let n_kv_req = n_len;
    if n_kv_req > n_ctx {
        bail!(
            "n_kv_req > n_ctx, the required kv cache size is not big enough\neither reduce n_len or increase n_ctx"
        )
    }
    if tokens_list_len > batch_size {
        bail!("the prompt is too long, it has more tokens than batch_size:{batch_size}")
    }
    if tokens_list_len >= usize::try_from(n_len)? {
        bail!("the prompt is too long, it has more tokens than n_len:{n_len}")
    }
    // Cannot fail after the `n_len` check above, but a checked conversion avoids a silent `as`.
    let n_prompt = i32::try_from(tokens_list_len)?;

    let mut batch = LlamaBatch::new(batch_size, 1);
    for (pos, token) in (0_i32..).zip(tokens_list) {
        // Only the final prompt token is flagged; its candidates are read by the first
        // sampling step below. Comparing `pos + 1` avoids the `len - 1` underflow on an
        // empty token list.
        let is_last = pos + 1 == n_prompt;
        batch.add(token, pos, &[0], is_last).with_context(|| format!("failed to add token to batch, is your token list length ({tokens_list_len}) bigger than batch size ({batch_size})?"))?;
    }
    ctx.decode(&mut batch)
        .with_context(|| "llama_decode() failed")?;

    let mut n_cur = batch.n_tokens();
    // The completion starts with the prompt, so the parsers below see prompt and output together.
    let mut completion = String::new();
    completion.push_str(&prompt);

    // Final sampling stage: apply softmax, take the first candidate and record it in the history
    // that the repetition-penalty step reads.
    let finalizer = &|mut candidates: LlamaTokenDataArray, history: &mut Vec<LlamaToken>| {
        candidates.sample_softmax(None);
        let token = candidates.data[0];
        history.push(token.id());
        vec![token]
    };
    let mut history = vec![];
    let mut sampler = Sampler::new(finalizer);
    sampler.push_step(&|c, history| c.sample_repetition_penalty(None, history, 64, 1.1, 0.0, 0.0));
    sampler.push_step(&|c, _| c.sample_top_k(None, 40, 1));
    sampler.push_step(&|c, _| c.sample_tail_free(None, 1.0, 1));
    sampler.push_step(&|c, _| c.sample_typical(None, 1.0, 1));
    sampler.push_step(&|c, _| c.sample_top_p(None, 0.95, 1));
    sampler.push_step(&|c, _| c.sample_min_p(None, 0.05, 1));
    sampler.push_step(&|c, _| c.sample_temp(None, 1.0));

    while n_cur <= n_len {
        {
            let candidates = ctx.candidates_ith(batch.n_tokens() - 1);
            let candidates_p = LlamaTokenDataArray::from_iter(candidates, false);
            // The array is not used again, so it is moved into the sampler instead of cloned.
            let tokens = sampler.sample(&mut history, candidates_p);
            let new_token_id = tokens[0].id();
            if new_token_id == model.token_eos() {
                break;
            }
            let new_str = model
                .token_to_str(new_token_id)
                .with_context(|| "failed to convert token to string")?;
            completion.push_str(&new_str);

            // Nothing is sampled after the token at position `n_len`, so decoding it would be
            // wasted work and would need one more KV cell than the `n_kv_req` check validated.
            if n_cur == n_len {
                break;
            }
            batch.clear();
            batch
                .add(new_token_id, n_cur, &[0], true)
                .with_context(|| "failed to add generated token to batch")?;
        }
        n_cur += 1;
        ctx.decode(&mut batch).with_context(|| "failed to eval")?;
    }
    ctx.clear_kv_cache();

    // Both parts are needed; a missing answer is reported like a missing question instead of
    // panicking. `parse_a` is only called once a question has been found.
    // NOTE(review): assumes `parse_a` returns an `Option` like `parse_q` does; confirm.
    let parsed = parse_q(&completion, &test_chunk)
        .and_then(|q| parse_a(&completion).map(|a| (q, a)));
    match parsed {
        Some((q, a)) => {
            println!("q: {q}");
            println!("a: {a}");
        }
        None => println!("failed to parse q/a, here is the completion:\n{}", &completion),
    }
    Ok(())
}