Skip to main content

llama_cpp_bindings/
llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
/// A safe wrapper around `llama_batch`.
///
/// The `'tokens` lifetime is only meaningful for batches built with
/// [`LlamaBatch::get_one`], which borrow the caller's token slice; batches built
/// with [`LlamaBatch::new`] get their buffers from `llama_batch_init`.
///
/// `PartialEq` is intentionally not implemented because the underlying `llama_batch`
/// from the C API contains raw pointers whose address comparison would be meaningless.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    /// The number of token slots the batch was allocated with. The slots are safe to
    /// write to, but not necessarily to read from, because they are not necessarily
    /// initialized. This is `0` for batches created by [`LlamaBatch::get_one`].
    allocated: usize,
    /// The indices of the tokens whose logits were requested. Used by [`LlamaContext`]
    /// to ensure that only initialized logits are accessed.
    pub initialized_logits: Vec<i32>,
    /// The underlying `llama_batch` from the C API.
    pub llama_batch: llama_batch,
    /// Zero-sized marker that ties the batch to the `'tokens` borrow.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
23
/// Errors that can occur when creating a batch or adding tokens to it.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// There was not enough space in the batch to add the token(s).
    ///
    /// The payload is the number of token slots the batch was allocated with.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// An empty token buffer was passed to [`LlamaBatch::get_one`].
    #[error("Empty buffer")]
    EmptyBuffer,
    /// An integer value exceeded the range of the type it had to be converted to.
    ///
    /// The payload is a human-readable description of the failed conversion.
    #[error("Integer overflow: {0}")]
    IntegerOverflow(String),
}
37
38impl<'tokens> LlamaBatch<'tokens> {
39    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
40    /// the number of tokens to 0.
41    pub fn clear(&mut self) {
42        self.llama_batch.n_tokens = 0;
43        self.initialized_logits.clear();
44    }
45
46    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
47    /// token will be initialized and can be read from after the next decode.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
52    pub fn add(
53        &mut self,
54        LlamaToken(id): LlamaToken,
55        pos: llama_pos,
56        seq_ids: &[i32],
57        logits: bool,
58    ) -> Result<(), BatchAddError> {
59        let required = usize::try_from(self.n_tokens() + 1).map_err(|convert_error| {
60            BatchAddError::IntegerOverflow(format!(
61                "cannot fit n_tokens into a usize: {convert_error}"
62            ))
63        })?;
64
65        if self.allocated < required {
66            return Err(BatchAddError::InsufficientSpace(self.allocated));
67        }
68
69        let offset = self.llama_batch.n_tokens;
70        let offset_usize = usize::try_from(offset).map_err(|convert_error| {
71            BatchAddError::IntegerOverflow(format!(
72                "cannot fit n_tokens into a usize: {convert_error}"
73            ))
74        })?;
75
76        let n_seq_id = llama_seq_id::try_from(seq_ids.len()).map_err(|convert_error| {
77            BatchAddError::IntegerOverflow(format!(
78                "cannot fit seq_ids.len() into a llama_seq_id: {convert_error}"
79            ))
80        })?;
81
82        unsafe {
83            // batch.token   [batch.n_tokens] = id;
84            self.llama_batch.token.add(offset_usize).write(id);
85            // batch.pos     [batch.n_tokens] = pos,
86            self.llama_batch.pos.add(offset_usize).write(pos);
87            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
88            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
89            // for (size_t i = 0; i < seq_ids.size(); ++i) {
90            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
91            // }
92            for (i, seq_id) in seq_ids.iter().enumerate() {
93                let tmp = *self.llama_batch.seq_id.add(offset_usize);
94                tmp.add(i).write(*seq_id);
95            }
96            // batch.logits  [batch.n_tokens] = logits;
97            self.llama_batch
98                .logits
99                .add(offset_usize)
100                .write(i8::from(logits));
101        }
102
103        if logits {
104            self.initialized_logits.push(offset);
105        } else {
106            self.initialized_logits.retain(|l| l != &offset);
107        }
108
109        // batch.n_tokens++;
110        self.llama_batch.n_tokens += 1;
111
112        Ok(())
113    }
114
115    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
116    /// tokens will be initialized and can be read from after the next decode.
117    ///
118    /// Either way the last token in the sequence will have its logits set to `true`.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
123    pub fn add_sequence(
124        &mut self,
125        tokens: &[LlamaToken],
126        seq_id: i32,
127        logits_all: bool,
128    ) -> Result<(), BatchAddError> {
129        let n_tokens_0 = usize::try_from(self.llama_batch.n_tokens).map_err(|convert_error| {
130            BatchAddError::IntegerOverflow(format!(
131                "cannot fit n_tokens into a usize: {convert_error}"
132            ))
133        })?;
134        let n_tokens = tokens.len();
135
136        if self.allocated < n_tokens_0 + n_tokens {
137            return Err(BatchAddError::InsufficientSpace(self.allocated));
138        }
139
140        let last_index =
141            llama_pos::try_from(n_tokens.saturating_sub(1)).map_err(|convert_error| {
142                BatchAddError::IntegerOverflow(format!(
143                    "cannot fit n_tokens into a llama_pos: {convert_error}"
144                ))
145            })?;
146
147        for (i, token) in (0..).zip(tokens.iter()) {
148            self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
149        }
150
151        Ok(())
152    }
153
154    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
155    ///
156    /// # Arguments
157    ///
158    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
159    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
160    ///
161    /// # Errors
162    ///
163    /// Returns an error if `n_tokens` exceeds `i32::MAX`.
164    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
165        let n_tokens_i32 = i32::try_from(n_tokens).map_err(|convert_error| {
166            BatchAddError::IntegerOverflow(format!(
167                "cannot fit n_tokens into a i32: {convert_error}"
168            ))
169        })?;
170        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
171
172        Ok(LlamaBatch {
173            allocated: n_tokens,
174            initialized_logits: vec![],
175            llama_batch: batch,
176            phantom: PhantomData,
177        })
178    }
179
180    /// ``llama_batch_get_one``
181    /// Return batch for single sequence of tokens
182    ///
183    /// NOTE: this is a helper function to facilitate transition to the new batch API
184    ///
185    /// # Errors
186    ///
187    /// Returns an error if the provided token buffer is empty or if integer conversions fail.
188    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
189        if tokens.is_empty() {
190            return Err(BatchAddError::EmptyBuffer);
191        }
192
193        let token_count = tokens.len().try_into().map_err(|convert_error| {
194            BatchAddError::IntegerOverflow(format!(
195                "number of tokens exceeds i32::MAX: {convert_error}"
196            ))
197        })?;
198
199        let batch = unsafe {
200            let ptr = tokens.as_ptr() as *mut i32;
201            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
202        };
203
204        let last_token_index = (tokens.len() - 1).try_into().map_err(|convert_error| {
205            BatchAddError::IntegerOverflow(format!(
206                "number of tokens exceeds i32::MAX: {convert_error}"
207            ))
208        })?;
209
210        Ok(Self {
211            allocated: 0,
212            initialized_logits: vec![last_token_index],
213            llama_batch: batch,
214            phantom: PhantomData,
215        })
216    }
217
218    /// Returns the number of tokens in the batch.
219    #[must_use]
220    pub fn n_tokens(&self) -> i32 {
221        self.llama_batch.n_tokens
222    }
223}
224
225impl Drop for LlamaBatch<'_> {
226    /// Drops the `LlamaBatch`.
227    ///
228    /// ```
229    /// # use llama_cpp_bindings::llama_batch::LlamaBatch;
230    /// # use std::error::Error;
231    /// # fn main() -> Result<(), Box<dyn Error>> {
232    /// let batch = LlamaBatch::new(512, 1)?;
233    /// // frees the memory associated with the batch. (allocated by llama.cpp)
234    /// drop(batch);
235    /// # Ok(())
236    /// # }
237    fn drop(&mut self) {
238        unsafe {
239            if self.allocated > 0 {
240                llama_batch_free(self.llama_batch);
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::{BatchAddError, LlamaBatch};
    use crate::token::LlamaToken;

    /// Three arbitrary tokens shared by the `add_sequence` tests.
    fn three_tokens() -> [LlamaToken; 3] {
        [
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ]
    }

    /// A freshly allocated batch holds no tokens and tracks no logits.
    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let fresh = LlamaBatch::new(16, 1)?;

        assert_eq!(fresh.n_tokens(), 0);
        assert!(fresh.initialized_logits.is_empty());

        Ok(())
    }

    /// `clear` resets both the token count and the tracked logits.
    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], true)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    /// Every successful `add` bumps the token count by one.
    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        for (expected_count, (token_id, pos)) in (1..).zip([(1, 0), (2, 1)]) {
            batch.add(LlamaToken::new(token_id), pos, &[0], false)?;
            assert_eq!(batch.n_tokens(), expected_count);
        }

        Ok(())
    }

    /// Only tokens added with `logits == true` are recorded as initialized.
    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        batch.add(LlamaToken::new(2), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, [1]);

        Ok(())
    }

    /// Adding to a full batch reports the allocated capacity in the error.
    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], false)?;

        let overflowing_add = batch.add(LlamaToken::new(2), 1, &[0], false);

        assert_eq!(overflowing_add, Err(BatchAddError::InsufficientSpace(1)));

        Ok(())
    }

    /// `add_sequence` adds every token of the slice.
    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&three_tokens(), 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    /// Without `logits_all`, only the last token of the sequence gets logits.
    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&three_tokens(), 0, false)?;

        assert_eq!(batch.initialized_logits, [2]);

        Ok(())
    }

    /// A sequence longer than the batch capacity is rejected.
    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;

        let outcome = batch.add_sequence(&three_tokens(), 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    /// `get_one` wraps the slice and marks the last token's logits.
    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = [LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, [1]);
    }

    /// An empty slice is rejected with `EmptyBuffer`.
    #[test]
    fn get_one_empty_slice_returns_error() {
        let no_tokens: [LlamaToken; 0] = [];
        let outcome = LlamaBatch::get_one(&no_tokens);

        assert!(
            matches!(outcome, Err(BatchAddError::EmptyBuffer)),
            "expected EmptyBuffer error"
        );
    }

    /// With a single token, index 0 is the one with initialized logits.
    #[test]
    fn get_one_single_token() {
        let tokens = [LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, [0]);
    }
}