tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! KV cache for repeated inference — port of `tabicl._model.kv_cache`.
//!
//! Three layered cache types mirror the Python implementation:
//!
//!   - [`KvCacheEntry`] — `(key, value)` pair for a single attention layer.
//!     K/V tensors always have shape `(B, H, T, D)` (batch / heads / seq /
//!     head_dim).
//!   - [`KvCache`]      — `BTreeMap<layer_idx, KvCacheEntry>` covering a
//!     whole transformer stack (ColEmbedding ISAB blocks, or ICLearning
//!     transformer blocks).
//!   - [`TabICLCache`]  — the top-level cache that aggregates a column
//!     cache, an ICL cache, an optional pre-computed row-repr, and the
//!     `(B, train_size, n_features)` shape it was built for.
//!
//! Status: types, validity helpers, batch slicing, concatenation, and
//! memory-size reporting are implemented and tested. The graph-level
//! "write back into pre-allocated buffers" path will land alongside the
//! ColEmbedding / ICLearning ports that produce the K/V tensors.

use ndarray::{Array3, Array4, Axis};
use std::collections::BTreeMap;

/// One layer's cached `(key, value)` projections.
///
/// Both tensors share shape `(B, H, T, D)` (batch / heads / seq / head_dim)
/// and dtype f32 in this port. [`KvCacheEntry::new`] enforces the shared
/// shape; because the fields are public, code that assigns them directly is
/// responsible for keeping them consistent. An entry is only usable when
/// *both* tensors are present — see [`KvCacheEntry::is_valid`].
///
/// The Python `to(device, dtype)` API is unnecessary here since the host-side
/// representation is always fp32; on-device transitions happen at
/// runtime-dispatch time.
#[derive(Debug, Clone, Default)]
pub struct KvCacheEntry {
    /// Cached key projections, `(B, H, T, D)`; `None` in a default (empty) entry.
    pub key: Option<Array4<f32>>,
    /// Cached value projections, same shape as `key`; `None` in a default entry.
    pub value: Option<Array4<f32>>,
}

impl KvCacheEntry {
    /// Build a valid entry from a key/value pair.
    ///
    /// # Panics
    ///
    /// Panics if `key` and `value` do not have identical shapes.
    pub fn new(key: Array4<f32>, value: Array4<f32>) -> Self {
        assert_eq!(key.shape(), value.shape(), "K and V must share shape");
        Self {
            key: Some(key),
            value: Some(value),
        }
    }

    /// `true` when both the key and the value tensor are present.
    pub fn is_valid(&self) -> bool {
        self.key.is_some() && self.value.is_some()
    }

    /// Borrow `(key, value)` when the entry is valid, `None` otherwise.
    fn key_value(&self) -> Option<(&Array4<f32>, &Array4<f32>)> {
        self.key.as_ref().zip(self.value.as_ref())
    }

    /// Slice along the batch dim (axis 0). Mirrors Python `entry[indices]`
    /// for the half-open batch range `[start, end)`.
    ///
    /// An invalid entry (missing key or value) slices to an empty default
    /// entry.
    ///
    /// # Panics
    ///
    /// Panics (in `ndarray`'s slice bounds check) if the range is out of
    /// bounds for the batch dimension.
    pub fn slice_batch(&self, start: usize, end: usize) -> KvCacheEntry {
        match self.key_value() {
            Some((k, v)) => KvCacheEntry {
                key: Some(k.slice(ndarray::s![start..end, .., .., ..]).to_owned()),
                value: Some(v.slice(ndarray::s![start..end, .., .., ..]).to_owned()),
            },
            None => KvCacheEntry::default(),
        }
    }

    /// Concatenate entries along the batch dim (axis 0), in the order given.
    ///
    /// Invalid entries are skipped; if none are valid the result is an empty
    /// default entry.
    ///
    /// # Panics
    ///
    /// Panics if the valid entries disagree on any non-batch dimension
    /// `(H, T, D)`.
    pub fn concat(entries: &[&KvCacheEntry]) -> KvCacheEntry {
        // Pattern-match key and value together so no `unwrap` is needed on
        // the Options after a separate validity check.
        let (keys, values): (Vec<_>, Vec<_>) = entries
            .iter()
            .filter_map(|e| e.key_value())
            .map(|(k, v)| (k.view(), v.view()))
            .unzip();
        if keys.is_empty() {
            return KvCacheEntry::default();
        }
        let k = ndarray::concatenate(Axis(0), &keys)
            .expect("KvCacheEntry::concat: keys disagree on (H, T, D)");
        let v = ndarray::concatenate(Axis(0), &values)
            .expect("KvCacheEntry::concat: values disagree on (H, T, D)");
        KvCacheEntry::new(k, v)
    }

    /// Total bytes held by the key and value tensors, i.e. element count
    /// times `size_of::<f32>()`; an absent tensor contributes zero.
    pub fn byte_size(&self) -> usize {
        self.key
            .iter()
            .chain(self.value.iter())
            .map(|t| t.len() * std::mem::size_of::<f32>())
            .sum()
    }
}

/// Cached K/V for an entire transformer stack, keyed by layer index.
///
/// Layers live in a `BTreeMap`, so they are always visited in ascending
/// index order; iteration and [`KvCache::concat`] are therefore
/// deterministic.
#[derive(Debug, Clone, Default)]
pub struct KvCache {
    /// Per-layer entries, keyed by layer index.
    pub kv: BTreeMap<usize, KvCacheEntry>,
}

impl KvCache {
    /// Create a cache with no layers.
    pub fn new() -> Self {
        KvCache { kv: BTreeMap::new() }
    }

    /// Store `entry` under layer `idx`, replacing whatever was there.
    pub fn insert(&mut self, idx: usize, entry: KvCacheEntry) {
        self.kv.insert(idx, entry);
    }

    /// `true` if at least one layer holds a valid (key *and* value) entry.
    pub fn is_populated(&self) -> bool {
        self.kv.values().any(KvCacheEntry::is_valid)
    }

    /// Slice every layer's entry to the batch range `[start, end)`, keeping
    /// the same layer indices.
    pub fn slice_batch(&self, start: usize, end: usize) -> KvCache {
        let mut sliced = BTreeMap::new();
        for (&layer, entry) in &self.kv {
            sliced.insert(layer, entry.slice_batch(start, end));
        }
        KvCache { kv: sliced }
    }

    /// Concatenate caches layer by layer along the batch dim.
    ///
    /// The result covers the union of all layer indices. For each layer,
    /// only the caches that actually contain that layer contribute (in the
    /// order given), so inputs with differing layer sets can produce layers
    /// with differing batch sizes.
    pub fn concat(caches: &[&KvCache]) -> KvCache {
        let layers: std::collections::BTreeSet<usize> = caches
            .iter()
            .flat_map(|cache| cache.kv.keys().copied())
            .collect();
        let kv = layers
            .into_iter()
            .map(|layer| {
                let parts: Vec<&KvCacheEntry> = caches
                    .iter()
                    .filter_map(|cache| cache.kv.get(&layer))
                    .collect();
                (layer, KvCacheEntry::concat(&parts))
            })
            .collect();
        KvCache { kv }
    }

    /// Total bytes across every layer's key and value tensors.
    pub fn byte_size(&self) -> usize {
        self.kv
            .values()
            .fold(0, |acc, entry| acc + entry.byte_size())
    }
}

/// Kind of payload a [`TabICLCache`] currently holds.
///
/// Mirrors Python `TabICLCache.cache_type`, one variant per reported kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheType {
    /// Neither row representations nor any layer K/V entries are present.
    Empty,
    /// Per-layer K/V entries are present and no row representations are.
    Kv,
    /// Pre-computed row representations are present (they take precedence
    /// over any K/V entries).
    Repr,
}

/// Top-level cache for repeated TabICL inference.
///
/// Holds pre-computed row representations (`row_repr`) and/or per-layer K/V
/// caches (`col_cache`, `icl_cache`), together with the shape of the
/// training data the cache was built from. [`TabICLCache::cache_type`]
/// decides which payload is in effect.
#[derive(Debug, Clone, Default)]
pub struct TabICLCache {
    /// K/V cache for the column-embedding stack.
    pub col_cache: KvCache,
    /// Pre-computed row representations, shape `(B, train_size, embed_dim *
    /// row_num_cls)`. When present, takes precedence over `icl_cache` —
    /// see `cache_type`.
    pub row_repr: Option<Array3<f32>>,
    /// K/V cache for the ICL transformer stack.
    pub icl_cache: KvCache,
    /// `(batch_size, train_size, num_features)` of the training data the
    /// cache was built with.
    pub train_shape: (usize, usize, usize),
    /// Number of classes recorded when the cache was built, if any.
    pub num_classes: Option<usize>,
}

impl TabICLCache {
    /// An empty cache: no K/V entries, no row representations,
    /// `train_shape == (0, 0, 0)` and `num_classes == None`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Build a `row_repr`-only cache from a stored representation matrix.
    /// This is the "repr cache" variant — fastest path when the training
    /// data fits comfortably in memory.
    pub fn from_row_repr(
        row_repr: Array3<f32>,
        train_shape: (usize, usize, usize),
        num_classes: Option<usize>,
    ) -> Self {
        Self {
            row_repr: Some(row_repr),
            train_shape,
            num_classes,
            ..Self::default()
        }
    }

    /// Which payload this cache holds.
    ///
    /// `row_repr` takes precedence: if it is present the cache is
    /// [`CacheType::Repr`] regardless of the K/V caches. Otherwise any layer
    /// entry in `col_cache` or `icl_cache` makes it [`CacheType::Kv`]. This
    /// checks map emptiness, not entry validity, so even invalid entries
    /// count.
    pub fn cache_type(&self) -> CacheType {
        if self.row_repr.is_some() {
            return CacheType::Repr;
        }
        if !self.col_cache.kv.is_empty() || !self.icl_cache.kv.is_empty() {
            return CacheType::Kv;
        }
        CacheType::Empty
    }

    /// `true` when [`TabICLCache::cache_type`] is [`CacheType::Empty`].
    pub fn is_empty(&self) -> bool {
        self.cache_type() == CacheType::Empty
    }

    /// Memory footprint in MB (both K/V caches plus `row_repr`). Matches the
    /// Python `cache_size_mb` integer-MB rounding (floor division).
    pub fn cache_size_mb(&self) -> usize {
        let mut bytes = self.col_cache.byte_size() + self.icl_cache.byte_size();
        if let Some(r) = &self.row_repr {
            bytes += r.len() * std::mem::size_of::<f32>();
        }
        bytes / (1024 * 1024)
    }

    /// Slice every component to the batch range `[start, end)`.
    ///
    /// `train_shape.0` becomes `end - start`; the other two dimensions and
    /// `num_classes` are carried over unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `start > end`, or if the range is out of bounds for the
    /// batch dimension of any tensor that is present.
    pub fn slice_batch(&self, start: usize, end: usize) -> TabICLCache {
        // Checked up front so a reversed range fails with a clear message
        // instead of underflowing in `end - start` below (an overflow panic
        // in debug builds, a silently wrapped batch size in release).
        assert!(
            start <= end,
            "TabICLCache::slice_batch: start ({}) > end ({})",
            start,
            end
        );
        TabICLCache {
            col_cache: self.col_cache.slice_batch(start, end),
            row_repr: self
                .row_repr
                .as_ref()
                .map(|r| r.slice(ndarray::s![start..end, .., ..]).to_owned()),
            icl_cache: self.icl_cache.slice_batch(start, end),
            train_shape: (end - start, self.train_shape.1, self.train_shape.2),
            num_classes: self.num_classes,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a valid entry of shape `(b, h, t, d)`. Keys are filled with
    /// `fill` and values with `fill + 0.5`, so the two stay distinguishable.
    fn fake_entry(b: usize, h: usize, t: usize, d: usize, fill: f32) -> KvCacheEntry {
        KvCacheEntry::new(
            Array4::from_elem((b, h, t, d), fill),
            Array4::from_elem((b, h, t, d), fill + 0.5),
        )
    }

    #[test]
    fn validity_and_byte_size() {
        let e = fake_entry(2, 4, 8, 16, 0.0);
        assert!(e.is_valid());
        // Each of key / value has 2*4*8*16 = 1024 f32 = 4096 bytes.
        assert_eq!(e.byte_size(), 2 * 1024 * std::mem::size_of::<f32>());
    }

    #[test]
    fn invalid_entries_are_empty_and_skipped() {
        let empty = KvCacheEntry::default();
        assert!(!empty.is_valid());
        assert_eq!(empty.byte_size(), 0);
        assert!(!empty.slice_batch(0, 0).is_valid());
        assert!(!KvCacheEntry::concat(&[&empty]).is_valid());

        let a = fake_entry(2, 1, 1, 1, 1.0);
        let merged = KvCacheEntry::concat(&[&empty, &a]);
        assert_eq!(merged.key.unwrap().shape(), &[2, 1, 1, 1]);
    }

    #[test]
    fn slice_batch_carves_first_dim() {
        let e = fake_entry(4, 2, 3, 2, 1.0);
        let s = e.slice_batch(1, 3);
        let k = s.key.unwrap();
        assert_eq!(k.shape(), &[2, 2, 3, 2]);
    }

    #[test]
    fn concat_recombines_batch_dim() {
        let a = fake_entry(2, 1, 1, 1, 1.0);
        let b = fake_entry(3, 1, 1, 1, 2.0);
        let merged = KvCacheEntry::concat(&[&a, &b]);
        let k = merged.key.unwrap();
        assert_eq!(k.shape(), &[5, 1, 1, 1]);
        assert_eq!(k[(0, 0, 0, 0)], 1.0);
        assert_eq!(k[(4, 0, 0, 0)], 2.0);
        let v = merged.value.unwrap();
        assert_eq!(v[(0, 0, 0, 0)], 1.5);
        assert_eq!(v[(4, 0, 0, 0)], 2.5);
    }

    #[test]
    fn kv_cache_concat_unions_layer_indices() {
        let mut a = KvCache::new();
        a.insert(0, fake_entry(2, 1, 1, 1, 0.0));
        a.insert(1, fake_entry(2, 1, 1, 1, 0.0));
        let mut b = KvCache::new();
        b.insert(0, fake_entry(3, 1, 1, 1, 1.0));

        let merged = KvCache::concat(&[&a, &b]);
        assert_eq!(merged.kv.len(), 2);
        assert_eq!(merged.kv[&0].key.as_ref().unwrap().shape(), &[5, 1, 1, 1]);
        // Layer 1 only exists in `a`, so only `a` contributes to it.
        assert_eq!(merged.kv[&1].key.as_ref().unwrap().shape(), &[2, 1, 1, 1]);
    }

    #[test]
    fn kv_cache_is_populated_requires_a_valid_entry() {
        let mut c = KvCache::new();
        assert!(!c.is_populated());
        c.insert(0, KvCacheEntry::default());
        assert!(!c.is_populated());
        c.insert(1, fake_entry(1, 1, 1, 1, 0.0));
        assert!(c.is_populated());
    }

    #[test]
    fn cache_type_state_machine() {
        let mut c = TabICLCache::new();
        assert_eq!(c.cache_type(), CacheType::Empty);
        assert!(c.is_empty());

        c.icl_cache.insert(0, fake_entry(1, 1, 1, 1, 0.0));
        assert_eq!(c.cache_type(), CacheType::Kv);
        assert!(!c.is_empty());

        c.row_repr = Some(Array3::<f32>::zeros((1, 1, 1)));
        // row_repr takes precedence over kv.
        assert_eq!(c.cache_type(), CacheType::Repr);
    }

    #[test]
    fn from_row_repr_builds_repr_cache() {
        let c = TabICLCache::from_row_repr(Array3::zeros((2, 3, 4)), (2, 3, 5), Some(3));
        assert_eq!(c.cache_type(), CacheType::Repr);
        assert_eq!(c.train_shape, (2, 3, 5));
        assert_eq!(c.num_classes, Some(3));
        assert!(c.col_cache.kv.is_empty());
        assert!(c.icl_cache.kv.is_empty());
    }

    #[test]
    fn slice_batch_propagates_to_subcaches_and_shape() {
        let mut c = TabICLCache {
            col_cache: KvCache::new(),
            row_repr: Some(Array3::from_shape_fn((4, 3, 2), |(b, _, _)| b as f32)),
            icl_cache: KvCache::new(),
            train_shape: (4, 10, 5),
            num_classes: Some(7),
        };
        c.col_cache.insert(0, fake_entry(4, 2, 3, 2, 0.0));
        c.icl_cache.insert(0, fake_entry(4, 2, 3, 2, 0.0));
        let s = c.slice_batch(1, 3);
        assert_eq!(s.train_shape, (2, 10, 5));
        assert_eq!(s.row_repr.as_ref().unwrap().shape(), &[2, 3, 2]);
        // Verify slicing actually picked the right batch slice.
        let r = s.row_repr.as_ref().unwrap();
        assert_eq!(r[(0, 0, 0)], 1.0);
        assert_eq!(r[(1, 0, 0)], 2.0);
        assert_eq!(
            s.col_cache.kv[&0].key.as_ref().unwrap().shape(),
            &[2, 2, 3, 2]
        );
        assert_eq!(s.num_classes, Some(7));
    }

    #[test]
    fn cache_size_mb_sums_components() {
        let mut c = TabICLCache::new();
        // 256K f32 elements (1 MB) per tensor, key + value.
        c.icl_cache.insert(
            0,
            KvCacheEntry::new(
                Array4::<f32>::zeros((1, 1, 1, 1024 * 256)),
                Array4::<f32>::zeros((1, 1, 1, 1024 * 256)),
            ),
        );
        // 1024 * 256 * 4 bytes * 2 tensors = 2 MB.
        assert_eq!(c.cache_size_mb(), 2);
    }
}