1use crate::tokens::{SequenceHash, Token};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9#[serde(tag = "method", rename_all = "snake_case")]
10pub enum RouterRequest {
11 #[serde(rename = "new")]
13 New {
14 tokens: Vec<Token>,
15 },
16 MarkPrefill,
17 MarkFree,
18}
19
20impl Default for RouterRequest {
21 fn default() -> Self {
22 RouterRequest::New { tokens: vec![] }
23 }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "method", rename_all = "snake_case")]
28pub enum RouterResponse {
29 New { worker_id: i64, overlap_blocks: u32 },
30 PrefillMarked { success: bool },
31 FreeMarked { success: bool },
32}
33
34#[derive(Debug)]
35pub struct WorkerSelectionResult {
36 pub worker_id: i64,
38
39 pub required_blocks: u64,
41
42 pub overlap_blocks: u32,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
48pub struct ForwardPassMetrics {
49 pub worker_stats: WorkerStats,
50 pub kv_stats: KvStats,
51 pub spec_decode_stats: Option<SpecDecodeStats>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
55pub struct WorkerStats {
56 pub data_parallel_rank: Option<u32>,
58 pub request_active_slots: u64,
59 pub request_total_slots: u64,
60 pub num_requests_waiting: u64,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
64pub struct KvStats {
65 pub kv_active_blocks: u64,
66 pub kv_total_blocks: u64,
67 pub gpu_cache_usage_perc: f32,
69 pub gpu_prefix_cache_hit_rate: f32,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
74pub struct PredictiveLoadMetrics {
75 pub kv_active_blocks: u64,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "snake_case")]
80pub enum LoadMetrics {
81 EngineLoadMetrics(ForwardPassMetrics),
82 PredictiveLoadMetrics(PredictiveLoadMetrics),
83}
84
85impl LoadMetrics {
86 pub fn kv_active_blocks(&self) -> u64 {
87 match self {
88 LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
89 LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
90 }
91 }
92}
93
94impl Default for LoadMetrics {
95 fn default() -> Self {
96 LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
101pub struct SpecDecodeStats {
102 pub num_spec_tokens: Option<u32>,
103 pub num_drafts: Option<u32>,
104 pub num_draft_tokens: Option<u32>,
105 pub num_accepted_tokens: Option<u32>,
106 pub num_accepted_tokens_per_pos: Option<Vec<u32>>,
107}
108
109#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
112pub struct LocalBlockHash(pub u64);
113
114#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
119pub struct ExternalSequenceBlockHash(pub u64);
120
121impl From<u64> for ExternalSequenceBlockHash {
123 fn from(value: u64) -> Self {
124 Self(value)
125 }
126}
127
128impl From<i64> for ExternalSequenceBlockHash {
129 fn from(value: i64) -> Self {
132 Self(value as u64)
133 }
134}
135
136#[derive(Serialize, Deserialize, Debug, Clone)]
137pub struct PrefillEvent {
138 pub request_id: String,
139 pub worker_id: i64,
140 pub data: PrefillEventData,
141 pub router_id: Uuid,
142}
143
144#[derive(Serialize, Deserialize, Debug, Clone)]
149pub enum PrefillEventData {
150 NewPrefill(usize),
151 UpdatePrefill(usize),
152 CompletePrefill,
153}
154
155#[derive(Serialize, Deserialize, Debug, Clone)]
156pub struct ActiveSequenceEvent {
157 pub request_id: String,
158 pub worker_id: i64,
159 pub data: ActiveSequenceEventData,
160 pub router_id: Uuid,
161}
162
163#[derive(Serialize, Deserialize, Debug, Clone)]
164pub enum ActiveSequenceEventData {
165 AddRequest {
166 token_sequence: Option<Vec<SequenceHash>>,
167 isl: usize,
168 overlap: u32,
169 },
170 Free,
171 MarkPrefillCompleted,
172}
173
174#[derive(Serialize, Deserialize, Debug, Clone)]
175pub struct ActiveBlockEvent {
176 pub request_id: String,
177 pub data: ActiveBlockEventData,
178}
179
180#[derive(Serialize, Deserialize, Debug, Clone)]
181pub enum ActiveBlockEventData {
182 NewBlock(Vec<SequenceHash>),
183 FreeBlock,
184}
185
186#[derive(Serialize, Deserialize, Debug, Clone)]
188pub struct KvCacheEvents {
189 pub events: Vec<KvCacheEvent>,
191 pub shutdown: bool,
193}
194
195#[derive(Serialize, Deserialize, Debug, Clone)]
197pub struct KvCacheEvent {
198 pub event_id: u64,
200 pub data: KvCacheEventData,
202}
203
204#[derive(Serialize, Deserialize, Debug, Clone)]
208#[serde(rename_all = "snake_case")]
209pub enum KvCacheEventData {
210 Stored(KvCacheStoreData),
211 Removed(KvCacheRemoveData),
212 Cleared,
213}
214
215#[derive(Serialize, Deserialize, Debug, Clone)]
217pub struct KvCacheStoreData {
218 pub parent_hash: Option<ExternalSequenceBlockHash>,
220 pub blocks: Vec<KvCacheStoredBlockData>,
222}
223
224#[derive(Serialize, Deserialize, Debug, Clone)]
226pub struct KvCacheStoredBlockData {
227 pub block_hash: ExternalSequenceBlockHash,
229 pub tokens_hash: LocalBlockHash,
231}
232
233#[derive(Serialize, Deserialize, Debug, Clone)]
235pub struct KvCacheRemoveData {
236 pub block_hashes: Vec<ExternalSequenceBlockHash>,
238}
239
240impl Serialize for LocalBlockHash {
241 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
242 where
243 S: serde::Serializer,
244 {
245 serializer.serialize_u64(self.0)
246 }
247}
248
249impl<'de> Deserialize<'de> for LocalBlockHash {
250 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 let value = u64::deserialize(deserializer)?;
255 Ok(LocalBlockHash(value))
256 }
257}
258
259impl Serialize for ExternalSequenceBlockHash {
260 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
261 where
262 S: serde::Serializer,
263 {
264 serializer.serialize_u64(self.0)
265 }
266}
267
268impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
269 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270 where
271 D: serde::Deserializer<'de>,
272 {
273 let value = u64::deserialize(deserializer)?;
274 Ok(ExternalSequenceBlockHash(value))
275 }
276}
277
278#[cfg(test)]
279mod tests {
280 use super::*;
281 use serde_json;
282
283 #[test]
284 fn test_local_block_hash_serialization() {
285 let hash = LocalBlockHash(12345);
286 let serialized = serde_json::to_string(&hash).unwrap();
287 assert_eq!(serialized, "12345");
288
289 let deserialized: LocalBlockHash = serde_json::from_str(&serialized).unwrap();
290 assert_eq!(deserialized, hash);
291 }
292
293 #[test]
294 fn test_external_sequence_block_hash_serialization() {
295 let hash = ExternalSequenceBlockHash(67890);
296 let serialized = serde_json::to_string(&hash).unwrap();
297 assert_eq!(serialized, "67890");
298
299 let deserialized: ExternalSequenceBlockHash = serde_json::from_str(&serialized).unwrap();
300 assert_eq!(deserialized, hash);
301 }
302
303 #[test]
304 fn test_kv_cache_events_serialization() {
305 let event_data = KvCacheEventData::Stored(KvCacheStoreData {
306 parent_hash: Some(ExternalSequenceBlockHash(1)),
307 blocks: vec![KvCacheStoredBlockData {
308 block_hash: ExternalSequenceBlockHash(2),
309 tokens_hash: LocalBlockHash(3),
310 }],
311 });
312
313 let event = KvCacheEvent {
314 event_id: 1,
315 data: event_data,
316 };
317
318 let events = KvCacheEvents {
319 events: vec![event],
320 shutdown: false,
321 };
322
323 let serialized = serde_json::to_string(&events).unwrap();
324 let deserialized: KvCacheEvents = serde_json::from_str(&serialized).unwrap();
325
326 assert_eq!(deserialized.events.len(), 1);
327 assert_eq!(deserialized.events[0].event_id, 1);
328 if let KvCacheEventData::Stored(store_data) = &deserialized.events[0].data {
329 assert_eq!(store_data.parent_hash.unwrap().0, 1);
330 assert_eq!(store_data.blocks.len(), 1);
331 assert_eq!(store_data.blocks[0].block_hash.0, 2);
332 assert_eq!(store_data.blocks[0].tokens_hash.0, 3);
333 } else {
334 panic!("Expected KvCacheEventData::Stored variant");
335 }
336 assert!(!deserialized.shutdown);
337 }
338
339 #[test]
340 fn test_kv_cache_remove_data_serialization() {
341 let remove_data = KvCacheRemoveData {
342 block_hashes: vec![ExternalSequenceBlockHash(4), ExternalSequenceBlockHash(5)],
343 };
344
345 let serialized = serde_json::to_string(&remove_data).unwrap();
346 let deserialized: KvCacheRemoveData = serde_json::from_str(&serialized).unwrap();
347
348 assert_eq!(deserialized.block_hashes.len(), 2);
349 assert_eq!(deserialized.block_hashes[0].0, 4);
350 assert_eq!(deserialized.block_hashes[1].0, 5);
351 }
352}