1use crate::token::LlamaToken;
4use llama_cpp_bindings_sys::{
5 llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id,
6};
7use std::marker::PhantomData;
8
/// Safe-ish wrapper over the FFI `llama_batch`, tracking how many token
/// slots were allocated and which offsets have logits enabled.
#[derive(Debug)]
pub struct LlamaBatch<'tokens> {
    // Number of token slots allocated via `llama_batch_init`; 0 for batches
    // built by `get_one`, which borrow the caller's tokens and must not be
    // passed to `llama_batch_free` (see `Drop`).
    allocated: usize,
    /// Offsets (the value of `n_tokens` at insertion time) at which logits
    /// were requested via `add`/`add_sequence`.
    pub initialized_logits: Vec<i32>,
    /// The raw FFI batch handed to llama.cpp calls.
    pub llama_batch: llama_batch,
    // Ties this wrapper's lifetime to the token slice borrowed by `get_one`,
    // so the slice cannot be dropped while the batch still points into it.
    phantom: PhantomData<&'tokens [LlamaToken]>,
}
23
/// Errors that can occur while constructing or filling a `LlamaBatch`.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum BatchAddError {
    /// The batch has no free slots left; the payload is the allocated capacity.
    #[error("Insufficient Space of {0}")]
    InsufficientSpace(usize),
    /// `get_one` was called with an empty token slice.
    #[error("Empty buffer")]
    EmptyBuffer,
    /// A count could not be converted between integer types (usize/i32/FFI).
    #[error("Integer overflow: {0}")]
    IntegerOverflow(String),
}
37
38impl<'tokens> LlamaBatch<'tokens> {
39 pub fn clear(&mut self) {
42 self.llama_batch.n_tokens = 0;
43 self.initialized_logits.clear();
44 }
45
46 pub fn add(
53 &mut self,
54 LlamaToken(id): LlamaToken,
55 pos: llama_pos,
56 seq_ids: &[i32],
57 logits: bool,
58 ) -> Result<(), BatchAddError> {
59 let required = usize::try_from(self.n_tokens() + 1).map_err(|convert_error| {
60 BatchAddError::IntegerOverflow(format!(
61 "cannot fit n_tokens into a usize: {convert_error}"
62 ))
63 })?;
64
65 if self.allocated < required {
66 return Err(BatchAddError::InsufficientSpace(self.allocated));
67 }
68
69 let offset = self.llama_batch.n_tokens;
70 let offset_usize = usize::try_from(offset).map_err(|convert_error| {
71 BatchAddError::IntegerOverflow(format!(
72 "cannot fit n_tokens into a usize: {convert_error}"
73 ))
74 })?;
75
76 let n_seq_id = llama_seq_id::try_from(seq_ids.len()).map_err(|convert_error| {
77 BatchAddError::IntegerOverflow(format!(
78 "cannot fit seq_ids.len() into a llama_seq_id: {convert_error}"
79 ))
80 })?;
81
82 unsafe {
83 self.llama_batch.token.add(offset_usize).write(id);
85 self.llama_batch.pos.add(offset_usize).write(pos);
87 self.llama_batch.n_seq_id.add(offset_usize).write(n_seq_id);
89 for (seq_index, seq_id) in seq_ids.iter().enumerate() {
93 let tmp = *self.llama_batch.seq_id.add(offset_usize);
94 tmp.add(seq_index).write(*seq_id);
95 }
96 self.llama_batch
98 .logits
99 .add(offset_usize)
100 .write(i8::from(logits));
101 }
102
103 if logits {
104 self.initialized_logits.push(offset);
105 } else {
106 self.initialized_logits
107 .retain(|logit_offset| logit_offset != &offset);
108 }
109
110 self.llama_batch.n_tokens += 1;
111
112 Ok(())
113 }
114
115 pub fn add_sequence(
124 &mut self,
125 tokens: &[LlamaToken],
126 seq_id: i32,
127 logits_all: bool,
128 ) -> Result<(), BatchAddError> {
129 let n_tokens_0 = usize::try_from(self.llama_batch.n_tokens).map_err(|convert_error| {
130 BatchAddError::IntegerOverflow(format!(
131 "cannot fit n_tokens into a usize: {convert_error}"
132 ))
133 })?;
134 let n_tokens = tokens.len();
135
136 if self.allocated < n_tokens_0 + n_tokens {
137 return Err(BatchAddError::InsufficientSpace(self.allocated));
138 }
139
140 let last_index =
141 llama_pos::try_from(n_tokens.saturating_sub(1)).map_err(|convert_error| {
142 BatchAddError::IntegerOverflow(format!(
143 "cannot fit n_tokens into a llama_pos: {convert_error}"
144 ))
145 })?;
146
147 for (position, token) in (0..).zip(tokens.iter()) {
148 self.add(
149 *token,
150 position,
151 &[seq_id],
152 logits_all || position == last_index,
153 )?;
154 }
155
156 Ok(())
157 }
158
159 pub fn new(n_tokens: usize, n_seq_max: i32) -> Result<Self, BatchAddError> {
170 let n_tokens_i32 = i32::try_from(n_tokens).map_err(|convert_error| {
171 BatchAddError::IntegerOverflow(format!(
172 "cannot fit n_tokens into a i32: {convert_error}"
173 ))
174 })?;
175 let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) };
176
177 Ok(LlamaBatch {
178 allocated: n_tokens,
179 initialized_logits: vec![],
180 llama_batch: batch,
181 phantom: PhantomData,
182 })
183 }
184
185 pub fn get_one(tokens: &'tokens [LlamaToken]) -> Result<Self, BatchAddError> {
194 if tokens.is_empty() {
195 return Err(BatchAddError::EmptyBuffer);
196 }
197
198 let token_count = tokens.len().try_into().map_err(|convert_error| {
199 BatchAddError::IntegerOverflow(format!(
200 "number of tokens exceeds i32::MAX: {convert_error}"
201 ))
202 })?;
203
204 let batch = unsafe {
205 let ptr = tokens.as_ptr() as *mut i32;
206 llama_cpp_bindings_sys::llama_batch_get_one(ptr, token_count)
207 };
208
209 let last_token_index = (tokens.len() - 1).try_into().map_err(|convert_error| {
210 BatchAddError::IntegerOverflow(format!(
211 "number of tokens exceeds i32::MAX: {convert_error}"
212 ))
213 })?;
214
215 Ok(Self {
216 allocated: 0,
217 initialized_logits: vec![last_token_index],
218 llama_batch: batch,
219 phantom: PhantomData,
220 })
221 }
222
223 #[must_use]
225 pub fn n_tokens(&self) -> i32 {
226 self.llama_batch.n_tokens
227 }
228}
229
230impl Drop for LlamaBatch<'_> {
231 fn drop(&mut self) {
243 unsafe {
244 if self.allocated > 0 {
245 llama_batch_free(self.llama_batch);
246 }
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::{BatchAddError, LlamaBatch};
    use crate::token::LlamaToken;

    #[test]
    fn new_creates_empty_batch() -> Result<(), BatchAddError> {
        let fresh = LlamaBatch::new(16, 1)?;

        assert_eq!(fresh.n_tokens(), 0);
        assert!(fresh.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn clear_resets_batch() -> Result<(), BatchAddError> {
        // Fill one token so clearing has something to undo.
        let mut filled = LlamaBatch::new(16, 1)?;
        filled.add(LlamaToken::new(1), 0, &[0], true)?;
        assert_eq!(filled.n_tokens(), 1);

        filled.clear();

        assert_eq!(filled.n_tokens(), 0);
        assert!(filled.initialized_logits.is_empty());

        Ok(())
    }

    #[test]
    fn add_increments_token_count() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        for (step, token_id) in [(1, 1), (2, 2)] {
            batch.add(LlamaToken::new(token_id), step - 1, &[0], false)?;
            assert_eq!(batch.n_tokens(), step);
        }

        Ok(())
    }

    #[test]
    fn add_tracks_logits() -> Result<(), BatchAddError> {
        let mut batch = LlamaBatch::new(16, 1)?;

        // A token added without logits leaves the bookkeeping empty...
        batch.add(LlamaToken::new(1), 0, &[0], false)?;
        assert!(batch.initialized_logits.is_empty());

        // ...and one added with logits records its offset.
        batch.add(LlamaToken::new(2), 1, &[0], true)?;
        assert_eq!(batch.initialized_logits, vec![1]);

        Ok(())
    }

    #[test]
    fn add_returns_insufficient_space_when_full() -> Result<(), BatchAddError> {
        // Capacity of exactly one token.
        let mut tiny = LlamaBatch::new(1, 1)?;
        tiny.add(LlamaToken::new(1), 0, &[0], false)?;

        let overflowing = tiny.add(LlamaToken::new(2), 1, &[0], false);

        assert_eq!(overflowing, Err(BatchAddError::InsufficientSpace(1)));

        Ok(())
    }

    #[test]
    fn add_sequence_adds_all_tokens() -> Result<(), BatchAddError> {
        let sequence = [
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&sequence, 0, false)?;

        assert_eq!(batch.n_tokens(), 3);

        Ok(())
    }

    #[test]
    fn add_sequence_sets_logits_on_last_token() -> Result<(), BatchAddError> {
        let sequence = [
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];
        let mut batch = LlamaBatch::new(16, 1)?;

        batch.add_sequence(&sequence, 0, false)?;

        // With logits_all == false, only the final offset carries logits.
        assert_eq!(batch.initialized_logits, vec![2]);

        Ok(())
    }

    #[test]
    fn add_sequence_insufficient_space() -> Result<(), BatchAddError> {
        // Three tokens into a two-slot batch must fail.
        let sequence = [
            LlamaToken::new(10),
            LlamaToken::new(20),
            LlamaToken::new(30),
        ];
        let mut cramped = LlamaBatch::new(2, 1)?;

        let outcome = cramped.add_sequence(&sequence, 0, false);

        assert!(outcome.is_err());

        Ok(())
    }

    #[test]
    fn get_one_with_valid_tokens() {
        let borrowed = [LlamaToken::new(1), LlamaToken::new(2)];
        let batch = LlamaBatch::get_one(&borrowed).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 2);
        assert_eq!(batch.initialized_logits, vec![1]);
    }

    #[test]
    fn get_one_empty_slice_returns_error() {
        let nothing: [LlamaToken; 0] = [];
        let outcome = LlamaBatch::get_one(&nothing);

        assert!(
            matches!(outcome, Err(BatchAddError::EmptyBuffer)),
            "expected EmptyBuffer error"
        );
    }

    #[test]
    fn get_one_single_token() {
        let solo = [LlamaToken::new(42)];
        let batch = LlamaBatch::get_one(&solo).expect("test: get_one should succeed");

        assert_eq!(batch.n_tokens(), 1);
        assert_eq!(batch.initialized_logits, vec![0]);
    }
}