dynamo_llm/kv_router/
protocols.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::tokens::{SequenceHash, Token};
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9#[serde(tag = "method", rename_all = "snake_case")]
10pub enum RouterRequest {
11    // ini
12    #[serde(rename = "new")]
13    New {
14        tokens: Vec<Token>,
15    },
16    MarkPrefill,
17    MarkFree,
18}
19
20impl Default for RouterRequest {
21    fn default() -> Self {
22        RouterRequest::New { tokens: vec![] }
23    }
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27#[serde(tag = "method", rename_all = "snake_case")]
28pub enum RouterResponse {
29    New { worker_id: i64, overlap_blocks: u32 },
30    PrefillMarked { success: bool },
31    FreeMarked { success: bool },
32}
33
/// Outcome of choosing a worker for a request.
#[derive(Debug)]
pub struct WorkerSelectionResult {
    /// Identifier of the worker that was chosen.
    pub worker_id: i64,
    /// Total number of blocks the request needs in order to be prefilled.
    pub required_blocks: u64,
    /// Estimated number of blocks the chosen worker may already have cached.
    /// This is an estimate, not a guarantee.
    pub overlap_blocks: u32,
}
46
47#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
48pub struct ForwardPassMetrics {
49    pub worker_stats: WorkerStats,
50    pub kv_stats: KvStats,
51    pub spec_decode_stats: Option<SpecDecodeStats>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
55pub struct WorkerStats {
56    // https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models
57    pub data_parallel_rank: Option<u32>,
58    pub request_active_slots: u64,
59    pub request_total_slots: u64,
60    pub num_requests_waiting: u64,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
64pub struct KvStats {
65    pub kv_active_blocks: u64,
66    pub kv_total_blocks: u64,
67    // percentage represented as a float from 0 to 1
68    pub gpu_cache_usage_perc: f32,
69    // percentage represented as a float from 0 to 1
70    pub gpu_prefix_cache_hit_rate: f32,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
74pub struct PredictiveLoadMetrics {
75    pub kv_active_blocks: u64,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79#[serde(rename_all = "snake_case")]
80pub enum LoadMetrics {
81    EngineLoadMetrics(ForwardPassMetrics),
82    PredictiveLoadMetrics(PredictiveLoadMetrics),
83}
84
85impl LoadMetrics {
86    pub fn kv_active_blocks(&self) -> u64 {
87        match self {
88            LoadMetrics::EngineLoadMetrics(metrics) => metrics.kv_stats.kv_active_blocks,
89            LoadMetrics::PredictiveLoadMetrics(metrics) => metrics.kv_active_blocks,
90        }
91    }
92}
93
94impl Default for LoadMetrics {
95    fn default() -> Self {
96        LoadMetrics::PredictiveLoadMetrics(PredictiveLoadMetrics::default())
97    }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
101pub struct SpecDecodeStats {
102    pub num_spec_tokens: Option<u32>,
103    pub num_drafts: Option<u32>,
104    pub num_draft_tokens: Option<u32>,
105    pub num_accepted_tokens: Option<u32>,
106    pub num_accepted_tokens_per_pos: Option<Vec<u32>>,
107}
108
/// Hash of a single block's contents: its token ids, extra token ids and optional lora id.
///
/// Unlike [`ExternalSequenceBlockHash`], it does not fold in the parent block's hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct LocalBlockHash(pub u64);

/// Sequence-aware block hash.
///
/// Covers the block's token ids, extra token ids and optional lora id, plus the hash of its
/// parent block. The hashing function is external and unknown to this code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ExternalSequenceBlockHash(pub u64);

/// Wraps a raw `u64` hash unchanged.
impl From<u64> for ExternalSequenceBlockHash {
    fn from(raw: u64) -> Self {
        ExternalSequenceBlockHash(raw)
    }
}

/// Reinterprets the bits of an `i64` hash as a `u64`.
///
/// Lossless: every bit is kept, so negative inputs become values of `2^63` or more
/// (for example `-1` maps to `u64::MAX`).
impl From<i64> for ExternalSequenceBlockHash {
    fn from(raw: i64) -> Self {
        ExternalSequenceBlockHash(u64::from_ne_bytes(raw.to_ne_bytes()))
    }
}
135
136#[derive(Serialize, Deserialize, Debug, Clone)]
137pub struct PrefillEvent {
138    pub request_id: String,
139    pub worker_id: i64,
140    pub data: PrefillEventData,
141    pub router_id: Uuid,
142}
143
144/// Represents the different stages of prefilling tokens for a request.
145///
146/// Each variant contains a `usize` representing the number of tokens
147/// that are pending prefill in the request.
148#[derive(Serialize, Deserialize, Debug, Clone)]
149pub enum PrefillEventData {
150    NewPrefill(usize),
151    UpdatePrefill(usize),
152    CompletePrefill,
153}
154
155#[derive(Serialize, Deserialize, Debug, Clone)]
156pub struct ActiveSequenceEvent {
157    pub request_id: String,
158    pub worker_id: i64,
159    pub data: ActiveSequenceEventData,
160    pub router_id: Uuid,
161}
162
163#[derive(Serialize, Deserialize, Debug, Clone)]
164pub enum ActiveSequenceEventData {
165    AddRequest {
166        token_sequence: Option<Vec<SequenceHash>>,
167        isl: usize,
168        overlap: u32,
169    },
170    Free,
171    MarkPrefillCompleted,
172}
173
174#[derive(Serialize, Deserialize, Debug, Clone)]
175pub struct ActiveBlockEvent {
176    pub request_id: String,
177    pub data: ActiveBlockEventData,
178}
179
180#[derive(Serialize, Deserialize, Debug, Clone)]
181pub enum ActiveBlockEventData {
182    NewBlock(Vec<SequenceHash>),
183    FreeBlock,
184}
185
186/// Represents a collection of cache events and a shutdown flag.
187#[derive(Serialize, Deserialize, Debug, Clone)]
188pub struct KvCacheEvents {
189    /// A list of cache events.
190    pub events: Vec<KvCacheEvent>,
191    /// A flag indicating whether the cache is shutting down.
192    pub shutdown: bool,
193}
194
195/// Represents a single cache event with an ID and associated data.
196#[derive(Serialize, Deserialize, Debug, Clone)]
197pub struct KvCacheEvent {
198    /// The unique identifier of the event.
199    pub event_id: u64,
200    /// The data associated with the event.
201    pub data: KvCacheEventData,
202}
203
204/// Represents the data associated with a cache event.
205///
206/// Data is either stored or removed.
207#[derive(Serialize, Deserialize, Debug, Clone)]
208#[serde(rename_all = "snake_case")]
209pub enum KvCacheEventData {
210    Stored(KvCacheStoreData),
211    Removed(KvCacheRemoveData),
212    Cleared,
213}
214
215/// Represents the data associated with a stored cache event.
216#[derive(Serialize, Deserialize, Debug, Clone)]
217pub struct KvCacheStoreData {
218    /// The optional hash of the parent block.
219    pub parent_hash: Option<ExternalSequenceBlockHash>,
220    /// A list of stored blocked data.
221    pub blocks: Vec<KvCacheStoredBlockData>,
222}
223
224/// Represents data for a stored block.
225#[derive(Serialize, Deserialize, Debug, Clone)]
226pub struct KvCacheStoredBlockData {
227    /// The hash of the block.
228    pub block_hash: ExternalSequenceBlockHash,
229    /// The hash of the tokens in the block.
230    pub tokens_hash: LocalBlockHash,
231}
232
233/// Represents the data associated with a removed cache event.
234#[derive(Serialize, Deserialize, Debug, Clone)]
235pub struct KvCacheRemoveData {
236    /// A list of block hashes to remove.
237    pub block_hashes: Vec<ExternalSequenceBlockHash>,
238}
239
240impl Serialize for LocalBlockHash {
241    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
242    where
243        S: serde::Serializer,
244    {
245        serializer.serialize_u64(self.0)
246    }
247}
248
249impl<'de> Deserialize<'de> for LocalBlockHash {
250    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
251    where
252        D: serde::Deserializer<'de>,
253    {
254        let value = u64::deserialize(deserializer)?;
255        Ok(LocalBlockHash(value))
256    }
257}
258
259impl Serialize for ExternalSequenceBlockHash {
260    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
261    where
262        S: serde::Serializer,
263    {
264        serializer.serialize_u64(self.0)
265    }
266}
267
268impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
269    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
270    where
271        D: serde::Deserializer<'de>,
272    {
273        let value = u64::deserialize(deserializer)?;
274        Ok(ExternalSequenceBlockHash(value))
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_local_block_hash_serialization() {
        let hash = LocalBlockHash(12345);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "12345");

        let deserialized: LocalBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    #[test]
    fn test_external_sequence_block_hash_serialization() {
        let hash = ExternalSequenceBlockHash(67890);
        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, "67890");

        let deserialized: ExternalSequenceBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    /// Negative `i64` hashes are bit-reinterpreted, so they must survive a JSON round trip as
    /// large `u64` values.
    #[test]
    fn test_external_sequence_block_hash_from_negative_i64() {
        let hash = ExternalSequenceBlockHash::from(-1i64);
        assert_eq!(hash.0, u64::MAX);

        let serialized = serde_json::to_string(&hash).unwrap();
        assert_eq!(serialized, u64::MAX.to_string());

        let deserialized: ExternalSequenceBlockHash = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, hash);
    }

    /// Pins the internally tagged wire format of `RouterRequest`.
    #[test]
    fn test_router_request_tagging() {
        let new = serde_json::to_string(&RouterRequest::default()).unwrap();
        assert_eq!(new, r#"{"method":"new","tokens":[]}"#);

        let prefill = serde_json::to_string(&RouterRequest::MarkPrefill).unwrap();
        assert_eq!(prefill, r#"{"method":"mark_prefill"}"#);

        let free: RouterRequest = serde_json::from_str(r#"{"method":"mark_free"}"#).unwrap();
        assert!(matches!(free, RouterRequest::MarkFree));
    }

    #[test]
    fn test_kv_cache_events_serialization() {
        let event_data = KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: Some(ExternalSequenceBlockHash(1)),
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(2),
                tokens_hash: LocalBlockHash(3),
            }],
        });

        let event = KvCacheEvent {
            event_id: 1,
            data: event_data,
        };

        let events = KvCacheEvents {
            events: vec![event],
            shutdown: false,
        };

        let serialized = serde_json::to_string(&events).unwrap();
        let deserialized: KvCacheEvents = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.events.len(), 1);
        assert_eq!(deserialized.events[0].event_id, 1);
        if let KvCacheEventData::Stored(store_data) = &deserialized.events[0].data {
            assert_eq!(store_data.parent_hash.unwrap().0, 1);
            assert_eq!(store_data.blocks.len(), 1);
            assert_eq!(store_data.blocks[0].block_hash.0, 2);
            assert_eq!(store_data.blocks[0].tokens_hash.0, 3);
        } else {
            panic!("Expected KvCacheEventData::Stored variant");
        }
        assert!(!deserialized.shutdown);
    }

    /// The data-less `Cleared` variant serializes as a bare snake_case string.
    #[test]
    fn test_kv_cache_event_cleared_serialization() {
        let event = KvCacheEvent {
            event_id: 2,
            data: KvCacheEventData::Cleared,
        };

        let serialized = serde_json::to_string(&event).unwrap();
        assert_eq!(serialized, r#"{"event_id":2,"data":"cleared"}"#);

        let deserialized: KvCacheEvent = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.event_id, 2);
        assert!(matches!(deserialized.data, KvCacheEventData::Cleared));
    }

    #[test]
    fn test_kv_cache_remove_data_serialization() {
        let remove_data = KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(4), ExternalSequenceBlockHash(5)],
        };

        let serialized = serde_json::to_string(&remove_data).unwrap();
        let deserialized: KvCacheRemoveData = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.block_hashes.len(), 2);
        assert_eq!(deserialized.block_hashes[0].0, 4);
        assert_eq!(deserialized.block_hashes[1].0, 5);
    }

    /// Checks both variants of `LoadMetrics::kv_active_blocks` and the default wire format.
    #[test]
    fn test_load_metrics_kv_active_blocks() {
        let default = LoadMetrics::default();
        assert_eq!(default.kv_active_blocks(), 0);
        assert_eq!(
            serde_json::to_string(&default).unwrap(),
            r#"{"predictive_load_metrics":{"kv_active_blocks":0}}"#
        );

        let engine = LoadMetrics::EngineLoadMetrics(ForwardPassMetrics {
            kv_stats: KvStats {
                kv_active_blocks: 7,
                ..Default::default()
            },
            ..Default::default()
        });
        assert_eq!(engine.kv_active_blocks(), 7);

        let serialized = serde_json::to_string(&engine).unwrap();
        let round_trip: LoadMetrics = serde_json::from_str(&serialized).unwrap();
        assert_eq!(round_trip, engine);
    }
}