use futures::{channel::mpsc, SinkExt, StreamExt};
use llama_cpp_2::{
context::params::LlamaContextParams, llama_backend::LlamaBackend, llama_batch::LlamaBatch,
model::LlamaModel, sampling::LlamaSampler,
};
use std::{future::Future, num::NonZeroU32};
use crate::{prelude::*, template};
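/// A language model backed by `llama.cpp` (via the `llama_cpp_2` crate).
///
/// Inference runs on a dedicated background thread; each request is sent
/// over `server_request_tx` together with a per-request channel on which
/// the generated text is streamed back.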
#[derive(Debug)]
pub struct LlamaCpp
{
server_request_tx: mpsc::Sender<(Prompt, mpsc::Sender<Result<String>>)>,
}
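/// Construction parameters for [`LlamaCpp`].
///
/// `model_uri` accepts either a local file path or an `hf://<repo>/<file>`
/// URI, which is resolved through the Hugging Face hub. The default points
/// at a Q4_0 GGUF build of Meta-Llama-3.1-8B-Instruct.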
pub struct LlamaCppParameters
{
#[cfg_attr(not(feature = "gpu"), allow(dead_code))]
disable_gpu: bool,
#[cfg_attr(not(feature = "gpu"), allow(dead_code))]
n_gpu_layers: u32,
context_size: u32,
seed: u32,
model_uri: String,
template: String,
threads: Option<i32>,
threads_batch: Option<i32>,
}
impl LlamaCppParameters
{
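/// Sets how many model layers are offloaded to the GPU. The value is only
/// used when the `gpu` feature is enabled and GPU usage is not disabled.
///
/// Usage sketch (marked `ignore`; assumes builder-style construction from
/// the defaults):
/// ```ignore
/// let params = LlamaCppParameters::default().n_gpu_layers(32);
/// ```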
pub fn n_gpu_layers(mut self, n_gpu_layers: u32) -> LlamaCppParameters
{
self.n_gpu_layers = n_gpu_layers;
self
}
}
impl Default for LlamaCppParameters
{
fn default() -> Self
{
Self {
disable_gpu: false,
n_gpu_layers: 20,
context_size: 2048,
seed: 1234,
model_uri:
"hf://ggml-org/Meta-Llama-3.1-8B-Instruct-Q4_0-GGUF/meta-llama-3.1-8b-instruct-q4_0.gguf"
.into(),
template: include_str!("../data/templates/llama").into(),
threads: None,
threads_batch: None,
}
}
}
impl LlamaCpp
{
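/// Builds a `LlamaCpp` from a model URI, keeping every other parameter at
/// its default value.
///
/// Usage sketch (marked `ignore`; the surrounding crate path and error
/// handling are assumed):
/// ```ignore
/// let llm = LlamaCpp::from_model(
/// "hf://ggml-org/Meta-Llama-3.1-8B-Instruct-Q4_0-GGUF/meta-llama-3.1-8b-instruct-q4_0.gguf",
/// )?;
/// ```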
pub fn from_model(model_uri: impl Into<String>) -> Result<Self>
{
let model_uri = model_uri.into();
Self::new(LlamaCppParameters {
model_uri,
..Default::default()
})
}
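/// Initializes the llama.cpp backend, loads the model (downloading it from
/// the Hugging Face hub for `hf://` URIs), and spawns the worker thread that
/// serves inference requests.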
pub fn new(params: LlamaCppParameters) -> Result<Self>
{
let backend = LlamaBackend::init()?;
let model_params = {
#[cfg(feature = "gpu")]
if !params.disable_gpu
{
llama_cpp_2::model::params::LlamaModelParams::default()
.with_n_gpu_layers(params.n_gpu_layers)
}
else
{
llama_cpp_2::model::params::LlamaModelParams::default()
}
#[cfg(not(feature = "gpu"))]
llama_cpp_2::model::params::LlamaModelParams::default()
};
let mut ctx_params =
LlamaContextParams::default().with_n_ctx(NonZeroU32::new(params.context_size));
if let Some(threads) = params.threads
{
ctx_params = ctx_params.with_n_threads(threads);
}
if let Some(threads_batch) = params.threads_batch.or(params.threads)
{
ctx_params = ctx_params.with_n_threads_batch(threads_batch);
}
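// Resolve the model location: `hf://<repo>/<file>` URIs are downloaded via
// the Hugging Face hub; anything else is treated as a local path.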
let model_path = if params.model_uri.starts_with("hf://")
{
let last_slash = params.model_uri[5..]
.rfind("/")
.ok_or(Error::HfInvalidUri)?
+ 5;
let model_repo = &params.model_uri[5..last_slash];
let model_name = &params.model_uri[(last_slash + 1)..];
hf_hub::api::sync::ApiBuilder::new()
.with_progress(true)
.build()?
.model(model_repo.to_string())
.get(model_name)?
}
else
{
params.model_uri.to_owned().into()
};
let model = LlamaModel::load_from_file(&backend, model_path, &model_params)?;
let template = template::Template::new(params.template)?;
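// Serve requests from a dedicated thread so the blocking llama.cpp calls
// never run on an async executor; callers talk to it over this channel.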
let (server_request_tx, mut server_request_rx) =
mpsc::channel::<(Prompt, mpsc::Sender<Result<String>>)>(10);
std::thread::spawn(move || {
futures::executor::block_on(async {
while let Some((prompt, mut answer_tx)) = server_request_rx.next().await
{
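// Run one full generation per request; any error is forwarded to the
// caller on its answer channel below.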
let r = async {
let mut ctx = model.new_context(&backend, ctx_params.clone())?;
let prompt_str = template.render(
Some(&prompt.prompt),
prompt.assistant.as_ref().map(|x| x.as_str()),
prompt.system.as_ref().map(|x| x.as_str()),
)?;
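// Tokenize the rendered prompt and submit it to the model as one batch.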
let tokens_list =
model.str_to_token(&prompt_str, llama_cpp_2::model::AddBos::Always)?;
let mut batch = LlamaBatch::new(ctx_params.n_batch() as usize, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter())
{
let is_last = i == last_index;
batch.add(token, i, &[0], is_last)?;
}
ctx.decode(&mut batch)?;
let mut n_cur = batch.n_tokens();
let mut decoder = encoding_rs::UTF_8.new_decoder();
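// Pick a sampling chain for the requested output format; JSON output is
// constrained by a GBNF grammar.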
let mut sampler = match prompt.format
{
Format::Text => LlamaSampler::chain_simple([
LlamaSampler::dist(params.seed),
LlamaSampler::greedy(),
]),
Format::Json => LlamaSampler::chain_simple([
LlamaSampler::grammar(&model, include_str!("../data/grammar/json.gbnf"), "root"),
LlamaSampler::min_p(0.05, 1),
LlamaSampler::temp(0.8),
LlamaSampler::dist(params.seed),
]),
};
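// Decode loop: sample one token at a time, stream it back, and stop at the
// end-of-generation token.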
loop
{
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if model.is_eog_token(token)
{
return Ok(());
}
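// Convert the sampled token to UTF-8 incrementally and send the chunk to
// the caller.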
let output_bytes =
model.token_to_bytes(token, llama_cpp_2::model::Special::Tokenize)?;
let mut output_string = String::with_capacity(32);
let _decode_result =
decoder.decode_to_string(&output_bytes, &mut output_string, false);
answer_tx.send(Ok(output_string)).await?;
batch.clear();
batch.add(token, n_cur, &[0], true)?;
n_cur += 1;
ctx.decode(&mut batch)?;
}
}
.await;
if let Err(e) = r
{
let tx_r = answer_tx.send(Err(e)).await;
if let Err(e) = tx_r
{
log::error!("Failed to send error, with error {:?}", e);
}
}
}
});
});
Ok(Self { server_request_tx })
}
}
impl Default for LlamaCpp
{
fn default() -> Self
{
Self::new(Default::default()).expect("failed to initialize LlamaCpp with default parameters")
}
}
impl LargeLanguageModel for LlamaCpp
{
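// Forwards the prompt to the worker thread and exposes the per-request
// receiver as a token stream.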
fn infer_stream(
&self,
prompt: Prompt,
) -> Result<impl Future<Output = Result<StringStream>> + Send>
{
Ok(async {
let (tx, rx) = mpsc::channel(20);
self.server_request_tx.clone().send((prompt, tx)).await?;
Ok(pin_stream(rx))
})
}
}