tabicl-model 2.1.1

TabICL transformer model — column embedding, row interaction, ICL learning, KV cache.
//! Inference configuration — port of `tabicl._model.inference_config`.
//!
//! Three [`MgrConfig`] blocks (`COL_CONFIG`, `ROW_CONFIG`, `ICL_CONFIG`) drive
//! the column-wise, row-wise, and in-context learning passes respectively.
//! The Python module exposes ~30 knobs; we mirror the field names exactly so
//! Python configs can be deserialized into [`InferenceConfig`] verbatim.

use serde::{Deserialize, Serialize};

/// Output placement policy stored in [`MgrConfig::offload`].
///
/// With `rename_all = "lowercase"` the serde representation of each variant
/// is its lowercase name (`"gpu"`, `"cpu"`, `"disk"`, `"auto"`).
/// [`Offload::Auto`] is the [`Default`] value. The enum is
/// `#[non_exhaustive]`, so downstream crates must match it with a wildcard arm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Offload {
    /// Outputs stay on the GPU.
    Gpu,
    /// Outputs are moved to CPU memory.
    Cpu,
    /// Outputs are written to memory-mapped files.
    Disk,
    /// Pick gpu / cpu / disk automatically, based on memory.
    #[default]
    Auto,
}

/// Settings for a single inference pass (column-wise, row-wise, or ICL).
///
/// Ready-made presets come from [`MgrConfig::col_defaults`],
/// [`MgrConfig::row_defaults`] and [`MgrConfig::icl_defaults`]; the "preset"
/// values quoted on each field below are the ones those constructors set.
///
/// NOTE(review): no `#[serde(default)]` is declared, so a partial config
/// presumably fails to deserialize — confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MgrConfig {
    // General
    /// Device name; preset: `None`.
    /// NOTE(review): presumably `None` means "choose automatically" — confirm.
    pub device: Option<String>,
    /// Preset: `true`. Presumably enables automatic mixed precision — confirm.
    pub use_amp: bool,
    /// Preset: `true`. Presumably enables FlashAttention-3 — confirm.
    pub use_fa3: bool,
    /// Preset: `false`.
    pub verbose: bool,
    // Batching
    /// Preset: `1`.
    pub min_batch_size: usize,
    /// Preset: `0.8`. Presumably a fraction of available memory — confirm.
    pub safety_factor: f32,
    // Offloading
    /// Output placement policy. Preset: [`Offload::Auto`] for the column
    /// pass, [`Offload::Gpu`] for the row and ICL passes.
    pub offload: Offload,
    /// Preset: `0.5`. Presumably only consulted when `offload` is
    /// [`Offload::Auto`] — confirm.
    pub auto_offload_threshold: f32,
    // CPU offloading
    /// Preset: `0.85`.
    pub cpu_safety_factor: f32,
    /// Preset: `32_768.0` (the field name indicates megabytes).
    pub max_pinned_memory_mb: f32,
    // Disk offloading
    /// Preset: `None`.
    pub disk_offload_dir: Option<String>,
    /// Preset: `1024.0` (the field name indicates megabytes).
    pub disk_min_free_mb: f32,
    /// Preset: `8192.0` (the field name indicates megabytes).
    pub disk_flush_mb: f32,
    /// Preset: `true`.
    pub disk_cleanup: bool,
    /// Preset: the empty string.
    pub disk_file_prefix: String,
    /// Preset: `None`.
    pub disk_dtype: Option<String>,
    /// Preset: `0.95`.
    pub disk_safety_factor: f32,
    // Async transfer
    /// Preset: `true`.
    pub use_async: bool,
    /// Preset: `4`.
    pub async_depth: usize,
}

impl MgrConfig {
    /// Baseline preset, used unchanged for the row and ICL passes.
    ///
    /// The column preset starts from this value and overrides only `offload`.
    fn shared_defaults() -> Self {
        Self {
            // General
            device: None,
            use_amp: true,
            use_fa3: true,
            verbose: false,

            // Batching
            min_batch_size: 1,
            safety_factor: 0.8,

            // Offloading — the Python ROW/ICL configs use `False`, which
            // corresponds to keeping outputs on the GPU.
            offload: Offload::Gpu,
            auto_offload_threshold: 0.5,

            // CPU offloading
            cpu_safety_factor: 0.85,
            max_pinned_memory_mb: 32_768.0,

            // Disk offloading
            disk_offload_dir: None,
            disk_min_free_mb: 1024.0,
            disk_flush_mb: 8192.0,
            disk_cleanup: true,
            disk_file_prefix: String::new(),
            disk_dtype: None,
            disk_safety_factor: 0.95,

            // Async transfer
            use_async: true,
            async_depth: 4,
        }
    }

    /// Preset for the column-wise pass: the shared baseline with
    /// `offload` switched to [`Offload::Auto`].
    pub fn col_defaults() -> Self {
        let mut preset = Self::shared_defaults();
        preset.offload = Offload::Auto;
        preset
    }

    /// Preset for the row-wise pass: the shared baseline, unmodified.
    pub fn row_defaults() -> Self {
        Self::shared_defaults()
    }

    /// Preset for the in-context learning pass: the shared baseline, unmodified.
    pub fn icl_defaults() -> Self {
        Self::shared_defaults()
    }
}

/// Bundle of the three per-pass [`MgrConfig`] blocks.
///
/// The [`Default`] impl fills each field from the matching
/// `MgrConfig::*_defaults` constructor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Settings for the column-wise pass.
    pub col_config: MgrConfig,
    /// Settings for the row-wise pass.
    pub row_config: MgrConfig,
    /// Settings for the in-context learning pass.
    pub icl_config: MgrConfig,
}

impl Default for InferenceConfig {
    /// Assembles the configuration from the column, row and ICL presets.
    fn default() -> Self {
        let col_config = MgrConfig::col_defaults();
        let row_config = MgrConfig::row_defaults();
        let icl_config = MgrConfig::icl_defaults();

        Self {
            col_config,
            row_config,
            icl_config,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks a handful of preset values against the Python defaults.
    #[test]
    fn defaults_match_python_block() {
        let c = InferenceConfig::default();
        assert!(c.col_config.use_amp);
        assert_eq!(c.col_config.min_batch_size, 1);
        assert_eq!(c.col_config.async_depth, 4);
        assert_eq!(c.col_config.offload, Offload::Auto);
        assert_eq!(c.row_config.offload, Offload::Gpu);
        assert_eq!(c.icl_config.offload, Offload::Gpu);
    }

    /// The column preset must differ from the shared baseline in `offload`
    /// only. `MgrConfig` does not implement `PartialEq`, so the comparison is
    /// made on the derived `Debug` rendering, which covers every field.
    #[test]
    fn col_preset_differs_from_row_only_in_offload() {
        let InferenceConfig {
            col_config,
            row_config,
            ..
        } = InferenceConfig::default();
        let col_as_gpu = MgrConfig {
            offload: Offload::Gpu,
            ..col_config
        };
        assert_eq!(format!("{col_as_gpu:?}"), format!("{row_config:?}"));
    }

    /// Row and ICL presets are both the unmodified shared baseline.
    #[test]
    fn row_and_icl_presets_are_identical() {
        let InferenceConfig {
            row_config,
            icl_config,
            ..
        } = InferenceConfig::default();
        assert_eq!(format!("{row_config:?}"), format!("{icl_config:?}"));
    }
}