use llama_crab_sys as sys;
use crate::error::LlamaError;
use crate::token::LlamaToken;
/// A fixed-capacity batch of tokens laid out the way llama.cpp's `llama_batch` expects.
///
/// The per-token arrays are owned here as `Vec`s; `raw` only holds raw pointers into
/// their heap buffers (see `new`). Those pointers stay valid as long as none of the
/// `Vec`s below is resized or dropped. The visible impl never resizes them after
/// construction; NOTE(review): any new code that does would leave `raw` dangling.
///
/// Not `Clone`: a field-wise clone would leave the copy's `raw` pointing into the
/// original batch's buffers.
#[derive(Debug)]
pub struct LlamaBatch {
// FFI view handed to llama.cpp. `n_tokens` counts the filled slots; the pointer
// fields reference the buffers of the `Vec`s below, and `embd` is null.
raw: sys::llama_batch,
// Token id per slot. Its length is the batch capacity used by `add`.
tokens: Vec<sys::llama_token>,
// Position per slot (backs `raw.pos`).
positions: Vec<sys::llama_pos>,
// Per-slot count of sequence ids (backs `raw.n_seq_id`); `new` initialises every
// entry to `n_seq_max`.
n_seq_id: Vec<i32>,
// Owned storage for the sequence ids: one inner `Vec` of length `n_seq_max` per slot.
seq_ids: Vec<Vec<sys::llama_seq_id>>,
// Pointer to each inner `seq_ids` buffer; `raw.seq_id` points at this array.
seq_ids_ptrs: Vec<*mut sys::llama_seq_id>,
// Per-slot flag: `add` stores 1 when called with `logits = true`, otherwise 0.
logits: Vec<i8>,
// Set to `true` by `new`; not read anywhere in this impl.
// NOTE(review): presumably consulted elsewhere (e.g. a `Drop` impl) — confirm.
allocated: bool,
}
/// Reasons `LlamaBatch::add` can reject a token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchAddError {
    /// Every slot is already filled; the payload is the batch's total token capacity.
    InsufficientSpace(usize),
    /// Returned by `LlamaBatch::add` when its `seq_ids` slice is empty.
    Empty,
}

impl std::fmt::Display for BatchAddError {
    /// Writes a short, lowercase description suitable for embedding in a larger message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            BatchAddError::InsufficientSpace(capacity) => {
                write!(f, "batch only has space for {capacity} tokens")
            }
            BatchAddError::Empty => f.write_str("no token to add"),
        }
    }
}

/// Leaf error: no `source`, described entirely by its `Debug` and `Display` impls.
impl std::error::Error for BatchAddError {}
impl LlamaBatch {
    /// Allocates an empty batch with room for `n_tokens` tokens, each of which can be
    /// assigned to at most `n_seq_max` sequences.
    ///
    /// Every per-token buffer is allocated up front and never resized, so the raw
    /// pointers stored in the FFI view stay valid for the batch's whole lifetime, even
    /// when the `LlamaBatch` value itself is moved (moving a `Vec` does not move its
    /// heap buffer).
    ///
    /// # Panics
    ///
    /// Panics if `n_seq_max` is negative.
    #[must_use]
    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
        let seq_capacity =
            usize::try_from(n_seq_max).expect("n_seq_max must be non-negative");

        let mut tokens = vec![0_i32; n_tokens];
        let mut positions = vec![0_i32; n_tokens];
        let mut n_seq_id = vec![n_seq_max; n_tokens];
        let mut logits = vec![0_i8; n_tokens];
        let mut seq_ids: Vec<Vec<sys::llama_seq_id>> =
            (0..n_tokens).map(|_| vec![0; seq_capacity]).collect();
        // Pointers are taken with `as_mut_ptr` (not `as_ptr().cast_mut()`): only a
        // pointer obtained that way may legally be written through, by `add` or by
        // llama.cpp.
        let mut seq_ids_ptrs: Vec<*mut sys::llama_seq_id> =
            seq_ids.iter_mut().map(Vec::as_mut_ptr).collect();

        let raw = sys::llama_batch {
            n_tokens: 0,
            token: tokens.as_mut_ptr(),
            embd: std::ptr::null_mut(),
            pos: positions.as_mut_ptr(),
            n_seq_id: n_seq_id.as_mut_ptr(),
            seq_id: seq_ids_ptrs.as_mut_ptr(),
            logits: logits.as_mut_ptr(),
        };

        Self {
            raw,
            tokens,
            positions,
            n_seq_id,
            seq_ids,
            seq_ids_ptrs,
            logits,
            allocated: true,
        }
    }

    /// Builds a one-slot batch holding `token` at position `pos` in sequence `seq_id`,
    /// with the slot's logits flag set from `logits`.
    #[must_use]
    pub fn one(token: LlamaToken, pos: i32, seq_id: i32, logits: bool) -> Self {
        let mut batch = Self::new(1, 1);
        batch
            .add(token, pos, &[seq_id], logits)
            .expect("a fresh batch of capacity 1 accepts one token with one sequence id");
        batch
    }

    /// Number of slots currently filled.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        self.raw.n_tokens
    }

    /// Marks the batch as empty so it can be refilled.
    ///
    /// The buffers are kept; subsequent `add` calls overwrite the old slots.
    pub fn clear(&mut self) {
        self.raw.n_tokens = 0;
    }

    /// Appends `token` at position `pos`, assigned to every sequence in `seq_ids`.
    ///
    /// `logits` sets the slot's logits flag (stored as 1 or 0). At most `n_seq_max`
    /// (as given to `new`) sequence ids are stored per token; extra ids are ignored.
    /// The slot's `n_seq_id` entry records how many ids were actually stored, so
    /// llama.cpp never reads unset or stale sequence slots.
    ///
    /// # Errors
    ///
    /// * `BatchAddError::InsufficientSpace` if every slot is already filled.
    /// * `BatchAddError::Empty` if `seq_ids` is empty.
    pub fn add(
        &mut self,
        token: LlamaToken,
        pos: i32,
        seq_ids: &[i32],
        logits: bool,
    ) -> std::result::Result<(), BatchAddError> {
        let capacity = self.tokens.len();
        let idx = usize::try_from(self.raw.n_tokens)
            .expect("n_tokens is only ever reset to 0 or incremented");
        if idx >= capacity {
            return Err(BatchAddError::InsufficientSpace(capacity));
        }
        if seq_ids.is_empty() {
            return Err(BatchAddError::Empty);
        }

        let slot = &mut self.seq_ids[idx];
        let n_seq = seq_ids.len().min(slot.len());
        let slot_ptr = slot.as_mut_ptr();
        let n_seq_count =
            i32::try_from(n_seq).expect("bounded by n_seq_max, which is an i32");

        // SAFETY: `new` allocates `tokens`, `positions`, `n_seq_id` and `logits` with
        // the same length `capacity` and nothing resizes them, so `idx < capacity` is
        // in bounds for each. `slot_ptr` addresses `slot.len() >= n_seq` elements, and
        // `seq_ids` (borrowed from the caller) cannot overlap memory owned by `self`.
        // Writing via `as_mut_ptr` rather than `IndexMut` avoids materialising
        // `&mut [T]` over these buffers, which could invalidate the pointers already
        // stored in `self.raw`.
        unsafe {
            self.tokens.as_mut_ptr().add(idx).write(token.0);
            self.positions.as_mut_ptr().add(idx).write(pos);
            self.n_seq_id.as_mut_ptr().add(idx).write(n_seq_count);
            self.logits.as_mut_ptr().add(idx).write(i8::from(logits));
            std::ptr::copy_nonoverlapping(seq_ids.as_ptr(), slot_ptr, n_seq);
        }

        self.raw.n_tokens += 1;
        Ok(())
    }

    /// Borrows the FFI view of this batch.
    ///
    /// Its pointers reference buffers owned by `self` and are valid only while `self`
    /// is alive.
    pub(crate) fn raw(&self) -> &sys::llama_batch {
        &self.raw
    }
}
impl From<BatchAddError> for LlamaError {
fn from(e: BatchAddError) -> Self {
Self::Batch(e.to_string())
}
}