// trueno/brick/kv_cache/mod.rs
/// Bookkeeping metadata for a single KV-cache slot.
///
/// Tracks which (position, token, layer, head) entry occupies the slot and
/// when it was last used, so the manager can run LRU eviction.
#[derive(Debug, Clone, Default)]
pub struct KvCacheSlotInfo {
    /// Sequence position of the cached entry.
    pub position: u32,
    /// Token id cached at this position.
    pub token_id: u32,
    /// Layer index the entry belongs to.
    pub layer: u16,
    /// Head index the entry belongs to.
    pub head: u16,
    /// `true` while the slot holds live data; `false` once freed/evicted.
    pub valid: bool,
    /// Logical step of the most recent access (LRU timestamp).
    pub last_access: u64,
}

impl KvCacheSlotInfo {
    /// Builds a live slot for the given entry with its LRU timestamp at zero.
    pub fn new(position: u32, token_id: u32, layer: u16, head: u16) -> Self {
        Self {
            position,
            token_id,
            layer,
            head,
            valid: true,
            last_access: 0,
        }
    }

    /// Records an access at logical time `step`.
    pub fn touch(&mut self, step: u64) {
        self.last_access = step;
    }

    /// Frees the slot so it can be reallocated.
    pub fn invalidate(&mut self) {
        self.valid = false;
    }

    /// Eviction score relative to `current_step`: the older the access the
    /// higher the score. Invalid slots score `u64::MAX` so they go first.
    #[must_use]
    pub fn eviction_priority(&self, current_step: u64) -> u64 {
        if self.valid {
            // Saturating: a slot touched "in the future" just scores 0.
            current_step.saturating_sub(self.last_access)
        } else {
            u64::MAX
        }
    }
}
54
55#[derive(Debug)]
57pub struct KvCacheManager {
58 slots: Vec<KvCacheSlotInfo>,
60 current_step: u64,
62 valid_count: usize,
64}
65
66impl KvCacheManager {
67 pub fn new(capacity: usize) -> Self {
69 Self { slots: vec![KvCacheSlotInfo::default(); capacity], current_step: 0, valid_count: 0 }
70 }
71
72 pub fn allocate(
74 &mut self,
75 position: u32,
76 token_id: u32,
77 layer: u16,
78 head: u16,
79 ) -> Option<usize> {
80 for (i, slot) in self.slots.iter_mut().enumerate() {
82 if !slot.valid {
83 *slot = KvCacheSlotInfo::new(position, token_id, layer, head);
84 slot.touch(self.current_step);
85 self.valid_count += 1;
86 return Some(i);
87 }
88 }
89 None }
91
92 pub fn access(&mut self, index: usize) -> Option<&KvCacheSlotInfo> {
94 if index < self.slots.len() {
95 self.slots[index].touch(self.current_step);
96 Some(&self.slots[index])
97 } else {
98 None
99 }
100 }
101
102 pub fn evict_lru(&mut self) -> Option<usize> {
104 let mut best_idx = None;
105 let mut best_priority = 0u64;
106
107 for (i, slot) in self.slots.iter().enumerate() {
108 if slot.valid {
109 let priority = slot.eviction_priority(self.current_step);
110 if best_idx.is_none() || priority > best_priority {
112 best_priority = priority;
113 best_idx = Some(i);
114 }
115 }
116 }
117
118 if let Some(idx) = best_idx {
119 self.slots[idx].invalidate();
120 self.valid_count -= 1;
121 }
122 best_idx
123 }
124
125 pub fn step(&mut self) {
127 self.current_step += 1;
128 }
129
130 #[must_use]
132 pub fn valid_count(&self) -> usize {
133 self.valid_count
134 }
135
136 #[must_use]
138 pub fn capacity(&self) -> usize {
139 self.slots.len()
140 }
141}
142
/// Yields batch indices according to a precomputed visiting order.
///
/// Construct with [`new`](Self::new) (ascending), [`reversed`](Self::reversed)
/// (descending), or [`interleaved`](Self::interleaved) (first/second half
/// alternation), then drain via [`next_batch`](Self::next_batch) or the
/// `Iterator` impl. `reset` rewinds to the start without rebuilding the order.
#[derive(Debug, Clone)]
pub struct SequentialBatchOrderer {
    /// Batch indices in visiting order.
    order: Vec<usize>,
    /// Cursor into `order`; everything before it has been yielded.
    position: usize,
}

impl SequentialBatchOrderer {
    /// Ascending order: `0, 1, ..., n_batches - 1`.
    pub fn new(n_batches: usize) -> Self {
        let order: Vec<usize> = (0..n_batches).collect();
        Self { order, position: 0 }
    }

    /// Descending order: `n_batches - 1, ..., 1, 0`.
    pub fn reversed(n_batches: usize) -> Self {
        let order: Vec<usize> = (0..n_batches).rev().collect();
        Self { order, position: 0 }
    }

    /// Alternates between the two halves: `0, mid, 1, mid+1, ...`; for odd
    /// `n_batches` the final batch is appended at the end.
    pub fn interleaved(n_batches: usize) -> Self {
        let mid = n_batches / 2;
        let mut order = Vec::with_capacity(n_batches);

        for i in 0..mid {
            order.push(i);
            let partner = mid + i;
            if partner < n_batches {
                order.push(partner);
            }
        }
        if n_batches % 2 == 1 {
            // The middle element is in neither half-pairing; emit it last.
            order.push(n_batches - 1);
        }

        Self { order, position: 0 }
    }

    /// Returns the next batch index and advances, or `None` when exhausted.
    pub fn next_batch(&mut self) -> Option<usize> {
        let idx = self.order.get(self.position).copied()?;
        self.position += 1;
        Some(idx)
    }

    /// Rewinds to the beginning of the order.
    pub fn reset(&mut self) {
        self.position = 0;
    }

    /// `true` once every batch has been yielded.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.remaining() == 0
    }

    /// Number of batches not yet yielded.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.order.len().saturating_sub(self.position)
    }
}

impl Iterator for SequentialBatchOrderer {
    type Item = usize;

    /// Delegates to [`SequentialBatchOrderer::next_batch`].
    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }
}
225
226#[cfg(test)]
227mod tests;