use serde::{Deserialize, Serialize};
/// Offload mode selector stored in [`MgrConfig::offload`].
///
/// Serde (de)serializes the variants by their lowercase names (`"gpu"`,
/// `"cpu"`, `"disk"`, `"auto"`). The enum is `#[non_exhaustive]`, so code in
/// other crates must keep a wildcard arm when matching on it.
///
/// What each mode does at runtime is decided outside this file; the
/// per-variant notes below are inferred from the names and should be
/// confirmed against the consumer.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum Offload {
    /// Presumably: keep everything on the GPU — TODO confirm.
    Gpu,
    /// Presumably: offload to host (CPU) memory — TODO confirm.
    Cpu,
    /// Presumably: offload to disk — TODO confirm.
    Disk,
    /// The value returned by `Offload::default()`.
    /// Presumably lets the consumer pick a mode itself — TODO confirm.
    #[default]
    Auto,
}
/// Configuration record for a single `Mgr`; (de)serializable through serde.
///
/// This type only carries values. Preset values come from
/// `MgrConfig::col_defaults`, `MgrConfig::row_defaults` and
/// `MgrConfig::icl_defaults`. How each field is interpreted is up to the
/// consumer, which is not visible in this file: notes worded "presumably"
/// are inferred from field names and must be confirmed there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MgrConfig {
    /// Optional device identifier; `None` leaves it unspecified.
    pub device: Option<String>,
    /// Presumably toggles automatic mixed precision — TODO confirm.
    pub use_amp: bool,
    /// Presumably toggles FlashAttention-3 kernels — TODO confirm.
    pub use_fa3: bool,
    /// Verbosity flag.
    pub verbose: bool,
    /// Presumably the smallest batch size the manager may fall back to —
    /// TODO confirm.
    pub min_batch_size: usize,
    /// Unit-less safety factor; presumably scales a memory budget —
    /// TODO confirm.
    pub safety_factor: f32,
    /// Selected [`Offload`] mode.
    pub offload: Offload,
    /// Threshold presumably consulted when `offload` is [`Offload::Auto`] —
    /// TODO confirm.
    pub auto_offload_threshold: f32,
    /// Safety factor presumably applied on the CPU offload path —
    /// TODO confirm.
    pub cpu_safety_factor: f32,
    /// Presumably a cap on pinned host memory, in megabytes per the `_mb`
    /// suffix — TODO confirm.
    pub max_pinned_memory_mb: f32,
    /// Optional directory, presumably used for disk offload files —
    /// TODO confirm.
    pub disk_offload_dir: Option<String>,
    /// Presumably the minimum free disk space to keep, in megabytes —
    /// TODO confirm.
    pub disk_min_free_mb: f32,
    /// Presumably the buffered amount that triggers a flush to disk, in
    /// megabytes — TODO confirm.
    pub disk_flush_mb: f32,
    /// Presumably whether disk offload files are removed afterwards —
    /// TODO confirm.
    pub disk_cleanup: bool,
    /// Presumably a filename prefix for disk offload files — TODO confirm.
    pub disk_file_prefix: String,
    /// Optional dtype name, presumably for data written to disk —
    /// TODO confirm.
    pub disk_dtype: Option<String>,
    /// Safety factor presumably applied on the disk offload path —
    /// TODO confirm.
    pub disk_safety_factor: f32,
    /// Presumably enables asynchronous operation — TODO confirm.
    pub use_async: bool,
    /// Presumably the depth of the asynchronous pipeline/queue —
    /// TODO confirm.
    pub async_depth: usize,
}
impl MgrConfig {
fn shared_defaults() -> Self {
Self {
device: None,
use_amp: true,
use_fa3: true,
verbose: false,
min_batch_size: 1,
safety_factor: 0.8,
offload: Offload::Gpu, auto_offload_threshold: 0.5,
cpu_safety_factor: 0.85,
max_pinned_memory_mb: 32_768.0,
disk_offload_dir: None,
disk_min_free_mb: 1024.0,
disk_flush_mb: 8192.0,
disk_cleanup: true,
disk_file_prefix: String::new(),
disk_dtype: None,
disk_safety_factor: 0.95,
use_async: true,
async_depth: 4,
}
}
pub fn col_defaults() -> Self {
Self {
offload: Offload::Auto,
..Self::shared_defaults()
}
}
pub fn row_defaults() -> Self {
Self::shared_defaults()
}
pub fn icl_defaults() -> Self {
Self::shared_defaults()
}
}
/// Top-level inference configuration: one [`MgrConfig`] per slot.
///
/// The slots are named `col`, `row` and `icl` (presumably the column, row
/// and in-context-learning stages — TODO confirm against the consumer).
/// `InferenceConfig::default()` fills each slot from the matching
/// `MgrConfig::*_defaults` constructor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Settings for the `col` slot.
    pub col_config: MgrConfig,
    /// Settings for the `row` slot.
    pub row_config: MgrConfig,
    /// Settings for the `icl` slot.
    pub icl_config: MgrConfig,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
col_config: MgrConfig::col_defaults(),
row_config: MgrConfig::row_defaults(),
icl_config: MgrConfig::icl_defaults(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `InferenceConfig::default()`: a few `col` scalars plus the
    /// `offload` mode of every slot. The test name refers to a Python
    /// reference block that is not part of this file.
    #[test]
    fn defaults_match_python_block() {
        let InferenceConfig {
            col_config: col,
            row_config: row,
            icl_config: icl,
        } = InferenceConfig::default();

        // Scalar defaults on the `col` slot.
        assert!(col.use_amp);
        assert_eq!(col.min_batch_size, 1);
        assert_eq!(col.async_depth, 4);

        // Only `col` uses `Auto`; the other two slots stay on `Gpu`.
        assert_eq!(
            [col.offload, row.offload, icl.offload],
            [Offload::Auto, Offload::Gpu, Offload::Gpu],
        );
    }
}