Skip to main content

trueno/brick/kv_cache/
mod.rs

1//! KV Cache Management and Batch Ordering
2//!
3//! LCP-10: KV Cache Slot tracking for transformer inference.
4//! LCP-14: Sequential Batch Ordering for cache-friendly processing.
5
6// ----------------------------------------------------------------------------
7// LCP-10: KV Cache Slot Info
8// ----------------------------------------------------------------------------
9
/// Bookkeeping record for a single KV cache slot during transformer inference.
///
/// Records which sequence position and token the slot holds, the layer and head it
/// belongs to, whether it currently holds live data, and the step of its last use.
#[derive(Debug, Clone, Default)]
pub struct KvCacheSlotInfo {
    /// Position in the sequence held by this slot.
    pub position: u32,
    /// Token ID cached in this slot.
    pub token_id: u32,
    /// Layer index.
    pub layer: u16,
    /// Head index.
    pub head: u16,
    /// `true` while the slot holds live data; `false` when free or invalidated.
    pub valid: bool,
    /// Step counter value recorded at the most recent access.
    pub last_access: u64,
}

impl KvCacheSlotInfo {
    /// Build a valid slot for `position` / `token_id` in `layer` / `head`.
    ///
    /// `last_access` starts at step 0.
    pub fn new(position: u32, token_id: u32, layer: u16, head: u16) -> Self {
        Self {
            position,
            token_id,
            layer,
            head,
            valid: true,
            ..Self::default()
        }
    }

    /// Record that the slot was used at `step`.
    pub fn touch(&mut self, step: u64) {
        self.last_access = step;
    }

    /// Mark the slot as no longer holding live data.
    pub fn invalidate(&mut self) {
        self.valid = false;
    }

    /// LRU eviction score relative to `current_step`; a larger score means "evict sooner".
    ///
    /// A valid slot scores the number of steps since its last access. The score saturates
    /// at 0 if `last_access` is ahead of `current_step`. An invalid slot scores `u64::MAX`.
    #[must_use]
    pub fn eviction_priority(&self, current_step: u64) -> u64 {
        if self.valid {
            current_step.saturating_sub(self.last_access)
        } else {
            // A free slot is always the first candidate for reuse.
            u64::MAX
        }
    }
}
54
55/// KV cache manager with slot tracking.
56#[derive(Debug)]
57pub struct KvCacheManager {
58    /// Slot metadata
59    slots: Vec<KvCacheSlotInfo>,
60    /// Current step counter
61    current_step: u64,
62    /// Number of valid slots
63    valid_count: usize,
64}
65
66impl KvCacheManager {
67    /// Create manager with given capacity.
68    pub fn new(capacity: usize) -> Self {
69        Self { slots: vec![KvCacheSlotInfo::default(); capacity], current_step: 0, valid_count: 0 }
70    }
71
72    /// Allocate a slot.
73    pub fn allocate(
74        &mut self,
75        position: u32,
76        token_id: u32,
77        layer: u16,
78        head: u16,
79    ) -> Option<usize> {
80        // Find first invalid slot
81        for (i, slot) in self.slots.iter_mut().enumerate() {
82            if !slot.valid {
83                *slot = KvCacheSlotInfo::new(position, token_id, layer, head);
84                slot.touch(self.current_step);
85                self.valid_count += 1;
86                return Some(i);
87            }
88        }
89        None // No free slots
90    }
91
92    /// Access a slot.
93    pub fn access(&mut self, index: usize) -> Option<&KvCacheSlotInfo> {
94        if index < self.slots.len() {
95            self.slots[index].touch(self.current_step);
96            Some(&self.slots[index])
97        } else {
98            None
99        }
100    }
101
102    /// Evict LRU slot.
103    pub fn evict_lru(&mut self) -> Option<usize> {
104        let mut best_idx = None;
105        let mut best_priority = 0u64;
106
107        for (i, slot) in self.slots.iter().enumerate() {
108            if slot.valid {
109                let priority = slot.eviction_priority(self.current_step);
110                // Use >= for first found slot (best_idx.is_none()), then > for ties
111                if best_idx.is_none() || priority > best_priority {
112                    best_priority = priority;
113                    best_idx = Some(i);
114                }
115            }
116        }
117
118        if let Some(idx) = best_idx {
119            self.slots[idx].invalidate();
120            self.valid_count -= 1;
121        }
122        best_idx
123    }
124
125    /// Advance step counter.
126    pub fn step(&mut self) {
127        self.current_step += 1;
128    }
129
130    /// Get number of valid slots.
131    #[must_use]
132    pub fn valid_count(&self) -> usize {
133        self.valid_count
134    }
135
136    /// Get capacity.
137    #[must_use]
138    pub fn capacity(&self) -> usize {
139        self.slots.len()
140    }
141}
142
143// ----------------------------------------------------------------------------
144// LCP-14: Sequential Batch Ordering
145// ----------------------------------------------------------------------------
146
/// Sequential batch orderer for cache-friendly processing.
///
/// Holds a precomputed permutation of batch indices and hands them out one at a time,
/// either via [`SequentialBatchOrderer::next_batch`] or through its [`Iterator`]
/// implementation. Three orderings are provided:
/// - ascending ([`SequentialBatchOrderer::new`]),
/// - descending ([`SequentialBatchOrderer::reversed`]),
/// - front/back-half interleaved ([`SequentialBatchOrderer::interleaved`]).
#[derive(Debug, Clone)]
pub struct SequentialBatchOrderer {
    /// Batch indices in processing order.
    order: Vec<usize>,
    /// Index into `order` of the next batch to hand out.
    position: usize,
}

impl SequentialBatchOrderer {
    /// Create an orderer yielding `0, 1, ..., n_batches - 1`.
    pub fn new(n_batches: usize) -> Self {
        Self { order: (0..n_batches).collect(), position: 0 }
    }

    /// Create an orderer yielding `n_batches - 1, ..., 1, 0`.
    pub fn reversed(n_batches: usize) -> Self {
        Self { order: (0..n_batches).rev().collect(), position: 0 }
    }

    /// Create an orderer that alternates between the first and second halves.
    ///
    /// With `mid = n_batches / 2`, the order is `0, mid, 1, mid + 1, ..., mid - 1, 2 * mid - 1`.
    /// When `n_batches` is odd, `n_batches - 1` follows. For example, 5 batches yield
    /// `0, 2, 1, 3, 4`.
    pub fn interleaved(n_batches: usize) -> Self {
        let mid = n_batches / 2;
        let mut order = Vec::with_capacity(n_batches);
        // Pair each front-half index with its back-half partner; this covers 2 * mid indices,
        // all of which are < n_batches, so no bounds check is needed.
        order.extend((0..mid).flat_map(|i| [i, mid + i]));
        // An odd count leaves exactly one index, the last one, unpaired.
        if order.len() < n_batches {
            order.push(n_batches - 1);
        }
        Self { order, position: 0 }
    }

    /// Return the next batch index, or `None` once every batch has been handed out.
    pub fn next_batch(&mut self) -> Option<usize> {
        let idx = *self.order.get(self.position)?;
        self.position += 1;
        Some(idx)
    }

    /// Rewind to the first batch so the same order can be replayed.
    pub fn reset(&mut self) {
        self.position = 0;
    }

    /// Check if all batches have been processed.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.position >= self.order.len()
    }

    /// Number of batches not yet handed out.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.order.len().saturating_sub(self.position)
    }
}

impl Iterator for SequentialBatchOrderer {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }

    // The exact number of remaining batches is always known. Reporting it lets
    // `collect` and similar adapters preallocate, and it enables `ExactSizeIterator`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining();
        (left, Some(left))
    }
}

impl ExactSizeIterator for SequentialBatchOrderer {}
225
226#[cfg(test)]
227mod tests;