use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use uuid::Uuid;
use validator::Validate;
use crate::common::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash, Token};
/// Sink for KV-cache events produced by the mock engine.
///
/// Implementations forward each [`KvCacheEvent`] to some downstream consumer.
/// `block_token_ids`, when present, carries per-block token ids — presumably
/// one inner `Vec<u32>` per block referenced by the event; TODO confirm
/// against implementors.
pub trait KvCacheEventSink: Send + Sync {
fn publish(
&self,
event: KvCacheEvent,
block_token_ids: Option<&[Vec<u32>]>,
) -> anyhow::Result<()>;
}
/// Count of KV-cache blocks.
pub type NumBlocks = usize;
/// Block lifecycle commands for the mock KV cache.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
/// Take the given blocks into use: the blocks, their hashes, and
/// optionally per-block token ids. NOTE(review): the three payloads
/// appear to be positionally paired — confirm at call sites.
Use(Vec<UniqueBlock>, Vec<BlockHash>, Option<Vec<Vec<u32>>>),
/// Destroy the given blocks.
Destroy(Vec<UniqueBlock>),
/// Drop a reference to the given blocks.
Deref(Vec<UniqueBlock>),
/// Promote a block identified by UUID. Payload: block UUID, sequence
/// hash, an optional `u64` (semantics not visible here — TODO confirm),
/// block hash, and optional token ids.
Promote(Uuid, SequenceHash, Option<u64>, BlockHash, Option<Vec<u32>>),
}
/// Responses emitted after block moves are applied.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
/// Blocks stored, identified by sequence hash, plus an optional `u64`
/// (semantics not visible here — mirrors the `Option<u64>` in
/// `MoveBlock::Promote`; TODO confirm).
Store(Vec<SequenceHash>, Option<u64>),
/// Blocks removed, identified by sequence hash.
Remove(Vec<SequenceHash>),
}
/// A request submitted directly to the mock engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
/// Prompt token ids.
pub tokens: Vec<Token>,
/// Maximum number of output tokens to generate.
pub max_output_tokens: usize,
/// Caller-supplied request id; generated elsewhere when `None`
/// (TODO confirm where).
pub uuid: Option<Uuid>,
/// Data-parallel rank this request targets.
pub dp_rank: u32,
}
/// Work required to prefill a request: blocks and tokens that must be newly
/// computed (presumably those not covered by cached prefixes — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
/// Number of new KV-cache blocks required.
pub new_blocks: usize,
/// Number of new tokens to prefill.
pub new_tokens: usize,
}
impl PrefillCost {
    /// Predicts prefill compute time via `perf_model` for the given token
    /// count, falling back to this cost's own `new_tokens` when `None`.
    pub fn predict_prefill_compute(
        &self,
        new_tokens: Option<usize>,
        perf_model: &PerfModel,
    ) -> f64 {
        perf_model.predict_prefill_time(new_tokens.unwrap_or(self.new_tokens))
    }
}
/// Signal associated with one request's output stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
/// Id of the request this signal belongs to.
pub uuid: Uuid,
/// Whether the request has finished.
pub completed: bool,
}
/// Role of a worker in a (possibly disaggregated) deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
/// Handles both prefill and decode (the default).
#[default]
Aggregated,
/// Prefill-only worker.
Prefill,
/// Decode-only worker.
Decode,
}
/// Configuration for simulated "thinking" (reasoning) output.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ReasoningConfig {
/// Token id that opens a thinking span.
pub start_thinking_token_id: u32,
/// Token id that closes a thinking span.
pub end_thinking_token_id: u32,
/// Fraction of the output budget spent thinking; validated to [0, 1]
/// when `validate()` is invoked on an instance.
#[validate(range(min = 0.0, max = 1.0))]
pub thinking_ratio: f64,
}
impl ReasoningConfig {
    /// Number of output tokens allotted to thinking for a given budget.
    ///
    /// Returns 0 when the budget is below 2 or the ratio rounds the share
    /// down to 0; otherwise the floored share, clamped to at least 2 tokens
    /// (room for the start/end markers) and at most the full budget.
    pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
        if max_output_tokens < 2 {
            return 0;
        }
        let share = (max_output_tokens as f64 * self.thinking_ratio).floor() as usize;
        match share {
            0 => 0,
            // Safe: the guard above guarantees max_output_tokens >= 2.
            s => s.clamp(2, max_output_tokens),
        }
    }

    /// Tokens left for the visible response after thinking is budgeted.
    pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
        max_output_tokens.saturating_sub(self.num_thinking_tokens(max_output_tokens))
    }
}
/// Configuration for the mock engine.
///
/// Built via the derive_builder-generated `MockEngineArgsBuilder`; defaults
/// live in the `#[builder(default = ...)]` attributes, and range constraints
/// are enforced by `validator` when `validate()` is invoked on an instance.
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
/// Total number of KV-cache blocks available.
#[builder(default = "16384")]
#[validate(range(min = 1))]
pub num_gpu_blocks: usize,
/// Tokens per KV-cache block.
#[builder(default = "64")]
#[validate(range(min = 2))]
pub block_size: usize,
/// Maximum concurrently running sequences; `None` = unlimited.
#[builder(default = Some(256))]
#[validate(range(min = 1))]
pub max_num_seqs: Option<usize>,
/// Maximum tokens per batch; `None` = unlimited.
#[builder(default = Some(8192))]
#[validate(range(min = 1))]
pub max_num_batched_tokens: Option<usize>,
/// Whether prefix caching is enabled.
#[builder(default = true)]
pub enable_prefix_caching: bool,
/// Whether chunked prefill is enabled.
#[builder(default = true)]
pub enable_chunked_prefill: bool,
/// Block-availability watermark as a fraction in [0, 1].
#[builder(default = "0.01")]
#[validate(range(min = 0.0, max = 1.0))]
pub watermark: f64,
/// Simulation speed multiplier (>= 0).
#[builder(default = "1.0")]
#[validate(range(min = 0.0))]
pub speedup_ratio: f64,
/// Number of data-parallel ranks.
#[builder(default = "1")]
#[validate(range(min = 1))]
pub dp_size: u32,
/// Optional startup delay; units not visible here — presumably seconds,
/// TODO confirm at usage site.
#[builder(default = "None")]
#[validate(range(min = 0.0))]
pub startup_time: Option<f64>,
/// Role of this worker (aggregated / prefill / decode).
#[builder(default = "WorkerType::Aggregated")]
pub worker_type: WorkerType,
/// Performance model used for time predictions. Skipped by serde;
/// rebuilt from `planner_profile_data` or defaulted on load.
#[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")]
pub perf_model: Arc<PerfModel>,
/// Whether the local KV indexer is enabled.
#[builder(default = "false")]
pub enable_local_indexer: bool,
/// Optional bootstrap port.
#[builder(default = "None")]
pub bootstrap_port: Option<u16>,
/// Optional KV-cache size per token, in bytes.
#[builder(default = "None")]
pub kv_bytes_per_token: Option<usize>,
/// Optional KV transfer bandwidth; units not visible here — TODO confirm.
#[builder(default = "None")]
#[validate(range(min = 0.0))]
pub kv_transfer_bandwidth: Option<f64>,
/// Optional reasoning ("thinking") configuration.
#[builder(default = "None")]
pub reasoning: Option<ReasoningConfig>,
/// Optional ZMQ port for KV-cache event publishing.
#[builder(default = "None")]
pub zmq_kv_events_port: Option<u16>,
}
impl Default for MockEngineArgs {
fn default() -> MockEngineArgs {
MockEngineArgsBuilder::default()
.build()
.expect("Failed to build default MockEngineArgs")
}
}
impl MockEngineArgs {
    /// Returns a fresh builder for assembling a `MockEngineArgs`.
    pub fn builder() -> MockEngineArgsBuilder {
        MockEngineArgsBuilder::default()
    }

    /// True when this worker handles only the prefill phase.
    pub fn is_prefill(&self) -> bool {
        self.worker_type == WorkerType::Prefill
    }

    /// True when this worker handles only the decode phase.
    pub fn is_decode(&self) -> bool {
        self.worker_type == WorkerType::Decode
    }

    /// A KV-event publisher is needed whenever prefix caching is enabled,
    /// except on decode-only workers.
    pub fn needs_kv_publisher(&self) -> bool {
        self.enable_prefix_caching && !self.is_decode()
    }

    /// Builds a `MockEngineArgs` from a flat JSON object stored at `path`.
    ///
    /// Recognized keys mirror the struct fields, with three special cases:
    /// `is_prefill` / `is_decode` (booleans folded into `worker_type`) and
    /// `planner_profile_data` (a path to an `.npz` file loaded into
    /// `perf_model`, falling back to the default model on load failure).
    /// Keys absent from the JSON keep their builder defaults; values of an
    /// unexpected JSON type are silently ignored (existing behavior).
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or parsed, contains
    /// unknown keys, sets both `is_prefill` and `is_decode`, supplies
    /// `dp_size` or a port outside its integer range, has an unparsable
    /// `reasoning` object, or the final builder validation fails.
    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
        let mut builder = Self::builder();
        let file_content = std::fs::read_to_string(path)?;
        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
        // Reject unknown keys up front so a typo fails loudly instead of
        // silently leaving the corresponding field at its default.
        let valid_fields: HashSet<&str> = [
            "num_gpu_blocks",
            "block_size",
            "max_num_seqs",
            "max_num_batched_tokens",
            "enable_prefix_caching",
            "enable_chunked_prefill",
            "watermark",
            "speedup_ratio",
            "dp_size",
            "startup_time",
            "is_prefill",
            "is_decode",
            "planner_profile_data",
            "enable_local_indexer",
            "bootstrap_port",
            "kv_bytes_per_token",
            "kv_transfer_bandwidth",
            "reasoning",
            "zmq_kv_events_port",
        ]
        .iter()
        .cloned()
        .collect();
        let invalid_args: Vec<String> = extra_args
            .keys()
            .filter(|key| !valid_fields.contains(key.as_str()))
            .cloned()
            .collect();
        if !invalid_args.is_empty() {
            return Err(anyhow::anyhow!(
                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
                invalid_args.join(", "),
                valid_fields
            ));
        }
        if let Some(value) = extra_args.get("num_gpu_blocks")
            && let Some(num) = value.as_u64()
        {
            builder = builder.num_gpu_blocks(num as usize);
        }
        if let Some(value) = extra_args.get("block_size")
            && let Some(num) = value.as_u64()
        {
            builder = builder.block_size(num as usize);
        }
        if let Some(value) = extra_args.get("max_num_seqs")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_seqs(Some(num as usize));
        }
        if let Some(value) = extra_args.get("max_num_batched_tokens")
            && let Some(num) = value.as_u64()
        {
            builder = builder.max_num_batched_tokens(Some(num as usize));
        }
        if let Some(value) = extra_args.get("enable_prefix_caching")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_prefix_caching(enabled);
        }
        if let Some(value) = extra_args.get("enable_chunked_prefill")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_chunked_prefill(enabled);
        }
        if let Some(value) = extra_args.get("watermark")
            && let Some(num) = value.as_f64()
        {
            builder = builder.watermark(num);
        }
        if let Some(value) = extra_args.get("speedup_ratio")
            && let Some(num) = value.as_f64()
        {
            builder = builder.speedup_ratio(num);
        }
        if let Some(value) = extra_args.get("dp_size")
            && let Some(num) = value.as_u64()
        {
            // `as u32` would silently truncate out-of-range values; fail instead.
            let dp_size = u32::try_from(num)
                .map_err(|_| anyhow::anyhow!("dp_size {} out of range for u32", num))?;
            builder = builder.dp_size(dp_size);
        }
        if let Some(value) = extra_args.get("startup_time")
            && let Some(num) = value.as_f64()
        {
            builder = builder.startup_time(Some(num));
        }
        if let Some(value) = extra_args.get("enable_local_indexer")
            && let Some(enabled) = value.as_bool()
        {
            builder = builder.enable_local_indexer(enabled);
        }
        if let Some(value) = extra_args.get("bootstrap_port")
            && let Some(port) = value.as_u64()
        {
            // `as u16` would silently truncate out-of-range values; fail instead.
            let port = u16::try_from(port)
                .map_err(|_| anyhow::anyhow!("bootstrap_port {} out of range for u16", port))?;
            builder = builder.bootstrap_port(Some(port));
        }
        if let Some(value) = extra_args.get("kv_bytes_per_token")
            && let Some(num) = value.as_u64()
        {
            builder = builder.kv_bytes_per_token(Some(num as usize));
        }
        if let Some(value) = extra_args.get("kv_transfer_bandwidth")
            && let Some(num) = value.as_f64()
        {
            builder = builder.kv_transfer_bandwidth(Some(num));
        }
        if let Some(value) = extra_args.get("reasoning") {
            let cfg: ReasoningConfig = serde_json::from_value(value.clone())
                .map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
            builder = builder.reasoning(Some(cfg));
        }
        if let Some(value) = extra_args.get("zmq_kv_events_port")
            && let Some(port) = value.as_u64()
        {
            // `as u16` would silently truncate out-of-range values; fail instead.
            let port = u16::try_from(port)
                .map_err(|_| anyhow::anyhow!("zmq_kv_events_port {} out of range for u16", port))?;
            builder = builder.zmq_kv_events_port(Some(port));
        }
        let is_prefill = extra_args
            .get("is_prefill")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let is_decode = extra_args
            .get("is_decode")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        // The two booleans are folded into a single WorkerType. Setting both
        // is a configuration error; report it as an Err — this is a fallible
        // API, so callers expect bad config to surface as a Result, not a panic.
        let worker_type = match (is_prefill, is_decode) {
            (false, false) => WorkerType::Aggregated,
            (true, false) => WorkerType::Prefill,
            (false, true) => WorkerType::Decode,
            (true, true) => {
                return Err(anyhow::anyhow!(
                    "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
                     Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
                ));
            }
        };
        builder = builder.worker_type(worker_type);
        // A profile-data path is optional; a failed load is deliberately
        // non-fatal and falls back to the default polynomial model.
        let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
            && let Some(path_str) = path_str.as_str()
        {
            let npz_path = PathBuf::from(path_str);
            match PerfModel::from_npz(&npz_path) {
                Ok(model) => {
                    tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
                    Arc::new(model)
                }
                Err(e) => {
                    tracing::error!(
                        "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
                        npz_path,
                        e
                    );
                    Arc::new(PerfModel::default())
                }
            }
        } else {
            Arc::new(PerfModel::default())
        };
        builder = builder.perf_model(perf_model);
        builder
            .build()
            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten default-constructed `UniqueBlock`s must each carry a distinct UUID.
    #[test]
    fn test_unique_block_default_uniqueness() {
        let uuids: Vec<Uuid> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                _ => panic!("Expected UuidIdentifier variant"),
            })
            .collect();
        // Pairwise comparison keeps the per-pair failure message precise.
        for i in 0..uuids.len() {
            for j in i + 1..uuids.len() {
                assert_ne!(
                    uuids[i], uuids[j],
                    "UUID at index {} and {} are identical",
                    i, j
                );
            }
        }
    }
}