1use llama_crab_sys as sys;
4
5use crate::error::LlamaError;
6use crate::token::LlamaToken;
7
8#[derive(Debug)]
10pub struct LlamaBatch {
11 raw: sys::llama_batch,
12 tokens: Vec<sys::llama_token>,
14 positions: Vec<sys::llama_pos>,
15 n_seq_id: Vec<i32>,
16 seq_ids: Vec<Vec<sys::llama_seq_id>>,
17 seq_ids_ptrs: Vec<*mut sys::llama_seq_id>,
18 logits: Vec<i8>,
19 allocated: bool,
20}
21
/// Reasons `LlamaBatch::add` can reject a token.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BatchAddError {
    /// Every slot is already in use; the payload is the batch's total
    /// token capacity.
    InsufficientSpace(usize),
    /// The call supplied an empty list of sequence ids.
    Empty,
}
30
31impl std::fmt::Display for BatchAddError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::InsufficientSpace(n) => write!(f, "batch only has space for {n} tokens"),
35 Self::Empty => write!(f, "no token to add"),
36 }
37 }
38}
39
40impl std::error::Error for BatchAddError {}
41
42impl LlamaBatch {
43 #[must_use]
48 pub fn new(n_tokens: usize, n_seq_max: i32) -> Self {
49 let tokens = vec![0_i32; n_tokens];
50 let positions = vec![0_i32; n_tokens];
51 let n_seq_id = vec![n_seq_max; n_tokens];
52 let mut seq_ids = Vec::with_capacity(n_tokens);
53 let mut seq_ids_ptrs: Vec<*mut sys::llama_seq_id> = Vec::with_capacity(n_tokens);
54 for _ in 0..n_tokens {
55 let mut v: Vec<i32> = vec![0; n_seq_max as usize];
56 seq_ids_ptrs.push(v.as_mut_ptr());
57 seq_ids.push(v);
58 }
59 let logits = vec![0_i8; n_tokens];
60 let raw = sys::llama_batch {
61 n_tokens: 0,
62 token: tokens.as_ptr().cast_mut(),
63 embd: std::ptr::null_mut(),
64 pos: positions.as_ptr().cast_mut(),
65 n_seq_id: n_seq_id.as_ptr().cast_mut(),
66 seq_id: seq_ids_ptrs.as_ptr().cast_mut(),
67 logits: logits.as_ptr().cast_mut(),
68 };
69 Self {
70 raw,
71 tokens,
72 positions,
73 n_seq_id,
74 seq_ids,
75 seq_ids_ptrs,
76 logits,
77 allocated: true,
78 }
79 }
80
81 #[must_use]
84 pub fn one(token: LlamaToken, pos: i32, seq_id: i32, logits: bool) -> Self {
85 let mut b = Self::new(1, 1);
86 b.add(token, pos, &[seq_id], logits).expect("capacity 1");
87 b
88 }
89
90 #[must_use]
92 pub fn n_tokens(&self) -> i32 {
93 self.raw.n_tokens
94 }
95
96 pub fn clear(&mut self) {
98 self.raw.n_tokens = 0;
99 }
100
101 pub fn add(
106 &mut self,
107 token: LlamaToken,
108 pos: i32,
109 seq_ids: &[i32],
110 logits: bool,
111 ) -> std::result::Result<(), BatchAddError> {
112 let idx = self.raw.n_tokens as usize;
113 if idx >= self.tokens.len() {
114 return Err(BatchAddError::InsufficientSpace(self.tokens.len()));
115 }
116 if seq_ids.is_empty() {
117 return Err(BatchAddError::Empty);
118 }
119 unsafe {
123 let mut_ptr = self.tokens.as_ptr().cast_mut();
124 std::ptr::write(mut_ptr.add(idx), token.0);
125 let pos_ptr = self.positions.as_ptr().cast_mut();
126 std::ptr::write(pos_ptr.add(idx), pos);
127 let logits_ptr = self.logits.as_ptr().cast_mut();
128 std::ptr::write(logits_ptr.add(idx), i8::from(logits));
129 }
130 for (i, &sid) in seq_ids.iter().enumerate() {
131 if i < self.seq_ids[idx].len() {
132 self.seq_ids[idx][i] = sid;
133 }
134 }
135 self.raw.n_tokens += 1;
136 Ok(())
137 }
138
139 pub(crate) fn raw(&self) -> &sys::llama_batch {
141 &self.raw
142 }
143}
144
145impl From<BatchAddError> for LlamaError {
147 fn from(e: BatchAddError) -> Self {
148 Self::Batch(e.to_string())
149 }
150}