//! dynamo_mocker/common/protocols.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use validator::Validate;

use crate::common::perf_model::PerfModel;
use dynamo_kv_router::protocols::KvCacheEvent;
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash, Token};
/// Trait for publishing KV cache events.
/// This abstracts the runtime dependency so mocker components can remain generic.
pub trait KvCacheEventSink: Send + Sync {
    /// Publish a single KV cache event.
    ///
    /// `block_token_ids` optionally carries the token ids behind the event,
    /// one `Vec<u32>` per block — presumably for indexers that need the raw
    /// tokens rather than hashes; TODO confirm against implementors.
    fn publish(
        &self,
        event: KvCacheEvent,
        block_token_ids: Option<&[Vec<u32>]>,
    ) -> anyhow::Result<()>;
}

/// Count of KV cache blocks.
pub type NumBlocks = usize;

/// Represents different block movement operations in the cache
/// For Use and Promote variants, block hashes are included for KV event publishing
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
    /// Acquire/use blocks; carries the block hashes (and optionally the
    /// per-block token ids) needed for KV event publishing.
    Use(Vec<UniqueBlock>, Vec<BlockHash>, Option<Vec<Vec<u32>>>),
    /// Remove blocks from the cache entirely.
    Destroy(Vec<UniqueBlock>),
    /// Release a reference to blocks without destroying them.
    Deref(Vec<UniqueBlock>),
    /// Promote a partial block (identified by UUID) to a completed block with
    /// the given sequence hash; carries the block hash (and optionally token
    /// ids) for KV event publishing. The `Option<u64>` looks like an optional
    /// parent/prior hash — TODO confirm against the scheduler.
    Promote(Uuid, SequenceHash, Option<u64>, BlockHash, Option<Vec<u32>>),
}

/// Response to block movement, mirroring KV cache store/remove events.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
    /// Blocks were stored; the `Option<u64>` appears to be an optional
    /// parent hash — TODO confirm against the event publisher.
    Store(Vec<SequenceHash>, Option<u64>),
    /// Blocks were removed from the cache.
    Remove(Vec<SequenceHash>),
}

/// A generation request submitted directly to the mock engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Input (prompt) tokens to prefill.
    pub tokens: Vec<Token>,
    /// Maximum number of output tokens to generate.
    pub max_output_tokens: usize,
    /// Optional request id; presumably generated downstream when absent —
    /// TODO confirm against the engine entry point.
    pub uuid: Option<Uuid>,
    /// Data-parallel rank this request is routed to.
    pub dp_rank: u32,
}

/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of cache blocks that must be newly allocated.
    pub new_blocks: usize,
    /// Number of tokens not already cached that must be computed.
    pub new_tokens: usize,
}

60impl PrefillCost {
61    pub fn predict_prefill_compute(
62        &self,
63        new_tokens: Option<usize>,
64        perf_model: &PerfModel,
65    ) -> f64 {
66        let tokens = new_tokens.unwrap_or(self.new_tokens);
67        perf_model.predict_prefill_time(tokens)
68    }
69}
70
/// Signal for output token generation with completion status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Request this output token belongs to.
    pub uuid: Uuid,
    /// True when this is the final token for the request.
    pub completed: bool,
}

/// Worker type for disaggregated serving configurations.
/// Defaults to `Aggregated`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WorkerType {
    /// Standard aggregated worker handling both prefill and decode
    #[default]
    Aggregated,
    /// Dedicated prefill worker in disaggregated mode
    Prefill,
    /// Dedicated decode worker in disaggregated mode
    Decode,
}

/// Configuration for reasoning/thinking token output in the mocker.
///
/// When set, the mocker wraps the first portion of each response in thinking
/// boundary tokens: `[start_token, random..., end_token, random...]`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct ReasoningConfig {
    /// Token id emitted to open the thinking block.
    pub start_thinking_token_id: u32,
    /// Token id emitted to close the thinking block.
    pub end_thinking_token_id: u32,
    /// Fraction of the output budget spent on thinking tokens, in [0, 1].
    #[validate(range(min = 0.0, max = 1.0))]
    pub thinking_ratio: f64,
}

102impl ReasoningConfig {
103    /// Number of thinking tokens (including start/end boundaries) for a given osl.
104    /// Returns 0 if osl < 2 (thinking disabled). Otherwise clamps to [2, osl].
105    pub fn num_thinking_tokens(&self, max_output_tokens: usize) -> usize {
106        if max_output_tokens < 2 {
107            return 0;
108        }
109        let raw = (max_output_tokens as f64 * self.thinking_ratio).floor() as usize;
110        if raw == 0 {
111            return 0;
112        }
113        raw.max(2).min(max_output_tokens)
114    }
115
116    /// Number of response tokens after the thinking block.
117    pub fn num_response_tokens(&self, max_output_tokens: usize) -> usize {
118        max_output_tokens.saturating_sub(self.num_thinking_tokens(max_output_tokens))
119    }
120}
121
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder, Validate)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    /// Total number of KV cache blocks available on the (mock) GPU.
    #[builder(default = "16384")]
    #[validate(range(min = 1))]
    pub num_gpu_blocks: usize,

    /// Number of tokens per KV cache block.
    #[builder(default = "64")]
    #[validate(range(min = 2))]
    pub block_size: usize,

    /// Maximum number of concurrently scheduled sequences.
    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    #[validate(range(min = 1))]
    pub max_num_seqs: Option<usize>,

    /// Maximum number of tokens per scheduling batch.
    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    #[validate(range(min = 1))]
    pub max_num_batched_tokens: Option<usize>,

    /// Enable prefix caching (KV block reuse across requests).
    #[builder(default = true)]
    pub enable_prefix_caching: bool,

    /// Allow a prefill to be split across multiple scheduling steps.
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

    /// Fraction of GPU blocks held back as headroom — assumed to mirror
    /// vLLM's watermark semantics; TODO confirm against the scheduler.
    #[builder(default = "0.01")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub watermark: f64,

    /// Simulation speedup factor applied to predicted timings.
    #[builder(default = "1.0")]
    #[validate(range(min = 0.0))]
    pub speedup_ratio: f64,

    /// Data-parallel world size.
    #[builder(default = "1")]
    #[validate(range(min = 1))]
    pub dp_size: u32,

    /// Optional startup time in seconds to simulate engine initialization delay
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub startup_time: Option<f64>,

    /// Worker type for disaggregated serving (Aggregated, Prefill, or Decode)
    #[builder(default = "WorkerType::Aggregated")]
    pub worker_type: WorkerType,

    /// Performance model for timing predictions (not serialized, loaded from planner_profile_data)
    #[serde(skip)]
    #[builder(default = "Arc::new(PerfModel::default())")]
    pub perf_model: Arc<PerfModel>,

    /// Enable worker-local KV indexer for tracking this worker's own KV cache state
    #[builder(default = "false")]
    pub enable_local_indexer: bool,

    /// Bootstrap port for disaggregated serving rendezvous.
    /// Prefill workers listen on this port; decode workers connect to it.
    /// If None, bootstrap rendezvous is disabled.
    #[builder(default = "None")]
    pub bootstrap_port: Option<u16>,

    /// KV cache bytes per token, auto-computed from model config by Python CLI.
    /// Formula: num_layers * 2 * num_kv_heads * head_dim * dtype_bytes
    #[builder(default = "None")]
    pub kv_bytes_per_token: Option<usize>,

    /// KV cache transfer bandwidth in GB/s for disaggregated serving latency simulation.
    /// Default: 64.0 (inter-node InfiniBand). Set to 0 to disable KV transfer delay.
    /// For intra-node NVLink, typical value is ~450.
    #[builder(default = "None")]
    #[validate(range(min = 0.0))]
    pub kv_transfer_bandwidth: Option<f64>,

    /// Reasoning/thinking token configuration.
    /// When set, the mocker wraps output in thinking boundary tokens.
    #[builder(default = "None")]
    pub reasoning: Option<ReasoningConfig>,

    /// ZMQ port for publishing KV events in vLLM's native wire format.
    /// When set, the scheduler publishes to a ZMQ PUB socket instead of directly to NATS.
    /// A KvEventPublisher relay subscribes to this socket and forwards events to NATS.
    #[builder(default = "None")]
    pub zmq_kv_events_port: Option<u16>,
}

210impl Default for MockEngineArgs {
211    fn default() -> MockEngineArgs {
212        MockEngineArgsBuilder::default()
213            .build()
214            .expect("Failed to build default MockEngineArgs")
215    }
216}
217
218impl MockEngineArgs {
219    pub fn builder() -> MockEngineArgsBuilder {
220        MockEngineArgsBuilder::default()
221    }
222
223    pub fn is_prefill(&self) -> bool {
224        self.worker_type == WorkerType::Prefill
225    }
226
227    pub fn is_decode(&self) -> bool {
228        self.worker_type == WorkerType::Decode
229    }
230
231    pub fn needs_kv_publisher(&self) -> bool {
232        self.enable_prefix_caching && !self.is_decode()
233    }
234
235    /// Create MockEngineArgs from a JSON file containing extra engine arguments
236    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
237        let mut builder = Self::builder();
238
239        // Load and parse the JSON file
240        let file_content = std::fs::read_to_string(path)?;
241        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
242
243        // Define valid field names
244        let valid_fields: HashSet<&str> = [
245            "num_gpu_blocks",
246            "block_size",
247            "max_num_seqs",
248            "max_num_batched_tokens",
249            "enable_prefix_caching",
250            "enable_chunked_prefill",
251            "watermark",
252            "speedup_ratio",
253            "dp_size",
254            "startup_time",
255            "is_prefill",
256            "is_decode",
257            "planner_profile_data",
258            "enable_local_indexer",
259            "bootstrap_port",
260            "kv_bytes_per_token",
261            "kv_transfer_bandwidth",
262            "reasoning",
263            "zmq_kv_events_port",
264        ]
265        .iter()
266        .cloned()
267        .collect();
268
269        // Check for invalid arguments
270        let invalid_args: Vec<String> = extra_args
271            .keys()
272            .filter(|key| !valid_fields.contains(key.as_str()))
273            .cloned()
274            .collect();
275
276        if !invalid_args.is_empty() {
277            return Err(anyhow::anyhow!(
278                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
279                invalid_args.join(", "),
280                valid_fields
281            ));
282        }
283
284        // Apply each extra argument to the builder
285        if let Some(value) = extra_args.get("num_gpu_blocks")
286            && let Some(num) = value.as_u64()
287        {
288            builder = builder.num_gpu_blocks(num as usize);
289        }
290
291        if let Some(value) = extra_args.get("block_size")
292            && let Some(num) = value.as_u64()
293        {
294            builder = builder.block_size(num as usize);
295        }
296
297        if let Some(value) = extra_args.get("max_num_seqs")
298            && let Some(num) = value.as_u64()
299        {
300            builder = builder.max_num_seqs(Some(num as usize));
301        }
302
303        if let Some(value) = extra_args.get("max_num_batched_tokens")
304            && let Some(num) = value.as_u64()
305        {
306            builder = builder.max_num_batched_tokens(Some(num as usize));
307        }
308
309        if let Some(value) = extra_args.get("enable_prefix_caching")
310            && let Some(enabled) = value.as_bool()
311        {
312            builder = builder.enable_prefix_caching(enabled);
313        }
314
315        if let Some(value) = extra_args.get("enable_chunked_prefill")
316            && let Some(enabled) = value.as_bool()
317        {
318            builder = builder.enable_chunked_prefill(enabled);
319        }
320
321        if let Some(value) = extra_args.get("watermark")
322            && let Some(num) = value.as_f64()
323        {
324            builder = builder.watermark(num);
325        }
326
327        if let Some(value) = extra_args.get("speedup_ratio")
328            && let Some(num) = value.as_f64()
329        {
330            builder = builder.speedup_ratio(num);
331        }
332
333        if let Some(value) = extra_args.get("dp_size")
334            && let Some(num) = value.as_u64()
335        {
336            builder = builder.dp_size(num as u32);
337        }
338
339        if let Some(value) = extra_args.get("startup_time")
340            && let Some(num) = value.as_f64()
341        {
342            builder = builder.startup_time(Some(num));
343        }
344
345        if let Some(value) = extra_args.get("enable_local_indexer")
346            && let Some(enabled) = value.as_bool()
347        {
348            builder = builder.enable_local_indexer(enabled);
349        }
350
351        if let Some(value) = extra_args.get("bootstrap_port")
352            && let Some(port) = value.as_u64()
353        {
354            builder = builder.bootstrap_port(Some(port as u16));
355        }
356
357        if let Some(value) = extra_args.get("kv_bytes_per_token")
358            && let Some(num) = value.as_u64()
359        {
360            builder = builder.kv_bytes_per_token(Some(num as usize));
361        }
362
363        if let Some(value) = extra_args.get("kv_transfer_bandwidth")
364            && let Some(num) = value.as_f64()
365        {
366            builder = builder.kv_transfer_bandwidth(Some(num));
367        }
368
369        if let Some(value) = extra_args.get("reasoning") {
370            let cfg: ReasoningConfig = serde_json::from_value(value.clone())
371                .map_err(|e| anyhow::anyhow!("Failed to parse reasoning config: {}", e))?;
372            builder = builder.reasoning(Some(cfg));
373        }
374
375        if let Some(value) = extra_args.get("zmq_kv_events_port")
376            && let Some(port) = value.as_u64()
377        {
378            builder = builder.zmq_kv_events_port(Some(port as u16));
379        }
380
381        // Parse worker type from is_prefill and is_decode flags
382        let is_prefill = extra_args
383            .get("is_prefill")
384            .and_then(|v| v.as_bool())
385            .unwrap_or(false);
386        let is_decode = extra_args
387            .get("is_decode")
388            .and_then(|v| v.as_bool())
389            .unwrap_or(false);
390
391        // Determine worker type based on flags
392        let worker_type = match (is_prefill, is_decode) {
393            (false, false) => WorkerType::Aggregated,
394            (true, false) => WorkerType::Prefill,
395            (false, true) => WorkerType::Decode,
396            (true, true) => panic!(
397                "Invalid worker configuration: is_prefill and is_decode cannot both be true. \
398                 Worker must be either Aggregated (both false), Prefill (is_prefill=true), or Decode (is_decode=true)."
399            ),
400        };
401        builder = builder.worker_type(worker_type);
402
403        // Load performance model from NPZ file if provided
404        let perf_model = if let Some(path_str) = extra_args.get("planner_profile_data")
405            && let Some(path_str) = path_str.as_str()
406        {
407            let npz_path = PathBuf::from(path_str);
408            match PerfModel::from_npz(&npz_path) {
409                Ok(model) => {
410                    tracing::info!("Successfully loaded performance model from: {:?}", npz_path);
411                    Arc::new(model)
412                }
413                Err(e) => {
414                    tracing::error!(
415                        "Failed to load performance model from {:?}: {}. Falling back to polynomial model.",
416                        npz_path,
417                        e
418                    );
419                    Arc::new(PerfModel::default())
420                }
421            }
422        } else {
423            Arc::new(PerfModel::default())
424        };
425        builder = builder.perf_model(perf_model);
426
427        // Build the MockEngineArgs with either defaults or overridden values
428        builder
429            .build()
430            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unique_block_default_uniqueness() {
        // Default UniqueBlock instances should each carry a freshly generated
        // UUID in the PartialBlock variant.
        let uuids: Vec<Uuid> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                // Message fixed: the matched variant is PartialBlock,
                // not "UuidIdentifier".
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();

        // A HashSet collapses duplicates, so equal lengths imply all UUIDs
        // are unique (O(n) instead of the pairwise O(n^2) comparison).
        let unique: HashSet<Uuid> = uuids.iter().copied().collect();
        assert_eq!(unique.len(), uuids.len(), "duplicate UUIDs generated");
    }
}