1use llama_cpp_sys::{llama_batch, llama_batch_free, llama_batch_init};
4use tracing::trace;
5
6use crate::Token;
7
/// Owning wrapper around a raw [`llama_batch`] created with `llama_batch_init`.
///
/// The wrapped buffers are released in this type's `Drop` impl via `llama_batch_free`.
pub struct Batch {
    /// The underlying llama.cpp batch; its per-token buffers are written through raw pointers.
    inner: llama_batch,

    /// Maximum number of tokens this batch was allocated for.
    capacity: usize,

    /// Maximum number of sequence ids that may be attached to a single token.
    max_sequences: usize,
}
34
35impl Batch {
36 pub fn new(capacity: usize, embed: usize, max_sequences: usize) -> Self {
37 if capacity == 0 {
41 panic!("Cannot create a batch with no capacity");
42 }
43 if max_sequences == 0 {
44 panic!("At least one sequence must be generated");
45 }
46
47 Self {
48 inner: unsafe { llama_batch_init(capacity as i32, embed as i32, max_sequences as i32) },
49 capacity,
50 max_sequences,
51 }
52 }
53
54 pub fn clear(&mut self) {
55 self.inner.n_tokens = 0;
56 }
57
58 pub fn add(
59 &mut self,
60 token: Token,
61 position: usize,
62 sequence_ids: &[i32],
63 logits: bool,
64 ) -> usize {
65 trace!(
66 "Writing token {} of {} ({token:?})",
67 self.inner.n_tokens,
68 self.capacity
69 );
70
71 let i = self.inner.n_tokens as usize;
72
73 if i == self.capacity || self.max_sequences < sequence_ids.len() {
74 return usize::MAX;
75 }
76
77 unsafe {
78 self.inner.token.add(i).write(token.0);
81 self.inner.pos.add(i).write(position as i32);
82 if logits {
83 self.inner.logits.add(i).write(1);
84 } else {
85 self.inner.logits.add(i).write(0);
86 }
87 self.inner.n_seq_id.add(i).write(sequence_ids.len() as i32);
88
89 let seq_ptr = *self.inner.seq_id.add(i);
90
91 if !seq_ptr.is_null() {
92 for (i, id) in sequence_ids.iter().enumerate() {
93 seq_ptr.add(i).write(*id);
94 }
95 }
96 }
97
98 self.inner.n_tokens += 1;
99 self.inner.n_tokens as usize - 1
100 }
101
102 pub fn set_logits(&self, idx: usize, value: bool) {
103 assert!(idx < self.inner.n_tokens as usize, "Index out of bounds");
104
105 unsafe {
106 if value {
107 self.inner.logits.add(idx).write(1);
108 } else {
109 self.inner.logits.add(idx).write(0);
110 }
111 }
112 }
113
114 pub fn tokens(&self) -> usize {
115 self.inner.n_tokens as usize
116 }
117
118 pub fn handle(&self) -> llama_batch {
119 self.inner
120 }
121}
122
123impl Drop for Batch {
124 fn drop(&mut self) {
125 trace!("Freeing batch");
126
127 unsafe { llama_batch_free(self.inner) }
128 }
129}