1use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5 llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
9fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
10 let incremented = n_tokens.checked_add(1).ok_or_else(|| {
11 BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
12 })?;
13
14 usize::try_from(incremented).map_err(|convert_error| {
15 BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
16 })
17}
18
19fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
20 usize::try_from(value).map_err(|convert_error| {
21 BatchAddError::IntegerOverflow(format!(
22 "cannot fit {description} into a usize: {convert_error}"
23 ))
24 })
25}
26
27fn checked_usize_as_llama_seq_id(
28 value: usize,
29 description: &str,
30) -> Result<llama_seq_id, BatchAddError> {
31 llama_seq_id::try_from(value).map_err(|convert_error| {
32 BatchAddError::IntegerOverflow(format!(
33 "cannot fit {description} into a llama_seq_id: {convert_error}"
34 ))
35 })
36}
37
38fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
39 i32::try_from(value).map_err(|convert_error| {
40 BatchAddError::IntegerOverflow(format!(
41 "cannot fit {description} into a i32: {convert_error}"
42 ))
43 })
44}
45
46fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
47 llama_pos::try_from(value).map_err(|convert_error| {
48 BatchAddError::IntegerOverflow(format!(
49 "cannot fit {description} into a llama_pos: {convert_error}"
50 ))
51 })
52}
53
/// A safe wrapper over the C `llama_batch` struct, tracking ownership of the
/// underlying buffers and which token positions have logits enabled.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    // Number of token slots requested from `llama_batch_init` in `new`;
    // 0 for batches built by `get_one`, which borrow the caller's tokens.
    // `Drop` frees the C-side buffers only when this is non-zero.
    allocated: usize,
    // Batch offsets (values of `n_tokens` at insertion time) for which
    // `add` was called with `logits == true`.
    pub initialized_logits: Vec<i32>,
    // The raw C struct holding the token/pos/seq_id/logits arrays.
    pub llama_batch: llama_batch,
    // Ties `get_one` batches to the lifetime of the borrowed token slice
    // without storing it.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
68
/// Errors produced while constructing or filling a [`LlamaBatch`].
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// The batch has no free token slots; the payload is the allocated capacity.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// `get_one` was called with an empty token slice.
    #[error("Empty buffer")]
    EmptyBuffer,
    /// An integer conversion between Rust and C index types failed.
    #[error("Integer overflow: {0}")]
    IntegerOverflow(String),
}
82
impl<'tokens> LlamaBatch<'tokens> {
    /// Resets the batch to empty: zeroes the token count and forgets which
    /// positions had logits requested. Buffers are kept for reuse.
    pub fn clear(&mut self) {
        self.llama_batch.n_tokens = 0;
        self.initialized_logits.clear();
    }

    /// Appends one token to the batch, writing its id, position, sequence ids,
    /// and logits flag into the C-side arrays at the next free slot, then
    /// increments `n_tokens`.
    ///
    /// # Errors
    /// - [`BatchAddError::InsufficientSpace`] when the batch is already full
    ///   (always the case for `get_one` batches, whose `allocated` is 0).
    /// - [`BatchAddError::IntegerOverflow`] when an index conversion fails.
    pub fn add(
        &mut self,
        LlamaToken(id): LlamaToken,
        pos: llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) -> Result<(), BatchAddError> {
        // Capacity needed after this insertion: current n_tokens + 1.
        let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;

        if self.allocated < required {
            return Err(BatchAddError::InsufficientSpace(self.allocated));
        }

        let offset = self.llama_batch.n_tokens;
        let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
        let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;

        // SAFETY: the capacity check above guarantees `offset_usize` is within
        // the `allocated` slots sized by `llama_batch_init` in `new`.
        // NOTE(review): the inner seq_id writes assume `seq_ids.len()` does not
        // exceed the `n_seq_max` passed to `new`; that is not checked here —
        // confirm callers uphold it.
        unsafe {
            self.llama_batch.token.add(offset_usize).write(id);
            self.llama_batch.pos.add(offset_usize).write(pos);
            self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
            for (seq_index, seq_id) in seq_ids.iter().enumerate() {
                // `seq_id` is a pointer to per-token pointers: dereference the
                // outer slot first, then write into its inner array.
                let tmp = *self.llama_batch.seq_id.add(offset_usize);
                tmp.add(seq_index).write(*seq_id);
            }
            self.llama_batch
                .logits
                .add(offset_usize)
                .write(i8::from(logits));
        }

        if logits {
            self.initialized_logits.push(offset);
        }

        self.llama_batch.n_tokens += 1;

        Ok(())
    }

    /// Adds all `tokens` under sequence `seq_id` at positions `0..tokens.len()`.
    /// Logits are requested only for the last token, or for every token when
    /// `logits_all` is true.
    ///
    /// # Errors
    /// Propagates any [`BatchAddError`] from the underlying [`Self::add`]
    /// calls; tokens added before the failure remain in the batch.
    pub fn add_sequence(
        &mut self,
        tokens: &[LlamaToken],
        seq_id: i32,
        logits_all: bool,
    ) -> Result<(), BatchAddError> {
        // `saturating_sub` keeps an empty slice from underflowing; the loop
        // below simply never runs in that case.
        let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;

        for (position, token) in (0..).zip(tokens.iter()) {
            self.add(
                *token,
                position,
                &[seq_id],
                logits_all || position == last_index,
            )?;
        }

        Ok(())
    }

    /// Allocates an owned batch with room for `n_tokens` tokens and up to
    /// `n_seq_max` sequence ids per token. The second `llama_batch_init`
    /// argument is 0 (presumably the embedding size, so the batch carries
    /// token ids — confirm against the bindings).
    ///
    /// # Errors
    /// [`BatchAddError::IntegerOverflow`] when `n_tokens` does not fit in `i32`.
    pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
        let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
        let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };

        Ok(LlamaBatch {
            allocated: n_tokens,
            initialized_logits: vec![],
            llama_batch: batch,
            phantom: PhantomData,
        })
    }

    /// Builds a batch that borrows `tokens` directly. `allocated` stays 0, so
    /// `Drop` does not call `llama_batch_free` and `add` always reports
    /// insufficient space. The last token's index is recorded as having logits.
    ///
    /// # Errors
    /// - [`BatchAddError::EmptyBuffer`] when `tokens` is empty.
    /// - [`BatchAddError::IntegerOverflow`] when the token count exceeds `i32`.
    pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
        if tokens.is_empty() {
            return Err(BatchAddError::EmptyBuffer);
        }

        let token_count = checked_usize_as_i32(tokens.len(), "token count")?;

        let batch = unsafe {
            // NOTE(review): this casts away const and assumes `LlamaToken` is
            // layout-compatible with `i32` and that the C side never writes
            // through the pointer — confirm `LlamaToken` is `#[repr(transparent)]`.
            #[allow(clippy::as_ptr_cast_mut)]
            let ptr = tokens.as_ptr() as *mut i32;
            llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
        };

        // Non-empty was checked above, so `len() - 1` cannot underflow.
        let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;

        Ok(Self {
            allocated: 0,
            initialized_logits: vec![last_token_index],
            llama_batch: batch,
            phantom: PhantomData,
        })
    }

    /// Number of tokens currently stored in the batch.
    #[must_use]
    pub const fn n_tokens(&self) -> i32 {
        self.llama_batch.n_tokens
    }
}
232
impl Drop for LlamaBatch<'_> {
    /// Frees the C-side buffers, but only for owning batches; `get_one`
    /// batches borrow caller memory and leave `allocated` at 0.
    fn drop(&mut self) {
        // SAFETY: `allocated > 0` implies the batch came from
        // `llama_batch_init` (see `new`), so `llama_batch_free` matches that
        // allocation, and drop runs at most once.
        unsafe {
            if self.allocated > 0 {
                llama_batch_free(self.llama_batch);
            }
        }
    }
}
253
// Unit tests. Tests that build batches via `new`/`get_one` exercise the real
// FFI allocation path in `llama_cpp_bindings_sys`, so they require the native
// library to be linked.
#[cfg(test)]
mod tests {
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    // --- LlamaBatch construction and mutation ---

    #[test]
    fn new_creates_empty_batch() {
        let batch = LlamaBatch::new(16, 1).unwrap();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn clear_resets_batch() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], true).unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());
    }

    #[test]
    fn add_increments_token_count() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();
        assert_eq!(batch.n_tokens(), 1);

        batch.add(LlamaToken::new(2), 1, &[0], false).unwrap();
        assert_eq!(batch.n_tokens(), 2);
    }

    #[test]
    fn add_tracks_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();
        assert!(batch.initialized_logits.is_empty());

        batch.add(LlamaToken::new(2), 1, &[0], true).unwrap();
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn add_returns_insufficient_space_when_full() {
        let mut batch = LlamaBatch::new(1, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();

        let result = batch.add(LlamaToken::new(2), 1, &[0], false);

        assert_eq!(result, Err(BatchAddError::InsufficientSpace(1)));
    }

    #[test]
    fn add_sequence_adds_all_tokens() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.n_tokens(), 3);
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, false).unwrap();

        assert_eq!(batch.initialized_logits, vec![2]);
    }

    #[test]
    fn add_sequence_insufficient_space() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() {
        let mut batch = LlamaBatch::new(2, 1).unwrap();
        batch.add(LlamaToken::new(1), 0, &[0], false).unwrap();

        let tokens = vec![LlamaToken::new(10), LlamaToken::new(20)];
        let result = batch.add_sequence(&tokens, 0, false);

        assert!(result.is_err());
    }

    // --- get_one (borrowed batches) ---

    #[test]
    fn get_one_with_valid_tokens() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: Vec<LlamaToken> = vec![];
        let result = LlamaBatch::get_one(&tokens);

        assert_eq!(result.unwrap_err(), BatchAddError::EmptyBuffer);
    }

    #[test]
    fn get_one_single_token() {
        let tokens = vec![LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_with_logits_false_retains_only_previous_logits() {
        let mut batch = LlamaBatch::new(16, 1).unwrap();

        batch.add(LlamaToken::new(1), 0, &[0], true).unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);

        batch.add(LlamaToken::new(2), 0, &[0], false).unwrap();
        assert_eq!(batch.initialized_logits, vec![0]);
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];

        batch.add_sequence(&tokens, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 4)?;

        batch.add(LlamaToken::new(1), 0, &[0, 1, 2], true)?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn drop_does_not_free_get_one_batch() {
        let tokens = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens).expect("test: get_one should succeed");

        // `allocated == 0` is the flag Drop checks before calling llama_batch_free.
        assert_eq!(batch.allocated, 0);
        drop(batch);
    }

    // --- checked integer conversion helpers ---

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        let result = checked_n_tokens_plus_one_as_usize(0);

        assert_eq!(result, Ok(1));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        let result = checked_n_tokens_plus_one_as_usize(-2);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        let result = checked_n_tokens_plus_one_as_usize(i32::MAX);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        let result = checked_i32_as_usize(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        let result = checked_i32_as_usize(i32::MIN, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        let result = checked_usize_as_llama_seq_id(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        let result = checked_usize_as_llama_seq_id(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        let result = checked_usize_as_i32(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        let result = checked_usize_as_i32(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        let result = checked_usize_as_llama_pos(0, "test_value");

        assert_eq!(result, Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        let result = checked_usize_as_llama_pos(usize::MAX, "test_value");

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }

    #[test]
    fn new_fails_for_oversized_n_tokens() {
        let result = LlamaBatch::new(usize::MAX, 1);

        assert!(result.unwrap_err().to_string().contains("overflow"));
    }
}