use crate::common::protocols::MockEngineArgs;
/// Default bandwidth applied to both the G1->G2 and G2->G1 directions.
/// NOTE(review): the `_GBPS` suffix does not say whether this is GB/s or Gbit/s — confirm.
const DEFAULT_G1_G2_BANDWIDTH_GBPS: f64 = 14.0;
/// Default bandwidth applied to both the G2->G3 and G3->G2 directions.
const DEFAULT_G2_G3_BANDWIDTH_GBPS: f64 = 7.0;
/// Default bandwidth applied to both the G2->G4 and G4->G2 directions.
const DEFAULT_G2_G4_BANDWIDTH_GBPS: f64 = 4.0;
/// Configuration for the mocker's KV block offload tiers (G2, and optionally G3 / G4).
///
/// Instances come either from [`Default`] or from [`KvbmOffloadConfig::from_args`],
/// which derives the values from `MockEngineArgs`.
///
/// NOTE(review): the physical meaning of the G1..G4 tiers is not visible in this
/// file; the field docs below only describe what the code here establishes.
#[derive(Debug, Clone)]
pub struct KvbmOffloadConfig {
/// Number of blocks in the G2 tier. `from_args` only builds a config when this is non-zero.
pub num_g2_blocks: usize,
/// Block size in tokens (`from_args` copies `MockEngineArgs::block_size`).
pub block_size_tokens: usize,
/// Offload batch size. `from_args` falls back to the default when the argument is unset or zero.
pub offload_batch_size: usize,
/// Number of blocks in the G3 tier; `None` when G3 is not configured
/// (`from_args` maps a zero count to `None`).
pub num_g3_blocks: Option<usize>,
/// Whether G4 storage is enabled.
pub enable_g4_storage: bool,
/// Block size in bytes. `from_args` sets it to `block_size * kv_bytes_per_token`;
/// the `Default` value is `None`.
pub block_size_bytes: Option<usize>,
/// Bandwidth for G1 -> G2 transfers.
/// NOTE(review): unit is whatever `_gbps` denotes (GB/s vs Gbit/s) — confirm with the consumer.
pub bandwidth_g1_to_g2_gbps: f64,
/// Bandwidth for G2 -> G1 transfers.
pub bandwidth_g2_to_g1_gbps: f64,
/// Bandwidth for G2 -> G3 transfers.
pub bandwidth_g2_to_g3_gbps: f64,
/// Bandwidth for G3 -> G2 transfers.
pub bandwidth_g3_to_g2_gbps: f64,
/// Bandwidth for G2 -> G4 transfers.
pub bandwidth_g2_to_g4_gbps: f64,
/// Bandwidth for G4 -> G2 transfers.
pub bandwidth_g4_to_g2_gbps: f64,
}
impl Default for KvbmOffloadConfig {
    /// Baseline configuration: 100 000 G2 blocks of 64 tokens, offload batches of 32,
    /// no G3 tier, G4 storage disabled, no byte-level block size, and the module's
    /// default bandwidth for every transfer direction.
    fn default() -> Self {
        // Each tier pair shares one default value for both transfer directions.
        let g1_g2_gbps = DEFAULT_G1_G2_BANDWIDTH_GBPS;
        let g2_g3_gbps = DEFAULT_G2_G3_BANDWIDTH_GBPS;
        let g2_g4_gbps = DEFAULT_G2_G4_BANDWIDTH_GBPS;

        Self {
            // Tier sizing.
            num_g2_blocks: 100_000,
            num_g3_blocks: None,
            enable_g4_storage: false,
            // Block geometry and batching.
            block_size_tokens: 64,
            block_size_bytes: None,
            offload_batch_size: 32,
            // Transfer bandwidths.
            bandwidth_g1_to_g2_gbps: g1_g2_gbps,
            bandwidth_g2_to_g1_gbps: g1_g2_gbps,
            bandwidth_g2_to_g3_gbps: g2_g3_gbps,
            bandwidth_g3_to_g2_gbps: g2_g3_gbps,
            bandwidth_g2_to_g4_gbps: g2_g4_gbps,
            bandwidth_g4_to_g2_gbps: g2_g4_gbps,
        }
    }
}
impl KvbmOffloadConfig {
pub fn from_args(args: &MockEngineArgs) -> anyhow::Result<Option<Self>> {
let num_g3_blocks = args
.num_g3_blocks
.and_then(|block_count| (block_count > 0).then_some(block_count));
let enable_g4_storage = args.enable_g4_storage;
let Some(num_g2_blocks) = args.num_g2_blocks else {
if num_g3_blocks.is_some() || enable_g4_storage {
anyhow::bail!(
"G3/G4 offload requires num_g2_blocks because mocker stages lower tiers through G2"
);
}
return Ok(None);
};
if num_g2_blocks == 0 {
return Ok(None);
}
let Some(bpt) = args.kv_bytes_per_token else {
if num_g3_blocks.is_some() || enable_g4_storage {
anyhow::bail!(
"G3/G4 offload requires kv_bytes_per_token so mocker can size lower-tier transfers"
);
}
return Ok(None);
};
let defaults = Self::default();
let offload_batch_size = args
.offload_batch_size
.filter(|batch_size| *batch_size > 0)
.unwrap_or(defaults.offload_batch_size);
Ok(Some(Self {
num_g2_blocks,
block_size_tokens: args.block_size,
offload_batch_size,
num_g3_blocks,
enable_g4_storage,
block_size_bytes: Some(args.block_size * bpt),
bandwidth_g1_to_g2_gbps: args
.bandwidth_g1_to_g2_gbps
.unwrap_or(defaults.bandwidth_g1_to_g2_gbps),
bandwidth_g2_to_g1_gbps: args
.bandwidth_g2_to_g1_gbps
.unwrap_or(defaults.bandwidth_g2_to_g1_gbps),
bandwidth_g2_to_g3_gbps: args
.bandwidth_g2_to_g3_gbps
.unwrap_or(defaults.bandwidth_g2_to_g3_gbps),
bandwidth_g3_to_g2_gbps: args
.bandwidth_g3_to_g2_gbps
.unwrap_or(defaults.bandwidth_g3_to_g2_gbps),
bandwidth_g2_to_g4_gbps: args
.bandwidth_g2_to_g4_gbps
.unwrap_or(defaults.bandwidth_g2_to_g4_gbps),
bandwidth_g4_to_g2_gbps: args
.bandwidth_g4_to_g2_gbps
.unwrap_or(defaults.bandwidth_g4_to_g2_gbps),
}))
}
}
// Unit tests for `KvbmOffloadConfig::from_args`: when it yields `None`, when it
// errors, and how optional arguments are threaded through or defaulted.
#[cfg(test)]
mod tests {
use super::*;
// G2 is configured but `kv_bytes_per_token` is absent and no G3/G4 is requested:
// offload is simply disabled (`Ok(None)`).
#[test]
fn from_args_none_when_kv_bytes_per_token_missing() {
let args = MockEngineArgs::builder()
.num_g2_blocks(Some(10_000))
.build()
.unwrap()
.normalized()
.unwrap();
assert!(args.kv_bytes_per_token.is_none());
assert!(KvbmOffloadConfig::from_args(&args).unwrap().is_none());
}
// Requesting G3 without `kv_bytes_per_token` must be reported as an error.
#[test]
fn from_args_errors_when_g3_kv_bytes_per_token_missing() {
let args = MockEngineArgs::builder()
.num_g2_blocks(Some(10_000))
.num_g3_blocks(Some(20_000))
.build()
.unwrap()
.normalized()
.unwrap();
let error = KvbmOffloadConfig::from_args(&args).unwrap_err();
assert!(
error.to_string().contains("requires kv_bytes_per_token"),
"unexpected error: {error}",
);
}
// Without `num_g2_blocks` (and without G3/G4) offload is disabled.
#[test]
fn from_args_none_when_num_g2_blocks_missing() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.build()
.unwrap()
.normalized()
.unwrap();
assert!(args.num_g2_blocks.is_none());
assert!(KvbmOffloadConfig::from_args(&args).unwrap().is_none());
}
// An explicit zero G2 block count behaves like "offload disabled".
// Note: this test passes the built args directly, without calling `normalized()`.
#[test]
fn from_args_none_when_num_g2_blocks_zero() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.num_g2_blocks(Some(0))
.build()
.unwrap();
assert!(KvbmOffloadConfig::from_args(&args).unwrap().is_none());
}
// Minimal enabled configuration: `block_size_bytes` is block_size * kv_bytes_per_token,
// and every knob that was not set takes its default value.
#[test]
fn from_args_computes_block_size_bytes() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.num_g2_blocks(Some(10_000))
.build()
.unwrap()
.normalized()
.unwrap();
let cfg = KvbmOffloadConfig::from_args(&args)
.unwrap()
.expect("bpt set");
assert_eq!(cfg.block_size_bytes, Some(64 * 131_072));
assert_eq!(cfg.num_g2_blocks, 10_000);
assert_eq!(cfg.block_size_tokens, 64);
assert_eq!(cfg.offload_batch_size, 32);
assert_eq!(cfg.num_g3_blocks, None);
assert!(!cfg.enable_g4_storage);
assert_eq!(cfg.bandwidth_g1_to_g2_gbps, DEFAULT_G1_G2_BANDWIDTH_GBPS);
assert_eq!(cfg.bandwidth_g2_to_g1_gbps, DEFAULT_G1_G2_BANDWIDTH_GBPS);
assert_eq!(cfg.bandwidth_g2_to_g3_gbps, DEFAULT_G2_G3_BANDWIDTH_GBPS);
assert_eq!(cfg.bandwidth_g3_to_g2_gbps, DEFAULT_G2_G3_BANDWIDTH_GBPS);
assert_eq!(cfg.bandwidth_g2_to_g4_gbps, DEFAULT_G2_G4_BANDWIDTH_GBPS);
assert_eq!(cfg.bandwidth_g4_to_g2_gbps, DEFAULT_G2_G4_BANDWIDTH_GBPS);
}
// Every optional argument that is explicitly set must reach the config unchanged;
// each bandwidth uses a distinct value so a mixed-up direction would be caught.
#[test]
fn from_args_threads_optional_knobs_when_set() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.num_g2_blocks(Some(10_000))
.offload_batch_size(Some(16))
.num_g3_blocks(Some(20_000))
.enable_g4_storage(true)
.bandwidth_g1_to_g2_gbps(Some(8.0))
.bandwidth_g2_to_g1_gbps(Some(12.0))
.bandwidth_g2_to_g3_gbps(Some(3.0))
.bandwidth_g3_to_g2_gbps(Some(4.0))
.bandwidth_g2_to_g4_gbps(Some(5.0))
.bandwidth_g4_to_g2_gbps(Some(6.0))
.build()
.unwrap()
.normalized()
.unwrap();
let cfg = KvbmOffloadConfig::from_args(&args)
.unwrap()
.expect("bpt set");
assert_eq!(cfg.num_g2_blocks, 10_000);
assert_eq!(cfg.block_size_tokens, 64);
assert_eq!(cfg.offload_batch_size, 16);
assert_eq!(cfg.num_g3_blocks, Some(20_000));
assert!(cfg.enable_g4_storage);
assert_eq!(cfg.bandwidth_g1_to_g2_gbps, 8.0);
assert_eq!(cfg.bandwidth_g2_to_g1_gbps, 12.0);
assert_eq!(cfg.bandwidth_g2_to_g3_gbps, 3.0);
assert_eq!(cfg.bandwidth_g3_to_g2_gbps, 4.0);
assert_eq!(cfg.bandwidth_g2_to_g4_gbps, 5.0);
assert_eq!(cfg.bandwidth_g4_to_g2_gbps, 6.0);
}
// Enabling G4 without any `num_g2_blocks` is an error.
#[test]
fn from_args_errors_when_g4_num_g2_blocks_missing() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.enable_g4_storage(true)
.build()
.unwrap();
let error = KvbmOffloadConfig::from_args(&args).unwrap_err();
assert!(
error.to_string().contains("requires num_g2_blocks"),
"unexpected error: {error}",
);
}
// Enabling G4 without `kv_bytes_per_token` is an error.
#[test]
fn from_args_errors_when_g4_kv_bytes_per_token_missing() {
let args = MockEngineArgs::builder()
.block_size(64)
.num_g2_blocks(Some(10_000))
.enable_g4_storage(true)
.build()
.unwrap();
let error = KvbmOffloadConfig::from_args(&args).unwrap_err();
assert!(
error.to_string().contains("requires kv_bytes_per_token"),
"unexpected error: {error}",
);
}
// Zero values for the optional G3 block count and the batch size are treated as
// "not set": G3 stays disabled and the batch size falls back to the default,
// while G2 offload itself remains enabled.
#[test]
fn from_args_treats_zero_optional_g3_and_batch_size_as_disabled() {
let args = MockEngineArgs::builder()
.block_size(64)
.kv_bytes_per_token(Some(131_072))
.num_g2_blocks(Some(10_000))
.num_g3_blocks(Some(0))
.offload_batch_size(Some(0))
.build()
.unwrap();
let cfg = KvbmOffloadConfig::from_args(&args)
.unwrap()
.expect("G2 remains enabled");
assert_eq!(cfg.num_g2_blocks, 10_000);
assert_eq!(cfg.num_g3_blocks, None);
assert_eq!(
cfg.offload_batch_size,
KvbmOffloadConfig::default().offload_batch_size
);
}
}