tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Inference loop with memory-aware batching — port of selected pieces
//! of `tabicl._model.inference.InferenceManager`.
//!
//! The Python class is 1,412 LOC and handles GPU↔CPU↔disk offloading.
//! We port the **batch-splitting** core (which is the user-visible
//! correctness mechanism) and skip the device-offloading layer since
//! the Rust port runs host-fp32 by default.
//!
//! Use [`InferenceManager::run_chunked`] to forward a large `(B, T, H)`
//! batch through any closure, splitting along the B axis into chunks
//! that fit memory budgets, and concatenating the outputs.

use ndarray::{Array3, ArrayView2, ArrayView3, Axis};

use crate::embedding::EmbeddingError;

/// Settings that decide how a batch is cut into chunks.
///
/// Mirrors the batching-related fields of the Python `MgrConfig`; the
/// GPU/CPU/disk offloading knobs are intentionally not ported.
///
/// The chunk length actually used is `max_chunk_size` scaled by
/// `safety_factor` (truncated to an integer), raised to at least
/// `min_batch_size`, and finally capped at the length of the batch.
#[derive(Clone, Copy, Debug)]
pub struct InferenceManagerConfig {
    /// Upper bound on B (number of tables) per chunk, before the safety
    /// factor is applied. A batch that is already shorter than the
    /// resulting chunk length is never split.
    pub max_chunk_size: usize,
    /// Lower bound on B per chunk. This floor wins even when the scaled
    /// `max_chunk_size` is smaller than it.
    pub min_batch_size: usize,
    /// Conservative multiplier applied to `max_chunk_size`; mirrors
    /// Python's `safety_factor`. `1.0` uses the full `max_chunk_size`,
    /// `0.8` uses 80% of it.
    pub safety_factor: f32,
    /// Whether to print progress messages.
    pub verbose: bool,
}

impl Default for InferenceManagerConfig {
    fn default() -> Self {
        Self {
            max_chunk_size: 16,
            min_batch_size: 1,
            safety_factor: 0.8,
            verbose: false,
        }
    }
}

/// Splits large batches into chunks sized by its
/// [`InferenceManagerConfig`] and forwards each chunk through a
/// caller-supplied closure.
#[derive(Clone, Debug)]
pub struct InferenceManager {
    /// Batch-splitting settings consulted on every chunked run.
    pub config: InferenceManagerConfig,
}

impl InferenceManager {
    /// Build a manager that chunks batches according to `config`.
    pub fn new(config: InferenceManagerConfig) -> Self {
        Self { config }
    }

    /// Run a 3-D `(B, T, H)` batch through `forward_fn`, splitting along
    /// the B axis when B exceeds the effective chunk size. Outputs are
    /// concatenated along axis 0 in the original batch order.
    ///
    /// `forward_fn` is called with each `(b_chunk, T, H)` view and is
    /// expected to return `(b_chunk, ?, ?)`. When the whole batch fits
    /// in one chunk (including an empty batch) `forward_fn` is called
    /// exactly once with `x` and its result is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `forward_fn`; later chunks
    /// are not processed.
    ///
    /// # Panics
    ///
    /// Panics if the chunk outputs disagree on axes 1 and 2, since they
    /// cannot then be concatenated.
    pub fn run_chunked<F>(
        &self,
        x: ArrayView3<f32>,
        mut forward_fn: F,
    ) -> Result<Array3<f32>, EmbeddingError>
    where
        F: FnMut(ArrayView3<f32>) -> Result<Array3<f32>, EmbeddingError>,
    {
        let b = x.shape()[0];
        let chunk_size = self.effective_chunk_size(b);
        if chunk_size >= b {
            // Single chunk — no splitting needed.
            return forward_fn(x);
        }
        // `effective_chunk_size` never returns 0, so `step_by` is valid
        // and every iteration makes progress.
        let mut outputs: Vec<Array3<f32>> = Vec::with_capacity(b / chunk_size + 1);
        for start in (0..b).step_by(chunk_size) {
            let end = (start + chunk_size).min(b);
            if self.config.verbose {
                eprintln!("[InferenceManager] processing chunk [{start}..{end}) of {b}");
            }
            outputs.push(forward_fn(x.slice(ndarray::s![start..end, .., ..]))?);
        }
        // Concatenate along axis 0.
        let views: Vec<_> = outputs.iter().map(|a| a.view()).collect();
        Ok(
            ndarray::concatenate(Axis(0), &views)
                .expect("chunk outputs must agree on axes 1 and 2"),
        )
    }

    /// Run a flat 2-D `(N, H)` test-row batch through `forward_fn`,
    /// splitting the rows of `x_test` into chunks to control memory.
    /// Used by the sklearn classifier/regressor predict path.
    ///
    /// `forward_fn` receives only the `(chunk_len, H)` slice of
    /// `x_test`; `x_train` is not read here — the caller is expected to
    /// stitch the shared training rows in inside the closure. Chunk
    /// outputs are concatenated along axis 1, i.e. each is assumed to be
    /// shaped `(1, chunk_len, out_dim)`. When all test rows fit in one
    /// chunk the closure is called once with `x_test` and its result is
    /// returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `forward_fn`; later chunks
    /// are not processed.
    ///
    /// # Panics
    ///
    /// Panics if the chunk outputs cannot be concatenated along axis 1
    /// because their other dimensions disagree.
    pub fn run_predict_chunked<F>(
        &self,
        x_train: ArrayView2<f32>,
        x_test: ArrayView2<f32>,
        mut forward_fn: F,
    ) -> Result<Array3<f32>, EmbeddingError>
    where
        F: FnMut(ArrayView2<f32>) -> Result<Array3<f32>, EmbeddingError>,
    {
        let _ = x_train; // x_train is shared across calls; the caller stitches it in
        let n_test = x_test.shape()[0];
        let chunk_size = self.effective_chunk_size(n_test);
        if chunk_size >= n_test {
            return forward_fn(x_test);
        }
        // `effective_chunk_size` never returns 0, so `step_by` is valid
        // and every iteration makes progress.
        let mut outputs: Vec<Array3<f32>> = Vec::with_capacity(n_test / chunk_size + 1);
        for start in (0..n_test).step_by(chunk_size) {
            let end = (start + chunk_size).min(n_test);
            if self.config.verbose {
                eprintln!(
                    "[InferenceManager] processing test rows [{start}..{end}) of {n_test}"
                );
            }
            outputs.push(forward_fn(x_test.slice(ndarray::s![start..end, ..]))?);
        }
        // Concatenate along the T axis (axis 1) since each chunk is
        // shaped (1, chunk_len, out_dim).
        let views: Vec<_> = outputs.iter().map(|a| a.view()).collect();
        Ok(ndarray::concatenate(Axis(1), &views)
            .expect("predict chunk outputs must agree on shapes"))
    }

    /// Chunk length to use for a batch of `n` items.
    ///
    /// Computed as `max_chunk_size * safety_factor` truncated to an
    /// integer, raised to at least `min_batch_size`, capped at `n`, and
    /// finally clamped to at least 1. The last clamp matters when both
    /// the scaled size and `min_batch_size` are 0: a zero-length chunk
    /// would make the chunking loops spin forever on a non-empty batch.
    fn effective_chunk_size(&self, n: usize) -> usize {
        // Float-to-int `as` saturates: a negative or NaN product becomes 0
        // and is then lifted by the floors below.
        let scaled = (self.config.max_chunk_size as f32 * self.config.safety_factor) as usize;
        scaled.max(self.config.min_batch_size).min(n).max(1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;
    use std::cell::{Cell, RefCell};

    /// Pins the default floor and safety factor.
    #[test]
    fn defaults_match_python_signature() {
        let defaults = InferenceManagerConfig::default();
        assert_eq!(defaults.min_batch_size, 1);
        let drift = (defaults.safety_factor - 0.8).abs();
        assert!(drift < 1e-6);
    }

    /// A batch shorter than the chunk length is forwarded in one call.
    #[test]
    fn single_chunk_when_batch_fits() {
        let config = InferenceManagerConfig {
            max_chunk_size: 100,
            ..Default::default()
        };
        let manager = InferenceManager::new(config);
        let input = Array::<f32, _>::zeros((3, 4, 2));
        let invocations = Cell::new(0);
        let result = manager.run_chunked(input.view(), |chunk| {
            invocations.set(invocations.get() + 1);
            let rows = chunk.shape()[0];
            // Tag each output row with its index inside the chunk.
            Ok(Array3::from_shape_fn((rows, 1, 2), |(bi, _, k)| {
                if k == 0 { bi as f32 } else { 0.0 }
            }))
        });
        let _output = result.unwrap();
        assert_eq!(invocations.get(), 1);
    }

    /// A batch of 7 with chunk length 3 is split 3/3/1 and re-assembled
    /// in the original order.
    #[test]
    fn multi_chunk_when_batch_exceeds_limit() {
        let manager = InferenceManager::new(InferenceManagerConfig {
            verbose: false,
            min_batch_size: 1,
            safety_factor: 1.0,
            max_chunk_size: 3,
        });
        // Every element of table `b` holds the value `b`.
        let input = Array::<f32, _>::from_shape_fn((7, 4, 2), |(b, _, _)| b as f32);
        let invocations = Cell::new(0);
        let chunk_lens = RefCell::new(Vec::new());
        let stitched = manager
            .run_chunked(input.view(), |chunk| {
                invocations.set(invocations.get() + 1);
                let rows = chunk.shape()[0];
                chunk_lens.borrow_mut().push(rows);
                // Echo the table id into the first output slot.
                Ok(Array3::from_shape_fn((rows, 1, 2), |(bi, _, k)| {
                    if k == 0 { chunk[(bi, 0, 0)] } else { 0.0 }
                }))
            })
            .unwrap();
        // chunks of 3, 3, 1.
        assert_eq!(chunk_lens.into_inner(), vec![3, 3, 1]);
        assert_eq!(invocations.get(), 3);
        // Output should preserve the order of the original batch.
        assert_eq!(stitched.shape(), &[7, 1, 2]);
        for (bi, table) in stitched.outer_iter().enumerate() {
            assert_eq!(table[(0, 0)], bi as f32);
        }
    }

    /// A 0.5 safety factor halves a chunk limit of 10.
    #[test]
    fn safety_factor_shrinks_chunk_size() {
        let config = InferenceManagerConfig {
            verbose: false,
            min_batch_size: 1,
            safety_factor: 0.5,
            max_chunk_size: 10,
        };
        let chunk_len = InferenceManager::new(config).effective_chunk_size(20);
        assert_eq!(chunk_len, 5);
    }

    /// The `min_batch_size` floor overrides a smaller scaled limit.
    #[test]
    fn min_batch_size_floor_honored() {
        let config = InferenceManagerConfig {
            verbose: false,
            min_batch_size: 4,
            safety_factor: 0.1,
            max_chunk_size: 1,
        };
        // 1 * 0.1 = 0 → floor to min_batch_size = 4.
        let chunk_len = InferenceManager::new(config).effective_chunk_size(100);
        assert_eq!(chunk_len, 4);
    }
}