Skip to main content

rlx_runtime/
paged_kv.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Paged KV cache + continuous batching (plan #31).
17//!
18//! Borrowed from MAX's `serve/scheduler/{prefill_scheduler,
19//! decode_scheduler, text_generation_scheduler, batch_constructor/}`.
20//! The standard LLM-serving arch:
21//!
22//!   - KV cache lives in fixed-size **pages** (block of N tokens
23//!     per page per layer), not contiguous per-sequence buffers.
24//!     A sequence is a list of page IDs; reaching the end of a
25//!     page allocates the next from a pool.
26//!   - **Continuous batching**: prefill chunks of new sequences
27//!     pack into the same forward pass as decode steps of
28//!     in-flight sequences. The batch constructor decides which
29//!     work goes into the next forward.
30//!
31//! Throughput vs. naive max-padding: 5-10× higher at the same
32//! latency budget, mostly because GPU utilization stays high
33//! when sequences finish at different times.
34//!
35//! This module is the **data layer** — pool allocation, the
36//! sequence-to-page mapping, and the batch packing logic. Kernel
37//! integration (gather KV bytes from pages into the attention
//! input) is per-attention-kernel work that will land together with an
//! autoregressive LLM model runner.
40
41use std::collections::{BTreeSet, VecDeque};
42
/// Opaque physical-page identifier issued by a [`KvPagePool`].
///
/// The wrapped value is the page's index in the pool: it selects the
/// page's [`KvPageDesc`] and therefore its byte range in the arena.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct KvPageId(pub u32);

/// Fixed-size KV-cache page descriptor. Owns a contiguous range
/// of byte offsets into a backing arena (managed externally).
#[derive(Debug, Clone, Copy)]
pub struct KvPageDesc {
    /// Byte offset into the KV arena for this page.
    pub offset: usize,
    /// Size of the page in bytes (identical for every page of a pool).
    /// Bytes per page = `tokens_per_page * num_layers * 2 (k+v) * num_heads * head_dim * dtype_bytes`.
    pub bytes: usize,
    /// How many tokens this page holds in its leading slots.
    /// Used during prefill where a page may be partial. Reset to 0 by
    /// [`KvPagePool::alloc`]; otherwise maintained by the caller.
    pub filled: u16,
}

/// Pool of fixed-size physical pages. Allocates from a free list;
/// [`KvPagePool::free`] returns a page so a later allocation can reuse it.
///
/// Pool capacity is `num_pages * bytes_per_page` bytes. The caller owns the
/// underlying byte arena (typically a single large MTLBuffer or host Vec);
/// the pool only tracks which page IDs are free.
pub struct KvPagePool {
    /// Sorted set of free page IDs. Invariant (enforced by `free`): every
    /// entry is `< descs.len()`, so `free.len() <= capacity()`.
    free: BTreeSet<u32>,
    /// Per-page metadata. `descs[i].offset = i * bytes_per_page`.
    descs: Vec<KvPageDesc>,
    /// Size of every page in bytes (constant, exposed for ergonomics).
    pub bytes_per_page: usize,
    /// Tokens stored per page (constant, exposed for ergonomics).
    pub tokens_per_page: u16,
}

impl KvPagePool {
    /// Create a pool of `num_pages` pages, all initially free.
    ///
    /// Page `i` is described as starting at byte `i * bytes_per_page` of the
    /// caller's arena, `bytes_per_page` long, with `filled == 0`.
    pub fn new(num_pages: u32, bytes_per_page: usize, tokens_per_page: u16) -> Self {
        let descs: Vec<KvPageDesc> = (0..num_pages)
            .map(|i| KvPageDesc {
                offset: (i as usize) * bytes_per_page,
                bytes: bytes_per_page,
                filled: 0,
            })
            .collect();
        Self {
            free: (0..num_pages).collect(),
            descs,
            bytes_per_page,
            tokens_per_page,
        }
    }

    /// Total number of pages managed by the pool.
    pub fn capacity(&self) -> u32 {
        // Lossless: `descs` was built from a `u32` page count.
        self.descs.len() as u32
    }

    /// Number of pages currently available to [`KvPagePool::alloc`].
    pub fn free_count(&self) -> u32 {
        self.free.len() as u32
    }

    /// Number of pages currently handed out.
    pub fn used_count(&self) -> u32 {
        // Cannot underflow: `free` only ever holds in-range, distinct IDs.
        self.capacity() - self.free_count()
    }

    /// Allocate one page (always the lowest-numbered free page).
    /// Returns `None` when the pool is empty.
    pub fn alloc(&mut self) -> Option<KvPageId> {
        let id = self.free.pop_first()?;
        // Reset filled count on alloc — caller starts fresh.
        self.descs[id as usize].filled = 0;
        Some(KvPageId(id))
    }

    /// Return `id` to the pool so a later [`KvPagePool::alloc`] can reuse it.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this pool. Accepting such an ID
    /// would let `alloc` later hand out a page with no descriptor and make
    /// `used_count` underflow. In debug builds, also panics if `id` is
    /// already free (double free).
    pub fn free(&mut self, id: KvPageId) {
        assert!(
            (id.0 as usize) < self.descs.len(),
            "KvPagePool::free: page id {} out of range (capacity {})",
            id.0,
            self.descs.len()
        );
        let was_in_use = self.free.insert(id.0);
        debug_assert!(was_in_use, "KvPagePool::free: page id {} freed twice", id.0);
    }

    /// Metadata for page `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this pool.
    pub fn descriptor(&self, id: KvPageId) -> &KvPageDesc {
        &self.descs[id.0 as usize]
    }

    /// Mutable metadata for page `id` (e.g. to update `filled`).
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range for this pool.
    pub fn descriptor_mut(&mut self, id: KvPageId) -> &mut KvPageDesc {
        &mut self.descs[id.0 as usize]
    }
}
125
126/// Per-sequence map of logical-token-position → physical page.
127/// Token `t` lives at `pages[t / tokens_per_page]` slot
128/// `t % tokens_per_page`.
129#[derive(Debug, Clone, Default)]
130pub struct KvBlockTable {
131    pages: Vec<KvPageId>,
132    /// Number of tokens this sequence currently has cached.
133    pub seq_len: u32,
134}
135
136impl KvBlockTable {
137    pub fn new() -> Self {
138        Self::default()
139    }
140
141    /// Append a new page from the pool. Used when seq_len mod
142    /// tokens_per_page == 0 (boundary).
143    pub fn push_page(&mut self, page: KvPageId) {
144        self.pages.push(page);
145    }
146
147    /// Look up which page holds token `t`. Returns `None` if `t`
148    /// is past the cached region.
149    pub fn page_for_token(&self, t: u32, tokens_per_page: u16) -> Option<KvPageId> {
150        let idx = (t / tokens_per_page as u32) as usize;
151        self.pages.get(idx).copied()
152    }
153
154    pub fn page_count(&self) -> usize {
155        self.pages.len()
156    }
157
158    /// Free every page back to the pool. Called when a sequence
159    /// finishes / is evicted.
160    pub fn release(&mut self, pool: &mut KvPagePool) {
161        for p in self.pages.drain(..) {
162            pool.free(p);
163        }
164        self.seq_len = 0;
165    }
166
167    /// Slice of page IDs. Useful for kernel-side gather.
168    pub fn pages(&self) -> &[KvPageId] {
169        &self.pages
170    }
171}
172
/// One entry of a continuous batch: either a decode step (a single new
/// token for a sequence that already has cached KV) or a prefill chunk
/// (new tokens for a sequence that is still being prefilled).
#[derive(Debug, Clone)]
pub struct BatchEntry {
    /// Caller-assigned sequence identifier.
    pub seq_id: u64,
    /// Whether this entry is prefill work or a decode step.
    pub kind: BatchKind,
    /// Token IDs fed to this forward pass.
    pub input_tokens: Vec<u32>,
    /// KV-cache length before this entry runs (number of tokens of this
    /// sequence already cached ahead of `input_tokens`).
    pub cached_len: u32,
}

/// What kind of work a [`BatchEntry`] represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchKind {
    /// Prefill (or a prefill chunk): N new tokens for a sequence that
    /// already has `cached_len` tokens prefilled.
    Prefill,
    /// Decode: one token sampled from the previous step.
    Decode,
}

/// Packs decode steps and prefill chunks into a single forward pass.
///
/// Decodes are taken first (they are cheap), then whatever is left of the
/// **token budget** goes to prefill work, splitting a prefill that does not
/// fit. The token budget, not the entry count, is the main constraint
/// because a prefill can be arbitrarily long; `max_entries` is a second cap.
pub struct BatchConstructor {
    /// Maximum tokens per forward across all entries.
    pub max_tokens_per_batch: usize,
    /// Maximum entries per forward (also bounds memory/scheduler overhead).
    pub max_entries: usize,
}

impl BatchConstructor {
    /// Create a constructor with the given per-forward token and entry caps.
    pub fn new(max_tokens_per_batch: usize, max_entries: usize) -> Self {
        Self {
            max_tokens_per_batch,
            max_entries,
        }
    }

    /// Assemble the next batch from the two queues.
    ///
    /// 1. Move decode entries off the front of `decode_queue` while they fit
    ///    in the token budget, stopping at the first one that does not.
    /// 2. Move whole prefill entries off the front of `prefill_queue` while
    ///    they fit. The first prefill larger than the leftover budget is
    ///    split: its leading tokens enter the batch as a `Prefill` entry and
    ///    the remainder stays at the front of `prefill_queue` with
    ///    `cached_len` advanced by the chunk size.
    ///
    /// Both phases also stop once the batch holds `max_entries` entries.
    /// Everything not taken stays queued for the next call.
    pub fn build(
        &self,
        decode_queue: &mut VecDeque<BatchEntry>,
        prefill_queue: &mut VecDeque<BatchEntry>,
    ) -> Vec<BatchEntry> {
        let mut out: Vec<BatchEntry> = Vec::new();
        let mut spent: usize = 0;

        // Phase 1: decode steps, strictly in queue order.
        while let Some(head) = decode_queue.front() {
            if out.len() >= self.max_entries {
                break;
            }
            let cost = head.input_tokens.len();
            if spent + cost > self.max_tokens_per_batch {
                break;
            }
            out.extend(decode_queue.pop_front());
            spent += cost;
        }

        // Phase 2: prefill work, using whatever budget the decodes left.
        while out.len() < self.max_entries {
            let left = self.max_tokens_per_batch.saturating_sub(spent);
            let head = match prefill_queue.front_mut() {
                Some(head) => head,
                None => break,
            };
            if left == 0 {
                break;
            }
            let wanted = head.input_tokens.len();
            if wanted <= left {
                // The whole prefill fits: move it into the batch unchanged.
                out.extend(prefill_queue.pop_front());
                spent += wanted;
                continue;
            }
            // Too big: carve the first `left` tokens off as a chunk and
            // advance the queued remainder in place. The budget is now
            // exhausted, so this is the final entry.
            let chunk: Vec<u32> = head.input_tokens.drain(..left).collect();
            out.push(BatchEntry {
                seq_id: head.seq_id,
                kind: BatchKind::Prefill,
                input_tokens: chunk,
                cached_len: head.cached_len,
            });
            head.cached_len += left as u32;
            break;
        }

        out
    }
}
275
#[cfg(test)]
/// Unit tests for the page pool, block table, and batch constructor.
mod tests {
    use super::*;

    #[test]
    fn pool_alloc_free_round_trip() {
        let mut pool = KvPagePool::new(4, 1024, 16);
        assert_eq!(pool.free_count(), 4);
        let p1 = pool.alloc().unwrap();
        let p2 = pool.alloc().unwrap();
        assert_eq!(pool.free_count(), 2);
        pool.free(p1);
        pool.free(p2);
        assert_eq!(pool.free_count(), 4);
    }

    #[test]
    fn pool_returns_none_when_exhausted() {
        let mut pool = KvPagePool::new(2, 64, 4);
        let _a = pool.alloc().unwrap();
        let _b = pool.alloc().unwrap();
        assert!(pool.alloc().is_none());
    }

    #[test]
    fn pool_descriptor_offsets_are_unique_and_aligned() {
        let pool = KvPagePool::new(4, 256, 16);
        for i in 0..4u32 {
            let d = pool.descriptor(KvPageId(i));
            assert_eq!(d.offset, i as usize * 256);
            assert_eq!(d.bytes, 256);
        }
    }

    #[test]
    fn pool_alloc_reuses_lowest_freed_page_and_resets_filled() {
        let mut pool = KvPagePool::new(3, 128, 8);
        let a = pool.alloc().unwrap();
        let _b = pool.alloc().unwrap();
        pool.descriptor_mut(a).filled = 5;
        pool.free(a);
        // The lowest free ID is handed out first, so `a` comes straight back.
        let c = pool.alloc().unwrap();
        assert_eq!(c, a);
        assert_eq!(pool.descriptor(c).filled, 0);
        assert_eq!(pool.used_count(), 2);
    }

    #[test]
    fn block_table_page_for_token() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let mut bt = KvBlockTable::new();
        for _ in 0..3 {
            bt.push_page(pool.alloc().unwrap());
        }
        // tokens_per_page = 4 → tokens 0..4 in page 0, 4..8 in page 1, ...
        assert_eq!(bt.page_for_token(0, 4), Some(bt.pages()[0]));
        assert_eq!(bt.page_for_token(7, 4), Some(bt.pages()[1]));
        assert_eq!(bt.page_for_token(11, 4), Some(bt.pages()[2]));
        assert_eq!(bt.page_for_token(12, 4), None);
    }

    #[test]
    fn block_table_release_returns_pages() {
        let mut pool = KvPagePool::new(8, 64, 4);
        let mut bt = KvBlockTable::new();
        for _ in 0..3 {
            bt.push_page(pool.alloc().unwrap());
        }
        assert_eq!(pool.free_count(), 5);
        bt.release(&mut pool);
        assert_eq!(pool.free_count(), 8);
        assert_eq!(bt.page_count(), 0);
    }

    #[test]
    fn block_table_release_resets_seq_len() {
        let mut pool = KvPagePool::new(4, 64, 4);
        let mut bt = KvBlockTable::new();
        bt.push_page(pool.alloc().unwrap());
        bt.seq_len = 3;
        bt.release(&mut pool);
        assert_eq!(bt.seq_len, 0);
        assert!(bt.pages().is_empty());
        assert_eq!(pool.used_count(), 0);
    }

    #[test]
    fn batch_constructor_decodes_first_then_prefill() {
        let bc = BatchConstructor::new(8, 16);
        let mut decodes: VecDeque<BatchEntry> = (0..3)
            .map(|i| BatchEntry {
                seq_id: i,
                kind: BatchKind::Decode,
                input_tokens: vec![100 + i as u32],
                cached_len: 50,
            })
            .collect();
        let mut prefills: VecDeque<BatchEntry> = (0..2)
            .map(|i| BatchEntry {
                seq_id: 100 + i,
                kind: BatchKind::Prefill,
                input_tokens: vec![1; 3],
                cached_len: 0,
            })
            .collect();

        let batch = bc.build(&mut decodes, &mut prefills);
        // Three decodes (3 tokens) + first prefill (3 tokens) =
        // 6 tokens, fits in budget 8. Second prefill chunks (2
        // tokens of 3) into the remaining 2 slots; the rest
        // stays for the next call.
        assert_eq!(batch.len(), 5);
        // First three are decodes.
        for entry in batch.iter().take(3) {
            assert_eq!(entry.kind, BatchKind::Decode);
        }
        for entry in batch.iter().skip(3).take(2) {
            assert_eq!(entry.kind, BatchKind::Prefill);
        }
        let total_tokens: usize = batch.iter().map(|e| e.input_tokens.len()).sum();
        assert_eq!(total_tokens, 8);

        // The leftover prefill should still be queued (1 token left).
        assert_eq!(prefills.len(), 1);
        assert_eq!(prefills[0].input_tokens.len(), 1);
        assert_eq!(prefills[0].cached_len, 2);
    }

    #[test]
    fn batch_constructor_respects_max_entries() {
        let bc = BatchConstructor::new(1024, 2); // very generous tokens, only 2 entries
        let mut decodes: VecDeque<BatchEntry> = (0..5)
            .map(|i| BatchEntry {
                seq_id: i,
                kind: BatchKind::Decode,
                input_tokens: vec![1],
                cached_len: 0,
            })
            .collect();
        let mut prefills: VecDeque<BatchEntry> = VecDeque::new();
        let batch = bc.build(&mut decodes, &mut prefills);
        assert_eq!(batch.len(), 2);
        assert_eq!(decodes.len(), 3);
    }

    #[test]
    fn batch_constructor_chunks_long_prefill_across_calls() {
        // Budget 4, one 10-token prefill: chunks of 4, 4, then the 2-token
        // tail, each carrying the cached_len reached by the earlier chunks.
        let bc = BatchConstructor::new(4, 8);
        let mut decodes: VecDeque<BatchEntry> = VecDeque::new();
        let mut prefills: VecDeque<BatchEntry> = VecDeque::new();
        prefills.push_back(BatchEntry {
            seq_id: 7,
            kind: BatchKind::Prefill,
            input_tokens: (0..10).collect(),
            cached_len: 0,
        });

        let mut chunks: Vec<(Vec<u32>, u32)> = Vec::new();
        for _ in 0..3 {
            let batch = bc.build(&mut decodes, &mut prefills);
            assert_eq!(batch.len(), 1);
            assert_eq!(batch[0].seq_id, 7);
            chunks.push((batch[0].input_tokens.clone(), batch[0].cached_len));
        }
        assert_eq!(
            chunks,
            vec![(vec![0, 1, 2, 3], 0), (vec![4, 5, 6, 7], 4), (vec![8, 9], 8)]
        );
        assert!(prefills.is_empty());
        assert!(bc.build(&mut decodes, &mut prefills).is_empty());
    }

    #[test]
    fn batch_constructor_zero_budget_leaves_queues_untouched() {
        let bc = BatchConstructor::new(0, 4);
        let mut decodes: VecDeque<BatchEntry> = VecDeque::from(vec![BatchEntry {
            seq_id: 1,
            kind: BatchKind::Decode,
            input_tokens: vec![5],
            cached_len: 3,
        }]);
        let mut prefills: VecDeque<BatchEntry> = VecDeque::from(vec![BatchEntry {
            seq_id: 2,
            kind: BatchKind::Prefill,
            input_tokens: vec![1, 2],
            cached_len: 0,
        }]);
        let batch = bc.build(&mut decodes, &mut prefills);
        assert!(batch.is_empty());
        assert_eq!(decodes.len(), 1);
        assert_eq!(prefills[0].input_tokens, vec![1, 2]);
        assert_eq!(prefills[0].cached_len, 0);
    }
}