Skip to main content

llama_cpp_bindings/
llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
9/// A safe wrapper around `llama_batch`.
10///
11/// `PartialEq` is intentionally not implemented because the underlying `llama_batch`
12/// from the C API contains raw pointers whose address comparison would be meaningless.
13#[derive(Debug)]
14pub struct LlamaBatch<'tokens> {
15    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
16    allocated: usize,
17    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
18    pub initialized_logits: Vec<i32>,
19    /// The underlying `llama_batch` from the C API.
20    pub llama_batch: llama_batch,
21    phantom: PhantomData<&'tokens [LlamaToken]>,
22}
23
24/// Errors that can occur when adding a token to a batch.
25#[derive(thiserror::Error, Debug, PartialEq, Eq)]
26pub enum BatchAddError {
27    /// There was not enough space in the batch to add the token.
28    #[error("Insufficient Space of {0}")]
29    InsufficientSpace(usize),
30    /// Empty buffer is provided for [`LlamaBatch::get_one`]
31    #[error("Empty buffer")]
32    EmptyBuffer,
33    /// An integer value exceeded the allowed range.
34    #[error("Integer overflow: {0}")]
35    IntegerOverflow(String),
36}
37
38impl<'tokens> LlamaBatch<'tokens> {
39    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
40    /// the number of tokens to 0.
41    pub fn clear(&mut self) {
42        self.llama_batch.n_tokens = 0;
43        self.initialized_logits.clear();
44    }
45
46    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
47    /// token will be initialized and can be read from after the next decode.
48    ///
49    /// # Errors
50    ///
51    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
52    pub fn add(
53        &mut self,
54        LlamaToken(id): LlamaToken,
55        pos: llama_pos,
56        seq_ids: &[i32],
57        logits: bool,
58    ) -> Result<(), BatchAddError> {
59        let required = usize::try_from(self.n_tokens() + 1).map_err(|convert_error| {
60            BatchAddError::IntegerOverflow(format!(
61                "cannot fit n_tokens into a usize: {convert_error}"
62            ))
63        })?;
64
65        if self.allocated < required {
66            return Err(BatchAddError::InsufficientSpace(self.allocated));
67        }
68
69        let offset = self.llama_batch.n_tokens;
70        let offset_usize = usize::try_from(offset).map_err(|convert_error| {
71            BatchAddError::IntegerOverflow(format!(
72                "cannot fit n_tokens into a usize: {convert_error}"
73            ))
74        })?;
75
76        let n_seq_id = llama_seq_id::try_from(seq_ids.len()).map_err(|convert_error| {
77            BatchAddError::IntegerOverflow(format!(
78                "cannot fit seq_ids.len() into a llama_seq_id: {convert_error}"
79            ))
80        })?;
81
82        unsafe {
83            // batch.token   [batch.n_tokens] = id;
84            self.llama_batch.token.add(offset_usize).write(id);
85            // batch.pos     [batch.n_tokens] = pos,
86            self.llama_batch.pos.add(offset_usize).write(pos);
87            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
88            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
89            // for (size_t i = 0; i < seq_ids.size(); ++i) {
90            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
91            // }
92            for (seq_index, seq_id) in seq_ids.iter().enumerate() {
93                let tmp = *self.llama_batch.seq_id.add(offset_usize);
94                tmp.add(seq_index).write(*seq_id);
95            }
96            // batch.logits  [batch.n_tokens] = logits;
97            self.llama_batch
98                .logits
99                .add(offset_usize)
100                .write(i8::from(logits));
101        }
102
103        if logits {
104            self.initialized_logits.push(offset);
105        } else {
106            self.initialized_logits
107                .retain(|logit_offset| logit_offset != &offset);
108        }
109
110        self.llama_batch.n_tokens += 1;
111
112        Ok(())
113    }
114
115    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
116    /// tokens will be initialized and can be read from after the next decode.
117    ///
118    /// Either way the last token in the sequence will have its logits set to `true`.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if there is insufficient space in the buffer or if integer conversions fail.
123    pub fn add_sequence(
124        &mut self,
125        tokens: &[LlamaToken],
126        seq_id: i32,
127        logits_all: bool,
128    ) -> Result<(), BatchAddError> {
129        let n_tokens_0 = usize::try_from(self.llama_batch.n_tokens).map_err(|convert_error| {
130            BatchAddError::IntegerOverflow(format!(
131                "cannot fit n_tokens into a usize: {convert_error}"
132            ))
133        })?;
134        let n_tokens = tokens.len();
135
136        if self.allocated < n_tokens_0 + n_tokens {
137            return Err(BatchAddError::InsufficientSpace(self.allocated));
138        }
139
140        let last_index =
141            llama_pos::try_from(n_tokens.saturating_sub(1)).map_err(|convert_error| {
142                BatchAddError::IntegerOverflow(format!(
143                    "cannot fit n_tokens into a llama_pos: {convert_error}"
144                ))
145            })?;
146
147        for (position, token) in (0..).zip(tokens.iter()) {
148            self.add(
149                *token,
150                position,
151                &[seq_id],
152                logits_all || position == last_index,
153            )?;
154        }
155
156        Ok(())
157    }
158
159    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
160    ///
161    /// # Arguments
162    ///
163    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
164    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if `n_tokens` exceeds `i32::MAX`.
169    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
170        let n_tokens_i32 = i32::try_from(n_tokens).map_err(|convert_error| {
171            BatchAddError::IntegerOverflow(format!(
172                "cannot fit n_tokens into a i32: {convert_error}"
173            ))
174        })?;
175        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
176
177        Ok(LlamaBatch {
178            allocated: n_tokens,
179            initialized_logits: vec![],
180            llama_batch: batch,
181            phantom: PhantomData,
182        })
183    }
184
185    /// ``llama_batch_get_one``
186    /// Return batch for single sequence of tokens
187    ///
188    /// NOTE: this is a helper function to facilitate transition to the new batch API
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if the provided token buffer is empty or if integer conversions fail.
193    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
194        if tokens.is_empty() {
195            return Err(BatchAddError::EmptyBuffer);
196        }
197
198        let token_count = tokens.len().try_into().map_err(|convert_error| {
199            BatchAddError::IntegerOverflow(format!(
200                "number of tokens exceeds i32::MAX: {convert_error}"
201            ))
202        })?;
203
204        let batch = unsafe {
205            let ptr = tokens.as_ptr() as *mut i32;
206            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
207        };
208
209        let last_token_index = (tokens.len() - 1).try_into().map_err(|convert_error| {
210            BatchAddError::IntegerOverflow(format!(
211                "number of tokens exceeds i32::MAX: {convert_error}"
212            ))
213        })?;
214
215        Ok(Self {
216            allocated: 0,
217            initialized_logits: vec![last_token_index],
218            llama_batch: batch,
219            phantom: PhantomData,
220        })
221    }
222
223    /// Returns the number of tokens in the batch.
224    #[must_use]
225    pub fn n_tokens(&self) -> i32 {
226        self.llama_batch.n_tokens
227    }
228}
229
230impl Drop for LlamaBatch<'_> {
231    /// Drops the `LlamaBatch`.
232    ///
233    /// ```
234    /// # use llama_cpp_bindings::llama_batch::LlamaBatch;
235    /// # use std::error::Error;
236    /// # fn main() -> Result<(), Box<dyn Error>> {
237    /// let batch = LlamaBatch::new(512, 1)?;
238    /// // frees the memory associated with the batch. (allocated by llama.cpp)
239    /// drop(batch);
240    /// # Ok(())
241    /// # }
242    fn drop(&mut self) {
243        unsafe {
244            if self.allocated > 0 {
245                llama_batch_free(self.llama_batch);
246            }
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use crate::token::LlamaToken;

    use super::{BatchAddError, LlamaBatch};

    /// Three distinct tokens shared by the `add_sequence` tests.
    fn three_tokens() -> [LlamaToken; 3] {
        [LlamaToken::new(10), LlamaToken::new(20), LlamaToken::new(30)]
    }

    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let batch = LlamaBatch::new(16, 1)?;

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn new_with_zero_capacity_rejects_add() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(0, 1)?;

        let result = batch.add(LlamaToken::new(1), 0, &[0], true);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(0)));
        assert_eq!(batch.n_tokens(), 0);

        Ok(())
    }

    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], true)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn clear_allows_slots_to_be_reused() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], true)?;

        batch.clear();
        batch.add(LlamaToken::new(2), 0, &[0], false)?;

        assert_eq!(batch.n_tokens(), 1);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.add(LlamaToken::new(2), 1, &[0], false)?;
        assert_eq!(batch.n_tokens(), 2);

        Ok(())
    }

    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        batch.add(LlamaToken::new(2), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], false)?;

        let result = batch.add(LlamaToken::new(2), 1, &[0], false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));
        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&three_tokens(), 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&three_tokens(), 0, false)?;

        assert_eq!(batch.initialized_logits, vec![2]);

        Ok(())
    }

    #[test]
    fn add_sequence_with_logits_all_initializes_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&three_tokens(), 0, true)?;

        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_sequence_appends_after_existing_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let pair = [LlamaToken::new(1), LlamaToken::new(2)];

        batch.add_sequence(&pair, 0, false)?;
        batch.add_sequence(&pair, 1, false)?;

        assert_eq!(batch.n_tokens(), 4);
        // The last token of each call sits at batch offsets 1 and 3.
        assert_eq!(batch.initialized_logits, vec![1, 3]);

        Ok(())
    }

    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;

        let result = batch.add_sequence(&three_tokens(), 0, false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(2)));
        // The capacity check happens up front, so nothing was written.
        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = [LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: [LlamaToken; 0] = [];
        let result = LlamaBatch::get_one(&tokens);

        assert!(
            matches!(result, Err(BatchAddError::EmptyBuffer)),
            "expected EmptyBuffer error"
        );
    }

    #[test]
    fn get_one_single_token() {
        let tokens = [LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn get_one_batch_rejects_add() {
        let tokens = [LlamaToken::new(1), LlamaToken::new(2)];
        let mut batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        let result = batch.add(LlamaToken::new(3), 2, &[0], true);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(0)));
        assert_eq!(batch.n_tokens(), 2);
    }
}