1use derive_builder::Builder;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9use uuid::Uuid;
10use validator::Validate;
11
12use crate::common::perf_model::PerfModel;
13use dynamo_kv_router::protocols::KvCacheEvent;
14use dynamo_tokens::blocks::UniqueBlock;
15use dynamo_tokens::{BlockHash, SequenceHash, Token};
16
17pub trait KvCacheEventSink: Send + Sync {
20 fn publish(
21 &self,
22 event: KvCacheEvent,
23 block_token_ids: Option<&[Vec<u32>]>,
24 ) -> anyhow::Result<()>;
25}
26
/// Alias for `usize`, used where a value represents a count of blocks.
pub type NumBlocks = usize;
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32pub enum MoveBlock {
33 Use(Vec<UniqueBlock>, Vec<BlockHash>, Option<Vec<Vec<u32>>>),
34 Destroy(Vec<UniqueBlock>),
35 Deref(Vec<UniqueBlock>),
36 Promote(Uuid, SequenceHash, Option<u64>, BlockHash, Option<Vec<u32>>),
37}
38
39#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
40pub enum MoveBlockResponse {
41 Store(Vec<SequenceHash>, Option<u64>),
42 Remove(Vec<SequenceHash>),
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct DirectRequest {
47 pub tokens: Vec<Token>,
48 pub max_output_tokens: usize,
49 pub uuid: Option<Uuid>,
50 pub dp_rank: u32,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct PrefillCost {
56 pub new_blocks: usize,
57 pub new_tokens: usize,
58}
59
60impl PrefillCost {
61 pub fn predict_prefill_compute(
62 &self,
63 new_tokens: Option<usize>,
64 perf_model: &PerfModel,
65 ) -> f64 {
66 let tokens = new_tokens.unwrap_or(self.new_tokens);
67 perf_model.predict_prefill_time(tokens)
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct OutputSignal {
74 pub uuid: Uuid,
75 pub completed: bool,
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
80pub enum WorkerType {
81 #[default]
83 Aggregated,
84 Prefill,
86 Decode,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
95pub struct ReasoningConfig {
96 pub start_thinking_token_id: u32,
97 pub end_thinking_token_id: u32,
98 #[validate(range(min = 0.0, max = 1.0))]
99 pub thinking_ratio: f64,
100}
101
102impl ReasoningConfig {
103 pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
106 if max_output_tokens < 2 {
107 return 0;
108 }
109 let raw = (max_output_tokens as f64 * self.thinking_ratio).floor() as usize;
110 if raw == 0 {
111 return 0;
112 }
113 raw.max(2).min(max_output_tokens)
114 }
115
116 pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
118 max_output_tokens.saturating_sub(self.num_thinking_tokens(max_output_tokens))
119 }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
124#[builder(pattern = "owned", build_fn(public))]
125pub struct MockEngineArgs {
126 #[builder(default = "16384")]
127 #[validate(range(min = 1))]
128 pub num_gpu_blocks: usize,
129
130 #[builder(default = "64")]
131 #[validate(range(min = 2))]
132 pub block_size: usize,
133
134 #[builder(default = Some(256))]
136 #[validate(range(min = 1))]
137 pub max_num_seqs: Option<usize>,
138
139 #[builder(default = Some(8192))]
141 #[validate(range(min = 1))]
142 pub max_num_batched_tokens: Option<usize>,
143
144 #[builder(default = true)]
145 pub enable_prefix_caching: bool,
146
147 #[builder(default = true)]
148 pub enable_chunked_prefill: bool,
149
150 #[builder(default = "0.01")]
151 #[validate(range(min = 0.0, max = 1.0))]
152 pub watermark: f64,
153
154 #[builder(default = "1.0")]
155 #[validate(range(min = 0.0))]
156 pub speedup_ratio: f64,
157
158 #[builder(default = "1")]
159 #[validate(range(min = 1))]
160 pub dp_size: u32,
161
162 #[builder(default = "None")]
164 #[validate(range(min = 0.0))]
165 pub startup_time: Option<f64>,
166
167 #[builder(default = "WorkerType::Aggregated")]
169 pub worker_type: WorkerType,
170
171 #[serde(skip)]
173 #[builder(default = "Arc::new(PerfModel::default())")]
174 pub perf_model: Arc<PerfModel>,
175
176 #[builder(default = "false")]
178 pub enable_local_indexer: bool,
179
180 #[builder(default = "None")]
184 pub bootstrap_port: Option<u16>,
185
186 #[builder(default = "None")]
189 pub kv_bytes_per_token: Option<usize>,
190
191 #[builder(default = "None")]
195 #[validate(range(min = 0.0))]
196 pub kv_transfer_bandwidth: Option<f64>,
197
198 #[builder(default = "None")]
201 pub reasoning: Option<ReasoningConfig>,
202
203 #[builder(default = "None")]
207 pub zmq_kv_events_port: Option<u16>,
208}
209
210impl Default for MockEngineArgs {
211 fn default() -> MockEngineArgs {
212 MockEngineArgsBuilder::default()
213 .build()
214 .expect("Failed to build default MockEngineArgs")
215 }
216}
217
218impl MockEngineArgs {
219 pub fn builder() -> MockEngineArgsBuilder {
220 MockEngineArgsBuilder::default()
221 }
222
223 pub fn is_prefill(&self) -> bool {
224 self.worker_type == WorkerType::Prefill
225 }
226
227 pub fn is_decode(&self) -> bool {
228 self.worker_type == WorkerType::Decode
229 }
230
231 pub fn needs_kv_publisher(&self) -> bool {
232 self.enable_prefix_caching && !self.is_decode()
233 }
234
235 pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
237 let mut builder = Self::builder();
238
239 let file_content = std::fs::read_to_string(path)?;
241 let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
242
243 let valid_fields: HashSet<&str> = [
245 "num_gpu_blocks",
246 "block_size",
247 "max_num_seqs",
248 "max_num_batched_tokens",
249 "enable_prefix_caching",
250 "enable_chunked_prefill",
251 "watermark",
252 "speedup_ratio",
253 "dp_size",
254 "startup_time",
255 "is_prefill",
256 "is_decode",
257 "planner_profile_data",
258 "enable_local_indexer",
259 "bootstrap_port",
260 "kv_bytes_per_token",
261 "kv_transfer_bandwidth",
262 "reasoning",
263 "zmq_kv_events_port",
264 ]
265 .iter()
266 .cloned()
267 .collect();
268
269 let invalid_args: Vec<String> = extra_args
271 .keys()
272 .filter(|key| !valid_fields.contains(key.as_str()))
273 .cloned()
274 .collect();
275
276 if !invalid_args.is_empty() {
277 return Err(anyhow::anyhow!(
278 "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
279 invalid_args.join(", "),
280 valid_fields
281 ));
282 }
283
284 if let Some(value) = extra_args.get("num_gpu_blocks")
286 && let Some(num) = value.as_u64()
287 {
288 builder = builder.num_gpu_blocks(num as usize);
289 }
290
291 if let Some(value) = extra_args.get("block_size")
292 && let Some(num) = value.as_u64()
293 {
294 builder = builder.block_size(num as usize);
295 }
296
297 if let Some(value) = extra_args.get("max_num_seqs")
298 && let Some(num) = value.as_u64()
299 {
300 builder = builder.max_num_seqs(Some(num as usize));
301 }
302
303 if let Some(value) = extra_args.get("max_num_batched_tokens")
304 && let Some(num) = value.as_u64()
305 {
306 builder = builder.max_num_batched_tokens(Some(num as usize));
307 }
308
309 if let Some(value) = extra_args.get("enable_prefix_caching")
310 && let Some(enabled) = value.as_bool()
311 {
312 builder = builder.enable_prefix_caching(enabled);
313 }
314
315 if let Some(value) = extra_args.get("enable_chunked_prefill")
316 && let Some(enabled) = value.as_bool()
317 {
318 builder = builder.enable_chunked_prefill(enabled);
319 }
320
321 if let Some(value) = extra_args.get("watermark")
322 && let Some(num) = value.as_f64()
323 {
324 builder = builder.watermark(num);
325 }
326
327 if let Some(value) = extra_args.get("speedup_ratio")
328 && let Some(num) = value.as_f64()
329 {
330 builder = builder.speedup_ratio(num);
331 }
332
333 if let Some(value) = extra_args.get("dp_size")
334 && let Some(num) = value.as_u64()
335 {
336 builder = builder.dp_size(num as u32);
337 }
338
339 if let Some(value) = extra_args.get("startup_time")
340 && let Some(num) = value.as_f64()
341 {
342 builder = builder.startup_time(Some(num));
343 }
344
345 if let Some(value) = extra_args.get("enable_local_indexer")
346 && let Some(enabled) = value.as_bool()
347 {
348 builder = builder.enable_local_indexer(enabled);
349 }
350
351 if let Some(value) = extra_args.get("bootstrap_port")
352 && let Some(port) = value.as_u64()
353 {
354 builder = builder.bootstrap_port(Some(port as u16));
355 }
356
357 if let Some(value) = extra_args.get("kv_bytes_per_token")
358 && let Some(num) = value.as_u64()
359 {
360 builder = builder.kv_bytes_per_token(Some(num as usize));
361 }
362
363 if let Some(value) = extra_args.get("kv_transfer_bandwidth")
364 && let Some(num) = value.as_f64()
365 {
366 builder = builder.kv_transfer_bandwidth(Some(num));
367 }
368
369 if let Some(value) = extra_args.get("reasoning") {
370 let cfg: ReasoningConfig = serde_json::from_value(value.clone())
371 .map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
372 builder = builder.reasoning(Some(cfg));
373 }
374
375 if let Some(value) = extra_args.get("zmq_kv_events_port")
376 && let Some(port) = value.as_u64()
377 {
378 builder = builder.zmq_kv_events_port(Some(port as u16));
379 }
380
381 let is_prefill = extra_args
383 .get("is_prefill")
384 .and_then(|v| v.as_bool())
385 .unwrap_or(false);
386 let is_decode = extra_args
387 .get("is_decode")
388 .and_then(|v| v.as_bool())
389 .unwrap_or(false);
390
391 let worker_type = match (is_prefill, is_decode) {
393 (false, false) => WorkerType::Aggregated,
394 (true, false) => WorkerType::Prefill,
395 (false, true) => WorkerType::Decode,
396 (true, true) => panic!(
397 "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
398 Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
399 ),
400 };
401 builder = builder.worker_type(worker_type);
402
403 let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
405 && let Some(path_str) = path_str.as_str()
406 {
407 let npz_path = PathBuf::from(path_str);
408 match PerfModel::from_npz(&npz_path) {
409 Ok(model) => {
410 tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
411 Arc::new(model)
412 }
413 Err(e) => {
414 tracing::error!(
415 "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
416 npz_path,
417 e
418 );
419 Arc::new(PerfModel::default())
420 }
421 }
422 } else {
423 Arc::new(PerfModel::default())
424 };
425 builder = builder.perf_model(perf_model);
426
427 builder
429 .build()
430 .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `UniqueBlock::default()` must be a `PartialBlock` carrying a UUID
    /// that differs from every other default-constructed block.
    #[test]
    fn test_unique_block_default_uniqueness() {
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                other => panic!(
                    "Expected UniqueBlock::PartialBlock variant, got {:?}",
                    other
                ),
            })
            .collect();

        // Compare every unordered pair exactly once.
        for (i, first) in uuids.iter().enumerate() {
            for (j, second) in uuids.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    first, second,
                    "UUID at index {} and {} are identical",
                    i, j
                );
            }
        }
    }
}