1use crate::batch_add_error::BatchAddError;
4use crate::sampled_token::SampledToken;
5use crate::token::LlamaToken;
6use llama_cpp_bindings_sys::{
7 llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
8};
9use std::marker::PhantomData;
10
11fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
12 let incremented = n_tokens.checked_add(1).ok_or_else(|| {
13 BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
14 })?;
15
16 usize::try_from(incremented).map_err(|convert_error| {
17 BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
18 })
19}
20
21fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
22 usize::try_from(value).map_err(|convert_error| {
23 BatchAddError::IntegerOverflow(format!(
24 "cannot fit {description} into a usize: {convert_error}"
25 ))
26 })
27}
28
29fn checked_usize_as_llama_seq_id(
30 value: usize,
31 description: &str,
32) -> Result<llama_seq_id, BatchAddError> {
33 llama_seq_id::try_from(value).map_err(|convert_error| {
34 BatchAddError::IntegerOverflow(format!(
35 "cannot fit {description} into a llama_seq_id: {convert_error}"
36 ))
37 })
38}
39
40fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
41 i32::try_from(value).map_err(|convert_error| {
42 BatchAddError::IntegerOverflow(format!(
43 "cannot fit {description} into a i32: {convert_error}"
44 ))
45 })
46}
47
48fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
49 llama_pos::try_from(value).map_err(|convert_error| {
50 BatchAddError::IntegerOverflow(format!(
51 "cannot fit {description} into a llama_pos: {convert_error}"
52 ))
53 })
54}
55
56#[derive(Debug)]
61pub struct LlamaBatch<'tokens> {
62 allocated: usize,
64 pub initialized_logits: Vec<i32>,
66 pub llama_batch: llama_batch,
68 phantom: PhantomData<&'tokens [LlamaToken]>,
69}
70
71impl<'tokens> LlamaBatch<'tokens> {
72 pub fn clear(&mut self) {
75 self.llama_batch.n_tokens = 0;
76 self.initialized_logits.clear();
77 }
78
79 pub fn add(
86 &mut self,
87 sampled_token: &SampledToken,
88 pos: llama_pos,
89 seq_ids: &[i32],
90 logits: bool,
91 ) -> Result<(), BatchAddError> {
92 let (SampledToken::Content(LlamaToken(id))
93 | SampledToken::Reasoning(LlamaToken(id))
94 | SampledToken::ToolCall(LlamaToken(id))
95 | SampledToken::Undeterminable(LlamaToken(id))) = *sampled_token;
96 let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;
97
98 if self.allocated < required {
99 return Err(BatchAddError::InsufficientSpace(self.allocated));
100 }
101
102 let offset = self.llama_batch.n_tokens;
103 let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
104 let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;
105
106 unsafe {
107 self.llama_batch.token.add(offset_usize).write(id);
108 self.llama_batch.pos.add(offset_usize).write(pos);
109 self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
110 for (seq_index, seq_id) in seq_ids.iter().enumerate() {
111 let tmp = *self.llama_batch.seq_id.add(offset_usize);
112 tmp.add(seq_index).write(*seq_id);
113 }
114 self.llama_batch
115 .logits
116 .add(offset_usize)
117 .write(i8::from(logits));
118 }
119
120 if logits {
121 self.initialized_logits.push(offset);
122 }
123
124 self.llama_batch.n_tokens += 1;
125
126 Ok(())
127 }
128
129 pub fn add_sequence(
138 &mut self,
139 tokens: &[LlamaToken],
140 seq_id: i32,
141 logits_all: bool,
142 ) -> Result<(), BatchAddError> {
143 let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;
144
145 for (position, token) in (0..).zip(tokens.iter()) {
146 self.add(
147 &SampledToken::Content(*token),
148 position,
149 &[seq_id],
150 logits_all || position == last_index,
151 )?;
152 }
153
154 Ok(())
155 }
156
157 pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
168 let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
169 let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
170
171 Ok(LlamaBatch {
172 allocated: n_tokens,
173 initialized_logits: vec![],
174 llama_batch: batch,
175 phantom: PhantomData,
176 })
177 }
178
179 pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
188 if tokens.is_empty() {
189 return Err(BatchAddError::EmptyBuffer);
190 }
191
192 let token_count = checked_usize_as_i32(tokens.len(), "token count")?;
193
194 let batch = unsafe {
195 #[expect(
196 clippy::as_ptr_cast_mut,
197 reason = "llama_batch_get_one signature requires *mut i32 but does not mutate the tokens"
198 )]
199 let ptr = tokens.as_ptr() as *mut i32;
200 llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
201 };
202
203 let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;
204
205 Ok(Self {
206 allocated: 0,
207 initialized_logits: vec![last_token_index],
208 llama_batch: batch,
209 phantom: PhantomData,
210 })
211 }
212
213 #[must_use]
215 pub const fn n_tokens(&self) -> i32 {
216 self.llama_batch.n_tokens
217 }
218}
219
220impl Drop for LlamaBatch<'_> {
221 fn drop(&mut self) {
233 unsafe {
234 if self.allocated > 0 {
235 llama_batch_free(self.llama_batch);
236 }
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use crate::sampled_token::SampledToken;
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    /// Wraps a raw token id in the `Content` variant used by most tests.
    fn content(id: i32) -> SampledToken {
        SampledToken::Content(LlamaToken::new(id))
    }

    /// The three-token prompt shared by the `add_sequence` tests.
    fn sample_prompt() -> Vec<LlamaToken> {
        vec![
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ]
    }

    /// Unwraps the error of `outcome` and reports whether its message mentions "overflow".
    fn mentions_overflow<T: std::fmt::Debug>(outcome: Result<T, BatchAddError>) -> bool {
        outcome.unwrap_err().to_string().contains("overflow")
    }

    // --- construction and clearing -------------------------------------------------------

    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let fresh = LlamaBatch::new(16, 1)?;

        assert_eq!(fresh.n_tokens(), 0);
        assert!(fresh.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        batch.add(&content(1), 0, &[0], true)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    // --- add -------------------------------------------------------------------------------

    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(&content(1), 0, &[0], false)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.add(&content(2), 1, &[0], false)?;
        assert_eq!(batch.n_tokens(), 2);

        Ok(())
    }

    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(&content(1), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        batch.add(&content(2), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(&content(1), 0, &[0], false)?;

        let overflowing_add = batch.add(&content(2), 1, &[0], false);

        assert_eq!(overflowing_add, Err(BatchAddError::InsufficientSpace(1)));

        Ok(())
    }

    #[test]
    fn add_accepts_reasoning_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;
        let reasoning = SampledToken::Reasoning(LlamaToken::new(11));

        batch.add(&reasoning, 0, &[0], true)?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_accepts_tool_call_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;
        let tool_call = SampledToken::ToolCall(LlamaToken::new(22));

        batch.add(&tool_call, 0, &[0], true)?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_accepts_undeterminable_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;
        let undeterminable = SampledToken::Undeterminable(LlamaToken::new(33));

        batch.add(&undeterminable, 0, &[0], false)?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_with_logits_false_retains_only_previous_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(&content(1), 0, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![0]);

        batch.add(&content(2), 0, &[0], false)?;
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 4)?;
        let sequences = [0, 1, 2];

        batch.add(&content(1), 0, &sequences, true)?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    // --- add_sequence ----------------------------------------------------------------------

    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let prompt = sample_prompt();

        batch.add_sequence(&prompt, 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let prompt = sample_prompt();

        batch.add_sequence(&prompt, 0, false)?;

        assert_eq!(batch.initialized_logits, vec![2]);

        Ok(())
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let prompt = sample_prompt();

        batch.add_sequence(&prompt, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;
        let prompt = sample_prompt();

        let outcome = batch.add_sequence(&prompt, 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;
        batch.add(&content(1), 0, &[0], false)?;

        let pair = vec![LlamaToken::new(10), LlamaToken::new(20)];
        let outcome = batch.add_sequence(&pair, 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    // --- get_one and drop ------------------------------------------------------------------

    #[test]
    fn get_one_with_valid_tokens() {
        let pair = vec![LlamaToken::new(1), LlamaToken::new(2)];

        let view = LlamaBatch::get_one(&pair).expect("test: get_one should succeed");

        assert_eq!(view.n_tokens(), 2);
        assert_eq!(view.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let nothing: Vec<LlamaToken> = Vec::new();

        let failure = LlamaBatch::get_one(&nothing).unwrap_err();

        assert_eq!(failure, BatchAddError::EmptyBuffer);
    }

    #[test]
    fn get_one_single_token() {
        let single = vec![LlamaToken::new(42)];

        let view = LlamaBatch::get_one(&single).expect("test: get_one should succeed");

        assert_eq!(view.n_tokens(), 1);
        assert_eq!(view.initialized_logits, vec![0]);
    }

    #[test]
    fn drop_does_not_free_get_one_batch() {
        let pair = vec![LlamaToken::new(1), LlamaToken::new(2)];
        let view = LlamaBatch::get_one(&pair).expect("test: get_one should succeed");

        assert_eq!(view.allocated, 0);
        drop(view);
    }

    // --- checked conversion helpers --------------------------------------------------------

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        assert_eq!(checked_n_tokens_plus_one_as_usize(0), Ok(1));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        assert!(mentions_overflow(checked_n_tokens_plus_one_as_usize(-2)));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        assert!(mentions_overflow(checked_n_tokens_plus_one_as_usize(
            i32::MAX
        )));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        assert_eq!(checked_i32_as_usize(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        assert!(mentions_overflow(checked_i32_as_usize(
            i32::MIN,
            "test_value"
        )));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        assert_eq!(checked_usize_as_llama_seq_id(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        assert!(mentions_overflow(checked_usize_as_llama_seq_id(
            usize::MAX,
            "test_value"
        )));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        assert_eq!(checked_usize_as_i32(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        assert!(mentions_overflow(checked_usize_as_i32(
            usize::MAX,
            "test_value"
        )));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        assert_eq!(checked_usize_as_llama_pos(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        assert!(mentions_overflow(checked_usize_as_llama_pos(
            usize::MAX,
            "test_value"
        )));
    }

    #[test]
    fn new_fails_for_oversized_n_tokens() {
        assert!(mentions_overflow(LlamaBatch::new(usize::MAX, 1)));
    }
}