Skip to main content

llama_cpp_bindings/
llama_batch.rs

1use crate::batch_add_error::BatchAddError;
2use crate::sampled_token::SampledToken;
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
9fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
10    let incremented = n_tokens.checked_add(1).ok_or_else(|| {
11        BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
12    })?;
13
14    usize::try_from(incremented).map_err(|convert_error| {
15        BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
16    })
17}
18
19fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
20    usize::try_from(value).map_err(|convert_error| {
21        BatchAddError::IntegerOverflow(format!(
22            "cannot fit {description} into a usize: {convert_error}"
23        ))
24    })
25}
26
27fn checked_usize_as_llama_seq_id(
28    value: usize,
29    description: &str,
30) -> Result<llama_seq_id, BatchAddError> {
31    llama_seq_id::try_from(value).map_err(|convert_error| {
32        BatchAddError::IntegerOverflow(format!(
33            "cannot fit {description} into a llama_seq_id: {convert_error}"
34        ))
35    })
36}
37
38fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
39    i32::try_from(value).map_err(|convert_error| {
40        BatchAddError::IntegerOverflow(format!(
41            "cannot fit {description} into a i32: {convert_error}"
42        ))
43    })
44}
45
46fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
47    llama_pos::try_from(value).map_err(|convert_error| {
48        BatchAddError::IntegerOverflow(format!(
49            "cannot fit {description} into a llama_pos: {convert_error}"
50        ))
51    })
52}
53
54#[derive(Debug)]
55pub struct LlamaBatch<'tokens> {
56    allocated: usize,
57    pub initialized_logits: Vec<i32>,
58    pub llama_batch: llama_batch,
59    phantom: PhantomData<&'tokens [LlamaToken]>,
60}
61
62impl<'tokens> LlamaBatch<'tokens> {
63    pub fn clear(&mut self) {
64        self.llama_batch.n_tokens = 0;
65        self.initialized_logits.clear();
66    }
67
68    /// # Errors
69    ///
70    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
71    pub fn add(
72        &mut self,
73        sampled_token: &SampledToken,
74        pos: llama_pos,
75        seq_ids: &[i32],
76        logits: bool,
77    ) -> Result<(), BatchAddError> {
78        let (SampledToken::Content(LlamaToken(id))
79        | SampledToken::Reasoning(LlamaToken(id))
80        | SampledToken::ToolCall(LlamaToken(id))
81        | SampledToken::Undeterminable(LlamaToken(id))) = *sampled_token;
82        let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;
83
84        if self.allocated < required {
85            return Err(BatchAddError::InsufficientSpace(self.allocated));
86        }
87
88        let offset = self.llama_batch.n_tokens;
89        let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
90        let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;
91
92        unsafe {
93            self.llama_batch.token.add(offset_usize).write(id);
94            self.llama_batch.pos.add(offset_usize).write(pos);
95            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
96            for (seq_index, seq_id) in seq_ids.iter().enumerate() {
97                let tmp = *self.llama_batch.seq_id.add(offset_usize);
98                tmp.add(seq_index).write(*seq_id);
99            }
100            self.llama_batch
101                .logits
102                .add(offset_usize)
103                .write(i8::from(logits));
104        }
105
106        if logits {
107            self.initialized_logits.push(offset);
108        }
109
110        self.llama_batch.n_tokens += 1;
111
112        Ok(())
113    }
114
115    /// # Errors
116    ///
117    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
118    pub fn add_sequence(
119        &mut self,
120        tokens: &[LlamaToken],
121        seq_id: i32,
122        logits_all: bool,
123    ) -> Result<(), BatchAddError> {
124        let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;
125
126        for (position, token) in (0..).zip(tokens.iter()) {
127            self.add(
128                &SampledToken::Content(*token),
129                position,
130                &[seq_id],
131                logits_all || position == last_index,
132            )?;
133        }
134
135        Ok(())
136    }
137
138    /// # Errors
139    ///
140    /// Returns an error if `n_tokens` exceeds `i32::MAX`.
141    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
142        let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
143        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
144
145        Ok(LlamaBatch {
146            allocated: n_tokens,
147            initialized_logits: vec![],
148            llama_batch: batch,
149            phantom: PhantomData,
150        })
151    }
152
153    /// # Errors
154    ///
155    /// Returns an error if the provided token buffer is empty or if integer conversions fail.
156    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
157        if tokens.is_empty() {
158            return Err(BatchAddError::EmptyBuffer);
159        }
160
161        let token_count = checked_usize_as_i32(tokens.len(), "token count")?;
162
163        let batch = unsafe {
164            #[expect(
165                clippy::as_ptr_cast_mut,
166                reason = "llama_batch_get_one signature requires *mut i32 but does not mutate the tokens"
167            )]
168            let ptr = tokens.as_ptr() as *mut i32;
169            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
170        };
171
172        let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;
173
174        Ok(Self {
175            allocated: 0,
176            initialized_logits: vec![last_token_index],
177            llama_batch: batch,
178            phantom: PhantomData,
179        })
180    }
181
182    #[must_use]
183    pub const fn n_tokens(&self) -> i32 {
184        self.llama_batch.n_tokens
185    }
186}
187
188impl Drop for LlamaBatch<'_> {
189    fn drop(&mut self) {
190        unsafe {
191            if self.allocated > 0 {
192                llama_batch_free(self.llama_batch);
193            }
194        }
195    }
196}
197
// Unit tests for `LlamaBatch` and the checked integer-conversion helpers.
// Tests that build batches call into llama.cpp through the FFI crate.
#[cfg(test)]
mod tests {
    use crate::sampled_token::SampledToken;
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    #[test]
    fn new_creates_empty_batch() {
        let batch = LlamaBatch::new(16, 1).unwrap();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn clear_resets_batch() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)
            .unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn add_increments_token_count() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], false)
            .unwrap();
        assert_eq!(batch.n_tokens(), 2);
    }

    // Only tokens added with `logits == true` are recorded, by their batch index.
    #[test]
    fn add_tracks_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();
        assert!(batch.initialized_logits.is_empty());

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], true)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    // The error payload is the batch's token capacity (1 here).
    #[test]
    fn add_returns_insufficient_space_when_full() {
        let mut batch = LlamaBatch::new(1, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();

        let result = batch.add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));
    }

    // The next three tests check that every `SampledToken` variant is accepted by `add`.
    #[test]
    fn add_accepts_reasoning_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(&SampledToken::Reasoning(LlamaToken::new(11)), 0, &[0], true)
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_accepts_tool_call_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(&SampledToken::ToolCall(LlamaToken::new(22)), 0, &[0], true)
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_accepts_undeterminable_sampled_token_variant() {
        let mut batch = LlamaBatch::new(4, 1).unwrap();

        batch
            .add(
                &SampledToken::Undeterminable(LlamaToken::new(33)),
                0,
                &[0],
                false,
            )
            .unwrap();

        assert_eq!(batch.n_tokens(), 1);
    }

    #[test]
    fn add_sequence_adds_all_tokens() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.n_tokens(), 3);
    }

    // With `logits_all == false`, only the last token (index 2) requests logits.
    #[test]
    fn add_sequence_sets_logits_on_last_token() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.initialized_logits, vec![2]);
    }

    #[test]
    fn add_sequence_insufficient_space() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    // One slot is used before the call, so the two-token sequence does not fit.
    // Only the error is asserted, not how many tokens were written before it.
    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)
            .unwrap();

        let tokens = vec![LlamaToken::new(10), LlamaToken::new(20)];
        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    // `get_one` records the last token's index (1) as the only logits entry.
    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: Vec<LlamaToken> = vec![];
        let result = LlamaBatch::get_one(&tokens);

        assert_eq!(result.unwrap_err(), BatchAddError::EmptyBuffer);
    }

    #[test]
    fn get_one_single_token() {
        let tokens = vec![LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    // Both tokens use position 0; only the logits bookkeeping is under test here.
    #[test]
    fn add_with_logits_false_retains_only_previous_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch
            .add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);

        batch
            .add(&SampledToken::Content(LlamaToken::new(2)), 0, &[0], false)
            .unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    // The batch is created with n_seq_max = 4, so three sequence ids per token fit.
    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 4)?;

        batch.add(
            &SampledToken::Content(LlamaToken::new(1)),
            0,
            &[0, 1, 2],
            true,
        )?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    // `get_one` batches borrow `tokens` and report no allocated capacity.
    // Dropping one must not free the caller's buffer.
    #[test]
    fn drop_does_not_free_get_one_batch() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.allocated, 0);
        drop(batch);
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        let result = checked_n_tokens_plus_one_as_usize(0);

        assert_eq!(result, Ok(1));
    }

    // -2 + 1 = -1 has no usize representation. The message check relies on
    // BatchAddError's Display output containing "overflow" (presumably added by the
    // IntegerOverflow variant; confirm in batch_add_error.rs).
    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        let result = checked_n_tokens_plus_one_as_usize(-2);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        let result = checked_n_tokens_plus_one_as_usize(i32::MAX);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        let result = checked_i32_as_usize(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        let result = checked_i32_as_usize(i32::MIN, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        let result = checked_usize_as_llama_seq_id(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        let result = checked_usize_as_llama_seq_id(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        let result = checked_usize_as_i32(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        let result = checked_usize_as_i32(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        let result = checked_usize_as_llama_pos(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        let result = checked_usize_as_llama_pos(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    // usize::MAX cannot be narrowed to the i32 token count passed to llama_batch_init.
    #[test]
    fn new_fails_for_oversized_n_tokens() {
        let result = LlamaBatch::new(usize::MAX, 1);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }
}