1use crate::batch_add_error::BatchAddError;
2use crate::sampled_token::SampledToken;
3use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5 llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
9fn checked_n_tokens_plus_one_as_usize(n_tokens: i32) -> Result<usize, BatchAddError> {
10 let incremented = n_tokens.checked_add(1).ok_or_else(|| {
11 BatchAddError::IntegerOverflow(format!("n_tokens + 1 overflows i32: {n_tokens}"))
12 })?;
13
14 usize::try_from(incremented).map_err(|convert_error| {
15 BatchAddError::IntegerOverflow(format!("cannot fit n_tokens into a usize: {convert_error}"))
16 })
17}
18
19fn checked_i32_as_usize(value: i32, description: &str) -> Result<usize, BatchAddError> {
20 usize::try_from(value).map_err(|convert_error| {
21 BatchAddError::IntegerOverflow(format!(
22 "cannot fit {description} into a usize: {convert_error}"
23 ))
24 })
25}
26
27fn checked_usize_as_llama_seq_id(
28 value: usize,
29 description: &str,
30) -> Result<llama_seq_id, BatchAddError> {
31 llama_seq_id::try_from(value).map_err(|convert_error| {
32 BatchAddError::IntegerOverflow(format!(
33 "cannot fit {description} into a llama_seq_id: {convert_error}"
34 ))
35 })
36}
37
38fn checked_usize_as_i32(value: usize, description: &str) -> Result<i32, BatchAddError> {
39 i32::try_from(value).map_err(|convert_error| {
40 BatchAddError::IntegerOverflow(format!(
41 "cannot fit {description} into a i32: {convert_error}"
42 ))
43 })
44}
45
46fn checked_usize_as_llama_pos(value: usize, description: &str) -> Result<llama_pos, BatchAddError> {
47 llama_pos::try_from(value).map_err(|convert_error| {
48 BatchAddError::IntegerOverflow(format!(
49 "cannot fit {description} into a llama_pos: {convert_error}"
50 ))
51 })
52}
53
54#[derive(Debug)]
55pub struct LlamaBatch<'tokens> {
56 allocated: usize,
57 pub initialized_logits: Vec<i32>,
58 pub llama_batch: llama_batch,
59 phantom: PhantomData<&'tokens [LlamaToken]>,
60}
61
62impl<'tokens> LlamaBatch<'tokens> {
63 pub fn clear(&mut self) {
64 self.llama_batch.n_tokens = 0;
65 self.initialized_logits.clear();
66 }
67
68 pub fn add(
72 &mut self,
73 sampled_token: &SampledToken,
74 pos: llama_pos,
75 seq_ids: &[i32],
76 logits: bool,
77 ) -> Result<(), BatchAddError> {
78 let (SampledToken::Content(LlamaToken(id))
79 | SampledToken::Reasoning(LlamaToken(id))
80 | SampledToken::ToolCall(LlamaToken(id))
81 | SampledToken::Undeterminable(LlamaToken(id))) = *sampled_token;
82 let required = checked_n_tokens_plus_one_as_usize(self.n_tokens())?;
83
84 if self.allocated < required {
85 return Err(BatchAddError::InsufficientSpace(self.allocated));
86 }
87
88 let offset = self.llama_batch.n_tokens;
89 let offset_usize = checked_i32_as_usize(offset, "n_tokens")?;
90 let n_seq_id = checked_usize_as_llama_seq_id(seq_ids.len(), "seq_ids.len()")?;
91
92 unsafe {
93 self.llama_batch.token.add(offset_usize).write(id);
94 self.llama_batch.pos.add(offset_usize).write(pos);
95 self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
96 for (seq_index, seq_id) in seq_ids.iter().enumerate() {
97 let tmp = *self.llama_batch.seq_id.add(offset_usize);
98 tmp.add(seq_index).write(*seq_id);
99 }
100 self.llama_batch
101 .logits
102 .add(offset_usize)
103 .write(i8::from(logits));
104 }
105
106 if logits {
107 self.initialized_logits.push(offset);
108 }
109
110 self.llama_batch.n_tokens += 1;
111
112 Ok(())
113 }
114
115 pub fn add_sequence(
119 &mut self,
120 tokens: &[LlamaToken],
121 seq_id: i32,
122 logits_all: bool,
123 ) -> Result<(), BatchAddError> {
124 let last_index = checked_usize_as_llama_pos(tokens.len().saturating_sub(1), "n_tokens")?;
125
126 for (position, token) in (0..).zip(tokens.iter()) {
127 self.add(
128 &SampledToken::Content(*token),
129 position,
130 &[seq_id],
131 logits_all || position == last_index,
132 )?;
133 }
134
135 Ok(())
136 }
137
138 pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
142 let n_tokens_i32 = checked_usize_as_i32(n_tokens, "n_tokens")?;
143 let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
144
145 Ok(LlamaBatch {
146 allocated: n_tokens,
147 initialized_logits: vec![],
148 llama_batch: batch,
149 phantom: PhantomData,
150 })
151 }
152
153 pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
157 if tokens.is_empty() {
158 return Err(BatchAddError::EmptyBuffer);
159 }
160
161 let token_count = checked_usize_as_i32(tokens.len(), "token count")?;
162
163 let batch = unsafe {
164 #[expect(
165 clippy::as_ptr_cast_mut,
166 reason = "llama_batch_get_one signature requires *mut i32 but does not mutate the tokens"
167 )]
168 let ptr = tokens.as_ptr() as *mut i32;
169 llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
170 };
171
172 let last_token_index = checked_usize_as_i32(tokens.len() - 1, "last token index")?;
173
174 Ok(Self {
175 allocated: 0,
176 initialized_logits: vec![last_token_index],
177 llama_batch: batch,
178 phantom: PhantomData,
179 })
180 }
181
182 #[must_use]
183 pub const fn n_tokens(&self) -> i32 {
184 self.llama_batch.n_tokens
185 }
186}
187
188impl Drop for LlamaBatch<'_> {
189 fn drop(&mut self) {
190 unsafe {
191 if self.allocated > 0 {
192 llama_batch_free(self.llama_batch);
193 }
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    use crate::sampled_token::SampledToken;
    use crate::token::LlamaToken;

    use super::{
        BatchAddError, LlamaBatch, checked_i32_as_usize, checked_n_tokens_plus_one_as_usize,
        checked_usize_as_i32, checked_usize_as_llama_pos, checked_usize_as_llama_seq_id,
    };

    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let batch = LlamaBatch::new(16, 1)?;

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)?;
        assert_eq!(batch.n_tokens(), 1);

        batch.clear();

        assert_eq!(batch.n_tokens(), 0);
        assert!(batch.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        for (token_id, position, expected_count) in [(1, 0, 1), (2, 1, 2)] {
            batch.add(
                &SampledToken::Content(LlamaToken::new(token_id)),
                position,
                &[0],
                false,
            )?;
            assert_eq!(batch.n_tokens(), expected_count);
        }

        Ok(())
    }

    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        batch.add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(1, 1)?;
        batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)?;

        let overflow = batch.add(&SampledToken::Content(LlamaToken::new(2)), 1, &[0], false);

        assert_eq!(overflow, Err(BatchAddError::InsufficientSpace(1)));

        Ok(())
    }

    #[test]
    fn add_accepts_reasoning_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;

        batch.add(&SampledToken::Reasoning(LlamaToken::new(11)), 0, &[0], true)?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_accepts_tool_call_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;

        batch.add(&SampledToken::ToolCall(LlamaToken::new(22)), 0, &[0], true)?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_accepts_undeterminable_sampled_token_variant() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(4, 1)?;

        batch.add(
            &SampledToken::Undeterminable(LlamaToken::new(33)),
            0,
            &[0],
            false,
        )?;

        assert_eq!(batch.n_tokens(), 1);

        Ok(())
    }

    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = [LlamaToken::new(10), LlamaToken::new(20), LlamaToken::new(30)];

        batch.add_sequence(&tokens, 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = [LlamaToken::new(10), LlamaToken::new(20), LlamaToken::new(30)];

        batch.add_sequence(&tokens, 0, false)?;

        assert_eq!(batch.initialized_logits, vec![2]);

        Ok(())
    }

    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;
        let tokens = [LlamaToken::new(10), LlamaToken::new(20), LlamaToken::new(30)];

        let outcome = batch.add_sequence(&tokens, 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    #[test]
    fn add_sequence_fails_mid_loop_when_batch_fills() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(2, 1)?;
        batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], false)?;

        let tokens = [LlamaToken::new(10), LlamaToken::new(20)];
        let outcome = batch.add_sequence(&tokens, 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    #[test]
    fn get_one_with_valid_tokens() -> Result<(), BatchAddError> {
        let tokens = [LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens)?;

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let tokens: [LlamaToken; 0] = [];

        let outcome = LlamaBatch::get_one(&tokens);

        assert!(matches!(outcome, Err(BatchAddError::EmptyBuffer)));
    }

    #[test]
    fn get_one_single_token() -> Result<(), BatchAddError> {
        let tokens = [LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&tokens)?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn add_with_logits_false_retains_only_previous_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add(&SampledToken::Content(LlamaToken::new(1)), 0, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![0]);

        batch.add(&SampledToken::Content(LlamaToken::new(2)), 0, &[0], false)?;
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn add_sequence_with_logits_all_marks_every_token() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;
        let tokens = [LlamaToken::new(10), LlamaToken::new(20), LlamaToken::new(30)];

        batch.add_sequence(&tokens, 0, true)?;

        assert_eq!(batch.n_tokens(), 3);
        assert_eq!(batch.initialized_logits, vec![0, 1, 2]);

        Ok(())
    }

    #[test]
    fn add_with_multiple_seq_ids() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 4)?;
        let shared_sequences = [0, 1, 2];

        batch.add(
            &SampledToken::Content(LlamaToken::new(1)),
            0,
            &shared_sequences,
            true,
        )?;

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);

        Ok(())
    }

    #[test]
    fn drop_does_not_free_get_one_batch() -> Result<(), BatchAddError> {
        let tokens = [LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&tokens)?;

        assert_eq!(batch.allocated, 0);
        drop(batch);

        Ok(())
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_succeeds_for_zero() {
        assert_eq!(checked_n_tokens_plus_one_as_usize(0), Ok(1));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_negative() {
        let message = checked_n_tokens_plus_one_as_usize(-2)
            .expect_err("-2 + 1 is negative and cannot become a usize")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn checked_n_tokens_plus_one_as_usize_fails_for_i32_max() {
        let message = checked_n_tokens_plus_one_as_usize(i32::MAX)
            .expect_err("i32::MAX + 1 overflows")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn checked_i32_as_usize_succeeds_for_zero() {
        assert_eq!(checked_i32_as_usize(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_i32_as_usize_fails_for_negative() {
        let message = checked_i32_as_usize(i32::MIN, "test_value")
            .expect_err("i32::MIN cannot become a usize")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_succeeds_for_zero() {
        assert_eq!(checked_usize_as_llama_seq_id(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_seq_id_fails_for_overflow() {
        let message = checked_usize_as_llama_seq_id(usize::MAX, "test_value")
            .expect_err("usize::MAX cannot become a llama_seq_id")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn checked_usize_as_i32_succeeds_for_zero() {
        assert_eq!(checked_usize_as_i32(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_i32_fails_for_overflow() {
        let message = checked_usize_as_i32(usize::MAX, "test_value")
            .expect_err("usize::MAX cannot become an i32")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn checked_usize_as_llama_pos_succeeds_for_zero() {
        assert_eq!(checked_usize_as_llama_pos(0, "test_value"), Ok(0));
    }

    #[test]
    fn checked_usize_as_llama_pos_fails_for_overflow() {
        let message = checked_usize_as_llama_pos(usize::MAX, "test_value")
            .expect_err("usize::MAX cannot become a llama_pos")
            .to_string();

        assert!(message.contains("overflow"));
    }

    #[test]
    fn new_fails_for_oversized_n_tokens() {
        let message = LlamaBatch::new(usize::MAX, 1)
            .expect_err("usize::MAX tokens cannot fit in an i32")
            .to_string();

        assert!(message.contains("overflow"));
    }
}