dynamo_llm/mocker/
protocols.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use derive_builder::Builder;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::path::Path;
8use uuid::Uuid;
9
10use crate::kv_router::protocols::{
11    ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
12    KvCacheStoredBlockData, LocalBlockHash,
13};
14use crate::tokens::blocks::UniqueBlock;
15use crate::tokens::{BlockHash, SequenceHash, Token};
16
/// A number of blocks, represented as a plain `usize`.
pub type NumBlocks = usize;
18
/// Represents different block movement operations in the cache.
///
/// `Use`, `Destroy`, and `Deref` each carry only a list of blocks.
/// `Promote` carries a `Uuid`, a `SequenceHash`, and an optional `u64`
/// (presumably the UUID of the partial block being promoted, the sequence
/// hash it is promoted to, and the parent hash -- TODO confirm against the
/// cache implementation).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
    Use(Vec<UniqueBlock>),
    Destroy(Vec<UniqueBlock>),
    Deref(Vec<UniqueBlock>),
    Promote(Uuid, SequenceHash, Option<u64>),
}
28
/// Result reported for a block movement.
///
/// `block_response_to_kv_event` translates these into `KvCacheEventData`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
    /// Sequence hashes of the blocks that were stored, plus an optional parent hash.
    Store(Vec<SequenceHash>, Option<u64>),
    /// Sequence hashes of the blocks that were removed.
    Remove(Vec<SequenceHash>),
}
34
/// A request that carries its input tokens directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Input tokens of the request.
    pub tokens: Vec<Token>,
    /// Maximum number of output tokens.
    pub max_output_tokens: usize,
    /// Optional request identifier.
    /// NOTE(review): how `None` is handled is decided by the consumer -- confirm.
    pub uuid: Option<Uuid>,
    /// Optional rank (presumably the data-parallel rank, cf. `MockEngineArgs::dp_size`
    /// -- TODO confirm).
    pub dp_rank: Option<u32>,
}
42
/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of new blocks.
    pub new_blocks: usize,
    /// Number of new tokens; used by `predict_prefill_compute` when the caller
    /// does not pass an explicit token count.
    pub new_tokens: usize,
}
49
50impl PrefillCost {
51    pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
52        let tokens = new_tokens.unwrap_or(self.new_tokens);
53        1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
54    }
55}
56
/// Signal for output token generation with completion status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier the signal refers to (presumably the request's UUID -- TODO confirm).
    pub uuid: Uuid,
    /// Completion status carried with the signal.
    pub completed: bool,
}
63
/// Configuration arguments for MockVllmEngine
///
/// Values are assembled through `MockEngineArgsBuilder`; any field left unset on
/// the builder takes the default declared in its `#[builder(default = ...)]`
/// attribute.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    /// Number of GPU blocks. Defaults to 16384.
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    /// Block size (presumably tokens per block -- TODO confirm). Defaults to 64.
    #[builder(default = "64")]
    pub block_size: usize,

    /// Optional cap on the number of sequences. Defaults to `Some(256)`.
    // This was 1024 in the past but reverted back to 256
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    /// Optional cap on the number of batched tokens. Defaults to `Some(8192)`.
    // default for open api server, for llm class it's 16384
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    /// Whether prefix caching is enabled. Defaults to `true`.
    #[builder(default = true)]
    pub enable_prefix_caching: bool,

    /// Whether chunked prefill is enabled. Defaults to `true`.
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

    /// Watermark value. Defaults to 0.01.
    /// NOTE(review): presumably a fraction of blocks kept in reserve -- confirm with the consumer.
    #[builder(default = "0.01")]
    pub watermark: f64,

    /// Speedup ratio. Defaults to 1.0.
    /// NOTE(review): presumably scales simulated timings -- confirm with the consumer.
    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    /// Data-parallel size (presumably; cf. `DirectRequest::dp_rank` -- TODO confirm). Defaults to 1.
    #[builder(default = "1")]
    pub dp_size: u32,

    /// Optional startup time in seconds to simulate engine initialization delay
    #[builder(default = "None")]
    pub startup_time: Option<f64>,
}
101
102impl Default for MockEngineArgs {
103    fn default() -> MockEngineArgs {
104        MockEngineArgsBuilder::default()
105            .build()
106            .expect("Failed to build default MockEngineArgs")
107    }
108}
109
110impl MockEngineArgs {
111    pub fn builder() -> MockEngineArgsBuilder {
112        MockEngineArgsBuilder::default()
113    }
114
115    /// Create MockEngineArgs from a JSON file containing extra engine arguments
116    pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
117        let mut builder = Self::builder();
118
119        // Load and parse the JSON file
120        let file_content = std::fs::read_to_string(path)?;
121        let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
122
123        // Define valid field names
124        let valid_fields: HashSet<&str> = [
125            "num_gpu_blocks",
126            "block_size",
127            "max_num_seqs",
128            "max_num_batched_tokens",
129            "enable_prefix_caching",
130            "enable_chunked_prefill",
131            "watermark",
132            "speedup_ratio",
133            "dp_size",
134            "startup_time",
135        ]
136        .iter()
137        .cloned()
138        .collect();
139
140        // Check for invalid arguments
141        let invalid_args: Vec<String> = extra_args
142            .keys()
143            .filter(|key| !valid_fields.contains(key.as_str()))
144            .cloned()
145            .collect();
146
147        if !invalid_args.is_empty() {
148            return Err(anyhow::anyhow!(
149                "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
150                invalid_args.join(", "),
151                valid_fields
152            ));
153        }
154
155        // Apply each extra argument to the builder
156        if let Some(value) = extra_args.get("num_gpu_blocks")
157            && let Some(num) = value.as_u64()
158        {
159            builder = builder.num_gpu_blocks(num as usize);
160        }
161
162        if let Some(value) = extra_args.get("block_size")
163            && let Some(num) = value.as_u64()
164        {
165            builder = builder.block_size(num as usize);
166        }
167
168        if let Some(value) = extra_args.get("max_num_seqs")
169            && let Some(num) = value.as_u64()
170        {
171            builder = builder.max_num_seqs(Some(num as usize));
172        }
173
174        if let Some(value) = extra_args.get("max_num_batched_tokens")
175            && let Some(num) = value.as_u64()
176        {
177            builder = builder.max_num_batched_tokens(Some(num as usize));
178        }
179
180        if let Some(value) = extra_args.get("enable_prefix_caching")
181            && let Some(enabled) = value.as_bool()
182        {
183            builder = builder.enable_prefix_caching(enabled);
184        }
185
186        if let Some(value) = extra_args.get("enable_chunked_prefill")
187            && let Some(enabled) = value.as_bool()
188        {
189            builder = builder.enable_chunked_prefill(enabled);
190        }
191
192        if let Some(value) = extra_args.get("watermark")
193            && let Some(num) = value.as_f64()
194        {
195            builder = builder.watermark(num);
196        }
197
198        if let Some(value) = extra_args.get("speedup_ratio")
199            && let Some(num) = value.as_f64()
200        {
201            builder = builder.speedup_ratio(num);
202        }
203
204        if let Some(value) = extra_args.get("dp_size")
205            && let Some(num) = value.as_u64()
206        {
207            builder = builder.dp_size(num as u32);
208        }
209
210        if let Some(value) = extra_args.get("startup_time")
211            && let Some(num) = value.as_f64()
212        {
213            builder = builder.startup_time(Some(num));
214        }
215
216        // Build the MockEngineArgs with either defaults or overridden values
217        builder
218            .build()
219            .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
220    }
221}
222
223/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
224///
225/// This function assumes that the stored sequence hashes in the response always
226/// correspond to the tail part of the local hashes array. This is the expected
227/// behavior of KV block storage, where blocks are stored sequentially and the
228/// response contains the most recent blocks that were stored.
229///
230/// # Panics
231/// Panics if the number of blocks in the Store response exceeds the length
232/// of local_hashes.
233pub fn block_response_to_kv_event(
234    response: MoveBlockResponse,
235    local_hashes: &[BlockHash],
236) -> KvCacheEventData {
237    match response {
238        MoveBlockResponse::Store(full_blocks, parent_hash) => {
239            let num_blocks = full_blocks.len();
240            let local_hashes_slice = &local_hashes[local_hashes
241                .len()
242                .checked_sub(num_blocks)
243                .expect("local hashes fewer than block response signal")..];
244
245            KvCacheEventData::Stored(KvCacheStoreData {
246                parent_hash: parent_hash.map(ExternalSequenceBlockHash),
247                blocks: full_blocks
248                    .into_iter()
249                    .zip(local_hashes_slice.iter())
250                    .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
251                        block_hash: ExternalSequenceBlockHash(global_hash),
252                        tokens_hash: LocalBlockHash(*local_hash),
253                    })
254                    .collect(),
255            })
256        }
257        MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
258            block_hashes: full_blocks
259                .into_iter()
260                .map(ExternalSequenceBlockHash)
261                .collect(),
262        }),
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten default-constructed `UniqueBlock`s must all be `PartialBlock`s
    /// carrying pairwise-distinct UUIDs.
    #[test]
    fn test_unique_block_default_uniqueness() {
        // Create 10 default UniqueBlock instances and pull the UUID out of each.
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();

        // Compare every UUID against each later one so a failure reports both indices.
        for (i, first) in uuids.iter().enumerate() {
            for (offset, second) in uuids[i + 1..].iter().enumerate() {
                assert_ne!(
                    first,
                    second,
                    "UUID at index {} and {} are identical",
                    i,
                    i + 1 + offset
                );
            }
        }
    }
}