Skip to main content

llama_crab/
batch.rs

1//! Reusable batching primitive.
2
3use llama_crab_sys as sys;
4
5use crate::error::LlamaError;
6use crate::token::LlamaToken;
7
/// A `llama_batch` wrapper. Owns the underlying C struct and its memory.
///
/// `raw` stores raw pointers into the heap buffers of the vectors below.
/// Those vectors are allocated once in `LlamaBatch::new`, so they must never
/// be resized, reallocated or replaced while the batch is alive; otherwise
/// the pointers in `raw` would dangle.
#[derive(Debug)]
pub struct LlamaBatch {
    /// The C-side view of the batch. Its pointer fields point into the
    /// vectors below; `raw.n_tokens` is the number of filled slots.
    raw: sys::llama_batch,
    // The C struct borrows from these vectors; we keep them alive.
    /// Token ids, one slot per unit of capacity (backs `raw.token`).
    tokens: Vec<sys::llama_token>,
    /// Token positions, one per slot (backs `raw.pos`).
    positions: Vec<sys::llama_pos>,
    /// Per-slot count of sequence ids (backs `raw.n_seq_id`).
    n_seq_id: Vec<i32>,
    /// Per-slot sequence-id buffers; their heap buffers are what
    /// `seq_ids_ptrs` points at.
    seq_ids: Vec<Vec<sys::llama_seq_id>>,
    /// One pointer per slot into `seq_ids` (backs `raw.seq_id`).
    seq_ids_ptrs: Vec<*mut sys::llama_seq_id>,
    /// Per-slot logits flag stored as 0/1 (backs `raw.logits`).
    logits: Vec<i8>,
    /// Set to `true` by `new`. NOTE(review): never read in this file —
    /// confirm whether it is still needed.
    allocated: bool,
}
21
/// Reason a token could not be added to a `LlamaBatch`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchAddError {
    /// The batch is already full. The payload is the batch's token capacity.
    InsufficientSpace(usize),
    /// The token was given an empty list of sequence ids, so it would not
    /// belong to any sequence.
    Empty,
}

impl std::fmt::Display for BatchAddError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientSpace(n) => write!(f, "batch only has space for {n} tokens"),
            // `LlamaBatch::add` returns `Empty` only when `seq_ids` is empty,
            // so the message names that cause rather than a missing token.
            Self::Empty => f.write_str("token must belong to at least one sequence"),
        }
    }
}

/// `BatchAddError` has no underlying cause, so the default `source` (None)
/// is correct.
impl std::error::Error for BatchAddError {}
41
42impl LlamaBatch {
43    /// Allocate a batch that can hold up to `n_tokens`.
44    ///
45    /// `n_seq_max` is the maximum number of sequences any single token can
46    /// belong to (typically 1 for single-stream inference).
47    #[must_use]
48    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
49        let tokens = vec![0_i32; n_tokens];
50        let positions = vec![0_i32; n_tokens];
51        let n_seq_id = vec![n_seq_max; n_tokens];
52        let mut seq_ids = Vec::with_capacity(n_tokens);
53        let mut seq_ids_ptrs: Vec<*mut sys::llama_seq_id> = Vec::with_capacity(n_tokens);
54        for _ in 0..n_tokens {
55            let mut v: Vec<i32> = vec![0; n_seq_max as usize];
56            seq_ids_ptrs.push(v.as_mut_ptr());
57            seq_ids.push(v);
58        }
59        let logits = vec![0_i8; n_tokens];
60        let raw = sys::llama_batch {
61            n_tokens: 0,
62            token: tokens.as_ptr().cast_mut(),
63            embd: std::ptr::null_mut(),
64            pos: positions.as_ptr().cast_mut(),
65            n_seq_id: n_seq_id.as_ptr().cast_mut(),
66            seq_id: seq_ids_ptrs.as_ptr().cast_mut(),
67            logits: logits.as_ptr().cast_mut(),
68        };
69        Self {
70            raw,
71            tokens,
72            positions,
73            n_seq_id,
74            seq_ids,
75            seq_ids_ptrs,
76            logits,
77            allocated: true,
78        }
79    }
80
81    /// Construct a single-sequence batch of one token. Convenience for the
82    /// most common decode step.
83    #[must_use]
84    pub fn one(token: LlamaToken, pos: i32, seq_id: i32, logits: bool) -> Self {
85        let mut b = Self::new(1, 1);
86        b.add(token, pos, &[seq_id], logits).expect("capacity 1");
87        b
88    }
89
90    /// Number of tokens currently in the batch.
91    #[must_use]
92    pub fn n_tokens(&self) -> i32 {
93        self.raw.n_tokens
94    }
95
96    /// Reset the batch so it can be reused without reallocating.
97    pub fn clear(&mut self) {
98        self.raw.n_tokens = 0;
99    }
100
101    /// Append a single token to the batch.
102    ///
103    /// # Errors
104    /// Returns [`BatchAddError::InsufficientSpace`] if the batch is full.
105    pub fn add(
106        &mut self,
107        token: LlamaToken,
108        pos: i32,
109        seq_ids: &[i32],
110        logits: bool,
111    ) -> std::result::Result<(), BatchAddError> {
112        let idx = self.raw.n_tokens as usize;
113        if idx >= self.tokens.len() {
114            return Err(BatchAddError::InsufficientSpace(self.tokens.len()));
115        }
116        if seq_ids.is_empty() {
117            return Err(BatchAddError::Empty);
118        }
119        // Storage vectors are immutable for borrow-checker; we go through raw
120        // pointers for the mutation because the C batch only reads from them.
121        // Safety: idx < capacity and the vectors outlive the batch.
122        unsafe {
123            let mut_ptr = self.tokens.as_ptr().cast_mut();
124            std::ptr::write(mut_ptr.add(idx), token.0);
125            let pos_ptr = self.positions.as_ptr().cast_mut();
126            std::ptr::write(pos_ptr.add(idx), pos);
127            let logits_ptr = self.logits.as_ptr().cast_mut();
128            std::ptr::write(logits_ptr.add(idx), i8::from(logits));
129        }
130        for (i, &sid) in seq_ids.iter().enumerate() {
131            if i < self.seq_ids[idx].len() {
132                self.seq_ids[idx][i] = sid;
133            }
134        }
135        self.raw.n_tokens += 1;
136        Ok(())
137    }
138
139    /// Borrow the underlying C struct (read-only).
140    pub(crate) fn raw(&self) -> &sys::llama_batch {
141        &self.raw
142    }
143}
144
145/// Convert a `BatchAddError` into the crate-wide [`LlamaError`].
146impl From<BatchAddError> for LlamaError {
147    fn from(e: BatchAddError) -> Self {
148        Self::Batch(e.to_string())
149    }
150}