llama_cpp_2/llama_batch.rs

//! Safe wrapper around `llama_batch`.

use crate::token::LlamaToken;
use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id};
use std::marker::PhantomData;

/// A safe wrapper around `llama_batch`.
#[derive(Debug)]
pub struct LlamaBatch<'a> {
    /// The number of tokens the batch was allocated with. These slots are safe to write to, but
    /// not necessarily to read from, as they are not necessarily initialized.
    allocated: usize,
    /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed.
    pub(crate) initialized_logits: Vec<i32>,
    #[allow(clippy::doc_markdown)]
    /// The llama_cpp batch. Created by `llama_cpp_sys_2::llama_batch_init` in [`Self::new`] or by
    /// `llama_cpp_sys_2::llama_batch_get_one` in [`Self::get_one`].
    pub(crate) llama_batch: llama_batch,
    phantom: PhantomData<&'a [LlamaToken]>,
}

/// Errors that can occur when adding a token to a batch.
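///
/// # Example
///
/// A minimal sketch of how [`BatchAddError::InsufficientSpace`] can arise; the token ids are
/// constructed with `LlamaToken::new` purely for illustration:
///
/// ```
/// # use llama_cpp_2::llama_batch::{BatchAddError, LlamaBatch};
/// # use llama_cpp_2::token::LlamaToken;
/// // a batch with room for exactly one token
/// let mut batch = LlamaBatch::new(1, 1);
/// assert!(batch.add(LlamaToken::new(0), 0, &[0], true).is_ok());
/// // the second token no longer fits; the error carries the allocated capacity
/// assert_eq!(
///     batch.add(LlamaToken::new(1), 1, &[0], true),
///     Err(BatchAddError::InsufficientSpace(1))
/// );
/// ```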
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// There was not enough space in the batch to add the token.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// An empty buffer was provided to [`LlamaBatch::get_one`]
    #[error("Empty buffer")]
    EmptyBuffer,
}

impl<'a> LlamaBatch<'a> {
    /// Clear the batch. This does not free the memory associated with the batch, but it does reset
    /// the number of tokens to 0.
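    ///
    /// # Example
    ///
    /// A minimal sketch (the token id here is illustrative only):
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// # use llama_cpp_2::token::LlamaToken;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let mut batch = LlamaBatch::new(512, 1);
    /// batch.add(LlamaToken::new(0), 0, &[0], true)?;
    /// assert_eq!(batch.n_tokens(), 1);
    /// // resets the token count without freeing the underlying allocation
    /// batch.clear();
    /// assert_eq!(batch.n_tokens(), 0);
    /// # Ok(())
    /// # }
    /// ```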
    pub fn clear(&mut self) {
        self.llama_batch.n_tokens = 0;
        self.initialized_logits.clear();
    }

    /// Add a token to the batch for the sequences `seq_ids` at position `pos`. If `logits` is
    /// true, logits will be computed for this token and can be read after the next decode.
    ///
    /// # Panics
    ///
    /// - [`self.llama_batch.n_tokens`] does not fit into a usize
    /// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`]
    ///
    /// # Errors
    ///
    /// Returns an error if there is insufficient space in the buffer.
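    ///
    /// # Example
    ///
    /// A minimal sketch of adding one token to sequence `0` at position `0` (the token id is
    /// illustrative only; real ids come from tokenization):
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// # use llama_cpp_2::token::LlamaToken;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let mut batch = LlamaBatch::new(512, 1);
    /// // request logits for this token so they can be read after the next decode
    /// batch.add(LlamaToken::new(0), 0, &[0], true)?;
    /// assert_eq!(batch.n_tokens(), 1);
    /// # Ok(())
    /// # }
    /// ```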
    pub fn add(
        &mut self,
        LlamaToken(id): LlamaToken,
        pos: llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) -> Result<(), BatchAddError> {
        if self.allocated
            < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize")
        {
            return Err(BatchAddError::InsufficientSpace(self.allocated));
        }
        let offset = self.llama_batch.n_tokens;
        let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize");
        unsafe {
            // batch.token   [batch.n_tokens] = id;
            self.llama_batch.token.add(offset_usize).write(id);
            // batch.pos     [batch.n_tokens] = pos,
            self.llama_batch.pos.add(offset_usize).write(pos);
            // batch.n_seq_id[batch.n_tokens] = seq_ids.size();
            self.llama_batch.n_seq_id.add(offset_usize).write(
                llama_seq_id::try_from(seq_ids.len())
                    .expect("cannot fit seq_ids.len() into a llama_seq_id"),
            );
            // for (size_t i = 0; i < seq_ids.size(); ++i) {
            //     batch.seq_id[batch.n_tokens][i] = seq_ids[i];
            // }
            for (i, seq_id) in seq_ids.iter().enumerate() {
                let tmp = *self.llama_batch.seq_id.add(offset_usize);
                tmp.add(i).write(*seq_id);
            }
            // batch.logits  [batch.n_tokens] = logits;
            self.llama_batch
                .logits
                .add(offset_usize)
                .write(i8::from(logits));
        }

        if logits {
            self.initialized_logits.push(offset);
        } else {
            self.initialized_logits.retain(|l| l != &offset);
        }

        // batch.n_tokens++;
        self.llama_batch.n_tokens += 1;

        Ok(())
    }

    /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true,
    /// logits will be computed for every token and can be read after the next decode.
    ///
    /// Either way, the last token in the sequence will have its logits enabled.
    ///
    /// # Errors
    ///
    /// Returns an error if there is insufficient space in the buffer.
    ///
    /// # Panics
    ///
    /// - [`self.llama_batch.n_tokens`] does not fit into a [`usize`]
    /// - [`n_tokens - 1`] does not fit into a [`llama_pos`]
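    ///
    /// # Example
    ///
    /// A minimal sketch of batching a whole prompt into sequence `0` (the token ids are
    /// illustrative only; real ids come from tokenization):
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// # use llama_cpp_2::token::LlamaToken;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let tokens = vec![LlamaToken::new(0), LlamaToken::new(1), LlamaToken::new(2)];
    /// let mut batch = LlamaBatch::new(512, 1);
    /// // with `logits_all = false`, only the last token has its logits enabled
    /// batch.add_sequence(&tokens, 0, false)?;
    /// assert_eq!(batch.n_tokens(), 3);
    /// # Ok(())
    /// # }
    /// ```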
    pub fn add_sequence(
        &mut self,
        tokens: &[LlamaToken],
        seq_id: i32,
        logits_all: bool,
    ) -> Result<(), BatchAddError> {
        let n_tokens_0 =
            usize::try_from(self.llama_batch.n_tokens).expect("cannot fit n_tokens into a usize");
        let n_tokens = tokens.len();

        if self.allocated < n_tokens_0 + n_tokens {
            return Err(BatchAddError::InsufficientSpace(self.allocated));
        }

        let last_index = llama_pos::try_from(n_tokens.saturating_sub(1))
            .expect("cannot fit n_tokens into a llama_pos");
        for (i, token) in (0..).zip(tokens.iter()) {
            self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
        }

        Ok(())
    }

    /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens.
    ///
    /// # Arguments
    ///
    /// - `n_tokens`: the maximum number of tokens that can be added to the batch
    /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing)
    ///
    /// # Panics
    ///
    /// Panics if `n_tokens` is greater than `i32::MAX`.
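    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// // room for up to 512 tokens in a single sequence
    /// let batch = LlamaBatch::new(512, 1);
    /// assert_eq!(batch.n_tokens(), 0);
    /// ```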
    #[must_use]
    pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
        let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into a i32");
        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };

        LlamaBatch {
            allocated: n_tokens,
            initialized_logits: vec![],
            llama_batch: batch,
            phantom: PhantomData,
        }
    }

    /// ``llama_batch_get_one``
    /// Returns a batch for a single sequence of tokens.
    ///
    /// NOTE: this is a helper function to facilitate transition to the new batch API.
    ///
    /// # Errors
    /// If the provided token buffer is empty.
    ///
    /// # Panics
    /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`].
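    ///
    /// # Example
    ///
    /// A minimal sketch (the token ids are illustrative only):
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// # use llama_cpp_2::token::LlamaToken;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let tokens = vec![LlamaToken::new(0), LlamaToken::new(1)];
    /// // the batch borrows `tokens` for its lifetime; no copy is made
    /// let batch = LlamaBatch::get_one(&tokens)?;
    /// assert_eq!(batch.n_tokens(), 2);
    /// # Ok(())
    /// # }
    /// ```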
    pub fn get_one(tokens: &'a [LlamaToken]) -> Result<Self, BatchAddError> {
        if tokens.is_empty() {
            return Err(BatchAddError::EmptyBuffer);
        }
        let batch = unsafe {
            // the FFI signature takes a mutable pointer, so cast away the const; the tokens are
            // only borrowed, tied to the batch via the `'a` lifetime on `Self`.
            let ptr = tokens.as_ptr() as *mut i32;
            llama_cpp_sys_2::llama_batch_get_one(
                ptr,
                tokens
                    .len()
                    .try_into()
                    .expect("number of tokens exceeds i32::MAX"),
            )
        };
        let batch = Self {
            // `allocated` stays 0 so that `Drop` does not free memory this batch does not own
            allocated: 0,
            initialized_logits: vec![(tokens.len() - 1)
                .try_into()
                .expect("number of tokens exceeds i32::MAX + 1")],
            llama_batch: batch,
            phantom: PhantomData,
        };
        Ok(batch)
    }

    /// Returns the number of tokens in the batch.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        self.llama_batch.n_tokens
    }
}

impl<'a> Drop for LlamaBatch<'a> {
    /// Drops the `LlamaBatch`.
    ///
    /// ```
    /// # use llama_cpp_2::llama_batch::LlamaBatch;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// let batch = LlamaBatch::new(512, 1);
    /// // frees the memory associated with the batch. (allocated by llama.cpp)
    /// drop(batch);
    /// # Ok(())
    /// # }
    /// ```
    fn drop(&mut self) {
        unsafe {
            if self.allocated > 0 {
                llama_batch_free(self.llama_batch);
            }
        }
    }
}