use std::path::Path;
use anyhow::{Context, Result};
use llama_cpp_2::{
context::params::LlamaContextParams,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{AddBos, LlamaModel, params::LlamaModelParams},
sampling::LlamaSampler,
token::LlamaToken,
};
use crate::tokens::STOP_TOKEN;
/// Text-generation backbone backed by a GGUF model loaded through `llama_cpp_2`.
///
/// A fresh llama.cpp context is created for every generation call, so the
/// struct itself only holds the model, the backend handle and the settings.
pub struct LlamaCppBackbone {
    /// Loaded GGUF model.
    ///
    /// Declared before `_backend` on purpose: Rust drops struct fields in
    /// declaration order, so the model is released while the backend handle is
    /// still alive and the backend is torn down last.
    model: LlamaModel,
    /// llama.cpp backend handle; passed to `new_context` on every generation call.
    _backend: LlamaBackend,
    /// Context window size, in tokens, requested for every new context.
    n_ctx: u32,
    /// Seed for the sampling chain used by `generate` / `generate_streaming`.
    /// `None` means a fresh random seed is drawn for each call.
    pub seed: Option<u32>,
}
impl LlamaCppBackbone {
fn neutts_sampler(seed: u32) -> LlamaSampler {
LlamaSampler::chain_simple([
LlamaSampler::top_k(50),
LlamaSampler::top_p(0.9, 1),
LlamaSampler::temp(1.0),
LlamaSampler::dist(seed),
])
}
pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
let backend = LlamaBackend::init().context("Failed to initialise llama.cpp backend")?;
let model_params = LlamaModelParams::default().with_n_gpu_layers(0);
let model = LlamaModel::load_from_file(&backend, path, &model_params)
.with_context(|| format!("Cannot load GGUF model: {}", path.display()))?;
Ok(Self {
_backend: backend,
model,
n_ctx,
seed: None,
})
}
pub fn generate_greedy_ids(&self, prompt_ids: &[u32], max_new_tokens: u32) -> Result<Vec<u32>> {
let n_ctx = std::num::NonZeroU32::new(self.n_ctx).context("n_ctx must be non-zero")?;
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(n_ctx));
let mut ctx = self
.model
.new_context(&self._backend, ctx_params)
.context("Failed to create llama.cpp context")?;
if prompt_ids.is_empty() {
return Ok(Vec::new());
}
let mut batch = LlamaBatch::new(prompt_ids.len().max(1), 1);
let last_idx = prompt_ids.len() as i32 - 1;
for (i, &tok) in prompt_ids.iter().enumerate() {
batch.add(LlamaToken(tok as i32), i as i32, &[0], i as i32 == last_idx)?;
}
ctx.decode(&mut batch).context("Prompt decode failed")?;
let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);
let mut n_cur = prompt_ids.len() as i32;
let mut out: Vec<u32> = Vec::with_capacity(max_new_tokens as usize);
for _ in 0..max_new_tokens {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if self.model.is_eog_token(token) {
break;
}
out.push(token.0 as u32);
batch.clear();
batch.add(token, n_cur, &[0], true)?;
ctx.decode(&mut batch).context("Decode step failed")?;
n_cur += 1;
}
Ok(out)
}
pub fn generate_greedy(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
let mut output = String::new();
self.generate_streaming_greedy(prompt, max_new_tokens, |piece| {
output.push_str(piece);
Ok(())
})?;
Ok(output)
}
pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
let mut output = String::new();
self.generate_streaming(prompt, max_new_tokens, |piece| {
output.push_str(piece);
Ok(())
})?;
Ok(output)
}
pub fn generate_streaming<F>(
&self,
prompt: &str,
max_new_tokens: u32,
mut on_piece: F,
) -> Result<()>
where
F: FnMut(&str) -> Result<()>,
{
let n_ctx = std::num::NonZeroU32::new(self.n_ctx).context("n_ctx must be non-zero")?;
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(n_ctx));
let mut ctx = self
.model
.new_context(&self._backend, ctx_params)
.context("Failed to create llama.cpp context")?;
let tokens = self
.model
.str_to_token(prompt, AddBos::Always)
.context("Tokenisation failed")?;
eprintln!(
"[backbone/llama-cpp] prompt tokens: {} / n_ctx={}",
tokens.len(),
self.n_ctx
);
if tokens.len() as u32 > self.n_ctx {
anyhow::bail!(
"Prompt too long: {} tokens exceeds n_ctx={}",
tokens.len(),
self.n_ctx
);
}
if tokens.is_empty() {
return Ok(());
}
let mut batch = LlamaBatch::new(tokens.len().max(1), 1);
let last_idx = tokens.len() as i32 - 1;
for (i, &tok) in tokens.iter().enumerate() {
batch.add(tok, i as i32, &[0], i as i32 == last_idx)?;
}
ctx.decode(&mut batch).context("Prompt decode failed")?;
let mut decoder = encoding_rs::UTF_8.new_decoder();
let seed = self.seed.unwrap_or_else(rand::random);
let mut sampler = Self::neutts_sampler(seed);
let mut n_cur = tokens.len() as i32;
let max_cur = n_cur + max_new_tokens as i32;
loop {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if self.model.is_eog_token(token) {
break;
}
let piece = self
.model
.token_to_piece(token, &mut decoder, true, None)
.map_err(|e| anyhow::anyhow!("token decode error: {e}"))?;
if let Some(pos) = piece.find(STOP_TOKEN) {
let before = &piece[..pos];
if !before.is_empty() {
on_piece(before)?;
}
break;
}
on_piece(&piece)?;
if n_cur >= max_cur {
break;
}
batch.clear();
batch.add(token, n_cur, &[0], true)?;
ctx.decode(&mut batch).context("Decode step failed")?;
n_cur += 1;
}
Ok(())
}
fn generate_streaming_greedy<F>(
&self,
prompt: &str,
max_new_tokens: u32,
mut on_piece: F,
) -> Result<()>
where
F: FnMut(&str) -> Result<()>,
{
let n_ctx = std::num::NonZeroU32::new(self.n_ctx).context("n_ctx must be non-zero")?;
let ctx_params = LlamaContextParams::default().with_n_ctx(Some(n_ctx));
let mut ctx = self
.model
.new_context(&self._backend, ctx_params)
.context("Failed to create llama.cpp context")?;
let tokens = self
.model
.str_to_token(prompt, AddBos::Always)
.context("Tokenisation failed")?;
if tokens.is_empty() {
return Ok(());
}
let mut batch = LlamaBatch::new(tokens.len().max(1), 1);
let last_idx = tokens.len() as i32 - 1;
for (i, &tok) in tokens.iter().enumerate() {
batch.add(tok, i as i32, &[0], i as i32 == last_idx)?;
}
ctx.decode(&mut batch).context("Prompt decode failed")?;
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);
let mut n_cur = tokens.len() as i32;
let max_cur = n_cur + max_new_tokens as i32;
loop {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if self.model.is_eog_token(token) {
break;
}
let piece = self
.model
.token_to_piece(token, &mut decoder, true, None)
.map_err(|e| anyhow::anyhow!("token decode error: {e}"))?;
if let Some(pos) = piece.find(STOP_TOKEN) {
let before = &piece[..pos];
if !before.is_empty() {
on_piece(before)?;
}
break;
}
on_piece(&piece)?;
if n_cur >= max_cur {
break;
}
batch.clear();
batch.add(token, n_cur, &[0], true)?;
ctx.decode(&mut batch).context("Decode step failed")?;
n_cur += 1;
}
Ok(())
}
}