use derive_builder::Builder;
use dynamo_kv_router::config::RouterQueuePolicy;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use uuid::Uuid;
use validator::{Validate, ValidationError};
use crate::common::perf_model::PerfModel;
use dynamo_kv_router::protocols::{KvCacheEvent, StorageTier};
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, PositionalLineageHash, SequenceHash, Token};
/// Zero-sized marker for the G1 tier.
///
/// NOTE(review): G1 is presumably the device-resident KV tier that the `*_g1_to_g2_*` /
/// `*_g2_to_g1_*` bandwidth knobs refer to; confirm against the code that uses this marker.
#[derive(Clone, Debug)]
pub struct G1;

/// Selects which eviction backend the mocker's KV cache uses.
///
/// Defaults to [`MockerEvictionBackend::Lineage`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub enum MockerEvictionBackend {
    Lru,
    MultiLru,
    #[default]
    Lineage,
}
/// Destination for typed KV-cache events emitted by the mocker.
///
/// Sinks are stored as `Arc<dyn KvCacheEventSink>` by [`KvEventPublishers`], hence the
/// `Send + Sync` bound.
pub trait KvCacheEventSink: Send + Sync {
    /// Publishes a single KV-cache event.
    fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()>;

    /// Publishes an event together with the storage tier it refers to.
    ///
    /// The default implementation ignores `_storage_tier` and forwards to
    /// [`KvCacheEventSink::publish`]. Override it if the sink must distinguish tiers.
    fn publish_with_storage_tier(
        &self,
        event: KvCacheEvent,
        _storage_tier: StorageTier,
    ) -> anyhow::Result<()> {
        self.publish(event)
    }
}
/// A KV-cache event bundled with optional per-block token ids and its storage tier.
#[derive(Debug, Clone)]
pub struct RawKvEvent {
    /// The typed event itself.
    pub event: KvCacheEvent,
    /// Token ids for each block, when the publisher supplied them.
    /// NOTE(review): presumably one inner `Vec` per block referenced by `event`; confirm.
    pub block_token_ids: Option<Vec<Vec<u32>>>,
    /// Storage tier the event applies to.
    pub storage_tier: StorageTier,
}

/// Destination for [`RawKvEvent`]s. Stored as `Arc<dyn RawKvEventSink>`, hence `Send + Sync`.
pub trait RawKvEventSink: Send + Sync {
    /// Publishes one raw event.
    fn publish(&self, event: RawKvEvent) -> anyhow::Result<()>;
}
/// Fans KV-cache events out to an optional typed sink and an optional raw sink.
///
/// `Default` yields a publisher with no sinks, for which publishing is a no-op.
#[derive(Clone, Default)]
pub struct KvEventPublishers {
    /// Receives the typed [`KvCacheEvent`] (with storage tier).
    event_sink: Option<Arc<dyn KvCacheEventSink>>,
    /// Receives a [`RawKvEvent`] that additionally carries block token ids.
    raw_sink: Option<Arc<dyn RawKvEventSink>>,
}
impl KvEventPublishers {
    /// Creates a publisher from optional sinks; either, both, or neither may be set.
    pub fn new(
        event_sink: Option<Arc<dyn KvCacheEventSink>>,
        raw_sink: Option<Arc<dyn RawKvEventSink>>,
    ) -> Self {
        Self {
            event_sink,
            raw_sink,
        }
    }

    /// Returns `true` when a raw-event sink is attached.
    pub fn raw_enabled(&self) -> bool {
        self.raw_sink.is_some()
    }

    /// Returns `true` when no sink of either kind is attached.
    pub fn is_empty(&self) -> bool {
        self.event_sink.is_none() && self.raw_sink.is_none()
    }

    /// Publishes `event` tagged with [`StorageTier::Device`].
    ///
    /// # Errors
    /// Propagates the first sink error; see [`Self::publish_with_storage_tier`].
    pub fn publish(
        &self,
        event: KvCacheEvent,
        block_token_ids: Option<&[Vec<u32>]>,
    ) -> anyhow::Result<()> {
        self.publish_with_storage_tier(event, block_token_ids, StorageTier::Device)
    }

    /// Publishes `event` to every attached sink for the given `storage_tier`.
    ///
    /// The typed sink is called first, then the raw sink. `block_token_ids` is copied into
    /// the [`RawKvEvent`] only when a raw sink is attached.
    ///
    /// # Errors
    /// Returns the first sink error. If the typed sink fails, the raw sink is not called.
    pub fn publish_with_storage_tier(
        &self,
        event: KvCacheEvent,
        block_token_ids: Option<&[Vec<u32>]>,
        storage_tier: StorageTier,
    ) -> anyhow::Result<()> {
        let Some(raw_sink) = self.raw_sink.as_ref() else {
            // Only the typed sink (if any) needs the event, so hand it over without cloning.
            if let Some(sink) = self.event_sink.as_ref() {
                sink.publish_with_storage_tier(event, storage_tier)?;
            }
            return Ok(());
        };

        // Both sinks may consume the event: the typed sink gets a clone, the raw sink the original.
        if let Some(sink) = self.event_sink.as_ref() {
            sink.publish_with_storage_tier(event.clone(), storage_tier)?;
        }
        raw_sink.publish(RawKvEvent {
            event,
            block_token_ids: block_token_ids.map(|token_ids| token_ids.to_vec()),
            storage_tier,
        })?;
        Ok(())
    }
}
/// Snapshot of scheduler load for one forward pass, published through [`FpmSink`].
///
/// Field prefixes follow a `num_` / `sum_` / `var_` (count / total / variance) convention.
/// NOTE(review): this reading comes from field names only; confirm the exact definitions
/// with the producer of these snapshots.
#[derive(Debug, Clone, Default)]
pub struct ForwardPassSnapshot {
    pub version: u32,
    pub worker_id: String,
    pub dp_rank: u32,
    pub counter_id: u64,
    // Scheduled prefill work.
    pub num_prefill_requests: u32,
    pub sum_prefill_tokens: u64,
    pub var_prefill_length: f64,
    pub sum_prefill_kv_tokens: u64,
    // Scheduled decode work.
    pub num_decode_requests: u32,
    pub sum_decode_kv_tokens: u64,
    pub var_decode_kv_tokens: f64,
    // Queued (not yet scheduled) prefill work.
    pub num_queued_prefill: u32,
    pub sum_queued_prefill_tokens: u64,
    pub var_queued_prefill_length: f64,
    // Queued (not yet scheduled) decode work.
    pub num_queued_decode: u32,
    pub sum_queued_decode_kv_tokens: u64,
    pub var_queued_decode_kv_tokens: f64,
    /// Wall-clock duration in seconds, going by the name. NOTE(review): confirm units.
    pub wall_time_secs: f64,
}

/// Destination for [`ForwardPassSnapshot`]s. Stored as `Arc<dyn FpmSink>`, hence `Send + Sync`.
pub trait FpmSink: Send + Sync {
    /// Publishes one snapshot.
    fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()>;
}
/// Optional forward-pass-metrics publisher. `Default` has no sink, which makes publishing a no-op.
#[derive(Clone, Default)]
pub struct FpmPublisher {
    /// Sink that receives snapshots, if any.
    sink: Option<Arc<dyn FpmSink>>,
}
impl FpmPublisher {
pub fn new(sink: Option<Arc<dyn FpmSink>>) -> Self {
Self { sink }
}
pub fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()> {
if let Some(sink) = &self.sink {
sink.publish(snapshot)?;
}
Ok(())
}
}
/// Count of KV blocks.
pub type NumBlocks = usize;

/// Block-level operations sent to the mocker's KV manager.
///
/// NOTE(review): tuple payload meanings are inferred from their types and should be
/// confirmed against the KV manager that consumes these.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
    /// Use blocks: unique block ids, their block hashes and lineage hashes, optional per-block
    /// token ids, and an optional extra block (presumably a parent; confirm).
    Use(
        Vec<UniqueBlock>,
        Vec<BlockHash>,
        Vec<PositionalLineageHash>,
        Option<Vec<Vec<u32>>>,
        Option<UniqueBlock>,
    ),
    /// Drop references to the listed blocks.
    Deref(Vec<UniqueBlock>),
    /// Promote a block identified by `Uuid` to a full block keyed by `SequenceHash`, with an
    /// optional `u64` (presumably a parent hash; confirm), its block/lineage hashes and
    /// optional token ids.
    Promote(
        Uuid,
        SequenceHash,
        Option<u64>,
        BlockHash,
        PositionalLineageHash,
        Option<Vec<u32>>,
    ),
}

/// Outcome of a block move, expressed as sequence hashes that were stored or removed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
    Store(Vec<SequenceHash>, Option<u64>),
    Remove(Vec<SequenceHash>),
}

/// A request injected directly into the mocker engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Prompt tokens.
    pub tokens: Vec<Token>,
    /// Upper bound on generated tokens.
    pub max_output_tokens: usize,
    /// Optional caller-chosen request id.
    pub uuid: Option<Uuid>,
    /// Data-parallel rank that should serve the request.
    pub dp_rank: u32,
    /// Optional arrival time; milliseconds according to the field name (epoch unspecified here).
    pub arrival_timestamp_ms: Option<f64>,
}

/// Token and block accounting for prefilling one request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    pub new_blocks: usize,
    pub new_tokens: usize,
    pub cached_tokens: usize,
    pub active_cached_tokens: usize,
}
impl PrefillCost {
    /// Estimates prefill compute time for this request using `perf_model`.
    ///
    /// `new_tokens` overrides `self.new_tokens` when provided. The model is queried with
    /// `1` as its first argument (presumably batch size), an input length of
    /// `cached_tokens + new_tokens`, and `cached_tokens` as the already-cached prefix.
    pub fn predict_prefill_compute(
        &self,
        new_tokens: Option<usize>,
        perf_model: &PerfModel,
    ) -> f64 {
        let prefix = self.cached_tokens;
        let fresh = match new_tokens {
            Some(override_tokens) => override_tokens,
            None => self.new_tokens,
        };
        perf_model.predict_prefill_time(1, prefix + fresh, prefix)
    }
}
/// Per-step output notification for a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    pub uuid: Uuid,
    pub completed: bool,
    /// Optional handoff delay in milliseconds (per the name). It defaults to `None` when absent
    /// and is omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub handoff_delay_ms: Option<f64>,
}

/// Order in which running requests are preempted. Parsed from `"lifo"` / `"fifo"`; defaults to `Lifo`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PreemptionMode {
    #[default]
    Lifo,
    Fifo,
}

/// Inference engine being emulated. Parsed from `"vllm"`, `"sglang"`, or `"trtllm"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EngineType {
    #[default]
    Vllm,
    Sglang,
    Trtllm,
}

/// Scheduling policy derived from [`EngineType`] by `MockEngineArgs::scheduling_policy`:
/// TRT-LLM maps to `TrtllmGuaranteedNoEvict`, every other engine maps to `Vllm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SchedulingPolicy {
    #[default]
    Vllm,
    TrtllmGuaranteedNoEvict,
}

/// Role of the worker in (dis)aggregated serving. Parsed from `"aggregated"`, `"prefill"`,
/// or `"decode"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
    #[default]
    Aggregated,
    Prefill,
    Decode,
}
/// Emulates "thinking" output: part of each response is bracketed by start/end thinking tokens.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ReasoningConfig {
    /// Token id emitted to open the thinking span.
    pub start_thinking_token_id: u32,
    /// Token id emitted to close the thinking span.
    pub end_thinking_token_id: u32,
    /// Fraction of `max_output_tokens` spent thinking. Must be within `[0.0, 1.0]`.
    #[validate(range(min = 0.0, max = 1.0))]
    pub thinking_ratio: f64,
}
impl ReasoningConfig {
    /// Number of output tokens allotted to the thinking span.
    ///
    /// Returns 0 when `max_output_tokens < 2` or when `floor(max_output_tokens * thinking_ratio)`
    /// is 0. Otherwise the value is clamped to `[2, max_output_tokens]`, presumably so that both
    /// the start and end thinking tokens fit.
    pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
        if max_output_tokens < 2 {
            return 0;
        }
        let budget = max_output_tokens as f64 * self.thinking_ratio;
        match budget.floor() as usize {
            0 => 0,
            // The early return above guarantees 2 <= max_output_tokens, so clamp cannot panic.
            thinking => thinking.clamp(2, max_output_tokens),
        }
    }

    /// Output tokens left for the visible response after the thinking span.
    pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
        let thinking = self.num_thinking_tokens(max_output_tokens);
        max_output_tokens.saturating_sub(thinking)
    }
}
/// SGLang-specific knobs. Every field is optional; the `#[validate]` bounds apply when a
/// value is set.
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Default)]
pub struct SglangArgs {
    pub schedule_policy: Option<String>,
    /// Must equal `block_size` when both are set. It is also used as the block size when
    /// `block_size` is left unset (0).
    #[validate(range(min = 1))]
    pub page_size: Option<usize>,
    #[validate(range(min = 1))]
    pub max_prefill_tokens: Option<usize>,
    /// Must be divisible by `block_size` when the engine is SGLang.
    #[validate(range(min = 1))]
    pub chunked_prefill_size: Option<usize>,
    #[validate(range(min = 1))]
    pub clip_max_new_tokens: Option<usize>,
    #[validate(range(min = 0.0, max = 1.0))]
    pub schedule_conservativeness: Option<f64>,
}

/// TRT-LLM-specific knobs.
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Default)]
pub struct TrtllmArgs {
    /// Only `"guaranteed_no_evict"` passes `MockEngineArgs` validation.
    pub capacity_scheduler_policy: Option<String>,
}
/// Tri-state wrapper that separates an absent config key from an explicit `null`.
///
/// `Missing` is the `Default`, so inside a `#[serde(default)]` container an absent key stays
/// `Missing`. A present key becomes `Present(None)` for `null` and `Present(Some(v))` otherwise.
#[derive(Debug, Clone, Default)]
enum OptionalConfigValue<T> {
    #[default]
    Missing,
    Present(Option<T>),
}
impl<'de, T> Deserialize<'de> for OptionalConfigValue<T>
where
T: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Option::<T>::deserialize(deserializer).map(Self::Present)
}
}
impl<T> OptionalConfigValue<T> {
    /// Collapses to nested options: `Missing` becomes `None`, `Present(v)` becomes `Some(v)`.
    ///
    /// Use it for fields where an explicit `null` is meaningful (it overrides a default with `None`).
    fn into_nullable(self) -> Option<Option<T>> {
        if let Self::Present(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Like [`Self::into_nullable`], but rejects an explicit `null`.
    ///
    /// # Errors
    /// Returns `"{field} must not be null"` when the key was present with a `null` value.
    fn into_non_null(self, field: &str) -> Result<Option<T>, String> {
        self.into_nullable()
            .map(|value| value.ok_or_else(|| format!("{field} must not be null")))
            .transpose()
    }
}
/// Wire-format mirror of [`MockEngineArgs`], used only for deserialization.
///
/// Each field is an [`OptionalConfigValue`], so the conversion into `MockEngineArgs` can
/// tell an absent key (keep the builder default) from an explicit `null`. Enum-like settings
/// arrive as plain strings and are parsed during that conversion. Unknown keys are rejected.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
struct MockEngineArgsSerde {
    engine_type: OptionalConfigValue<String>,
    num_gpu_blocks: OptionalConfigValue<usize>,
    block_size: OptionalConfigValue<usize>,
    max_num_seqs: OptionalConfigValue<usize>,
    max_num_batched_tokens: OptionalConfigValue<usize>,
    enable_prefix_caching: OptionalConfigValue<bool>,
    enable_chunked_prefill: OptionalConfigValue<bool>,
    speedup_ratio: OptionalConfigValue<f64>,
    decode_speedup_ratio: OptionalConfigValue<f64>,
    dp_size: OptionalConfigValue<u32>,
    startup_time: OptionalConfigValue<f64>,
    worker_type: OptionalConfigValue<String>,
    // Legacy role flags: only read when `worker_type` is absent.
    is_prefill: OptionalConfigValue<bool>,
    is_decode: OptionalConfigValue<bool>,
    planner_profile_data: OptionalConfigValue<PathBuf>,
    // The `aic_*` keys are accepted here although `MockEngineArgs` skips them when serializing.
    aic_backend: OptionalConfigValue<String>,
    aic_system: OptionalConfigValue<String>,
    aic_backend_version: OptionalConfigValue<String>,
    aic_tp_size: OptionalConfigValue<usize>,
    aic_model_path: OptionalConfigValue<String>,
    aic_moe_tp_size: OptionalConfigValue<usize>,
    aic_moe_ep_size: OptionalConfigValue<usize>,
    aic_attention_dp_size: OptionalConfigValue<usize>,
    aic_nextn: OptionalConfigValue<usize>,
    aic_nextn_accept_rates: OptionalConfigValue<String>,
    gpu_memory_utilization: OptionalConfigValue<f64>,
    mem_fraction_static: OptionalConfigValue<f64>,
    free_gpu_memory_fraction: OptionalConfigValue<f64>,
    enable_local_indexer: OptionalConfigValue<bool>,
    bootstrap_port: OptionalConfigValue<u16>,
    kv_bytes_per_token: OptionalConfigValue<usize>,
    kv_transfer_bandwidth: OptionalConfigValue<f64>,
    num_g2_blocks: OptionalConfigValue<usize>,
    num_g3_blocks: OptionalConfigValue<usize>,
    enable_g4_storage: OptionalConfigValue<bool>,
    offload_batch_size: OptionalConfigValue<usize>,
    bandwidth_g1_to_g2_gbps: OptionalConfigValue<f64>,
    bandwidth_g2_to_g1_gbps: OptionalConfigValue<f64>,
    bandwidth_g2_to_g3_gbps: OptionalConfigValue<f64>,
    bandwidth_g3_to_g2_gbps: OptionalConfigValue<f64>,
    bandwidth_g2_to_g4_gbps: OptionalConfigValue<f64>,
    bandwidth_g4_to_g2_gbps: OptionalConfigValue<f64>,
    reasoning: OptionalConfigValue<ReasoningConfig>,
    zmq_kv_events_port: OptionalConfigValue<u16>,
    zmq_replay_port: OptionalConfigValue<u16>,
    preemption_mode: OptionalConfigValue<String>,
    router_queue_policy: OptionalConfigValue<String>,
    sglang: OptionalConfigValue<SglangArgs>,
    trtllm: OptionalConfigValue<TrtllmArgs>,
    // Accepted under the key `has_perf_model` so `deny_unknown_fields` does not reject it;
    // the conversion never reads it.
    #[serde(rename = "has_perf_model")]
    _has_perf_model: OptionalConfigValue<serde_json::Value>,
}
/// Parses the lowercase config spelling of an [`EngineType`].
///
/// # Errors
/// Returns a message listing the accepted values when `value` is not recognised.
fn parse_engine_type(value: &str) -> Result<EngineType, String> {
    let engine = match value {
        "vllm" => EngineType::Vllm,
        "sglang" => EngineType::Sglang,
        "trtllm" => EngineType::Trtllm,
        other => {
            return Err(format!(
                "Invalid engine_type '{other}'. Must be 'vllm', 'sglang', or 'trtllm'."
            ));
        }
    };
    Ok(engine)
}
/// Parses the lowercase config spelling of a [`WorkerType`].
///
/// # Errors
/// Returns a message listing the accepted values when `value` is not recognised.
fn parse_worker_type(value: &str) -> Result<WorkerType, String> {
    let role = match value {
        "aggregated" => WorkerType::Aggregated,
        "prefill" => WorkerType::Prefill,
        "decode" => WorkerType::Decode,
        other => {
            return Err(format!(
                "Invalid worker_type '{other}'. Must be 'aggregated', 'prefill', or 'decode'."
            ));
        }
    };
    Ok(role)
}
/// Parses the lowercase config spelling of a [`PreemptionMode`].
///
/// # Errors
/// Returns a message listing the accepted values when `value` is not recognised.
fn parse_preemption_mode(value: &str) -> Result<PreemptionMode, String> {
    let mode = match value {
        "lifo" => PreemptionMode::Lifo,
        "fifo" => PreemptionMode::Fifo,
        other => {
            return Err(format!(
                "Invalid preemption_mode: '{other}'. Must be 'lifo' or 'fifo'."
            ));
        }
    };
    Ok(mode)
}
/// Loads a [`PerfModel`] from an `.npz` file at `path`.
///
/// This never fails: on a load error it logs at `error` level and falls back to
/// `PerfModel::default()`, which the log message describes as the polynomial model.
fn load_perf_model(path: &Path) -> Arc<PerfModel> {
    let model = match PerfModel::from_npz(path) {
        Err(e) => {
            tracing::error!(
                "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
                path,
                e
            );
            PerfModel::default()
        }
        Ok(loaded) => {
            tracing::info!("Successfully loaded performance model from: {:?}", path);
            loaded
        }
    };
    Arc::new(model)
}
/// Full configuration of a mocker engine instance.
///
/// Build it with [`MockEngineArgs::builder`] or deserialize it from JSON. Deserialization goes
/// through `MockEngineArgsSerde` (`#[serde(try_from)]`), which applies builder defaults for
/// absent keys and then calls `normalized()`. Field-level `#[validate]` rules plus
/// `validate_mock_engine_args` run during `normalized()`.
///
/// NOTE(review): `Serialize` is derived, so enum fields serialize with their Rust variant
/// names (e.g. `"Vllm"`, `"Decode"`, `"Lifo"`). The deserialization path only accepts the
/// lowercase spellings, so serialized output may not round-trip. Fields marked
/// `#[serde(skip)]` are also dropped on output. Confirm this is intended.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[serde(try_from = "MockEngineArgsSerde")]
#[validate(schema(function = "validate_mock_engine_args"))]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    #[builder(default = "EngineType::Vllm")]
    pub engine_type: EngineType,
    #[builder(default = "16384")]
    #[validate(range(min = 1))]
    pub num_gpu_blocks: usize,
    /// 0 means "use the engine default", resolved by `normalized()`. For SGLang,
    /// `sglang.page_size` is used when set.
    #[builder(default = "0")]
    pub block_size: usize,
    #[builder(default = Some(256))]
    #[validate(range(min = 1))]
    pub max_num_seqs: Option<usize>,
    #[builder(default = Some(8192))]
    #[validate(range(min = 1))]
    pub max_num_batched_tokens: Option<usize>,
    #[builder(default = true)]
    pub enable_prefix_caching: bool,
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,
    #[builder(default = "1.0")]
    #[validate(range(min = 0.0))]
    pub speedup_ratio: f64,
    #[builder(default = "1.0")]
    #[validate(range(min = 0.0))]
    pub decode_speedup_ratio: f64,
    #[builder(default = "1")]
    #[validate(range(min = 1))]
    pub dp_size: u32,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub startup_time: Option<f64>,
    #[builder(default = "WorkerType::Aggregated")]
    pub worker_type: WorkerType,
    /// When set through deserialization, `perf_model` is loaded from this path.
    #[builder(default = "None")]
    pub planner_profile_data: Option<PathBuf>,
    /// Not serialized; rebuilt from `planner_profile_data` when deserializing.
    #[serde(skip)]
    #[builder(default = "Arc::new(PerfModel::default())")]
    pub perf_model: Arc<PerfModel>,
    // The `aic_*` fields are skipped when serializing but still accepted when deserializing.
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_backend: Option<String>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_system: Option<String>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_backend_version: Option<String>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_tp_size: Option<usize>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_model_path: Option<String>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_moe_tp_size: Option<usize>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_moe_ep_size: Option<usize>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_attention_dp_size: Option<usize>,
    #[serde(skip)]
    #[builder(default = "None")]
    #[validate(range(min = 1, max = 5))]
    pub aic_nextn: Option<usize>,
    #[serde(skip)]
    #[builder(default = "None")]
    pub aic_nextn_accept_rates: Option<String>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub gpu_memory_utilization: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub mem_fraction_static: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub free_gpu_memory_fraction: Option<f64>,
    #[builder(default = "false")]
    pub enable_local_indexer: bool,
    #[builder(default = "None")]
    pub bootstrap_port: Option<u16>,
    #[builder(default = "None")]
    pub kv_bytes_per_token: Option<usize>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub kv_transfer_bandwidth: Option<f64>,
    /// `Some(0)` is normalized to `None`, which disables the G2 tier.
    #[builder(default = "None")]
    #[validate(range(min = 1))]
    pub num_g2_blocks: Option<usize>,
    /// `Some(0)` is normalized to `None`. A non-zero value requires `num_g2_blocks`.
    #[builder(default = "None")]
    #[validate(range(min = 1))]
    pub num_g3_blocks: Option<usize>,
    /// Requires `num_g2_blocks` when `true`.
    #[builder(default = "false")]
    pub enable_g4_storage: bool,
    /// `Some(0)` is normalized to `None`.
    #[builder(default = "None")]
    #[validate(range(min = 1))]
    pub offload_batch_size: Option<usize>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g1_to_g2_gbps: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g2_to_g1_gbps: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g2_to_g3_gbps: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g3_to_g2_gbps: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g2_to_g4_gbps: Option<f64>,
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub bandwidth_g4_to_g2_gbps: Option<f64>,
    /// NOTE(review): no `#[validate(nested)]` attribute here, so `ReasoningConfig`'s own
    /// rules are not run by the derive.
    #[builder(default = "None")]
    pub reasoning: Option<ReasoningConfig>,
    #[builder(default = "None")]
    pub zmq_kv_events_port: Option<u16>,
    #[builder(default = "None")]
    pub zmq_replay_port: Option<u16>,
    #[builder(default)]
    pub preemption_mode: PreemptionMode,
    #[builder(default = "None")]
    pub router_queue_policy: Option<RouterQueuePolicy>,
    /// NOTE(review): no `#[validate(nested)]` attribute here, so `SglangArgs`' own rules are
    /// not run by the derive.
    #[builder(default = "None")]
    pub sglang: Option<SglangArgs>,
    #[builder(default = "None")]
    pub trtllm: Option<TrtllmArgs>,
}
/// Builds a [`ValidationError`] with a stable machine-readable `code` and a human-readable `message`.
fn mock_engine_args_validation_error(code: &'static str, message: String) -> ValidationError {
    let mut validation_error = ValidationError::new(code);
    validation_error.message = Some(message.into());
    validation_error
}
/// Struct-level (schema) validation for [`MockEngineArgs`].
///
/// Checks cross-field invariants that per-field `#[validate]` attributes cannot express:
/// - `block_size` is non-zero (it must be materialized before validation);
/// - the nested `reasoning` and `sglang` configs satisfy their own `#[validate]` rules. The
///   derive does not recurse into them because those fields carry no `#[validate(nested)]`;
/// - G3 and G4 tiers require a G2 tier;
/// - TRT-LLM accepts only `capacity_scheduler_policy = "guaranteed_no_evict"`;
/// - for SGLang, `block_size` matches `sglang.page_size` and divides `sglang.chunked_prefill_size`.
///
/// # Errors
/// Returns the first violated invariant as a [`ValidationError`] with a stable `code`.
fn validate_mock_engine_args(args: &MockEngineArgs) -> Result<(), ValidationError> {
    if args.block_size == 0 {
        return Err(mock_engine_args_validation_error(
            "block_size_zero",
            "block_size must be greater than 0".to_string(),
        ));
    }
    // Without these calls, ranges such as `thinking_ratio <= 1.0` or
    // `chunked_prefill_size >= 1` would be declared but never enforced.
    if let Some(reasoning) = args.reasoning.as_ref() {
        reasoning.validate().map_err(|errors| {
            mock_engine_args_validation_error("invalid_reasoning", format!("reasoning: {errors}"))
        })?;
    }
    if let Some(sglang) = args.sglang.as_ref() {
        sglang.validate().map_err(|errors| {
            mock_engine_args_validation_error("invalid_sglang", format!("sglang: {errors}"))
        })?;
    }
    if args.num_g3_blocks.is_some() && args.num_g2_blocks.is_none() {
        return Err(mock_engine_args_validation_error(
            "g3_requires_g2",
            "num_g3_blocks requires num_g2_blocks because mocker stages G3 through G2".to_string(),
        ));
    }
    if args.enable_g4_storage && args.num_g2_blocks.is_none() {
        return Err(mock_engine_args_validation_error(
            "g4_requires_g2",
            "enable_g4_storage requires num_g2_blocks because mocker stages G4 through G2"
                .to_string(),
        ));
    }
    if let Some(policy) = args
        .trtllm
        .as_ref()
        .and_then(|trtllm| trtllm.capacity_scheduler_policy.as_deref())
        && policy != "guaranteed_no_evict"
    {
        return Err(mock_engine_args_validation_error(
            "trtllm_unsupported_capacity_scheduler_policy",
            format!(
                "engine_type=trtllm v1 supports only capacity_scheduler_policy='guaranteed_no_evict', got '{policy}'",
            ),
        ));
    }
    // The remaining checks only apply to the SGLang engine.
    if args.engine_type != EngineType::Sglang {
        return Ok(());
    }
    if let Some(page_size) = args.sglang.as_ref().and_then(|sglang| sglang.page_size)
        && args.block_size != page_size
    {
        return Err(mock_engine_args_validation_error(
            "sglang_block_size_page_size_mismatch",
            format!(
                "engine_type=sglang requires block_size and sglang.page_size to match when both are set, got block_size={} and sglang.page_size={page_size}",
                args.block_size,
            ),
        ));
    }
    // `block_size` is known non-zero here (checked first), so the modulo cannot panic.
    if let Some(chunked_prefill_size) = args
        .sglang
        .as_ref()
        .and_then(|sglang| sglang.chunked_prefill_size)
        && chunked_prefill_size % args.block_size != 0
    {
        return Err(mock_engine_args_validation_error(
            "sglang_chunked_prefill_size_not_divisible_by_block_size",
            format!(
                "engine_type=sglang requires sglang.chunked_prefill_size to be divisible by block_size, got chunked_prefill_size={} and block_size={}",
                chunked_prefill_size, args.block_size,
            ),
        ));
    }
    Ok(())
}
/// Converts the wire-format `MockEngineArgsSerde` into a normalized [`MockEngineArgs`].
///
/// Absent keys keep the builder defaults. For each key:
/// - `into_non_null` fields reject an explicit `null` with `"<field> must not be null"`;
/// - `into_nullable` fields treat `null` as "set to `None`", overriding any default.
///
/// String-typed settings (`engine_type`, `worker_type`, `preemption_mode`,
/// `router_queue_policy`) are parsed here. The result is built and then passed through
/// `normalized()`, so defaults are materialized and validation has run.
///
/// Statement order matters: the first failing field determines the returned error, and
/// `planner_profile_data` loads (and logs) the perf model before later fields are parsed.
impl TryFrom<MockEngineArgsSerde> for MockEngineArgs {
    type Error = String;
    fn try_from(compat: MockEngineArgsSerde) -> Result<Self, Self::Error> {
        let mut builder = Self::builder();
        if let Some(engine_type) = compat.engine_type.into_non_null("engine_type")? {
            builder = builder.engine_type(parse_engine_type(&engine_type)?);
        }
        // An explicit `null` is silently ignored here (keeps the default), unlike the other
        // non-`Option` fields, which reject `null`.
        // NOTE(review): confirm this leniency is intended.
        if let Some(Some(num_gpu_blocks)) = compat.num_gpu_blocks.into_nullable() {
            builder = builder.num_gpu_blocks(num_gpu_blocks);
        }
        if let Some(block_size) = compat.block_size.into_non_null("block_size")? {
            builder = builder.block_size(block_size);
        }
        if let Some(max_num_seqs) = compat.max_num_seqs.into_nullable() {
            builder = builder.max_num_seqs(max_num_seqs);
        }
        if let Some(max_num_batched_tokens) = compat.max_num_batched_tokens.into_nullable() {
            builder = builder.max_num_batched_tokens(max_num_batched_tokens);
        }
        if let Some(enable_prefix_caching) = compat
            .enable_prefix_caching
            .into_non_null("enable_prefix_caching")?
        {
            builder = builder.enable_prefix_caching(enable_prefix_caching);
        }
        if let Some(enable_chunked_prefill) = compat
            .enable_chunked_prefill
            .into_non_null("enable_chunked_prefill")?
        {
            builder = builder.enable_chunked_prefill(enable_chunked_prefill);
        }
        if let Some(speedup_ratio) = compat.speedup_ratio.into_non_null("speedup_ratio")? {
            builder = builder.speedup_ratio(speedup_ratio);
        }
        if let Some(decode_speedup_ratio) = compat
            .decode_speedup_ratio
            .into_non_null("decode_speedup_ratio")?
        {
            builder = builder.decode_speedup_ratio(decode_speedup_ratio);
        }
        if let Some(dp_size) = compat.dp_size.into_non_null("dp_size")? {
            builder = builder.dp_size(dp_size);
        }
        if let Some(startup_time) = compat.startup_time.into_nullable() {
            builder = builder.startup_time(startup_time);
        }
        // An explicit `worker_type` wins. Only when it is absent are the legacy
        // `is_prefill` / `is_decode` flags consulted (each defaulting to false).
        let worker_type = if let Some(worker_type) =
            compat.worker_type.into_non_null("worker_type")?
        {
            parse_worker_type(&worker_type)?
        } else {
            let is_prefill = compat
                .is_prefill
                .into_non_null("is_prefill")?
                .unwrap_or(false);
            let is_decode = compat
                .is_decode
                .into_non_null("is_decode")?
                .unwrap_or(false);
            match (is_prefill, is_decode) {
                (false, false) => WorkerType::Aggregated,
                (true, false) => WorkerType::Prefill,
                (false, true) => WorkerType::Decode,
                (true, true) => {
                    return Err(
                        "Invalid worker configuration: is_prefill and is_decode cannot both be true."
                            .to_string(),
                    );
                }
            }
        };
        builder = builder.worker_type(worker_type);
        // A non-null path also loads the perf model now. `load_perf_model` falls back to the
        // default model (logging the failure) instead of erroring.
        if let Some(planner_profile_data) = compat.planner_profile_data.into_nullable() {
            builder = builder.planner_profile_data(planner_profile_data.clone());
            if let Some(path) = planner_profile_data {
                builder = builder.perf_model(load_perf_model(&path));
            }
        }
        if let Some(aic_backend) = compat.aic_backend.into_nullable() {
            builder = builder.aic_backend(aic_backend);
        }
        if let Some(aic_system) = compat.aic_system.into_nullable() {
            builder = builder.aic_system(aic_system);
        }
        if let Some(aic_backend_version) = compat.aic_backend_version.into_nullable() {
            builder = builder.aic_backend_version(aic_backend_version);
        }
        if let Some(aic_tp_size) = compat.aic_tp_size.into_nullable() {
            builder = builder.aic_tp_size(aic_tp_size);
        }
        if let Some(aic_model_path) = compat.aic_model_path.into_nullable() {
            builder = builder.aic_model_path(aic_model_path);
        }
        if let Some(aic_moe_tp_size) = compat.aic_moe_tp_size.into_nullable() {
            builder = builder.aic_moe_tp_size(aic_moe_tp_size);
        }
        if let Some(aic_moe_ep_size) = compat.aic_moe_ep_size.into_nullable() {
            builder = builder.aic_moe_ep_size(aic_moe_ep_size);
        }
        if let Some(aic_attention_dp_size) = compat.aic_attention_dp_size.into_nullable() {
            builder = builder.aic_attention_dp_size(aic_attention_dp_size);
        }
        if let Some(aic_nextn) = compat.aic_nextn.into_nullable() {
            builder = builder.aic_nextn(aic_nextn);
        }
        if let Some(aic_nextn_accept_rates) = compat.aic_nextn_accept_rates.into_nullable() {
            builder = builder.aic_nextn_accept_rates(aic_nextn_accept_rates);
        }
        if let Some(gpu_memory_utilization) = compat.gpu_memory_utilization.into_nullable() {
            builder = builder.gpu_memory_utilization(gpu_memory_utilization);
        }
        if let Some(mem_fraction_static) = compat.mem_fraction_static.into_nullable() {
            builder = builder.mem_fraction_static(mem_fraction_static);
        }
        if let Some(free_gpu_memory_fraction) = compat.free_gpu_memory_fraction.into_nullable() {
            builder = builder.free_gpu_memory_fraction(free_gpu_memory_fraction);
        }
        if let Some(enable_local_indexer) = compat
            .enable_local_indexer
            .into_non_null("enable_local_indexer")?
        {
            builder = builder.enable_local_indexer(enable_local_indexer);
        }
        if let Some(bootstrap_port) = compat.bootstrap_port.into_nullable() {
            builder = builder.bootstrap_port(bootstrap_port);
        }
        if let Some(kv_bytes_per_token) = compat.kv_bytes_per_token.into_nullable() {
            builder = builder.kv_bytes_per_token(kv_bytes_per_token);
        }
        if let Some(kv_transfer_bandwidth) = compat.kv_transfer_bandwidth.into_nullable() {
            builder = builder.kv_transfer_bandwidth(kv_transfer_bandwidth);
        }
        if let Some(num_g2_blocks) = compat.num_g2_blocks.into_nullable() {
            builder = builder.num_g2_blocks(num_g2_blocks);
        }
        if let Some(num_g3_blocks) = compat.num_g3_blocks.into_nullable() {
            builder = builder.num_g3_blocks(num_g3_blocks);
        }
        if let Some(enable_g4_storage) = compat
            .enable_g4_storage
            .into_non_null("enable_g4_storage")?
        {
            builder = builder.enable_g4_storage(enable_g4_storage);
        }
        if let Some(offload_batch_size) = compat.offload_batch_size.into_nullable() {
            builder = builder.offload_batch_size(offload_batch_size);
        }
        if let Some(bandwidth_g1_to_g2_gbps) = compat.bandwidth_g1_to_g2_gbps.into_nullable() {
            builder = builder.bandwidth_g1_to_g2_gbps(bandwidth_g1_to_g2_gbps);
        }
        if let Some(bandwidth_g2_to_g1_gbps) = compat.bandwidth_g2_to_g1_gbps.into_nullable() {
            builder = builder.bandwidth_g2_to_g1_gbps(bandwidth_g2_to_g1_gbps);
        }
        if let Some(bandwidth_g2_to_g3_gbps) = compat.bandwidth_g2_to_g3_gbps.into_nullable() {
            builder = builder.bandwidth_g2_to_g3_gbps(bandwidth_g2_to_g3_gbps);
        }
        if let Some(bandwidth_g3_to_g2_gbps) = compat.bandwidth_g3_to_g2_gbps.into_nullable() {
            builder = builder.bandwidth_g3_to_g2_gbps(bandwidth_g3_to_g2_gbps);
        }
        if let Some(bandwidth_g2_to_g4_gbps) = compat.bandwidth_g2_to_g4_gbps.into_nullable() {
            builder = builder.bandwidth_g2_to_g4_gbps(bandwidth_g2_to_g4_gbps);
        }
        if let Some(bandwidth_g4_to_g2_gbps) = compat.bandwidth_g4_to_g2_gbps.into_nullable() {
            builder = builder.bandwidth_g4_to_g2_gbps(bandwidth_g4_to_g2_gbps);
        }
        if let Some(reasoning) = compat.reasoning.into_nullable() {
            builder = builder.reasoning(reasoning);
        }
        if let Some(zmq_kv_events_port) = compat.zmq_kv_events_port.into_nullable() {
            builder = builder.zmq_kv_events_port(zmq_kv_events_port);
        }
        if let Some(zmq_replay_port) = compat.zmq_replay_port.into_nullable() {
            builder = builder.zmq_replay_port(zmq_replay_port);
        }
        if let Some(preemption_mode) = compat.preemption_mode.into_non_null("preemption_mode")? {
            builder = builder.preemption_mode(parse_preemption_mode(&preemption_mode)?);
        }
        // `null` clears the policy; a string is parsed via `FromStr` (error type `String`).
        if let Some(router_queue_policy) = compat.router_queue_policy.into_nullable() {
            let router_queue_policy = router_queue_policy
                .map(|policy| policy.parse().map_err(|e: String| e))
                .transpose()?;
            builder = builder.router_queue_policy(router_queue_policy);
        }
        if let Some(sglang) = compat.sglang.into_nullable() {
            builder = builder.sglang(sglang);
        }
        if let Some(trtllm) = compat.trtllm.into_nullable() {
            builder = builder.trtllm(trtllm);
        }
        // Materialize engine defaults and run validation before handing the value out.
        builder
            .build()
            .map_err(|e| format!("Failed to build MockEngineArgs: {e}"))?
            .normalized()
            .map_err(|e| e.to_string())
    }
}
/// Builder defaults, normalized (engine-default `block_size` materialized and validated).
///
/// # Panics
/// Panics only if the builder defaults themselves fail to build or validate, which would be a bug.
impl Default for MockEngineArgs {
    fn default() -> Self {
        let built = MockEngineArgsBuilder::default()
            .build()
            .expect("Failed to build default MockEngineArgs");
        built
            .normalized()
            .expect("Failed to normalize default MockEngineArgs")
    }
}
impl MockEngineArgs {
const DEFAULT_VLLM_BLOCK_SIZE: usize = 64;
const DEFAULT_SGLANG_BLOCK_SIZE: usize = 1;
const DEFAULT_TRTLLM_BLOCK_SIZE: usize = 32;
pub fn builder() -> MockEngineArgsBuilder {
MockEngineArgsBuilder::default()
}
pub fn normalized(mut self) -> anyhow::Result<Self> {
self.materialize_defaults();
self.validate_config()?;
Ok(self)
}
fn materialize_defaults(&mut self) {
match self.engine_type {
EngineType::Vllm => {
if self.block_size == 0 {
self.block_size = Self::DEFAULT_VLLM_BLOCK_SIZE;
}
}
EngineType::Sglang => {
let page_size = self.sglang.as_ref().and_then(|sglang| sglang.page_size);
match (self.block_size, page_size) {
(0, None) => {
self.block_size = Self::DEFAULT_SGLANG_BLOCK_SIZE;
}
(0, Some(page_size)) => {
self.block_size = page_size;
}
(_, Some(_)) => {}
(_, None) => {}
}
}
EngineType::Trtllm => {
if self.block_size == 0 {
self.block_size = Self::DEFAULT_TRTLLM_BLOCK_SIZE;
}
}
}
if self.num_g2_blocks == Some(0) {
self.num_g2_blocks = None;
}
if self.num_g3_blocks == Some(0) {
self.num_g3_blocks = None;
}
if self.offload_batch_size == Some(0) {
self.offload_batch_size = None;
}
}
fn validate_config(&self) -> anyhow::Result<()> {
self.validate()
.map_err(|error| anyhow::anyhow!("Failed to validate MockEngineArgs: {error}"))?;
Ok(())
}
pub fn scheduling_policy(&self) -> SchedulingPolicy {
match self.engine_type {
EngineType::Trtllm => SchedulingPolicy::TrtllmGuaranteedNoEvict,
EngineType::Vllm | EngineType::Sglang => SchedulingPolicy::Vllm,
}
}
pub fn is_prefill(&self) -> bool {
self.worker_type == WorkerType::Prefill
}
pub fn is_decode(&self) -> bool {
self.worker_type == WorkerType::Decode
}
pub fn needs_kv_publisher(&self) -> bool {
self.enable_prefix_caching && !self.is_decode()
}
pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
let file_content = std::fs::read_to_string(path)?;
Self::from_json_str(&file_content)
}
pub fn from_json_str(content: &str) -> anyhow::Result<Self> {
let mut deserializer = serde_json::Deserializer::from_str(content);
let args = serde_path_to_error::deserialize(&mut deserializer)
.map_err(|error| anyhow::anyhow!("{error}"))?;
deserializer
.end()
.map_err(|error| anyhow::anyhow!("{error}"))?;
Ok(args)
}
}
/// Unit tests for `MockEngineArgs` (de)serialization, normalization, and validation.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_mock_engine_args_json_round_trip_preserves_worker_type_and_nulls() {
        let args = MockEngineArgs::builder()
            .worker_type(WorkerType::Decode)
            .max_num_seqs(None)
            .max_num_batched_tokens(None)
            .reasoning(None)
            .sglang(None)
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        // The payload is built by hand with lowercase enum spellings, which is what the
        // deserializer accepts; the derived `Serialize` would emit variant names instead.
        let payload = serde_json::json!({
            "engine_type": "vllm",
            "num_gpu_blocks": args.num_gpu_blocks,
            "block_size": args.block_size,
            "max_num_seqs": args.max_num_seqs,
            "max_num_batched_tokens": args.max_num_batched_tokens,
            "enable_prefix_caching": args.enable_prefix_caching,
            "enable_chunked_prefill": args.enable_chunked_prefill,
            "speedup_ratio": args.speedup_ratio,
            "decode_speedup_ratio": args.decode_speedup_ratio,
            "dp_size": args.dp_size,
            "startup_time": args.startup_time,
            "worker_type": "decode",
            "planner_profile_data": args.planner_profile_data,
            "aic_backend": args.aic_backend,
            "aic_system": args.aic_system,
            "aic_backend_version": args.aic_backend_version,
            "aic_tp_size": args.aic_tp_size,
            "aic_model_path": args.aic_model_path,
            "enable_local_indexer": args.enable_local_indexer,
            "bootstrap_port": args.bootstrap_port,
            "kv_bytes_per_token": args.kv_bytes_per_token,
            "kv_transfer_bandwidth": args.kv_transfer_bandwidth,
            "num_g2_blocks": args.num_g2_blocks,
            "num_g3_blocks": args.num_g3_blocks,
            "enable_g4_storage": args.enable_g4_storage,
            "offload_batch_size": args.offload_batch_size,
            "bandwidth_g1_to_g2_gbps": args.bandwidth_g1_to_g2_gbps,
            "bandwidth_g2_to_g1_gbps": args.bandwidth_g2_to_g1_gbps,
            "bandwidth_g2_to_g3_gbps": args.bandwidth_g2_to_g3_gbps,
            "bandwidth_g3_to_g2_gbps": args.bandwidth_g3_to_g2_gbps,
            "bandwidth_g2_to_g4_gbps": args.bandwidth_g2_to_g4_gbps,
            "bandwidth_g4_to_g2_gbps": args.bandwidth_g4_to_g2_gbps,
            "reasoning": args.reasoning,
            "zmq_kv_events_port": args.zmq_kv_events_port,
            "zmq_replay_port": args.zmq_replay_port,
            "preemption_mode": "lifo",
            "router_queue_policy": args.router_queue_policy.map(|policy| policy.to_string()),
            "sglang": args.sglang,
            "has_perf_model": true,
        });
        let restored = MockEngineArgs::from_json_str(&payload.to_string()).unwrap();
        assert_eq!(restored.worker_type, WorkerType::Decode);
        // Explicit nulls must override the builder's `Some(..)` defaults.
        assert_eq!(restored.max_num_seqs, None);
        assert_eq!(restored.max_num_batched_tokens, None);
    }

    #[test]
    fn test_mock_engine_args_json_rejects_unknown_and_invalid_types() {
        let unknown = MockEngineArgs::from_json_str(&json!({"unknown": true}).to_string())
            .expect_err("unknown fields should be rejected");
        assert!(
            unknown.to_string().contains("unknown field"),
            "unexpected error: {unknown}",
        );
        let invalid =
            MockEngineArgs::from_json_str(&json!({"gpu_memory_utilization": "bad"}).to_string())
                .expect_err("wrongly typed fields should be rejected");
        assert!(
            invalid.to_string().contains("gpu_memory_utilization"),
            "unexpected error: {invalid}",
        );
        let trailing = MockEngineArgs::from_json_str(r#"{"block_size": 16} true"#)
            .expect_err("trailing JSON should be rejected");
        assert!(
            trailing.to_string().contains("trailing characters"),
            "unexpected error: {trailing}",
        );
    }

    #[test]
    fn test_unique_block_default_uniqueness() {
        // Every default-constructed block should be a partial block with a distinct UUID.
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();
        for (i, first) in uuids.iter().enumerate() {
            for (j, second) in uuids.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    first, second,
                    "UUID at index {} and {} are identical",
                    i, j
                );
            }
        }
    }

    #[test]
    fn test_normalized_sglang_uses_page_size_alias_for_block_size() {
        let args = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .sglang(Some(SglangArgs {
                page_size: Some(16),
                ..Default::default()
            }))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.block_size, 16);
    }

    #[test]
    fn test_normalized_sglang_accepts_equal_block_size_and_page_size() {
        let args = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(8)
            .sglang(Some(SglangArgs {
                page_size: Some(8),
                ..Default::default()
            }))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.block_size, 8);
    }

    #[test]
    fn test_normalized_sglang_rejects_mismatched_block_size_and_page_size() {
        let error = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(8)
            .sglang(Some(SglangArgs {
                page_size: Some(4),
                ..Default::default()
            }))
            .build()
            .unwrap()
            .normalized()
            .unwrap_err();
        assert!(
            error
                .to_string()
                .contains("block_size and sglang.page_size to match"),
            "unexpected error: {error}",
        );
    }

    #[test]
    fn test_normalized_g3_requires_g2() {
        let missing_g2 = MockEngineArgs::builder()
            .num_g3_blocks(Some(10))
            .kv_bytes_per_token(Some(1024))
            .build()
            .unwrap()
            .normalized()
            .unwrap_err();
        assert!(
            missing_g2.to_string().contains("requires num_g2_blocks"),
            "unexpected error: {missing_g2}",
        );
    }

    #[test]
    fn test_normalized_g4_requires_g2() {
        let missing_g2 = MockEngineArgs::builder()
            .enable_g4_storage(true)
            .kv_bytes_per_token(Some(1024))
            .build()
            .unwrap()
            .normalized()
            .unwrap_err();
        assert!(
            missing_g2.to_string().contains("requires num_g2_blocks"),
            "unexpected error: {missing_g2}",
        );
    }

    #[test]
    fn test_normalized_rejects_out_of_range_aic_nextn() {
        // Both sides of the `range(min = 1, max = 5)` bound must be rejected.
        for bad in [0_usize, 6] {
            let err = MockEngineArgs::builder()
                .aic_nextn(Some(bad))
                .build()
                .unwrap()
                .normalized()
                .unwrap_err();
            assert!(
                err.to_string().contains("aic_nextn"),
                "unexpected error for nextn={bad}: {err}",
            );
        }
        MockEngineArgs::builder()
            .aic_nextn(Some(3))
            .build()
            .unwrap()
            .normalized()
            .expect("in-range aic_nextn should validate");
    }

    #[test]
    fn test_normalized_zero_disables_optional_offload_knobs() {
        let args = MockEngineArgs::builder()
            .num_g2_blocks(Some(0))
            .num_g3_blocks(Some(0))
            .offload_batch_size(Some(0))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.num_g2_blocks, None);
        assert_eq!(args.num_g3_blocks, None);
        assert!(!args.enable_g4_storage);
        assert_eq!(args.offload_batch_size, None);
    }

    #[test]
    fn test_normalized_zero_g3_does_not_require_g2_or_kv_bytes() {
        let args = MockEngineArgs::builder()
            .num_g3_blocks(Some(0))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.num_g3_blocks, None);
    }

    #[test]
    fn test_normalized_g3_allows_missing_kv_bytes_for_cli_auto_compute() {
        let args = MockEngineArgs::builder()
            .num_g2_blocks(Some(10))
            .num_g3_blocks(Some(10))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.num_g2_blocks, Some(10));
        assert_eq!(args.num_g3_blocks, Some(10));
        assert_eq!(args.kv_bytes_per_token, None);
    }

    #[test]
    fn test_normalized_g4_allows_missing_kv_bytes_for_cli_auto_compute() {
        let args = MockEngineArgs::builder()
            .num_g2_blocks(Some(10))
            .enable_g4_storage(true)
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.num_g2_blocks, Some(10));
        assert!(args.enable_g4_storage);
        assert_eq!(args.kv_bytes_per_token, None);
    }

    #[test]
    fn test_normalized_sglang_defaults_block_size_to_one() {
        let args = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.block_size, 1);
    }

    #[test]
    fn test_from_json_file_normalizes_sglang_page_size() {
        let tempdir = tempfile::tempdir().unwrap();
        let path = tempdir.path().join("args.json");
        std::fs::write(
            &path,
            serde_json::to_string(&json!({
                "engine_type": "sglang",
                "sglang": {
                    "page_size": 32
                }
            }))
            .unwrap(),
        )
        .unwrap();
        let args = MockEngineArgs::from_json_file(&path).unwrap();
        assert_eq!(args.block_size, 32);
    }

    #[test]
    fn test_normalized_sglang_rejects_chunked_prefill_not_divisible_by_block_size() {
        let error = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(4)
            .sglang(Some(SglangArgs {
                page_size: Some(4),
                chunked_prefill_size: Some(6),
                ..Default::default()
            }))
            .build()
            .unwrap()
            .normalized()
            .unwrap_err();
        assert!(
            error
                .to_string()
                .contains("chunked_prefill_size to be divisible by block_size"),
            "unexpected error: {error}",
        );
    }

    #[test]
    fn test_normalized_sglang_accepts_chunked_prefill_divisible_by_block_size() {
        let args = MockEngineArgs::builder()
            .engine_type(EngineType::Sglang)
            .block_size(4)
            .sglang(Some(SglangArgs {
                page_size: Some(4),
                chunked_prefill_size: Some(8),
                ..Default::default()
            }))
            .build()
            .unwrap()
            .normalized()
            .unwrap();
        assert_eq!(args.block_size, 4);
    }
}