/// Bookkeeping record describing a single cache slot.
///
/// The identifying fields (`position`, `token_id`, `layer`, `head`) are stored
/// verbatim and are not interpreted by this type. The `Default` value is an
/// all-zero record with `valid == false`, i.e. an empty slot.
#[derive(Debug, Clone, Default)]
pub struct KvCacheSlotInfo {
    /// Position value supplied by the caller at construction.
    pub position: u32,
    /// Token identifier supplied by the caller at construction.
    pub token_id: u32,
    /// Layer value supplied by the caller at construction.
    pub layer: u16,
    /// Head value supplied by the caller at construction.
    pub head: u16,
    /// `true` while the slot holds a live entry; cleared by [`Self::invalidate`].
    pub valid: bool,
    /// Step value recorded by the most recent [`Self::touch`] call.
    pub last_access: u64,
}

impl KvCacheSlotInfo {
    /// Builds a live (`valid == true`) record whose `last_access` is `0`.
    pub fn new(position: u32, token_id: u32, layer: u16, head: u16) -> Self {
        Self {
            position,
            token_id,
            layer,
            head,
            valid: true,
            // `last_access` takes its default of 0.
            ..Self::default()
        }
    }

    /// Records `step` as the moment this slot was last used.
    pub fn touch(&mut self, step: u64) {
        self.last_access = step;
    }

    /// Marks the slot as no longer holding a live entry.
    ///
    /// All other fields are left untouched.
    pub fn invalidate(&mut self) {
        self.valid = false;
    }

    /// Returns how attractive this slot is as an eviction target; larger means
    /// "evict sooner".
    ///
    /// An invalid slot always reports `u64::MAX`. A valid slot reports the
    /// number of steps between `last_access` and `current_step`, clamped to `0`
    /// when `last_access` lies ahead of `current_step`.
    #[must_use]
    pub fn eviction_priority(&self, current_step: u64) -> u64 {
        if self.valid {
            current_step.saturating_sub(self.last_access)
        } else {
            u64::MAX
        }
    }
}
/// Table of cache slots with a step counter used for least-recently-used
/// eviction.
///
/// The number of slots is chosen in [`KvCacheManager::new`]; none of the
/// methods here add or remove slots afterwards.
#[derive(Debug)]
pub struct KvCacheManager {
    /// One record per slot; empty slots have `valid == false`.
    slots: Vec<KvCacheSlotInfo>,
    /// Logical clock, advanced only by [`KvCacheManager::step`].
    current_step: u64,
    /// Number of slots handed out by `allocate` and not yet evicted.
    valid_count: usize,
}

impl KvCacheManager {
    /// Creates a manager with `capacity` empty slots and the clock at step `0`.
    pub fn new(capacity: usize) -> Self {
        let slots = vec![KvCacheSlotInfo::default(); capacity];
        Self {
            slots,
            current_step: 0,
            valid_count: 0,
        }
    }

    /// Claims the lowest-indexed empty slot for the given entry.
    ///
    /// The new record is stamped with the current step. Returns the slot index,
    /// or `None` when every slot is already valid.
    pub fn allocate(
        &mut self,
        position: u32,
        token_id: u32,
        layer: u16,
        head: u16,
    ) -> Option<usize> {
        let free = self.slots.iter().position(|slot| !slot.valid)?;

        let mut entry = KvCacheSlotInfo::new(position, token_id, layer, head);
        entry.touch(self.current_step);
        self.slots[free] = entry;

        self.valid_count += 1;
        Some(free)
    }

    /// Stamps slot `index` with the current step and returns a view of it.
    ///
    /// Returns `None` when `index` is out of range. The slot's `valid` flag is
    /// not consulted: an empty slot is stamped and returned as well, so callers
    /// that care must check `valid` themselves.
    pub fn access(&mut self, index: usize) -> Option<&KvCacheSlotInfo> {
        let now = self.current_step;
        let slot = self.slots.get_mut(index)?;
        slot.touch(now);
        Some(&*slot)
    }

    /// Invalidates the valid slot that has gone longest without being touched.
    ///
    /// Returns the index of the evicted slot, or `None` when no slot is valid.
    /// When several slots are equally stale, the lowest index is evicted.
    pub fn evict_lru(&mut self) -> Option<usize> {
        let now = self.current_step;

        // `min_by_key` returns the first of several equal minima, so wrapping
        // the priority in `Reverse` selects the first slot with the greatest
        // priority.
        let (victim, _) = self
            .slots
            .iter()
            .enumerate()
            .filter(|(_, slot)| slot.valid)
            .map(|(index, slot)| (index, slot.eviction_priority(now)))
            .min_by_key(|&(_, priority)| std::cmp::Reverse(priority))?;

        self.slots[victim].invalidate();
        self.valid_count -= 1;
        Some(victim)
    }

    /// Advances the logical clock by one step.
    pub fn step(&mut self) {
        self.current_step += 1;
    }

    /// Number of slots currently holding a live entry.
    #[must_use]
    pub fn valid_count(&self) -> usize {
        self.valid_count
    }

    /// Total number of slots, valid or not.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.slots.len()
    }
}
/// Hands out batch indices one at a time in an order fixed at construction.
///
/// The orderer can be driven either through [`SequentialBatchOrderer::next_batch`]
/// or through its [`Iterator`] implementation; both advance the same cursor.
#[derive(Debug, Clone)]
pub struct SequentialBatchOrderer {
    /// Batch indices in the order they will be yielded.
    order: Vec<usize>,
    /// Index into `order` of the next batch to yield.
    position: usize,
}

impl SequentialBatchOrderer {
    /// Yields `0, 1, ..., n_batches - 1`.
    pub fn new(n_batches: usize) -> Self {
        Self {
            order: (0..n_batches).collect(),
            position: 0,
        }
    }

    /// Yields `n_batches - 1, ..., 1, 0`.
    pub fn reversed(n_batches: usize) -> Self {
        Self {
            order: (0..n_batches).rev().collect(),
            position: 0,
        }
    }

    /// Alternates between the first and second half of the index range.
    ///
    /// With `mid = n_batches / 2` the order is
    /// `0, mid, 1, mid + 1, ..., mid - 1, 2 * mid - 1`, followed by the single
    /// leftover index `n_batches - 1` when `n_batches` is odd. For example,
    /// `interleaved(5)` yields `0, 2, 1, 3, 4`.
    pub fn interleaved(n_batches: usize) -> Self {
        let mid = n_batches / 2;
        let mut order = Vec::with_capacity(n_batches);
        for i in 0..mid {
            // `i < mid` guarantees `mid + i < 2 * mid <= n_batches`, so both
            // indices are always in range and no bounds test is needed.
            order.push(i);
            order.push(mid + i);
        }
        // `2 * mid` equals `n_batches` when even (empty range) and
        // `n_batches - 1` when odd (exactly the one unpaired index).
        order.extend(2 * mid..n_batches);
        Self { order, position: 0 }
    }

    /// Returns the next batch index and advances the cursor, or `None` once
    /// every index has been handed out.
    pub fn next_batch(&mut self) -> Option<usize> {
        let batch = *self.order.get(self.position)?;
        self.position += 1;
        Some(batch)
    }

    /// Rewinds the cursor so the same order is yielded again from the start.
    pub fn reset(&mut self) {
        self.position = 0;
    }

    /// `true` once every batch index has been yielded.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.position >= self.order.len()
    }

    /// Number of batch indices not yet yielded.
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.order.len().saturating_sub(self.position)
    }
}

impl Iterator for SequentialBatchOrderer {
    type Item = usize;

    /// Equivalent to [`SequentialBatchOrderer::next_batch`].
    fn next(&mut self) -> Option<Self::Item> {
        self.next_batch()
    }

    /// Reports the exact number of items left, so consumers such as `collect`
    /// and `extend` can reserve the right capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.remaining();
        (left, Some(left))
    }
}
// Unit tests are kept in the out-of-line `tests` module and are compiled only
// for test builds.
#[cfg(test)]
mod tests;