use std::{num::NonZeroU32, path::Path};

use anyhow::{Context, Result};
use llama_cpp_4::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
    token::LlamaToken,
};
use tracing::info;

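/// Thin wrapper around a llama.cpp backend handle and a loaded GGUF model,
/// plus basic model metadata and the default context size used when creating
/// per-call contexts.
///
/// Usage sketch (path and parameter values are illustrative only):
///
/// ```ignore
/// let backend = LlamaCppBackend::load(Path::new("model.gguf"), 0, 4096)?;
/// let prompt = backend.tokenize("Hello", true)?;
/// let out = backend.generate(&prompt, 16, None)?;
/// println!("{}", backend.detokenize(&out));
/// ```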
pub struct LlamaCppBackend {
    pub llama_backend: LlamaBackend,
    pub model: LlamaModel,
    pub n_vocab: usize,
    pub n_embd: usize,
    pub ctx_size: usize,
}

impl LlamaCppBackend {
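    /// Initialize the llama.cpp backend and load a GGUF model from disk,
    /// offloading `n_gpu_layers` layers to the GPU where the build supports it.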
    pub fn load(gguf_path: &Path, n_gpu_layers: u32, ctx_size: usize) -> Result<Self> {
        info!("Loading GGUF model from {:?}", gguf_path);

        let mut llama_backend = LlamaBackend::init()
            .map_err(|e| anyhow::anyhow!("LlamaBackend::init failed: {e}"))?;
        llama_backend.void_logs();

        let model_params = LlamaModelParams::default()
            .with_n_gpu_layers(n_gpu_layers);
        let model = LlamaModel::load_from_file(&llama_backend, gguf_path, &model_params)
            .with_context(|| format!("Failed to load model from {gguf_path:?}"))?;

        let n_vocab = usize::try_from(model.n_vocab())
            .context("n_vocab overflows usize")?;
        let n_embd = usize::try_from(model.n_embd())
            .context("n_embd overflows usize")?;
        info!(
            " vocab={n_vocab} embd={n_embd} ctx_train={}",
            model.n_ctx_train()
        );

        Ok(Self { llama_backend, model, n_vocab, n_embd, ctx_size })
    }
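
    /// Tokenize `text` with the model's tokenizer, optionally prepending the
    /// BOS token.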
    pub fn tokenize(&self, text: &str, add_bos: bool) -> Result<Vec<LlamaToken>> {
        let bos = if add_bos { AddBos::Always } else { AddBos::Never };
        self.model
            .str_to_token(text, bos)
            .map_err(|e| anyhow::anyhow!("tokenize failed: {e}"))
    }
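
    /// Convert tokens back into a string, returning an empty string if
    /// detokenization fails.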
    pub fn detokenize(&self, tokens: &[LlamaToken]) -> String {
        self.model
            .tokens_to_str(tokens, Special::Tokenize)
            .unwrap_or_default()
    }

    pub fn eos_token(&self) -> LlamaToken { self.model.token_eos() }

    pub fn bos_token(&self) -> LlamaToken { self.model.token_bos() }
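
    /// Score `answer_tokens` against `prompt_tokens` in a single forward pass
    /// (teacher forcing): for each answer token, return the logits of the
    /// position that predicts it. The result has `answer_tokens.len()` rows,
    /// each of vocabulary size; row `i` is the distribution the model assigns
    /// just before `answer_tokens[i]`.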
    pub fn answer_logits(
        &self,
        prompt_tokens: &[LlamaToken],
        answer_tokens: &[LlamaToken],
    ) -> Result<Vec<Vec<f32>>> {
        if answer_tokens.is_empty() {
            return Ok(vec![]);
        }
        if prompt_tokens.is_empty() {
            anyhow::bail!("answer_logits requires a non-empty prompt");
        }

        let p_len = prompt_tokens.len();
        let a_len = answer_tokens.len();
        let total = p_len + a_len;

        let n_ctx = (total + 1).max(self.ctx_size);
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(n_ctx as u32));
        let mut ctx = self.model
            .new_context(&self.llama_backend, ctx_params)
            .map_err(|e| anyhow::anyhow!("context creation failed: {e}"))?;

        // Request logits for the last prompt token and for every answer token
        // except the last: these are exactly the positions that predict
        // answer_tokens[0..a_len].
        let mut batch = LlamaBatch::new(total, 1);
        for (i, &token) in prompt_tokens.iter().enumerate() {
            batch.add(token, i as i32, &[0], i == p_len - 1)?;
        }
        for (i, &token) in answer_tokens.iter().enumerate() {
            let need = i < a_len - 1;
            batch.add(token, (p_len + i) as i32, &[0], need)?;
        }
        ctx.decode(&mut batch)
            .map_err(|e| anyhow::anyhow!("decode failed: {e}"))?;

        // Collect logits in answer order: position p_len - 1 predicts the
        // first answer token, position p_len + i predicts answer token i + 1.
        let mut result = Vec::with_capacity(a_len);
        result.push(ctx.get_logits_ith((p_len as i32) - 1).to_vec());
        for i in 0..(a_len - 1) {
            result.push(ctx.get_logits_ith((p_len + i) as i32).to_vec());
        }
        Ok(result)
    }
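
    /// Greedily generate up to `max_new_tokens` continuation tokens for
    /// `prompt_tokens`, stopping early on an end-of-generation token. If
    /// `logit_bias` is given, it is added element-wise to the logits before
    /// each argmax.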
    pub fn generate(
        &self,
        prompt_tokens: &[LlamaToken],
        max_new_tokens: usize,
        logit_bias: Option<&[f32]>,
    ) -> Result<Vec<LlamaToken>> {
        if prompt_tokens.is_empty() {
            return Ok(vec![]);
        }

        let capacity = prompt_tokens.len() + max_new_tokens + 1;
        let n_ctx = capacity.max(self.ctx_size);
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(n_ctx as u32));
        let mut ctx = self.model
            .new_context(&self.llama_backend, ctx_params)
            .map_err(|e| anyhow::anyhow!("context creation failed: {e}"))?;

        // Decode the whole prompt in one batch, requesting logits only for
        // the last prompt token.
        let mut batch = LlamaBatch::new(capacity, 1);
        let p_last = (prompt_tokens.len() as i32) - 1;
        for (i, &token) in prompt_tokens.iter().enumerate() {
            batch.add(token, i as i32, &[0], i as i32 == p_last)?;
        }
        ctx.decode(&mut batch)
            .map_err(|e| anyhow::anyhow!("decode (prompt) failed: {e}"))?;

        let mut generated = Vec::with_capacity(max_new_tokens);
        let mut n_cur: i32 = prompt_tokens.len() as i32;
        let mut logit_idx = p_last;
        for _ in 0..max_new_tokens {
            let logits = ctx.get_logits_ith(logit_idx);
            let next = greedy_sample(logits, logit_bias);
            if self.model.is_eog_token(next) {
                break;
            }
            generated.push(next);

            // Feed the chosen token back in as a single-token batch; its
            // logits sit at batch index 0 on the next iteration.
            batch.clear();
            batch.add(next, n_cur, &[0], true)?;
            ctx.decode(&mut batch)
                .map_err(|e| anyhow::anyhow!("decode (step {n_cur}) failed: {e}"))?;
            logit_idx = 0;
            n_cur += 1;
        }
        Ok(generated)
    }
}
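
/// Pick the index of the highest-scoring token, optionally adding a per-token
/// bias to the logits first. Falls back to token 0 if the logits are empty.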
fn greedy_sample(logits: &[f32], bias: Option<&[f32]>) -> LlamaToken {
    let idx = match bias {
        None => logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0),
        // Note: `zip` truncates to the shorter slice, so a bias shorter than
        // the vocabulary limits the argmax to its first `bias.len()` tokens.
        Some(b) => logits
            .iter()
            .zip(b.iter())
            .map(|(l, bias)| l + bias)
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0),
    };
    LlamaToken::new(idx as i32)
}