//! Safe wrapper around `llama_batch`.
use crate::token::LlamaToken;
use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
/// A safe wrapper around `llama_batch`.
#[derive(Debug)]
pub struct LlamaBatch {
/// The number of token slots the batch was allocated with. These slots are safe to write to, but not necessarily to read from, as they are not necessarily initialized.
allocated: usize,
/// The batch indices whose logits are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
pub(crate) initialized_logits: Vec<i32>,
/// The underlying `llama_cpp` batch. Always initialized by `llama_cpp_sys_2::llama_batch_init(allocated, 0, n_seq_max)` in [`Self::new`].
pub(crate) llama_batch: llama_batch,
}
impl LlamaBatch {
/// Clear the batch. This does not free the memory associated with the batch, but it does reset
/// the number of tokens to 0.
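///
/// # Examples
///
/// A minimal sketch of filling and clearing a batch (this assumes `LlamaToken`'s
/// token id is publicly constructible via the tuple constructor; the id `0` is an
/// arbitrary placeholder):
///
/// ```
/// # use llama_cpp_2::llama_batch::LlamaBatch;
/// # use llama_cpp_2::token::LlamaToken;
/// let mut batch = LlamaBatch::new(8, 1);
/// batch.add(LlamaToken(0), 0, &[0], true);
/// assert_eq!(batch.n_tokens(), 1);
/// batch.clear();
/// // the tokens are gone, but the allocation is kept for reuse
/// assert_eq!(batch.n_tokens(), 0);
/// ```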
pub fn clear(&mut self) {
self.llama_batch.n_tokens = 0;
self.initialized_logits.clear();
}
/// Add a token to the batch for the sequences [`seq_ids`] at position [`pos`]. If [`logits`] is
/// `true`, the logits for this token will be initialized and can be read from after the next decode.
///
/// # Panics
///
/// - the batch has no free token slots left (all `allocated` slots are in use)
/// - [`self.llama_batch.n_tokens`] does not fit into a `usize`
/// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`]
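///
/// # Examples
///
/// A sketch of batching a short prompt into sequence `0`, requesting logits only for the
/// last token (the token ids `1`, `2`, `3` are arbitrary placeholders, and the public
/// `LlamaToken` tuple constructor is assumed):
///
/// ```
/// # use llama_cpp_2::llama_batch::LlamaBatch;
/// # use llama_cpp_2::token::LlamaToken;
/// let mut batch = LlamaBatch::new(64, 1);
/// for (pos, &id) in [1, 2, 3].iter().enumerate() {
///     // only the last token's logits are needed after decoding
///     batch.add(LlamaToken(id), pos as i32, &[0], pos == 2);
/// }
/// assert_eq!(batch.n_tokens(), 3);
/// ```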
pub fn add(
&mut self,
LlamaToken(id): LlamaToken,
pos: llama_pos,
seq_ids: &[i32],
logits: bool,
) {
assert!(
    self.allocated >= usize::try_from(self.n_tokens() + 1).expect("self.n_tokens does not fit into a usize"),
    "there are only {} tokens allocated for the batch, but {} tokens in the batch when you tried to add one",
    self.allocated,
    self.n_tokens()
);
let offset = self.llama_batch.n_tokens;
let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
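// `llama_batch` lays its data out as parallel arrays (token, pos, n_seq_id, seq_id,
// logits), one entry per token slot; the writes below fill slot `offset_usize` of each
// array, mirroring the C++ code quoted in the comments.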
unsafe {
// batch.token [batch.n_tokens] = id;
self.llama_batch.token.add(offset_usize).write(id);
// batch.pos [batch.n_tokens] = pos,
self.llama_batch.pos.add(offset_usize).write(pos);
// batch.n_seq_id[batch.n_tokens] = seq_ids.size();
self.llama_batch.n_seq_id.add(offset_usize).write(llama_seq_id::try_from(seq_ids.len())
.expect("cannot fit seq_ids.len() into a llama_seq_id"));
// for (size_t i = 0; i < seq_ids.size(); ++i) {
// batch.seq_id[batch.n_tokens][i] = seq_ids[i];
// }
for (i, seq_id) in seq_ids.iter().enumerate() {
let tmp = *self.llama_batch.seq_id.add(offset_usize);
tmp.add(i).write(*seq_id);
}
// batch.logits [batch.n_tokens] = logits;
self.llama_batch.logits.add(offset_usize).write(i8::from(logits));
}
if logits {
self.initialized_logits.push(offset);
} else {
self.initialized_logits.retain(|l| l != &offset);
}
// batch.n_tokens++;
self.llama_batch.n_tokens += 1;
}
/// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
///
/// # Arguments
///
/// - `n_tokens`: the maximum number of tokens that can be added to the batch
/// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
///
/// # Panics
///
/// Panics if `n_tokens` is greater than `i32::MAX`.
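///
/// # Examples
///
/// A freshly created batch contains no tokens:
///
/// ```
/// # use llama_cpp_2::llama_batch::LlamaBatch;
/// // room for up to 512 tokens, with a single sequence per token
/// let batch = LlamaBatch::new(512, 1);
/// assert_eq!(batch.n_tokens(), 0);
/// ```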
#[must_use]
pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into an i32");
// embd == 0: allocate the `token` array (token ids) rather than an embeddings buffer
let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
LlamaBatch {
allocated: n_tokens,
initialized_logits: vec![],
llama_batch: batch,
}
}
/// Returns the number of tokens in the batch.
#[must_use]
pub fn n_tokens(&self) -> i32 {
self.llama_batch.n_tokens
}
}
impl Drop for LlamaBatch {
/// Drops the `LlamaBatch`.
///
/// ```
/// # use llama_cpp_2::llama_batch::LlamaBatch;
/// # use std::error::Error;
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let batch = LlamaBatch::new(512, 1);
/// // frees the memory associated with the batch. (allocated by llama.cpp)
/// drop(batch);
/// # Ok(())
/// # }
/// ```
fn drop(&mut self) {
unsafe {
llama_batch_free(self.llama_batch);
}
}
}