llama_cpp/batch.rs

//! Implements the [`Batch`] struct

use llama_cpp_sys::{llama_batch, llama_batch_free, llama_batch_init};
use tracing::trace;

use crate::Token;

/// A safe wrapper around a [`llama_batch`].
pub struct Batch {
    /// The raw batch.
    ///
    /// ## Members
    /// * `n_tokens`: [`i32`] - The number of tokens in the batch
    /// * `token`: `*mut` [`llama_token`][llama_token] - The id of each token (used when `embd` is null)
    /// * `embd`: `*mut` [`f32`] - The embedding data for each token (used when `token` is null)
    /// * `pos`: `*mut` [`llama_pos`][llama_pos] - The position of each token within its sequence
    /// * `n_seq_id`: `*mut` [`i32`] - The number of sequence ids assigned to each token
    /// * `seq_id`: `*mut *mut` [`llama_seq_id`][llama_seq_id] - The sequence ids assigned to each token
    /// * `logits`: `*mut` [`i8`] - For each token, whether logits should be computed
    /// * `all_pos_0`: [`llama_pos`][llama_pos] - The position of the first token, used only if `pos` is null
    /// * `all_pos_1`: [`llama_pos`][llama_pos] - The position stride between tokens, used only if `pos` is null
    /// * `all_seq_id`: [`llama_seq_id`][llama_seq_id] - The sequence id of every token, used only if `seq_id` is null
    ///
    /// [llama_token]: llama_cpp_sys::llama_token
    /// [llama_seq_id]: llama_cpp_sys::llama_seq_id
    /// [llama_pos]: llama_cpp_sys::llama_pos
    inner: llama_batch,

    /// The maximum number of tokens this batch can hold.
    capacity: usize,

    /// The maximum number of sequence ids that can be assigned to each token in this batch.
    max_sequences: usize,
}

impl Batch {
    /// Creates a new [`Batch`] that can hold up to `capacity` tokens, with `embed`
    /// floats of embedding data per token (`0` to allocate token id storage instead)
    /// and room for up to `max_sequences` sequence ids per token.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` or `max_sequences` is zero. Ideally panic shouldn't be
    /// used, but this struct is only used inside this crate, so it should be fine.
    pub fn new(capacity: usize, embed: usize, max_sequences: usize) -> Self {
        assert!(capacity > 0, "Cannot create a batch with no capacity");
        assert!(max_sequences > 0, "At least one sequence must be generated");

        Self {
            // SAFETY: `llama_batch_init` allocates the batch's buffers, which are
            // freed exactly once by `llama_batch_free` in `Drop`.
            inner: unsafe { llama_batch_init(capacity as i32, embed as i32, max_sequences as i32) },
            capacity,
            max_sequences,
        }
    }

    /// Resets the batch to zero tokens, retaining its allocated buffers for reuse.
    pub fn clear(&mut self) {
        self.inner.n_tokens = 0;
    }

    /// Writes `token` at `position` into the batch, assigning it to every sequence in
    /// `sequence_ids` and marking whether `logits` should be computed for it.
    ///
    /// Returns the index of the token within the batch, or [`usize::MAX`] if the batch
    /// is full or more sequence ids were provided than this batch supports.
    pub fn add(
        &mut self,
        token: Token,
        position: usize,
        sequence_ids: &[i32],
        logits: bool,
    ) -> usize {
        trace!(
            "Writing token {} of {} ({token:?})",
            self.inner.n_tokens,
            self.capacity
        );

        let i = self.inner.n_tokens as usize;

        if i == self.capacity || self.max_sequences < sequence_ids.len() {
            return usize::MAX;
        }

        unsafe {
            // SAFETY: `llama_batch_init` allocated every buffer below with room for
            // `capacity` elements, and `i < capacity`; the slots may be uninitialized,
            // so they are only written to, never read.
            self.inner.token.add(i).write(token.0);
            self.inner.pos.add(i).write(position as i32);
            self.inner.logits.add(i).write(logits as i8);
            self.inner.n_seq_id.add(i).write(sequence_ids.len() as i32);

            let seq_ptr = *self.inner.seq_id.add(i);

            if !seq_ptr.is_null() {
                for (j, id) in sequence_ids.iter().enumerate() {
                    seq_ptr.add(j).write(*id);
                }
            }
        }

        self.inner.n_tokens += 1;
        i
    }

    /// Sets whether logits should be computed for the token at index `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds.
    pub fn set_logits(&self, idx: usize, value: bool) {
        assert!(idx < self.inner.n_tokens as usize, "Index out of bounds");

        unsafe {
            // SAFETY: `idx` is within the batch's written region, as asserted above.
            self.inner.logits.add(idx).write(value as i8);
        }
    }

    /// Returns the number of tokens currently in the batch.
    pub fn tokens(&self) -> usize {
        self.inner.n_tokens as usize
    }

    /// Returns a copy of the raw [`llama_batch`] for use with `llama_cpp_sys` calls.
    ///
    /// The copy shares this batch's buffers, so it must not outlive `self`.
    pub fn handle(&self) -> llama_batch {
        self.inner
    }
}

impl Drop for Batch {
    fn drop(&mut self) {
        trace!("Freeing batch");

        // SAFETY: `inner` was created by `llama_batch_init` and is freed exactly once.
        unsafe { llama_batch_free(self.inner) }
    }
}
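
// A minimal usage sketch for this wrapper, written as a test. It assumes that
// `Token` is a tuple struct over a raw `llama_token` (as `Batch::add`'s use of
// `token.0` suggests) and that `llama_cpp_sys` links against llama.cpp at test
// time; treat it as illustrative rather than part of the crate's test suite.
#[cfg(test)]
mod tests {
    use super::Batch;
    use crate::Token;

    #[test]
    fn fill_and_clear() {
        // Room for 8 tokens, no embedding data, at most 1 sequence id per token.
        let mut batch = Batch::new(8, 0, 1);

        // Write three tokens at consecutive positions into sequence 0, requesting
        // logits only for the final token.
        for pos in 0..3usize {
            let idx = batch.add(Token(pos as i32), pos, &[0], pos == 2);
            assert_eq!(idx, pos);
        }
        assert_eq!(batch.tokens(), 3);

        // Logit flags can be flipped after the fact for tokens already written.
        batch.set_logits(0, true);

        // `clear` only resets the token count; the buffers stay allocated for reuse.
        batch.clear();
        assert_eq!(batch.tokens(), 0);
    }
}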