Skip to main content

llama_cpp_bindings/
llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::batch_add_error::BatchAddError;
4use crate::sampled_token::SampledToken;
5use crate::token::LlamaToken;
6use llama_cpp_bindings_sys::{
7    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
8};
9use std::marker::PhantomData;
10
11fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
12    let incremented = n_tokens.checked_add(1).ok_or_else(|| {
13        BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
14    })?;
15
16    usize::try_from(incremented).map_err(|convert_error| {
17        BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
18    })
19}
20
21fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
22    usize::try_from(value).map_err(|convert_error| {
23        BatchAddError::IntegerOverflow(format!(
24            "cannot fit {description} into a usize: {convert_error}"
25        ))
26    })
27}
28
29fn checked_usize_as_llama_seq_id(
30    value: usize,
31    description: &str,
32) -> Result<llama_seq_id, BatchAddError> {
33    llama_seq_id::try_from(value).map_err(|convert_error| {
34        BatchAddError::IntegerOverflow(format!(
35            "cannot fit {description} into a llama_seq_id: {convert_error}"
36        ))
37    })
38}
39
40fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
41    i32::try_from(value).map_err(|convert_error| {
42        BatchAddError::IntegerOverflow(format!(
43            "cannot fit {description} into a i32: {convert_error}"
44        ))
45    })
46}
47
48fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
49    llama_pos::try_from(value).map_err(|convert_error| {
50        BatchAddError::IntegerOverflow(format!(
51            "cannot fit {description} into a llama_pos: {convert_error}"
52        ))
53    })
54}
55
56/// A safe wrapper around `llama_batch`.
57///
58/// `PartialEq` is intentionally not implemented because the underlying `llama_batch`
59/// from the C API contains raw pointers whose address comparison would be meaningless.
60#[derive(Debug)]
61pub struct LlamaBatch<'tokens> {
62    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
63    allocated: usize,
64    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
65    pub initialized_logits: Vec<i32>,
66    /// The underlying `llama_batch` from the C API.
67    pub llama_batch: llama_batch,
68    phantom: PhantomData<&'tokens [LlamaToken]>,
69}
70
71impl<'tokens> LlamaBatch<'tokens> {
72    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
73    /// the number of tokens to 0.
74    pub fn clear(&mut self) {
75        self.llama_batch.n_tokens = 0;
76        self.initialized_logits.clear();
77    }
78
79    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
80    /// token will be initialized and can be read from after the next decode.
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
85    pub fn add(
86        &mut self,
87        sampled_token: &SampledToken,
88        pos: llama_pos,
89        seq_ids: &[i32],
90        logits: bool,
91    ) -> Result<(), BatchAddError> {
92        let (SampledToken::Content(LlamaToken(id))
93        | SampledToken::Reasoning(LlamaToken(id))
94        | SampledToken::ToolCall(LlamaToken(id))
95        | SampledToken::Undeterminable(LlamaToken(id))) = *sampled_token;
96        let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;
97
98        if self.allocated < required {
99            return Err(BatchAddError::InsufficientSpace(self.allocated));
100        }
101
102        let offset = self.llama_batch.n_tokens;
103        let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
104        let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;
105
106        unsafe {
107            self.llama_batch.token.add(offset_usize).write(id);
108            self.llama_batch.pos.add(offset_usize).write(pos);
109            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
110            for (seq_index, seq_id) in seq_ids.iter().enumerate() {
111                let tmp = *self.llama_batch.seq_id.add(offset_usize);
112                tmp.add(seq_index).write(*seq_id);
113            }
114            self.llama_batch
115                .logits
116                .add(offset_usize)
117                .write(i8::from(logits));
118        }
119
120        if logits {
121            self.initialized_logits.push(offset);
122        }
123
124        self.llama_batch.n_tokens += 1;
125
126        Ok(())
127    }
128
129    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
130    /// tokens will be initialized and can be read from after the next decode.
131    ///
132    /// Either way the last token in the sequence will have its logits set to `true`.
133    ///
134    /// # Errors
135    ///
136    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
137    pub fn add_sequence(
138        &mut self,
139        tokens: &[LlamaToken],
140        seq_id: i32,
141        logits_all: bool,
142    ) -> Result<(), BatchAddError> {
143        let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;
144
145        for (position, token) in (0..).zip(tokens.iter()) {
146            self.add(
147                &SampledToken::Content(*token),
148                position,
149                &[seq_id],
150                logits_all || position == last_index,
151            )?;
152        }
153
154        Ok(())
155    }
156
157    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
158    ///
159    /// # Arguments
160    ///
161    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
162    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
163    ///
164    /// # Errors
165    ///
166    /// Returns an error if `n_tokens` exceeds `i32::MAX`.
167    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
168        let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
169        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
170
171        Ok(LlamaBatch {
172            allocated: n_tokens,
173            initialized_logits: vec![],
174            llama_batch: batch,
175            phantom: PhantomData,
176        })
177    }
178
179    /// ``llama_batch_get_one``
180    /// Return batch for single sequence of tokens
181    ///
182    /// NOTE: this is a helper function to facilitate transition to the new batch API
183    ///
184    /// # Errors
185    ///
186    /// Returns an error if the provided token buffer is empty or if integer conversions fail.
187    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
188        if tokens.is_empty() {
189            return Err(BatchAddError::EmptyBuffer);
190        }
191
192        let token_count = checked_usize_as_i32(tokens.len(), "token count")?;
193
194        let batch = unsafe {
195            #[expect(
196                clippy::as_ptr_cast_mut,
197                reason = "llama_batch_get_one signature requires *mut i32 but does not mutate the tokens"
198            )]
199            let ptr = tokens.as_ptr() as *mut i32;
200            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
201        };
202
203        let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;
204
205        Ok(Self {
206            allocated: 0,
207            initialized_logits: vec![last_token_index],
208            llama_batch: batch,
209            phantom: PhantomData,
210        })
211    }
212
213    /// Returns the number of tokens in the batch.
214    #[must_use]
215    pub const fn n_tokens(&self) -> i32 {
216        self.llama_batch.n_tokens
217    }
218}
219
220impl Drop for LlamaBatch<'_> {
221    /// Drops the `LlamaBatch`.
222    ///
223    /// ```
224    /// # use llama_cpp_bindings::llama_batch::LlamaBatch;
225    /// # use std::error::Error;
226    /// # fn main() -> Result<(), Box<dyn Error>> {
227    /// let batch = LlamaBatch::new(512, 1)?;
228    /// // frees the memory associated with the batch. (allocated by llama.cpp)
229    /// drop(batch);
230    /// # Ok(())
231    /// # }
232    fn drop(&mut self) {
233        unsafe {
234            if self.allocated > 0 {
235                llama_batch_free(self.llama_batch);
236            }
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use crate::sampled_token::SampledToken;
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    #[test]
    fn new_creates_empty_batch() {
        let batch = LlamaBatch::new(16, 1).unwrap();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn new_with_zero_capacity_rejects_add() {
        let mut batch = LlamaBatch::new(0, 1).unwrap();

        let result = batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(0)));
        assert_eq!(batch.n_tokens(), 0);
    }

    #[test]
    fn clear_resets_batch() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)
            .unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn clear_allows_full_batch_to_be_refilled() {
        let mut batch = LlamaBatch::new(1, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)
            .unwrap();

        batch.clear();
        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 0, &[0], true)
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_increments_token_count() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], false)
            .unwrap();
        assert_eq!(batch.n_tokens(), 2);
    }

    #[test]
    fn add_tracks_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();
        assert!(batch.initialized_logits.is_empty());

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], true)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn add_returns_insufficient_space_when_full() {
        let mut batch = LlamaBatch::new(1, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();

        let result = batch.add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));
    }

    #[test]
    fn add_to_get_one_batch_returns_insufficient_space() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let mut batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        let result = batch.add(&SampledToken::Content(LlamaToken::new(3)), 2, &[0], true);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(0)));
        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn add_accepts_reasoning_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(&SampledToken::Reasoning(LlamaToken::new(11)), 0, &[0], true)
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_accepts_tool_call_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(&SampledToken::ToolCall(LlamaToken::new(22)), 0, &[0], true)
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_accepts_undeterminable_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(
                &SampledToken::Undeterminable(LlamaToken::new(33)),
                0,
                &[0],
                false,
            )
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_sequence_adds_all_tokens() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.n_tokens(), 3);
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.initialized_logits, vec![2]);
    }

    #[test]
    fn add_sequence_with_empty_tokens_is_noop() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch.add_sequence(&[], 0, true).unwrap();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn add_sequence_insufficient_space() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();

        let tokens = vec![LlamaToken::new(10), LlamaToken::new(20)];
        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: Vec<LlamaToken> = vec![];
        let result = LlamaBatch::get_one(&tokens);

        assert_eq!(result.unwrap_err(), BatchAddError::EmptyBuffer);
    }

    #[test]
    fn get_one_single_token() {
        let tokens = vec![LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_with_logits_false_retains_only_previous_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 0, &[0], false)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 4)?;

        batch.add(
            &SampledToken::Content(LlamaToken::new(1)),
            0,
            &[0, 1, 2],
            true,
        )?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn drop_does_not_free_get_one_batch() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.allocated, 0);
        drop(batch);
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        let result = checked_n_tokens_plus_one_as_usize(0);

        assert_eq!(result, Ok(1));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_largest_valid_value() {
        let result = checked_n_tokens_plus_one_as_usize(i32::MAX - 1);

        assert_eq!(result, Ok(2_147_483_647));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        let result = checked_n_tokens_plus_one_as_usize(-2);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        let result = checked_n_tokens_plus_one_as_usize(i32::MAX);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        let result = checked_i32_as_usize(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        let result = checked_i32_as_usize(i32::MIN, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_minus_one() {
        let result = checked_i32_as_usize(-1, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        let result = checked_usize_as_llama_seq_id(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        let result = checked_usize_as_llama_seq_id(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        let result = checked_usize_as_i32(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_i32_max() {
        let result = checked_usize_as_i32(2_147_483_647, "test_value");

        assert_eq!(result, Ok(i32::MAX));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        let result = checked_usize_as_i32(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        let result = checked_usize_as_llama_pos(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        let result = checked_usize_as_llama_pos(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn new_fails_for_oversized_n_tokens() {
        let result = LlamaBatch::new(usize::MAX, 1);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }
}