// llama_cpp_bindings/llama_batch.rs
1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
9fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
10    let incremented = n_tokens.checked_add(1).ok_or_else(|| {
11        BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
12    })?;
13
14    usize::try_from(incremented).map_err(|convert_error| {
15        BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
16    })
17}
18
19fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
20    usize::try_from(value).map_err(|convert_error| {
21        BatchAddError::IntegerOverflow(format!(
22            "cannot fit {description} into a usize: {convert_error}"
23        ))
24    })
25}
26
27fn checked_usize_as_llama_seq_id(
28    value: usize,
29    description: &str,
30) -> Result<llama_seq_id, BatchAddError> {
31    llama_seq_id::try_from(value).map_err(|convert_error| {
32        BatchAddError::IntegerOverflow(format!(
33            "cannot fit {description} into a llama_seq_id: {convert_error}"
34        ))
35    })
36}
37
38fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
39    i32::try_from(value).map_err(|convert_error| {
40        BatchAddError::IntegerOverflow(format!(
41            "cannot fit {description} into a i32: {convert_error}"
42        ))
43    })
44}
45
46fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
47    llama_pos::try_from(value).map_err(|convert_error| {
48        BatchAddError::IntegerOverflow(format!(
49            "cannot fit {description} into a llama_pos: {convert_error}"
50        ))
51    })
52}
53
/// A safe wrapper around `llama_batch`.
///
/// `PartialEq` is intentionally not implemented because the underlying `llama_batch`
/// from the C API contains raw pointers whose address comparison would be meaningless.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    /// The number of tokens the batch was allocated with (0 for batches created by
    /// [`LlamaBatch::get_one`], which do not own their storage — see `Drop`).
    /// Slots below this count are safe to write to - but not necessarily read from
    /// as they are not necessarily initialized.
    allocated: usize,
    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
    pub initialized_logits: Vec<i32>,
    /// The underlying `llama_batch` from the C API.
    pub llama_batch: llama_batch,
    // Ties the batch's lifetime to a token slice borrowed by `get_one`
    // without actually storing the slice.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
68
/// Errors that can occur when adding a token to a batch.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// There was not enough space in the batch to add the token.
    /// The payload is the batch's allocated capacity in tokens.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// Empty buffer is provided for [`LlamaBatch::get_one`]
    #[error("Empty buffer")]
    EmptyBuffer,
    /// An integer value exceeded the allowed range.
    /// The payload describes which conversion failed and why.
    #[error("Integer overflow: {0}")]
    IntegerOverflow(String),
}
82
impl<'tokens> LlamaBatch<'tokens> {
    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
    /// the number of tokens to 0.
    pub fn clear(&mut self) {
        self.llama_batch.n_tokens = 0;
        self.initialized_logits.clear();
    }

    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
    /// token will be initialized and can be read from after the next decode.
    ///
    /// # Errors
    ///
    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
    pub fn add(
        &mut self,
        LlamaToken(id): LlamaToken,
        pos: llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) -> Result<(), BatchAddError> {
        // The batch needs room for one more token: n_tokens + 1 <= allocated.
        // Note that `get_one` batches have `allocated == 0`, so this check also
        // rejects adds to borrowed batches.
        let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;

        if self.allocated < required {
            return Err(BatchAddError::InsufficientSpace(self.allocated));
        }

        let offset = self.llama_batch.n_tokens;
        let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
        let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;

        // SAFETY: `offset_usize < self.allocated` was established above, and the
        // per-token arrays were sized for `allocated` entries by `llama_batch_init`
        // in `Self::new`, so the indexed writes below stay in bounds.
        //
        // NOTE(review): `seq_ids.len()` is never validated against the `n_seq_max`
        // passed to `Self::new`, and the inner `seq_id[offset]` array presumably
        // only holds `n_seq_max` entries — exceeding it would write out of bounds.
        // TODO: consider storing `n_seq_max` in the wrapper and checking it here.
        unsafe {
            // batch.token   [batch.n_tokens] = id;
            self.llama_batch.token.add(offset_usize).write(id);
            // batch.pos     [batch.n_tokens] = pos,
            self.llama_batch.pos.add(offset_usize).write(pos);
            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
            // for (size_t i = 0; i < seq_ids.size(); ++i) {
            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
            // }
            for (seq_index, seq_id) in seq_ids.iter().enumerate() {
                let tmp = *self.llama_batch.seq_id.add(offset_usize);
                tmp.add(seq_index).write(*seq_id);
            }
            // batch.logits  [batch.n_tokens] = logits;
            self.llama_batch
                .logits
                .add(offset_usize)
                .write(i8::from(logits));
        }

        // Record which logit slots are valid to read back after decode.
        if logits {
            self.initialized_logits.push(offset);
        }

        self.llama_batch.n_tokens += 1;

        Ok(())
    }

    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
    /// tokens will be initialized and can be read from after the next decode.
    ///
    /// Either way the last token in the sequence will have its logits set to `true`.
    ///
    /// # Errors
    ///
    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
    pub fn add_sequence(
        &mut self,
        tokens: &[LlamaToken],
        seq_id: i32,
        logits_all: bool,
    ) -> Result<(), BatchAddError> {
        // saturating_sub keeps an empty slice from underflowing; in that case the
        // loop below simply does not run and the batch is left unchanged.
        let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;

        for (position, token) in (0..).zip(tokens.iter()) {
            self.add(
                *token,
                position,
                &[seq_id],
                // Always request logits for the final token of the sequence.
                logits_all || position == last_index,
            )?;
        }

        Ok(())
    }

    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
    ///
    /// # Arguments
    ///
    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
    ///
    /// # Errors
    ///
    /// Returns an error if `n_tokens` exceeds `i32::MAX`.
    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
        let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
        // SAFETY: llama_batch_init allocates the batch's arrays; the memory is
        // released in this type's `Drop` impl via llama_batch_free.
        // The `0` is the `embd` argument — a token-based (non-embedding) batch.
        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };

        Ok(LlamaBatch {
            allocated: n_tokens,
            initialized_logits: vec![],
            llama_batch: batch,
            phantom: PhantomData,
        })
    }

    /// ``llama_batch_get_one``
    /// Return batch for single sequence of tokens
    ///
    /// NOTE: this is a helper function to facilitate transition to the new batch API
    ///
    /// # Errors
    ///
    /// Returns an error if the provided token buffer is empty or if integer conversions fail.
    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
        if tokens.is_empty() {
            return Err(BatchAddError::EmptyBuffer);
        }

        let token_count = checked_usize_as_i32(tokens.len(), "token count")?;

        // SAFETY: `tokens` is non-empty and outlives the batch (the 'tokens
        // lifetime is carried by `phantom`), so the pointer stays valid.
        let batch = unsafe {
            // llama_batch_get_one takes *mut i32 even though it does not mutate the tokens.
            #[allow(clippy::as_ptr_cast_mut)]
            let ptr = tokens.as_ptr() as *mut i32;
            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
        };

        // Safe: emptiness was checked above, so `len() - 1` cannot underflow.
        let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;

        Ok(Self {
            // `allocated == 0` marks this batch as borrowing its storage, so
            // `Drop` will not call llama_batch_free on it.
            allocated: 0,
            initialized_logits: vec![last_token_index],
            llama_batch: batch,
            phantom: PhantomData,
        })
    }

    /// Returns the number of tokens in the batch.
    #[must_use]
    pub const fn n_tokens(&self) -> i32 {
        self.llama_batch.n_tokens
    }
}
232
impl Drop for LlamaBatch<'_> {
    /// Drops the `LlamaBatch`.
    ///
    /// ```
    /// # use llama_cpp_bindings::llama_batch::LlamaBatch;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let batch = LlamaBatch::new(512, 1)?;
    /// // frees the memory associated with the batch. (allocated by llama.cpp)
    /// drop(batch);
    /// # Ok(())
    /// # }
    /// ```
    fn drop(&mut self) {
        // SAFETY: only batches created by `new` (allocated > 0) own storage
        // obtained from llama_batch_init; batches from `get_one` have
        // `allocated == 0` and borrow their tokens, so they must not be freed.
        unsafe {
            if self.allocated > 0 {
                llama_batch_free(self.llama_batch);
            }
        }
    }
}
253
#[cfg(test)]
mod tests {
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    // --- LlamaBatch construction and basic state ---

    #[test]
    fn new_creates_empty_batch() {
        let batch = LlamaBatch::new(16, 1).unwrap();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn clear_resets_batch() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], true).unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    // --- LlamaBatch::add ---

    #[test]
    fn add_increments_token_count() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.add(LlamaToken::new(2), 1, &[0], false).unwrap();
        assert_eq!(batch.n_tokens(), 2);
    }

    #[test]
    fn add_tracks_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();
        assert!(batch.initialized_logits.is_empty());

        batch.add(LlamaToken::new(2), 1, &[0], true).unwrap();
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn add_returns_insufficient_space_when_full() {
        let mut batch = LlamaBatch::new(1, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();

        let result = batch.add(LlamaToken::new(2), 1, &[0], false);

        // The error payload is the batch's allocated capacity.
        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));
    }

    // --- LlamaBatch::add_sequence ---

    #[test]
    fn add_sequence_adds_all_tokens() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.n_tokens(), 3);
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        // With logits_all == false only the final token requests logits.
        assert_eq!(batch.initialized_logits, vec![2]);
    }

    #[test]
    fn add_sequence_insufficient_space() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();

        let tokens = vec![LlamaToken::new(10), LlamaToken::new(20)];
        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    // --- LlamaBatch::get_one ---

    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        // Only the last token's logits are marked initialized.
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: Vec<LlamaToken> = vec![];
        let result = LlamaBatch::get_one(&tokens);

        assert_eq!(result.unwrap_err(), BatchAddError::EmptyBuffer);
    }

    #[test]
    fn get_one_single_token() {
        let tokens = vec![LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_with_logits_false_retains_only_previous_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], true).unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);

        batch.add(LlamaToken::new(2), 0, &[0], false).unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        // n_seq_max = 4 so three sequence ids fit in one slot.
        let mut batch = LlamaBatch::new(16, 4)?;

        batch.add(LlamaToken::new(1), 0, &[0, 1, 2], true)?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn drop_does_not_free_get_one_batch() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        // allocated == 0 marks the batch as borrowed; Drop must skip llama_batch_free.
        assert_eq!(batch.allocated, 0);
        drop(batch);
    }

    // --- checked integer conversion helpers ---

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        let result = checked_n_tokens_plus_one_as_usize(0);

        assert_eq!(result, Ok(1));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        let result = checked_n_tokens_plus_one_as_usize(-2);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        let result = checked_n_tokens_plus_one_as_usize(i32::MAX);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        let result = checked_i32_as_usize(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        let result = checked_i32_as_usize(i32::MIN, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        let result = checked_usize_as_llama_seq_id(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        let result = checked_usize_as_llama_seq_id(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        let result = checked_usize_as_i32(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        let result = checked_usize_as_i32(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        let result = checked_usize_as_llama_pos(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        let result = checked_usize_as_llama_pos(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn new_fails_for_oversized_n_tokens() {
        let result = LlamaBatch::new(usize::MAX, 1);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }
}