llama_cpp_2/
llama_batch.rs

1//! Safe wrapper around `llama_batch`.
2
3use crate::token::LlamaToken;
4use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
5
6/// A safe wrapper around `llama_batch`.
7#[derive(Debug)]
8pub struct LlamaBatch {
9    /// The number of tokens the batch was allocated with. they are safe to write to - but not necessarily read from as they are not necessarily initialized
10    allocated: usize,
11    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
12    pub(crate) initialized_logits: Vec<i32>,
13    #[allow(clippy::doc_markdown)]
14    /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, <unknown>, <unknown>)`
15    pub(crate) llama_batch: llama_batch,
16}
17
18/// Errors that can occur when adding a token to a batch.
19#[derive(thiserror::Error, Debug, PartialEq, Eq)]
20pub enum BatchAddError {
21    /// There was not enough space in the batch to add the token.
22    #[error("Insufficient Space of {0}")]
23    InsufficientSpace(usize),
24    /// Empty buffer is provided for [`LlamaBatch::get_one`]
25    #[error("Empty buffer")]
26    EmptyBuffer,
27}
28
29impl LlamaBatch {
30    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
31    /// the number of tokens to 0.
32    pub fn clear(&mut self) {
33        self.llama_batch.n_tokens = 0;
34        self.initialized_logits.clear();
35    }
36
37    /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the
38    /// token will be initialized and can be read from after the next decode.
39    ///
40    /// # Panics
41    ///
42    /// - [`self.llama_batch.n_tokens`] does not fit into a usize
43    /// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`]
44    ///
45    /// # Errors
46    ///
47    /// returns a error if there is insufficient space in the buffer
48    pub fn add(
49        &mut self,
50        LlamaToken(id): LlamaToken,
51        pos: llama_pos,
52        seq_ids: &[i32],
53        logits: bool,
54    ) -> Result<(), BatchAddError> {
55        if self.allocated
56            < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
57        {
58            return Err(BatchAddError::InsufficientSpace(self.allocated));
59        }
60        let offset = self.llama_batch.n_tokens;
61        let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
62        unsafe {
63            // batch.token   [batch.n_tokens] = id;
64            self.llama_batch.token.add(offset_usize).write(id);
65            // batch.pos     [batch.n_tokens] = pos,
66            self.llama_batch.pos.add(offset_usize).write(pos);
67            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
68            self.llama_batch.n_seq_id.add(offset_usize).write(
69                llama_seq_id::try_from(seq_ids.len())
70                    .expect("cannot fit seq_ids.len() into a llama_seq_id"),
71            );
72            // for (size_t i = 0; i < seq_ids.size(); ++i) {
73            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
74            // }
75            for (i, seq_id) in seq_ids.iter().enumerate() {
76                let tmp = *self.llama_batch.seq_id.add(offset_usize);
77                tmp.add(i).write(*seq_id);
78            }
79            // batch.logits  [batch.n_tokens] = logits;
80            self.llama_batch
81                .logits
82                .add(offset_usize)
83                .write(i8::from(logits));
84        }
85
86        if logits {
87            self.initialized_logits.push(offset);
88        } else {
89            self.initialized_logits.retain(|l| l != &offset);
90        }
91
92        // batch.n_tokens++;
93        self.llama_batch.n_tokens += 1;
94
95        Ok(())
96    }
97
98    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the
99    /// tokens will be initialized and can be read from after the next decode.
100    ///
101    /// Either way the last token in the sequence will have its logits set to `true`.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if there is insufficient space in the buffer
106    ///
107    /// # Panics
108    ///
109    /// - [`self.llama_batch.n_tokens`] does not fit into a [`usize`]
110    /// - [`n_tokens - 1`] does not fit into a [`llama_pos`]
111    pub fn add_sequence(
112        &mut self,
113        tokens: &[LlamaToken],
114        seq_id: i32,
115        logits_all: bool,
116    ) -> Result<(), BatchAddError> {
117        let n_tokens_0 =
118            usize::try_from(self.llama_batch.n_tokens).expect("cannot fit n_tokens into a usize");
119        let n_tokens = tokens.len();
120
121        if self.allocated < n_tokens_0 + n_tokens {
122            return Err(BatchAddError::InsufficientSpace(self.allocated));
123        }
124
125        let last_index = llama_pos::try_from(n_tokens.saturating_sub(1))
126            .expect("cannot fit n_tokens into a llama_pos");
127        for (i, token) in (0..).zip(tokens.iter()) {
128            self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
129        }
130
131        Ok(())
132    }
133
134    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
135    ///
136    /// # Arguments
137    ///
138    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
139    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
140    ///
141    /// # Panics
142    ///
143    /// Panics if `n_tokens` is greater than `i32::MAX`.
144    #[must_use]
145    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
146        let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into a i32");
147        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
148
149        LlamaBatch {
150            allocated: n_tokens,
151            initialized_logits: vec![],
152            llama_batch: batch,
153        }
154    }
155
156    /// ``llama_batch_get_one``
157    /// Return batch for single sequence of tokens
158    ///
159    /// NOTE: this is a helper function to facilitate transition to the new batch API
160    ///
161    /// # Errors
162    /// If the provided token buffer is empty.
163    ///
164    /// # Panics
165    /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
166    pub fn get_one(tokens: &[LlamaToken]) -> Result<Self, BatchAddError> {
167        if tokens.is_empty() {
168            return Err(BatchAddError::EmptyBuffer);
169        }
170        let batch = unsafe {
171            let ptr = tokens.as_ptr() as *mut i32;
172            llama_cpp_sys_2::llama_batch_get_one(
173                ptr,
174                tokens
175                    .len()
176                    .try_into()
177                    .expect("number of tokens exceeds i32::MAX"),
178            )
179        };
180        let batch = Self {
181            allocated: 0,
182            initialized_logits: vec![(tokens.len() - 1)
183                .try_into()
184                .expect("number of tokens exceeds i32::MAX + 1")],
185            llama_batch: batch,
186        };
187        Ok(batch)
188    }
189
190    /// Returns the number of tokens in the batch.
191    #[must_use]
192    pub fn n_tokens(&self) -> i32 {
193        self.llama_batch.n_tokens
194    }
195}
196
197impl Drop for LlamaBatch {
198    /// Drops the `LlamaBatch`.
199    ///
200    /// ```
201    /// # use llama_cpp_2::llama_batch::LlamaBatch;
202    /// # use std::error::Error;
203    /// # fn main() -> Result<(), Box<dyn Error>> {
204    /// let batch = LlamaBatch::new(512, 1);
205    /// // frees the memory associated with the batch. (allocated by llama.cpp)
206    /// drop(batch);
207    /// # Ok(())
208    /// # }
209    fn drop(&mut self) {
210        unsafe {
211            if self.allocated > 0 {
212                llama_batch_free(self.llama_batch);
213            }
214        }
215    }
216}