1use derive_builder::Builder;
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::path::Path;
8use uuid::Uuid;
9
10use crate::kv_router::protocols::{
11 ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
12 KvCacheStoredBlockData, LocalBlockHash,
13};
14use crate::tokens::blocks::UniqueBlock;
15use crate::tokens::{BlockHash, SequenceHash, Token};
16
/// A count of blocks, expressed as a plain `usize`.
pub type NumBlocks = usize;
18
/// A block-movement operation, serializable with serde.
///
/// NOTE(review): the consumer of these operations is not visible in this
/// file; the per-variant descriptions below are inferred from the variant
/// names and payload types and should be confirmed against the handler.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
    /// Presumably marks the listed blocks as in use.
    Use(Vec<UniqueBlock>),
    /// Presumably discards the listed blocks.
    Destroy(Vec<UniqueBlock>),
    /// Presumably releases one reference to each listed block.
    Deref(Vec<UniqueBlock>),
    /// Carries a UUID, a sequence hash and an optional `u64`.
    /// Presumably converts the block identified by the UUID into one
    /// identified by the sequence hash, with the `u64` being a parent hash
    /// — TODO confirm.
    Promote(Uuid, SequenceHash, Option<u64>),
}
28
/// Outcome of a block movement, convertible into a KV cache event by
/// `block_response_to_kv_event`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
    /// Sequence hashes of the blocks that were stored, followed by the
    /// optional parent hash that is forwarded into the resulting store event.
    Store(Vec<SequenceHash>, Option<u64>),
    /// Sequence hashes of the blocks that were removed.
    Remove(Vec<SequenceHash>),
}
34
/// A request described directly by its tokens, serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
    /// Tokens carried by the request (presumably the prompt — confirm with caller).
    pub tokens: Vec<Token>,
    /// Upper bound on output tokens, judging by the name; enforcement is not
    /// visible in this file.
    pub max_output_tokens: usize,
    /// Optional request identifier; how a missing value is handled is decided
    /// by the consumer, which is not visible here.
    pub uuid: Option<Uuid>,
    /// Optional rank selector (presumably the data-parallel rank — TODO confirm).
    pub dp_rank: Option<u32>,
}
42
/// Size of a prefill step, measured in blocks and in tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
    /// Number of new blocks (presumably blocks that must be allocated — confirm).
    pub new_blocks: usize,
    /// Number of new tokens; used as the default input of
    /// `predict_prefill_compute`.
    pub new_tokens: usize,
}
49
50impl PrefillCost {
51 pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
52 let tokens = new_tokens.unwrap_or(self.new_tokens);
53 1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
54 }
55}
56
/// Signal tied to a request UUID, serializable with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
    /// Identifier of the request the signal refers to.
    pub uuid: Uuid,
    /// Completion flag (presumably `true` once the request has finished
    /// producing output — confirm with the producer).
    pub completed: bool,
}
63
/// Configuration of the mock engine.
///
/// A `MockEngineArgsBuilder` is generated by `derive_builder` using the owned
/// pattern; any field left unset on the builder takes the default declared in
/// its `#[builder(default = ...)]` attribute.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names; the code that consumes them is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
    /// Number of GPU blocks. Defaults to 16384.
    #[builder(default = "16384")]
    pub num_gpu_blocks: usize,

    /// Block size (presumably tokens per block — confirm). Defaults to 64.
    #[builder(default = "64")]
    pub block_size: usize,

    /// Optional cap on the number of sequences. Defaults to `Some(256)`.
    #[builder(default = Some(256))]
    pub max_num_seqs: Option<usize>,

    /// Optional cap on batched tokens. Defaults to `Some(8192)`.
    #[builder(default = Some(8192))]
    pub max_num_batched_tokens: Option<usize>,

    /// Toggles prefix caching. Defaults to `true`.
    #[builder(default = true)]
    pub enable_prefix_caching: bool,

    /// Toggles chunked prefill. Defaults to `true`.
    #[builder(default = true)]
    pub enable_chunked_prefill: bool,

    /// Watermark value (presumably a fraction of blocks held in reserve —
    /// TODO confirm). Defaults to 0.01.
    #[builder(default = "0.01")]
    pub watermark: f64,

    /// Speed-up ratio (presumably a scaling factor on simulated time —
    /// TODO confirm). Defaults to 1.0.
    #[builder(default = "1.0")]
    pub speedup_ratio: f64,

    /// Data-parallel size, judging by the name. Defaults to 1.
    #[builder(default = "1")]
    pub dp_size: u32,

    /// Optional startup time (unit not stated here — presumably seconds).
    /// Defaults to `None`.
    #[builder(default = "None")]
    pub startup_time: Option<f64>,
}
101
102impl Default for MockEngineArgs {
103 fn default() -> MockEngineArgs {
104 MockEngineArgsBuilder::default()
105 .build()
106 .expect("Failed to build default MockEngineArgs")
107 }
108}
109
110impl MockEngineArgs {
111 pub fn builder() -> MockEngineArgsBuilder {
112 MockEngineArgsBuilder::default()
113 }
114
115 pub fn from_json_file(path: &Path) -> anyhow::Result<Self> {
117 let mut builder = Self::builder();
118
119 let file_content = std::fs::read_to_string(path)?;
121 let extra_args: HashMap<String, serde_json::Value> = serde_json::from_str(&file_content)?;
122
123 let valid_fields: HashSet<&str> = [
125 "num_gpu_blocks",
126 "block_size",
127 "max_num_seqs",
128 "max_num_batched_tokens",
129 "enable_prefix_caching",
130 "enable_chunked_prefill",
131 "watermark",
132 "speedup_ratio",
133 "dp_size",
134 "startup_time",
135 ]
136 .iter()
137 .cloned()
138 .collect();
139
140 let invalid_args: Vec<String> = extra_args
142 .keys()
143 .filter(|key| !valid_fields.contains(key.as_str()))
144 .cloned()
145 .collect();
146
147 if !invalid_args.is_empty() {
148 return Err(anyhow::anyhow!(
149 "Invalid arguments found in JSON file: {}. Valid arguments are: {:?}",
150 invalid_args.join(", "),
151 valid_fields
152 ));
153 }
154
155 if let Some(value) = extra_args.get("num_gpu_blocks")
157 && let Some(num) = value.as_u64()
158 {
159 builder = builder.num_gpu_blocks(num as usize);
160 }
161
162 if let Some(value) = extra_args.get("block_size")
163 && let Some(num) = value.as_u64()
164 {
165 builder = builder.block_size(num as usize);
166 }
167
168 if let Some(value) = extra_args.get("max_num_seqs")
169 && let Some(num) = value.as_u64()
170 {
171 builder = builder.max_num_seqs(Some(num as usize));
172 }
173
174 if let Some(value) = extra_args.get("max_num_batched_tokens")
175 && let Some(num) = value.as_u64()
176 {
177 builder = builder.max_num_batched_tokens(Some(num as usize));
178 }
179
180 if let Some(value) = extra_args.get("enable_prefix_caching")
181 && let Some(enabled) = value.as_bool()
182 {
183 builder = builder.enable_prefix_caching(enabled);
184 }
185
186 if let Some(value) = extra_args.get("enable_chunked_prefill")
187 && let Some(enabled) = value.as_bool()
188 {
189 builder = builder.enable_chunked_prefill(enabled);
190 }
191
192 if let Some(value) = extra_args.get("watermark")
193 && let Some(num) = value.as_f64()
194 {
195 builder = builder.watermark(num);
196 }
197
198 if let Some(value) = extra_args.get("speedup_ratio")
199 && let Some(num) = value.as_f64()
200 {
201 builder = builder.speedup_ratio(num);
202 }
203
204 if let Some(value) = extra_args.get("dp_size")
205 && let Some(num) = value.as_u64()
206 {
207 builder = builder.dp_size(num as u32);
208 }
209
210 if let Some(value) = extra_args.get("startup_time")
211 && let Some(num) = value.as_f64()
212 {
213 builder = builder.startup_time(Some(num));
214 }
215
216 builder
218 .build()
219 .map_err(|e| anyhow::anyhow!("Failed to build MockEngineArgs: {}", e))
220 }
221}
222
223pub fn block_response_to_kv_event(
234 response: MoveBlockResponse,
235 local_hashes: &[BlockHash],
236) -> KvCacheEventData {
237 match response {
238 MoveBlockResponse::Store(full_blocks, parent_hash) => {
239 let num_blocks = full_blocks.len();
240 let local_hashes_slice = &local_hashes[local_hashes
241 .len()
242 .checked_sub(num_blocks)
243 .expect("local hashes fewer than block response signal")..];
244
245 KvCacheEventData::Stored(KvCacheStoreData {
246 parent_hash: parent_hash.map(ExternalSequenceBlockHash),
247 blocks: full_blocks
248 .into_iter()
249 .zip(local_hashes_slice.iter())
250 .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
251 block_hash: ExternalSequenceBlockHash(global_hash),
252 tokens_hash: LocalBlockHash(*local_hash),
253 })
254 .collect(),
255 })
256 }
257 MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
258 block_hashes: full_blocks
259 .into_iter()
260 .map(ExternalSequenceBlockHash)
261 .collect(),
262 }),
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten default-constructed `UniqueBlock`s must all be `PartialBlock`s and
    /// no two of them may carry the same UUID.
    #[test]
    fn test_unique_block_default_uniqueness() {
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                // The message names the variant that is actually expected.
                _ => panic!("Expected PartialBlock variant"),
            })
            .collect();

        // Compare every unordered pair exactly once.
        for (i, first) in uuids.iter().enumerate() {
            for (offset, second) in uuids[i + 1..].iter().enumerate() {
                assert_ne!(
                    first,
                    second,
                    "UUID at index {} and {} are identical",
                    i,
                    i + 1 + offset
                );
            }
        }
    }
}