use std::num::NonZeroU32;
use std::path::Path;
use anyhow::{Context, Result};
use llama_cpp_4::{
context::params::LlamaContextParams,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel, Special},
sampling::{LlamaSampler, LlamaSamplerParams},
};
use crate::tokens::STOP_TOKEN;
/// Default context-window size, in tokens (presumably what callers pass as
/// `n_ctx` to [`BackboneModel::load`] — confirm at call sites).
pub const DEFAULT_N_CTX: u32 = 2_048;
/// A GGUF model loaded through llama.cpp, used as the text-generation backbone.
///
/// Field order is deliberate: Rust drops struct fields in declaration order,
/// so `model` is declared before `_backend` to make sure the model is freed
/// while the llama.cpp backend it was loaded with is still initialised.
pub struct BackboneModel {
    /// The loaded model; every generation call creates a fresh context from it.
    model: LlamaModel,
    /// Context-window size (in tokens) used for every context created.
    n_ctx: u32,
    /// Sampling seed; `None` means a random seed is drawn for each call.
    pub seed: Option<u32>,
    /// Backend handle, kept alive for as long as the model is in use.
    _backend: LlamaBackend,
}
impl BackboneModel {
pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
let mut backend = LlamaBackend::init()
.context("Failed to initialise llama.cpp backend")?;
#[cfg(not(feature = "verbose"))]
backend.void_logs();
let model_params = LlamaModelParams::default();
let model = LlamaModel::load_from_file(&backend, path, &model_params)
.with_context(|| format!("Cannot load GGUF model: {}", path.display()))?;
Ok(Self { _backend: backend, model, n_ctx, seed: None })
}
pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
let ctx_params = LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(self.n_ctx));
let mut ctx = self.model
.new_context(&self._backend, ctx_params)
.context("Failed to create llama.cpp context")?;
let tokens = self.model
.str_to_token(prompt, AddBos::Always)
.context("Tokenisation failed")?;
eprintln!("[backbone] prompt token count: {} / n_ctx={}", tokens.len(), self.n_ctx);
if tokens.len() as u32 > self.n_ctx {
anyhow::bail!(
"Prompt too long: {} tokens exceeds n_ctx={}. \
Reduce reference code count.",
tokens.len(), self.n_ctx
);
}
if tokens.is_empty() {
return Ok(String::new());
}
let mut batch = LlamaBatch::new(tokens.len().max(1), 1);
let last_idx = tokens.len() - 1;
for (i, &tok) in tokens.iter().enumerate() {
batch
.add(tok, i as i32, &[0], i == last_idx)
.context("Failed to add token to batch")?;
}
ctx.decode(&mut batch).context("Prompt decode failed")?;
let seed = self.seed
.unwrap_or_else(|| LlamaSamplerParams::default().with_seed(rand::random()).seed());
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::top_k(50),
LlamaSampler::top_p(0.9, 1),
LlamaSampler::temp(1.0), LlamaSampler::dist(seed),
]);
let mut n_cur = tokens.len() as i32;
let max_tokens = n_cur + max_new_tokens as i32;
let mut output = String::new();
loop {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if self.model.is_eog_token(token) {
break;
}
let piece = token_to_piece(&self.model, token)?;
output.push_str(&piece);
if let Some(pos) = output.find(STOP_TOKEN) {
output.truncate(pos);
break;
}
if n_cur >= max_tokens {
break;
}
batch.clear();
batch
.add(token, n_cur, &[0], true)
.context("Failed to add generated token to batch")?;
ctx.decode(&mut batch).context("Decode step failed")?;
n_cur += 1;
}
Ok(output)
}
pub fn generate_streaming<F>(
&self,
prompt: &str,
max_new_tokens: u32,
mut on_piece: F,
) -> Result<()>
where
F: FnMut(&str) -> Result<()>,
{
let ctx_params = LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(self.n_ctx));
let mut ctx = self.model
.new_context(&self._backend, ctx_params)
.context("Failed to create llama.cpp context")?;
let tokens = self.model
.str_to_token(prompt, AddBos::Always)
.context("Tokenisation failed")?;
eprintln!("[backbone] prompt token count: {} / n_ctx={}", tokens.len(), self.n_ctx);
if tokens.len() as u32 > self.n_ctx {
anyhow::bail!(
"Prompt too long: {} tokens exceeds n_ctx={}. \
Reduce reference code count.",
tokens.len(), self.n_ctx
);
}
if tokens.is_empty() {
return Ok(());
}
let mut batch = LlamaBatch::new(tokens.len().max(1), 1);
let last_idx = tokens.len() - 1;
for (i, &tok) in tokens.iter().enumerate() {
batch
.add(tok, i as i32, &[0], i == last_idx)
.context("Failed to add token to batch")?;
}
ctx.decode(&mut batch).context("Prompt decode failed")?;
let seed = self.seed
.unwrap_or_else(|| LlamaSamplerParams::default().with_seed(rand::random()).seed());
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::top_k(50),
LlamaSampler::top_p(0.9, 1),
LlamaSampler::temp(1.0),
LlamaSampler::dist(seed),
]);
let mut n_cur = tokens.len() as i32;
let max_cur = n_cur + max_new_tokens as i32;
loop {
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
if self.model.is_eog_token(token) {
break;
}
let piece = token_to_piece(&self.model, token)?;
if let Some(pos) = piece.find(STOP_TOKEN) {
let before = &piece[..pos];
if !before.is_empty() {
on_piece(before)?;
}
break;
}
on_piece(&piece)?;
if n_cur >= max_cur {
break;
}
batch.clear();
batch
.add(token, n_cur, &[0], true)
.context("Failed to add generated token to batch")?;
ctx.decode(&mut batch).context("Decode step failed")?;
n_cur += 1;
}
Ok(())
}
}
fn token_to_piece(model: &LlamaModel, token: llama_cpp_4::token::LlamaToken) -> Result<String> {
use llama_cpp_4::TokenToStringError;
match model.token_to_str_with_size(token, 64, Special::Tokenize) {
Ok(s) => Ok(s),
Err(TokenToStringError::InsufficientBufferSpace(needed)) => {
let size = needed.unsigned_abs() as usize + 1;
model
.token_to_str_with_size(token, size, Special::Tokenize)
.map_err(|e| anyhow::anyhow!("token decode retry failed: {e}"))
}
Err(e) => Err(anyhow::anyhow!("token decode error: {e}")),
}
}