use ndarray::{Array3, Array4, Axis};
use std::collections::BTreeMap;
/// Cached key/value tensors for one attention block.
///
/// The two tensors are expected to share one shape (enforced by
/// [`KvCacheEntry::new`]); axis 0 is treated as the batch axis by
/// [`KvCacheEntry::slice_batch`] and [`KvCacheEntry::concat`]. The layout of
/// the remaining axes is presumably `(heads, tokens, head_dim)`, judging by the
/// test helper's parameter names — TODO confirm.
///
/// An entry is usable only when both tensors are present (see
/// [`KvCacheEntry::is_valid`]); the `Default` entry (both `None`) serves as an
/// empty placeholder.
#[derive(Debug, Clone, Default)]
pub struct KvCacheEntry {
    /// Cached key tensor, if any.
    pub key: Option<Array4<f32>>,
    /// Cached value tensor, if any.
    pub value: Option<Array4<f32>>,
}

impl KvCacheEntry {
    /// Builds a valid entry from a key/value pair.
    ///
    /// # Panics
    ///
    /// Panics if `key` and `value` do not have identical shapes.
    pub fn new(key: Array4<f32>, value: Array4<f32>) -> Self {
        assert_eq!(key.shape(), value.shape(), "K and V must share shape");
        Self {
            key: Some(key),
            value: Some(value),
        }
    }

    /// Returns `true` when both the key and the value tensor are present.
    pub fn is_valid(&self) -> bool {
        self.key.is_some() && self.value.is_some()
    }

    /// Returns an owned copy of rows `start..end` (axis 0) of both tensors.
    ///
    /// An entry that is not [valid](Self::is_valid) yields the default (empty)
    /// entry instead of a partial one.
    ///
    /// # Panics
    ///
    /// Panics (via ndarray's bounds check) if `start` or `end` lies past the
    /// end of axis 0.
    pub fn slice_batch(&self, start: usize, end: usize) -> KvCacheEntry {
        match (&self.key, &self.value) {
            (Some(k), Some(v)) => KvCacheEntry {
                key: Some(k.slice(ndarray::s![start..end, .., .., ..]).to_owned()),
                value: Some(v.slice(ndarray::s![start..end, .., .., ..]).to_owned()),
            },
            _ => KvCacheEntry::default(),
        }
    }

    /// Concatenates all valid entries along axis 0 (the batch axis).
    ///
    /// Entries missing either tensor are skipped. If no entry is valid, the
    /// default (empty) entry is returned.
    ///
    /// # Panics
    ///
    /// Panics if the valid entries disagree on any axis other than axis 0, or if
    /// the concatenated key and value tensors end up with different shapes.
    pub fn concat(entries: &[&KvCacheEntry]) -> KvCacheEntry {
        // Borrow only complete (key, value) pairs; this both filters out partial
        // entries and avoids re-checking + unwrapping the options afterwards.
        let (keys, values): (Vec<_>, Vec<_>) = entries
            .iter()
            .filter_map(|e| match (&e.key, &e.value) {
                (Some(k), Some(v)) => Some((k.view(), v.view())),
                _ => None,
            })
            .unzip();
        if keys.is_empty() {
            return KvCacheEntry::default();
        }
        let key = ndarray::concatenate(Axis(0), &keys)
            .expect("cached keys must agree on every axis except the batch axis");
        let value = ndarray::concatenate(Axis(0), &values)
            .expect("cached values must agree on every axis except the batch axis");
        KvCacheEntry::new(key, value)
    }

    /// Number of bytes held by the tensors that are present (absent ones count as 0).
    pub fn byte_size(&self) -> usize {
        let mut total = 0;
        if let Some(k) = &self.key {
            total += k.len() * std::mem::size_of::<f32>();
        }
        if let Some(v) = &self.value {
            total += v.len() * std::mem::size_of::<f32>();
        }
        total
    }
}
/// Collection of [`KvCacheEntry`] values keyed by an integer index.
///
/// The index is presumably the layer/block index — confirm with callers. A
/// `BTreeMap` keeps iteration ordered by index.
#[derive(Debug, Clone, Default)]
pub struct KvCache {
    /// Cached entry for each index.
    pub kv: BTreeMap<usize, KvCacheEntry>,
}

impl KvCache {
    /// Creates a cache with no entries.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `entry` under `idx`, replacing any entry already there.
    pub fn insert(&mut self, idx: usize, entry: KvCacheEntry) {
        self.kv.insert(idx, entry);
    }

    /// Returns `true` if at least one stored entry has both key and value.
    pub fn is_populated(&self) -> bool {
        self.kv.values().any(KvCacheEntry::is_valid)
    }

    /// Slices every entry to batch rows `start..end`, keeping all indices
    /// (invalid entries become empty placeholders).
    pub fn slice_batch(&self, start: usize, end: usize) -> KvCache {
        let mut sliced = BTreeMap::new();
        for (&layer, entry) in &self.kv {
            sliced.insert(layer, entry.slice_batch(start, end));
        }
        KvCache { kv: sliced }
    }

    /// Merges several caches by concatenating, per index, the entries of every
    /// cache that has that index (in the order the caches are given).
    pub fn concat(caches: &[&KvCache]) -> KvCache {
        // One pass over all caches; each group keeps the caller's cache order.
        let mut grouped: BTreeMap<usize, Vec<&KvCacheEntry>> = BTreeMap::new();
        for cache in caches {
            for (&layer, entry) in &cache.kv {
                grouped.entry(layer).or_default().push(entry);
            }
        }
        let kv = grouped
            .into_iter()
            .map(|(layer, group)| (layer, KvCacheEntry::concat(&group)))
            .collect();
        KvCache { kv }
    }

    /// Total bytes held by all entries.
    pub fn byte_size(&self) -> usize {
        self.kv.values().fold(0, |acc, entry| acc + entry.byte_size())
    }
}
/// Kind of state held by a [`TabICLCache`], as reported by
/// [`TabICLCache::cache_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheType {
    /// No row representation and no key/value entries are stored.
    Empty,
    /// No row representation, but at least one key/value cache has an entry.
    Kv,
    /// A row-representation tensor is stored (takes precedence over `Kv`).
    Repr,
}
/// Cached state for a TabICL model.
///
/// Holds either a row-representation tensor (`row_repr`) or key/value caches
/// (`col_cache`, `icl_cache`), plus the training-data shape and an optional
/// class count. Axis 0 of every cached tensor is the batch axis (see
/// [`TabICLCache::slice_batch`]).
#[derive(Debug, Clone, Default)]
pub struct TabICLCache {
    /// Key/value cache for the column stage (per the field name).
    pub col_cache: KvCache,
    /// Cached row representations, batch along axis 0.
    pub row_repr: Option<Array3<f32>>,
    /// Key/value cache for the ICL stage (per the field name).
    pub icl_cache: KvCache,
    /// Shape of the data the cache was built from. Element 0 is the batch size
    /// (rewritten by `slice_batch`); the meaning of the other two elements is
    /// not visible here — TODO confirm.
    pub train_shape: (usize, usize, usize),
    /// Number of target classes, if known.
    pub num_classes: Option<usize>,
}

impl TabICLCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a cache holding only a row representation (both KV caches empty).
    pub fn from_row_repr(
        row_repr: Array3<f32>,
        train_shape: (usize, usize, usize),
        num_classes: Option<usize>,
    ) -> Self {
        Self {
            row_repr: Some(row_repr),
            train_shape,
            num_classes,
            ..Self::default()
        }
    }

    /// Classifies the cached state.
    ///
    /// `Repr` wins whenever `row_repr` is set. Otherwise the cache is `Kv` as
    /// soon as either KV cache has any entry — including placeholder entries
    /// that are not valid — and `Empty` when neither does.
    pub fn cache_type(&self) -> CacheType {
        if self.row_repr.is_some() {
            return CacheType::Repr;
        }
        if !self.col_cache.kv.is_empty() || !self.icl_cache.kv.is_empty() {
            return CacheType::Kv;
        }
        CacheType::Empty
    }

    /// Returns `true` when [`cache_type`](Self::cache_type) is `Empty`.
    pub fn is_empty(&self) -> bool {
        self.cache_type() == CacheType::Empty
    }

    /// Exact number of bytes held by all cached tensors.
    pub fn cache_size_bytes(&self) -> usize {
        let repr_bytes = self
            .row_repr
            .as_ref()
            .map_or(0, |r| r.len() * std::mem::size_of::<f32>());
        self.col_cache.byte_size() + self.icl_cache.byte_size() + repr_bytes
    }

    /// Total cache size in whole MiB (1024 * 1024 bytes), rounded down.
    ///
    /// Caches smaller than 1 MiB report `0`; use
    /// [`cache_size_bytes`](Self::cache_size_bytes) for an exact figure.
    pub fn cache_size_mb(&self) -> usize {
        self.cache_size_bytes() / (1024 * 1024)
    }

    /// Returns a new cache restricted to batch items `start..end`.
    ///
    /// Every cached tensor is sliced along axis 0, `train_shape.0` becomes
    /// `end - start`, and `num_classes` is copied unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `start > end`, or (via ndarray's bounds check) if the range
    /// lies past the end of axis 0 of any tensor that gets sliced.
    pub fn slice_batch(&self, start: usize, end: usize) -> TabICLCache {
        // A reversed range slices to empty tensors without complaint in ndarray,
        // so without this check `end - start` below would underflow (and wrap
        // silently in release builds).
        assert!(
            start <= end,
            "slice_batch: start ({}) must not exceed end ({})",
            start,
            end
        );
        TabICLCache {
            col_cache: self.col_cache.slice_batch(start, end),
            row_repr: self
                .row_repr
                .as_ref()
                .map(|r| r.slice(ndarray::s![start..end, .., ..]).to_owned()),
            icl_cache: self.icl_cache.slice_batch(start, end),
            train_shape: (end - start, self.train_shape.1, self.train_shape.2),
            num_classes: self.num_classes,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a valid entry of shape `(b, h, t, d)` whose keys are all `fill`
    /// and whose values are all `fill + 0.5`.
    fn fake_entry(b: usize, h: usize, t: usize, d: usize, fill: f32) -> KvCacheEntry {
        let shape = (b, h, t, d);
        let key = Array4::from_elem(shape, fill);
        let value = Array4::from_elem(shape, fill + 0.5);
        KvCacheEntry::new(key, value)
    }

    #[test]
    fn validity_and_byte_size() {
        let entry = fake_entry(2, 4, 8, 16, 0.0);
        assert!(entry.is_valid());
        // 2 * 4 * 8 * 16 = 1024 f32s per tensor, and there are two tensors (K, V).
        let expected = 2 * 1024 * std::mem::size_of::<f32>();
        assert_eq!(entry.byte_size(), expected);
    }

    #[test]
    fn slice_batch_carves_first_dim() {
        let sliced = fake_entry(4, 2, 3, 2, 1.0).slice_batch(1, 3);
        let key = sliced.key.expect("slicing a valid entry keeps its key");
        assert_eq!(key.shape(), &[2, 2, 3, 2]);
    }

    #[test]
    fn concat_recombines_batch_dim() {
        let first = fake_entry(2, 1, 1, 1, 1.0);
        let second = fake_entry(3, 1, 1, 1, 2.0);
        let key = KvCacheEntry::concat(&[&first, &second])
            .key
            .expect("concatenating valid entries yields a key");
        assert_eq!(key.shape(), &[5, 1, 1, 1]);
        // Row 0 comes from `first`, the last row from `second`.
        assert_eq!(key[(0, 0, 0, 0)], 1.0);
        assert_eq!(key[(4, 0, 0, 0)], 2.0);
    }

    #[test]
    fn cache_type_state_machine() {
        let mut cache = TabICLCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.cache_type(), CacheType::Empty);

        cache.icl_cache.insert(0, fake_entry(1, 1, 1, 1, 0.0));
        assert!(!cache.is_empty());
        assert_eq!(cache.cache_type(), CacheType::Kv);

        // A row representation takes precedence over the KV caches.
        cache.row_repr = Some(Array3::<f32>::zeros((1, 1, 1)));
        assert_eq!(cache.cache_type(), CacheType::Repr);
    }

    #[test]
    fn slice_batch_propagates_to_subcaches_and_shape() {
        let mut col_cache = KvCache::new();
        col_cache.insert(0, fake_entry(4, 2, 3, 2, 0.0));
        let mut icl_cache = KvCache::new();
        icl_cache.insert(0, fake_entry(4, 2, 3, 2, 0.0));
        // Each batch row of the representation is filled with its batch index.
        let row_repr = Array3::from_shape_fn((4, 3, 2), |(b, _, _)| b as f32);
        let cache = TabICLCache {
            col_cache,
            row_repr: Some(row_repr),
            icl_cache,
            train_shape: (4, 10, 5),
            num_classes: Some(7),
        };

        let sliced = cache.slice_batch(1, 3);
        assert_eq!(sliced.train_shape, (2, 10, 5));
        assert_eq!(sliced.num_classes, Some(7));

        let repr = sliced.row_repr.as_ref().expect("row_repr survives slicing");
        assert_eq!(repr.shape(), &[2, 3, 2]);
        assert_eq!(repr[(0, 0, 0)], 1.0);
        assert_eq!(repr[(1, 0, 0)], 2.0);

        let col_key = sliced.col_cache.kv[&0]
            .key
            .as_ref()
            .expect("column key survives slicing");
        assert_eq!(col_key.shape(), &[2, 2, 3, 2]);
    }

    #[test]
    fn cache_size_mb_sums_components() {
        // 256 Ki f32s is exactly 1 MiB per tensor; K + V therefore give 2 MiB.
        let len = 1024 * 256;
        let mut cache = TabICLCache::new();
        cache.icl_cache.insert(
            0,
            KvCacheEntry::new(
                Array4::<f32>::zeros((1, 1, 1, len)),
                Array4::<f32>::zeros((1, 1, 1, len)),
            ),
        );
        assert_eq!(cache.cache_size_mb(), 2);
    }
}