// llama_cpp_bindings/llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5    llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
/// A safe wrapper around `llama_batch`.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    /// The number of tokens the batch was allocated with. Slots up to this count are
    /// safe to write to - but not necessarily read from as they are not necessarily
    /// initialized. A value of `0` marks a borrowed batch (see `get_one`) that must
    /// not be freed.
    allocated: usize,
    /// Indices of the tokens whose logits are initialized (requested via `add`).
    /// Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
    pub initialized_logits: Vec<i32>,
    /// The underlying `llama_batch` from the C API.
    pub llama_batch: llama_batch,
    // Ties the batch to the lifetime of a borrowed token slice (used by `get_one`)
    // without actually storing it.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
20
/// Errors that can occur when adding a token to a batch.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// There was not enough space in the batch to add the token.
    /// The payload is the batch's allocated capacity.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// Empty buffer is provided for [`LlamaBatch::get_one`]
    #[error("Empty buffer")]
    EmptyBuffer,
}
31
32impl<'tokens> LlamaBatch<'tokens> {
33    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
34    /// the number of tokens to 0.
35    pub fn clear(&mut self) {
36        self.llama_batch.n_tokens = 0;
37        self.initialized_logits.clear();
38    }
39
40    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
41    /// token will be initialized and can be read from after the next decode.
42    ///
43    /// # Panics
44    ///
45    /// - [`self.llama_batch.n_tokens`] does not fit into a usize
46    /// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`]
47    ///
48    /// # Errors
49    ///
50    /// returns a error if there is insufficient space in the buffer
51    pub fn add(
52        &mut self,
53        LlamaToken(id): LlamaToken,
54        pos: llama_pos,
55        seq_ids: &[i32],
56        logits: bool,
57    ) -> Result<(), BatchAddError> {
58        if self.allocated
59            < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
60        {
61            return Err(BatchAddError::InsufficientSpace(self.allocated));
62        }
63        let offset = self.llama_batch.n_tokens;
64        let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
65        unsafe {
66            // batch.token   [batch.n_tokens] = id;
67            self.llama_batch.token.add(offset_usize).write(id);
68            // batch.pos     [batch.n_tokens] = pos,
69            self.llama_batch.pos.add(offset_usize).write(pos);
70            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
71            self.llama_batch.n_seq_id.add(offset_usize).write(
72                llama_seq_id::try_from(seq_ids.len())
73                    .expect("cannot fit seq_ids.len() into a llama_seq_id"),
74            );
75            // for (size_t i = 0; i < seq_ids.size(); ++i) {
76            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
77            // }
78            for (i, seq_id) in seq_ids.iter().enumerate() {
79                let tmp = *self.llama_batch.seq_id.add(offset_usize);
80                tmp.add(i).write(*seq_id);
81            }
82            // batch.logits  [batch.n_tokens] = logits;
83            self.llama_batch
84                .logits
85                .add(offset_usize)
86                .write(i8::from(logits));
87        }
88
89        if logits {
90            self.initialized_logits.push(offset);
91        } else {
92            self.initialized_logits.retain(|l| l != &offset);
93        }
94
95        // batch.n_tokens++;
96        self.llama_batch.n_tokens += 1;
97
98        Ok(())
99    }
100
101    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
102    /// tokens will be initialized and can be read from after the next decode.
103    ///
104    /// Either way the last token in the sequence will have its logits set to `true`.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if there is insufficient space in the buffer
109    ///
110    /// # Panics
111    ///
112    /// - [`self.llama_batch.n_tokens`] does not fit into a [`usize`]
113    /// - [`n_tokens - 1`] does not fit into a [`llama_pos`]
114    pub fn add_sequence(
115        &mut self,
116        tokens: &[LlamaToken],
117        seq_id: i32,
118        logits_all: bool,
119    ) -> Result<(), BatchAddError> {
120        let n_tokens_0 =
121            usize::try_from(self.llama_batch.n_tokens).expect("cannot fit n_tokens into a usize");
122        let n_tokens = tokens.len();
123
124        if self.allocated < n_tokens_0 + n_tokens {
125            return Err(BatchAddError::InsufficientSpace(self.allocated));
126        }
127
128        let last_index = llama_pos::try_from(n_tokens.saturating_sub(1))
129            .expect("cannot fit n_tokens into a llama_pos");
130        for (i, token) in (0..).zip(tokens.iter()) {
131            self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
132        }
133
134        Ok(())
135    }
136
137    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
138    ///
139    /// # Arguments
140    ///
141    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
142    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
143    ///
144    /// # Panics
145    ///
146    /// Panics if `n_tokens` is greater than `i32::MAX`.
147    #[must_use]
148    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
149        let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into a i32");
150        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
151
152        LlamaBatch {
153            allocated: n_tokens,
154            initialized_logits: vec![],
155            llama_batch: batch,
156            phantom: PhantomData,
157        }
158    }
159
160    /// ``llama_batch_get_one``
161    /// Return batch for single sequence of tokens
162    ///
163    /// NOTE: this is a helper function to facilitate transition to the new batch API
164    ///
165    /// # Errors
166    /// If the provided token buffer is empty.
167    ///
168    /// # Panics
169    /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
170    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
171        if tokens.is_empty() {
172            return Err(BatchAddError::EmptyBuffer);
173        }
174        let batch = unsafe {
175            let ptr = tokens.as_ptr() as *mut i32;
176            llama_cpp_bindings_sys::llama_batch_get_one(
177                ptr,
178                tokens
179                    .len()
180                    .try_into()
181                    .expect("number of tokens exceeds i32::MAX"),
182            )
183        };
184        let batch = Self {
185            allocated: 0,
186            initialized_logits: vec![
187                (tokens.len() - 1)
188                    .try_into()
189                    .expect("number of tokens exceeds i32::MAX + 1"),
190            ],
191            llama_batch: batch,
192            phantom: PhantomData,
193        };
194        Ok(batch)
195    }
196
197    /// Returns the number of tokens in the batch.
198    #[must_use]
199    pub fn n_tokens(&self) -> i32 {
200        self.llama_batch.n_tokens
201    }
202}
203
204impl Drop for LlamaBatch<'_> {
205    /// Drops the `LlamaBatch`.
206    ///
207    /// ```
208    /// # use llama_cpp_bindings::llama_batch::LlamaBatch;
209    /// # use std::error::Error;
210    /// # fn main() -> Result<(), Box<dyn Error>> {
211    /// let batch = LlamaBatch::new(512, 1);
212    /// // frees the memory associated with the batch. (allocated by llama.cpp)
213    /// drop(batch);
214    /// # Ok(())
215    /// # }
216    fn drop(&mut self) {
217        unsafe {
218            if self.allocated > 0 {
219                llama_batch_free(self.llama_batch);
220            }
221        }
222    }
223}