1use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5 llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
/// A safe wrapper over a llama.cpp `llama_batch`: the buffer of tokens,
/// positions, sequence ids, and logits flags handed to the decoder.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    // Number of token slots allocated by `llama_batch_init`. Set to 0 for
    // batches built by `get_one`, which borrow caller memory; `Drop` uses
    // this to decide whether `llama_batch_free` must be called.
    allocated: usize,
    /// Offsets (indices into the batch) of tokens that were added with the
    /// logits flag set, i.e. positions that will have logits after decoding.
    pub initialized_logits: Vec<i32>,
    /// The underlying C batch struct from the sys bindings.
    pub llama_batch: llama_batch,
    // Ties this batch's lifetime to any token slice borrowed via `get_one`,
    // without storing the slice itself.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
23
/// Errors that can occur while building a [`LlamaBatch`].
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// The batch cannot hold any more tokens; the payload is the batch's
    /// allocated capacity.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// An empty token slice was supplied where at least one token is required.
    #[error("Empty buffer")]
    EmptyBuffer,
    /// A count could not be represented in the required integer type.
    #[error("Integer overflow: {0}")]
    IntegerOverflow(String),
}
37
38impl<'tokens> LlamaBatch<'tokens> {
39 pub fn clear(&mut self) {
42 self.llama_batch.n_tokens = 0;
43 self.initialized_logits.clear();
44 }
45
46 pub fn add(
53 &mut self,
54 LlamaToken(id): LlamaToken,
55 pos: llama_pos,
56 seq_ids: &[i32],
57 logits: bool,
58 ) -> Result<(), BatchAddError> {
59 let required = usize::try_from(self.n_tokens() + 1).map_err(|convert_error| {
60 BatchAddError::IntegerOverflow(format!(
61 "cannot fit n_tokens into a usize: {convert_error}"
62 ))
63 })?;
64
65 if self.allocated < required {
66 return Err(BatchAddError::InsufficientSpace(self.allocated));
67 }
68
69 let offset = self.llama_batch.n_tokens;
70 let offset_usize = usize::try_from(offset).map_err(|convert_error| {
71 BatchAddError::IntegerOverflow(format!(
72 "cannot fit n_tokens into a usize: {convert_error}"
73 ))
74 })?;
75
76 let n_seq_id = llama_seq_id::try_from(seq_ids.len()).map_err(|convert_error| {
77 BatchAddError::IntegerOverflow(format!(
78 "cannot fit seq_ids.len() into a llama_seq_id: {convert_error}"
79 ))
80 })?;
81
82 unsafe {
83 self.llama_batch.token.add(offset_usize).write(id);
85 self.llama_batch.pos.add(offset_usize).write(pos);
87 self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
89 for (i, seq_id) in seq_ids.iter().enumerate() {
93 let tmp = *self.llama_batch.seq_id.add(offset_usize);
94 tmp.add(i).write(*seq_id);
95 }
96 self.llama_batch
98 .logits
99 .add(offset_usize)
100 .write(i8::from(logits));
101 }
102
103 if logits {
104 self.initialized_logits.push(offset);
105 } else {
106 self.initialized_logits.retain(|l| l != &offset);
107 }
108
109 self.llama_batch.n_tokens += 1;
111
112 Ok(())
113 }
114
115 pub fn add_sequence(
124 &mut self,
125 tokens: &[LlamaToken],
126 seq_id: i32,
127 logits_all: bool,
128 ) -> Result<(), BatchAddError> {
129 let n_tokens_0 = usize::try_from(self.llama_batch.n_tokens).map_err(|convert_error| {
130 BatchAddError::IntegerOverflow(format!(
131 "cannot fit n_tokens into a usize: {convert_error}"
132 ))
133 })?;
134 let n_tokens = tokens.len();
135
136 if self.allocated < n_tokens_0 + n_tokens {
137 return Err(BatchAddError::InsufficientSpace(self.allocated));
138 }
139
140 let last_index =
141 llama_pos::try_from(n_tokens.saturating_sub(1)).map_err(|convert_error| {
142 BatchAddError::IntegerOverflow(format!(
143 "cannot fit n_tokens into a llama_pos: {convert_error}"
144 ))
145 })?;
146
147 for (i, token) in (0..).zip(tokens.iter()) {
148 self.add(*token, i, &[seq_id], logits_all || i == last_index)?;
149 }
150
151 Ok(())
152 }
153
154 pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
165 let n_tokens_i32 = i32::try_from(n_tokens).map_err(|convert_error| {
166 BatchAddError::IntegerOverflow(format!(
167 "cannot fit n_tokens into a i32: {convert_error}"
168 ))
169 })?;
170 let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
171
172 Ok(LlamaBatch {
173 allocated: n_tokens,
174 initialized_logits: vec![],
175 llama_batch: batch,
176 phantom: PhantomData,
177 })
178 }
179
180 pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
189 if tokens.is_empty() {
190 return Err(BatchAddError::EmptyBuffer);
191 }
192
193 let token_count = tokens.len().try_into().map_err(|convert_error| {
194 BatchAddError::IntegerOverflow(format!(
195 "number of tokens exceeds i32::MAX: {convert_error}"
196 ))
197 })?;
198
199 let batch = unsafe {
200 let ptr = tokens.as_ptr() as *mut i32;
201 llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
202 };
203
204 let last_token_index = (tokens.len() - 1).try_into().map_err(|convert_error| {
205 BatchAddError::IntegerOverflow(format!(
206 "number of tokens exceeds i32::MAX: {convert_error}"
207 ))
208 })?;
209
210 Ok(Self {
211 allocated: 0,
212 initialized_logits: vec![last_token_index],
213 llama_batch: batch,
214 phantom: PhantomData,
215 })
216 }
217
218 #[must_use]
220 pub fn n_tokens(&self) -> i32 {
221 self.llama_batch.n_tokens
222 }
223}
224
225impl Drop for LlamaBatch<'_> {
226 fn drop(&mut self) {
238 unsafe {
239 if self.allocated > 0 {
240 llama_batch_free(self.llama_batch);
241 }
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    //! Unit tests for `LlamaBatch`. These exercise construction, `add`,
    //! `add_sequence`, `get_one`, and the capacity/error paths; they rely on
    //! the linked llama.cpp library for `llama_batch_init`/`get_one`.
    use crate::token::LlamaToken;

    use super::{BatchAddError, LlamaBatch};

    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let batch = LlamaBatch::new(16, 1)?;

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], true)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        // Both the token count and the logits bookkeeping must reset.
        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.add(LlamaToken::new(2), 1, &[0], false)?;
        assert_eq!(batch.n_tokens(), 2);

        Ok(())
    }

    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        // Only offsets added with `logits == true` are recorded.
        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        batch.add(LlamaToken::new(2), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(LlamaToken::new(1), 0, &[0], false)?;

        let result = batch.add(LlamaToken::new(2), 1, &[0], false);

        // The error carries the batch's allocated capacity.
        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));

        Ok(())
    }

    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false)?;

        // With `logits_all == false` only the last offset (2) gets logits.
        assert_eq!(batch.initialized_logits, vec![2]);

        Ok(())
    }

    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        // `get_one` marks only the last token for logits.
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: Vec<LlamaToken> = vec![];
        let result = LlamaBatch::get_one(&tokens);

        assert!(
            matches!(result, Err(BatchAddError::EmptyBuffer)),
            "expected EmptyBuffer error"
        );
    }

    #[test]
    fn get_one_single_token() {
        let tokens = vec![LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }
}