use crate::discover::Model;
use crate::engine::{EngineBackend, EngineConfig};
use anyhow::Result;
use lazy_static::lazy_static;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;
use std::env;
use std::num::NonZeroU32;
lazy_static! {
    /// Shared llama.cpp backend handle, initialized lazily on first access and
    /// reused for every model load and context creation in this module.
    ///
    /// # Panics
    ///
    /// The first access panics if `LlamaBackend::init` returns an error.
    pub static ref LLAMA_BACKEND: LlamaBackend =
        LlamaBackend::init().expect("failed to initialize the llama.cpp backend");
}
// SAFETY: NOTE(review) — these impls assert that `LlamaEngine` may be moved to and
// shared between threads. The justification is not visible in this file: confirm that
// every field (`Model`, `LlamaModel`, `EngineConfig`) is safe to access concurrently
// through `&self`, or remove these impls if the fields are already `Send + Sync`.
unsafe impl Send for LlamaEngine {}
unsafe impl Sync for LlamaEngine {}
/// Inference engine backed by a llama.cpp model loaded through `llama_cpp_2`.
pub struct LlamaEngine {
// Descriptor of the loaded model; its `path` is what the model is loaded from,
// and a clone of it is handed back by `get_model_info`.
model_info: Model,
// The loaded llama.cpp model used for tokenization and context creation.
model: LlamaModel,
// Default engine configuration, used by `infer` when no per-call override is given.
args: EngineConfig,
}
impl EngineBackend for LlamaEngine {
fn new(args: &EngineConfig, model_info: &Model) -> Result<Self> {
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&LLAMA_BACKEND, &model_info.path, &model_params)?;
Ok(LlamaEngine {
model,
args: (*args).clone(),
model_info: model_info.clone(),
})
}
fn get_model_info(&self) -> Model {
self.model_info.clone()
}
fn infer(
&self,
prompt: &str,
args: Option<&EngineConfig>,
mut callback: Option<Box<dyn FnMut(String) + Send>>,
) -> Result<String> {
let args = args.unwrap_or(&self.args);
let mut decoder = encoding_rs::UTF_8.new_decoder();
let ctx_params = LlamaContextParams::default()
.with_n_ctx(Some(NonZeroU32::new(args.n_ctx as u32).unwrap()))
.with_n_batch(2048)
.with_n_ubatch(512)
.with_n_threads(
env::var("TLLAMA_THREADS")
.ok()
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or_else(|| {
std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(4)
}),
)
.with_n_threads_batch(
env::var("TLLAMA_THREADS")
.ok()
.and_then(|s| s.parse::<i32>().ok())
.unwrap_or_else(|| {
std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(4)
}),
);
let mut ctx = self.model.new_context(&LLAMA_BACKEND, ctx_params)?;
let tokens_list = self.model.str_to_token(&prompt, AddBos::Always)?;
let mut batch = LlamaBatch::new(tokens_list.len(), 1);
for (i, &token) in tokens_list.iter().enumerate() {
let logits = i == tokens_list.len() - 1;
batch.add(token, i as i32, &[0], logits)?;
}
ctx.decode(&mut batch)?;
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::temp(args.temperature),
LlamaSampler::top_p(args.top_p, 1),
LlamaSampler::top_k(args.top_k),
LlamaSampler::penalties(64, args.repeat_penalty, 0.0, 0.0),
LlamaSampler::greedy(),
])
.with_tokens(tokens_list.iter().copied());
let mut n_cur = batch.n_tokens();
let mut n_decode = 0;
let mut output = String::new();
let max_tokens = args.n_len.map(|n| n as i32);
while max_tokens.map_or(true, |max| n_decode < max) {
let token = sampler.sample(&ctx, -1);
if self.model.is_eog_token(token) {
break;
}
let output_bytes = self.model.token_to_bytes(token, Special::Tokenize)?;
let mut token_str = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut token_str, false);
if callback.is_some() {
callback.as_mut().unwrap()(token_str.clone());
}
sampler.accept(token);
batch.clear();
batch.add(token, n_cur as i32, &[0], true)?;
n_cur += 1;
n_decode += 1;
output += &token_str;
ctx.decode(&mut batch)?;
}
Ok(output)
}
}