1use std::collections::{BTreeSet, VecDeque};
42
/// Identifier of one page in a [`KvPagePool`].
///
/// The wrapped `u32` is the page's index into the pool's descriptor table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct KvPageId(pub u32);

/// Bookkeeping record describing a single page.
#[derive(Debug, Clone, Copy)]
pub struct KvPageDesc {
    /// Byte offset of the page, set to `index * bytes_per_page` by [`KvPagePool::new`].
    // NOTE(review): presumably an offset into a backing buffer owned elsewhere — confirm.
    pub offset: usize,
    /// Size of the page in bytes.
    pub bytes: usize,
    /// Fill counter; zeroed each time the page is handed out by [`KvPagePool::alloc`].
    // NOTE(review): presumably the number of token slots in use — confirm with callers.
    pub filled: u16,
}

/// Fixed-capacity pool of equally sized pages tracked through a sorted free list.
///
/// Invariant: every index stored in `free` is a valid index into `descs`.
pub struct KvPagePool {
    /// Indices of the pages that are currently available. Being an ordered set,
    /// it makes `alloc` hand out the lowest free index and makes a repeated
    /// `free` of the same page a no-op.
    free: BTreeSet<u32>,
    /// One descriptor per page, indexed by page id.
    descs: Vec<KvPageDesc>,
    /// Size in bytes of every page in the pool.
    pub bytes_per_page: usize,
    /// Value passed to [`KvPagePool::new`]; stored but not read by the pool itself.
    pub tokens_per_page: u16,
}

impl KvPagePool {
    /// Creates a pool of `num_pages` pages, all initially free.
    ///
    /// Page `i` gets the descriptor `{ offset: i * bytes_per_page, bytes: bytes_per_page, filled: 0 }`.
    pub fn new(num_pages: u32, bytes_per_page: usize, tokens_per_page: u16) -> Self {
        let descs: Vec<KvPageDesc> = (0..num_pages)
            .map(|i| KvPageDesc {
                offset: (i as usize) * bytes_per_page,
                bytes: bytes_per_page,
                filled: 0,
            })
            .collect();
        Self {
            free: (0..num_pages).collect(),
            descs,
            bytes_per_page,
            tokens_per_page,
        }
    }

    /// Total number of pages managed by the pool.
    pub fn capacity(&self) -> u32 {
        // Lossless: `descs` was built from a `u32` page count in `new`.
        self.descs.len() as u32
    }

    /// Number of pages currently available for allocation.
    pub fn free_count(&self) -> u32 {
        // Lossless: `free` only ever holds indices below `capacity()`.
        self.free.len() as u32
    }

    /// Number of pages currently handed out.
    pub fn used_count(&self) -> u32 {
        // Cannot underflow because `free` never holds more entries than `descs`.
        self.capacity() - self.free_count()
    }

    /// Hands out the free page with the lowest index and resets its `filled`
    /// counter to 0. Returns `None` when no page is free.
    pub fn alloc(&mut self) -> Option<KvPageId> {
        let id = self.free.iter().next().copied()?;
        self.free.remove(&id);
        self.descs[id as usize].filled = 0;
        Some(KvPageId(id))
    }

    /// Returns a page to the free list.
    ///
    /// Freeing a page that is already free has no effect. Ids that do not belong
    /// to this pool (index `>= capacity()`) are ignored: admitting them would let
    /// `free_count` exceed `capacity` and would make a later `alloc` index the
    /// descriptor table out of bounds.
    pub fn free(&mut self, id: KvPageId) {
        if (id.0 as usize) < self.descs.len() {
            self.free.insert(id.0);
        }
    }

    /// Shared access to the descriptor of page `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not an index below `capacity()`.
    pub fn descriptor(&self, id: KvPageId) -> &KvPageDesc {
        &self.descs[id.0 as usize]
    }

    /// Mutable access to the descriptor of page `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not an index below `capacity()`.
    pub fn descriptor_mut(&mut self, id: KvPageId) -> &mut KvPageDesc {
        &mut self.descs[id.0 as usize]
    }
}
125
126#[derive(Debug, Clone, Default)]
130pub struct KvBlockTable {
131 pages: Vec<KvPageId>,
132 pub seq_len: u32,
134}
135
136impl KvBlockTable {
137 pub fn new() -> Self {
138 Self::default()
139 }
140
141 pub fn push_page(&mut self, page: KvPageId) {
144 self.pages.push(page);
145 }
146
147 pub fn page_for_token(&self, t: u32, tokens_per_page: u16) -> Option<KvPageId> {
150 let idx = (t / tokens_per_page as u32) as usize;
151 self.pages.get(idx).copied()
152 }
153
154 pub fn page_count(&self) -> usize {
155 self.pages.len()
156 }
157
158 pub fn release(&mut self, pool: &mut KvPagePool) {
161 for p in self.pages.drain(..) {
162 pool.free(p);
163 }
164 self.seq_len = 0;
165 }
166
167 pub fn pages(&self) -> &[KvPageId] {
169 &self.pages
170 }
171}
172
/// One unit of work placed into a batch.
#[derive(Debug, Clone)]
pub struct BatchEntry {
    /// Identifier of the sequence this entry belongs to.
    pub seq_id: u64,
    /// Whether the entry is prefill or decode work.
    pub kind: BatchKind,
    /// Tokens to feed for this entry; their count is what is charged against
    /// the batch token budget.
    pub input_tokens: Vec<u32>,
    /// Advanced by the number of tokens split off whenever a prefill entry is
    /// chunked by [`BatchConstructor::build`].
    // NOTE(review): presumably the number of tokens already present in the KV cache — confirm.
    pub cached_len: u32,
}

/// Category of a [`BatchEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchKind {
    /// Entry taken from the prefill queue (may be split into chunks).
    Prefill,
    /// Entry taken from the decode queue (always scheduled whole).
    Decode,
}

/// Assembles batches from a decode queue and a prefill queue under two limits.
pub struct BatchConstructor {
    /// Upper bound on the total number of input tokens in one batch.
    pub max_tokens_per_batch: usize,
    /// Upper bound on the number of entries in one batch.
    pub max_entries: usize,
}

impl BatchConstructor {
    /// Creates a constructor with the given token and entry limits.
    pub fn new(max_tokens_per_batch: usize, max_entries: usize) -> Self {
        Self {
            max_tokens_per_batch,
            max_entries,
        }
    }

    /// Builds one batch, consuming entries from the front of both queues.
    ///
    /// Decode entries are taken first, whole and in queue order, until the
    /// entry limit is reached, the queue is empty, or the next entry would
    /// exceed the token budget. Prefill entries then fill what is left of the
    /// budget. A prefill entry that does not fit is split: the leading tokens
    /// that fit are scheduled, and the remainder is put back at the front of
    /// the prefill queue with `cached_len` advanced by the scheduled amount.
    /// At most one entry is split per batch.
    pub fn build(
        &self,
        decode_queue: &mut VecDeque<BatchEntry>,
        prefill_queue: &mut VecDeque<BatchEntry>,
    ) -> Vec<BatchEntry> {
        let mut scheduled: Vec<BatchEntry> = Vec::new();
        let mut spent = 0usize;

        // Phase 1: decode entries, never split.
        while scheduled.len() < self.max_entries {
            let cost = match decode_queue.front() {
                Some(head) => head.input_tokens.len(),
                None => break,
            };
            if spent + cost > self.max_tokens_per_batch {
                break;
            }
            scheduled.extend(decode_queue.pop_front());
            spent += cost;
        }

        // Phase 2: prefill entries, splitting the first one that overflows.
        while scheduled.len() < self.max_entries {
            let budget = self.max_tokens_per_batch.saturating_sub(spent);
            if budget == 0 {
                break;
            }
            let mut head = match prefill_queue.pop_front() {
                Some(entry) => entry,
                None => break,
            };
            let pending = head.input_tokens.len();
            if pending <= budget {
                spent += pending;
                scheduled.push(head);
                continue;
            }

            // Keep the first `budget` tokens in `head`; `rest` goes back to the queue.
            let rest = head.input_tokens.split_off(budget);
            scheduled.push(BatchEntry {
                seq_id: head.seq_id,
                kind: BatchKind::Prefill,
                input_tokens: head.input_tokens,
                cached_len: head.cached_len,
            });
            prefill_queue.push_front(BatchEntry {
                seq_id: head.seq_id,
                kind: head.kind,
                input_tokens: rest,
                cached_len: head.cached_len + budget as u32,
            });
            break;
        }

        scheduled
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a block table holding `n` freshly allocated pages from `pool`.
    fn table_with_pages(pool: &mut KvPagePool, n: usize) -> KvBlockTable {
        let mut table = KvBlockTable::new();
        for _ in 0..n {
            table.push_page(pool.alloc().unwrap());
        }
        table
    }

    /// Shorthand constructor for a batch entry.
    fn entry(seq_id: u64, kind: BatchKind, input_tokens: Vec<u32>, cached_len: u32) -> BatchEntry {
        BatchEntry {
            seq_id,
            kind,
            input_tokens,
            cached_len,
        }
    }

    // Allocating two pages and freeing them restores the full free count.
    #[test]
    fn pool_alloc_free_round_trip() {
        let mut pool = KvPagePool::new(4, 1024, 16);
        assert_eq!(pool.free_count(), 4);

        let first = pool.alloc().unwrap();
        let second = pool.alloc().unwrap();
        assert_eq!(pool.free_count(), 2);

        pool.free(first);
        pool.free(second);
        assert_eq!(pool.free_count(), 4);
    }

    // Once every page is handed out, `alloc` reports exhaustion with `None`.
    #[test]
    fn pool_returns_none_when_exhausted() {
        let mut pool = KvPagePool::new(2, 64, 4);
        assert!(pool.alloc().is_some());
        assert!(pool.alloc().is_some());
        assert!(pool.alloc().is_none());
    }

    // Page `i` starts at `i * bytes_per_page` and spans `bytes_per_page` bytes.
    #[test]
    fn pool_descriptor_offsets_are_unique_and_aligned() {
        let pool = KvPagePool::new(4, 256, 16);
        for idx in 0..4u32 {
            let desc = pool.descriptor(KvPageId(idx));
            assert_eq!(desc.offset, idx as usize * 256);
            assert_eq!(desc.bytes, 256);
        }
    }

    // Token positions map onto pages in blocks of `tokens_per_page`.
    #[test]
    fn block_table_page_for_token() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let bt = table_with_pages(&mut pool, 3);
        let pages = bt.pages();

        assert_eq!(bt.page_for_token(0, 4), Some(pages[0]));
        assert_eq!(bt.page_for_token(7, 4), Some(pages[1]));
        assert_eq!(bt.page_for_token(11, 4), Some(pages[2]));
        assert_eq!(bt.page_for_token(12, 4), None);
    }

    // Releasing a table gives all of its pages back to the pool.
    #[test]
    fn block_table_release_returns_pages() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let mut bt = table_with_pages(&mut pool, 3);
        assert_eq!(pool.free_count(), 5);

        bt.release(&mut pool);
        assert_eq!(pool.free_count(), 8);
        assert_eq!(bt.page_count(), 0);
    }

    // Decodes are scheduled ahead of prefills; the prefill that overflows the
    // token budget is split and its remainder stays queued.
    #[test]
    fn batch_constructor_decodes_first_then_prefill() {
        let bc = BatchConstructor::new(8, 16);
        let mut decodes: VecDeque<BatchEntry> = (0..3u64)
            .map(|i| entry(i, BatchKind::Decode, vec![100 + i as u32], 50))
            .collect();
        let mut prefills: VecDeque<BatchEntry> = (0..2u64)
            .map(|i| entry(100 + i, BatchKind::Prefill, vec![1; 3], 0))
            .collect();

        let batch = bc.build(&mut decodes, &mut prefills);
        assert_eq!(batch.len(), 5);
        for scheduled in &batch[..3] {
            assert_eq!(scheduled.kind, BatchKind::Decode);
        }
        for scheduled in &batch[3..5] {
            assert_eq!(scheduled.kind, BatchKind::Prefill);
        }
        let total_tokens: usize = batch.iter().map(|e| e.input_tokens.len()).sum();
        assert_eq!(total_tokens, 8);

        assert_eq!(prefills.len(), 1);
        let leftover = &prefills[0];
        assert_eq!(leftover.input_tokens.len(), 1);
        assert_eq!(leftover.cached_len, 2);
    }

    // The entry limit caps the batch even when the token budget is generous.
    #[test]
    fn batch_constructor_respects_max_entries() {
        let bc = BatchConstructor::new(1024, 2);
        let mut decodes: VecDeque<BatchEntry> = (0..5u64)
            .map(|i| entry(i, BatchKind::Decode, vec![1], 0))
            .collect();
        let mut prefills: VecDeque<BatchEntry> = VecDeque::new();

        let batch = bc.build(&mut decodes, &mut prefills);
        assert_eq!(batch.len(), 2);
        assert_eq!(decodes.len(), 3);
    }
}