// llama_cpp_4/llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_sys_4::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
5
/// A safe wrapper around `llama_batch`.
#[derive(Debug)]
#[allow(clippy::struct_field_names)]
pub struct LlamaBatch {
    /// The number of tokens the batch was allocated with. They are safe to write to - but not
    /// necessarily read from as they are not necessarily initialized.
    allocated: usize,
    /// The token indices whose logits were requested via `add(.., logits = true)`. Used by
    /// [`LlamaContext`] to ensure that only initialized logits are accessed.
    pub(crate) initialized_logits: Vec<i32>,
    /// The `llama_cpp` batch. Always initialized by
    /// `llama_cpp_sys_4::llama_batch_init(allocated, /* embd */ 0, n_seq_max)` in `Self::new`.
    pub(crate) llama_batch: llama_batch,
}
17
/// Errors that can occur when adding a token to a batch.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// There was not enough space in the batch to add the token.
    /// The payload is the total capacity the batch was allocated with.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
}
25
26impl LlamaBatch {
27    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
28    /// the number of tokens to 0.
29    pub fn clear(&mut self) {
30        self.llama_batch.n_tokens = 0;
31        self.initialized_logits.clear();
32    }
33
34    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
35    /// token will be initialized and can be read from after the next decode.
36    ///
37    /// # Panics
38    ///
39    /// - `self.llama_batch.n_tokens` does not fit into a usize
40    /// - `seq_ids.len()` does not fit into a [`llama_seq_id`]
41    ///
42    /// # Errors
43    ///
44    /// returns a error if there is insufficient space in the buffer
45    pub fn add(
46        &mut self,
47        LlamaToken(id): LlamaToken,
48        pos: llama_pos,
49        seq_ids: &[i32],
50        logits: bool,
51    ) -> Result<(), BatchAddError> {
52        if self.allocated
53            < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
54        {
55            return Err(BatchAddError::InsufficientSpace(self.allocated));
56        }
57        let offset = self.llama_batch.n_tokens;
58        let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
59        unsafe {
60            // batch.token   [batch.n_tokens] = id;
61            self.llama_batch.token.add(offset_usize).write(id);
62            // batch.pos     [batch.n_tokens] = pos,
63            self.llama_batch.pos.add(offset_usize).write(pos);
64            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
65            self.llama_batch.n_seq_id.add(offset_usize).write(
66                llama_seq_id::try_from(seq_ids.len())
67                    .expect("cannot fit seq_ids.len() into a llama_seq_id"),
68            );
69            // for (size_t i = 0; i < seq_ids.size(); ++i) {
70            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
71            // }
72            for (i, seq_id) in seq_ids.iter().enumerate() {
73                let tmp = *self.llama_batch.seq_id.add(offset_usize);
74                tmp.add(i).write(*seq_id);
75            }
76            // batch.logits  [batch.n_tokens] = logits;
77            self.llama_batch
78                .logits
79                .add(offset_usize)
80                .write(i8::from(logits));
81        }
82
83        if logits {
84            self.initialized_logits.push(offset);
85        } else {
86            self.initialized_logits.retain(|l| l != &offset);
87        }
88
89        // batch.n_tokens++;
90        self.llama_batch.n_tokens += 1;
91
92        Ok(())
93    }
94
95    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
96    /// tokens will be initialized and can be read from after the next decode.
97    ///
98    /// Either way the last token in the sequence will have its logits set to `true`.
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if there is insufficient space in the buffer
103    ///
104    /// # Panics
105    ///
106    /// - `self.llama_batch.n_tokens` does not fit into a [`usize`]
107    /// - [`n_tokens - 1`] does not fit into a [`llama_pos`]
108    pub fn add_sequence(
109        &mut self,
110        tokens: &[LlamaToken],
111        seq_id: i32,
112        logits_all: bool,
113    ) -> Result<(), BatchAddError> {
114        let n_tokens_0 =
115            usize::try_from(self.llama_batch.n_tokens).expect("cannot fit n_tokens into a usize");
116        let n_tokens = tokens.len();
117
118        if self.allocated < n_tokens_0 + n_tokens {
119            return Err(BatchAddError::InsufficientSpace(self.allocated));
120        }
121
122        let last_index = llama_pos::try_from(n_tokens.saturating_sub(1))
123            .expect("cannot fit n_tokens into a llama_pos");
124        for (i, token) in (0..).zip(tokens.iter()) {
125            self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
126        }
127
128        Ok(())
129    }
130
131    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
132    ///
133    /// # Arguments
134    ///
135    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
136    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
137    ///
138    /// # Panics
139    ///
140    /// Panics if `n_tokens` is greater than `i32::MAX`.
141    #[must_use]
142    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
143        let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into a i32");
144        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
145
146        LlamaBatch {
147            allocated: n_tokens,
148            initialized_logits: vec![],
149            llama_batch: batch,
150        }
151    }
152
153    /// Returns the number of tokens in the batch.
154    #[must_use]
155    pub fn n_tokens(&self) -> i32 {
156        self.llama_batch.n_tokens
157    }
158
159    /// Create a batch from a slice of tokens for simple one-shot decoding.
160    ///
161    /// The returned batch uses the provided token buffer directly and does not own the memory.
162    /// All tokens are assigned to sequence 0 and logits are enabled for the last token only.
163    ///
164    /// **Note:** The returned batch does NOT free memory on drop — it borrows from the input
165    /// slice. The caller must ensure `tokens` outlives the returned batch.
166    ///
167    /// # Panics
168    ///
169    /// Panics if `tokens.len()` does not fit into an `i32`.
170    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
171    #[must_use]
172    pub fn get_one(tokens: &mut [LlamaToken]) -> llama_batch {
173        unsafe {
174            llama_cpp_sys_4::llama_batch_get_one(
175                tokens.as_mut_ptr().cast(),
176                tokens.len() as i32,
177            )
178        }
179    }
180}
181
182impl Drop for LlamaBatch {
183    /// Drops the `LlamaBatch`.
184    ///
185    /// ```
186    /// # use llama_cpp_4::llama_batch::LlamaBatch;
187    /// # use std::error::Error;
188    /// # fn main() -> Result<(), Box<dyn Error>> {
189    /// let batch = LlamaBatch::new(512, 1);
190    /// // frees the memory associated with the batch. (allocated by llama.cpp)
191    /// drop(batch);
192    /// # Ok(())
193    /// # }
194    fn drop(&mut self) {
195        unsafe {
196            llama_batch_free(self.llama_batch);
197        }
198    }
199}