Skip to main content

dynamo_mocker/common/
bootstrap.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Bootstrap rendezvous for disaggregated mocker testing.
5//!
6//! Simulates the SGLang disaggregated serving handshake for KV transfer coordination.
7//! Either prefill or decode can arrive first; the rendezvous completes when both are ready.
8//!
9//! - Prefill: calls `complete_room(room_id)` after first token (KV cache ready)
10//! - Decode: connects to prefill's bootstrap server, blocks until prefill completes
11//!
12//! Wire protocol:
13//! - Decode -> Prefill: room_id (8 bytes, little-endian u64)
14//! - Prefill -> Decode: ACK (1 byte, 0x01) after prefill completes
15
16use std::sync::Arc;
17use std::time::Duration;
18
19use anyhow::{Result, bail};
20use dashmap::DashMap;
21use dashmap::mapref::entry::Entry;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpListener, TcpStream};
24use tokio::sync::oneshot;
25use tokio_util::sync::CancellationToken;
26
/// Upper bound for a single timed step of the bootstrap rendezvous
/// (for example connecting to the server or waiting for the ACK).
const RENDEZVOUS_TIMEOUT: Duration = Duration::from_millis(30_000);
29
/// Single-byte acknowledgement the server writes to a decode connection once
/// prefill has completed for the requested room.
const ACK_BYTE: u8 = b'\x01';
32
33/// State for a room in the rendezvous.
34struct RoomState {
35    /// True if prefill has completed (KV cache ready)
36    prefill_completed: bool,
37    /// Channel to notify decode when prefill completes (if decode is waiting)
38    decode_waiting: Option<oneshot::Sender<()>>,
39}
40
41/// Bootstrap server for prefill mockers.
42/// Handles rendezvous between prefill and decode for KV transfer coordination.
43pub struct BootstrapServer {
44    port: u16,
45    rooms: Arc<DashMap<u64, RoomState>>,
46}
47
48impl BootstrapServer {
49    /// Start the bootstrap server on the specified port.
50    pub async fn start(port: u16, cancel_token: CancellationToken) -> Result<Arc<Self>> {
51        let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
52        let actual_port = listener.local_addr()?.port();
53
54        tracing::info!("Bootstrap server started on port {actual_port}");
55
56        let rooms: Arc<DashMap<u64, RoomState>> = Arc::new(DashMap::new());
57        let server = Arc::new(Self {
58            port: actual_port,
59            rooms: rooms.clone(),
60        });
61
62        // Spawn accept loop
63        tokio::spawn(async move {
64            loop {
65                tokio::select! {
66                    result = listener.accept() => {
67                        match result {
68                            Ok((stream, addr)) => {
69                                tracing::debug!("Bootstrap: accepted connection from {addr}");
70                                let rooms_clone = rooms.clone();
71                                tokio::spawn(async move {
72                                    if let Err(e) = Self::handle_connection(stream, rooms_clone).await {
73                                        tracing::warn!("Bootstrap: connection error: {e}");
74                                    }
75                                });
76                            }
77                            Err(e) => {
78                                tracing::warn!("Bootstrap: accept failed: {e}");
79                            }
80                        }
81                    }
82                    _ = cancel_token.cancelled() => {
83                        tracing::debug!("Bootstrap server shutting down");
84                        break;
85                    }
86                }
87            }
88        });
89
90        Ok(server)
91    }
92
93    /// Handle a connection from decode. Blocks until prefill completes for this room.
94    async fn handle_connection(
95        mut stream: TcpStream,
96        rooms: Arc<DashMap<u64, RoomState>>,
97    ) -> Result<()> {
98        // Read room_id (8 bytes, little-endian)
99        let mut buf = [0u8; 8];
100        stream.read_exact(&mut buf).await?;
101        let room_id = u64::from_le_bytes(buf);
102
103        tracing::debug!("Bootstrap: decode connected for room {room_id}");
104
105        // Check room state and wait if needed
106        let rx = match rooms.entry(room_id) {
107            Entry::Occupied(mut entry) => {
108                if entry.get().prefill_completed {
109                    // Prefill already done, immediate ACK
110                    entry.remove();
111                    tracing::debug!("Bootstrap: room {room_id} already completed, immediate ACK");
112                    None
113                } else {
114                    // Prefill registered but not completed, wait
115                    let (tx, rx) = oneshot::channel();
116                    entry.get_mut().decode_waiting = Some(tx);
117                    tracing::debug!("Bootstrap: room {room_id} waiting for prefill to complete");
118                    Some(rx)
119                }
120            }
121            Entry::Vacant(entry) => {
122                // Decode arrived first, create entry and wait
123                let (tx, rx) = oneshot::channel();
124                entry.insert(RoomState {
125                    prefill_completed: false,
126                    decode_waiting: Some(tx),
127                });
128                tracing::debug!("Bootstrap: room {room_id} decode arrived first, waiting");
129                Some(rx)
130            }
131        };
132
133        // Wait for prefill to complete if needed
134        if let Some(rx) = rx {
135            match tokio::time::timeout(RENDEZVOUS_TIMEOUT, rx).await {
136                Ok(Ok(())) => {
137                    tracing::debug!("Bootstrap: room {room_id} prefill completed, sending ACK");
138                }
139                Ok(Err(_)) => {
140                    bail!("Bootstrap: room {room_id} sender dropped");
141                }
142                Err(_) => {
143                    rooms.remove(&room_id);
144                    bail!("Bootstrap: room {room_id} timeout waiting for prefill");
145                }
146            }
147        }
148
149        // Send ACK
150        stream.write_all(&[ACK_BYTE]).await?;
151        Ok(())
152    }
153
154    /// Mark a room as completed (prefill finished, KV cache ready).
155    /// If decode is already waiting, unblocks it.
156    pub fn complete_room(&self, room_id: u64) {
157        match self.rooms.entry(room_id) {
158            Entry::Occupied(mut entry) => {
159                if let Some(sender) = entry.get_mut().decode_waiting.take() {
160                    // Decode is waiting, unblock it
161                    let _ = sender.send(());
162                    entry.remove();
163                    tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
164                } else {
165                    // Decode not connected yet, mark completed
166                    entry.get_mut().prefill_completed = true;
167                    tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
168                }
169            }
170            Entry::Vacant(entry) => {
171                // Decode hasn't connected yet
172                entry.insert(RoomState {
173                    prefill_completed: true,
174                    decode_waiting: None,
175                });
176                tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
177            }
178        }
179    }
180
181    /// Get the port the server is listening on.
182    pub fn port(&self) -> u16 {
183        self.port
184    }
185}
186
187/// Connect to a prefill worker's bootstrap server and wait for KV to be ready.
188pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
189    let host = host.trim_matches(|c| c == '[' || c == ']');
190    let addr = format!("{host}:{port}");
191
192    tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");
193
194    // Connect with timeout
195    let mut stream = tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr))
196        .await
197        .map_err(|_| anyhow::anyhow!("Bootstrap: connect timeout to {addr}"))?
198        .map_err(|e| anyhow::anyhow!("Bootstrap: connect failed to {addr}: {e}"))?;
199
200    // Send room_id
201    stream.write_all(&room_id.to_le_bytes()).await?;
202
203    // Wait for ACK (blocks until prefill completes)
204    let mut ack = [0u8; 1];
205    tokio::time::timeout(RENDEZVOUS_TIMEOUT, stream.read_exact(&mut ack))
206        .await
207        .map_err(|_| anyhow::anyhow!("Bootstrap: ACK timeout for room {room_id}"))?
208        .map_err(|e| anyhow::anyhow!("Bootstrap: read ACK failed: {e}"))?;
209
210    if ack[0] != ACK_BYTE {
211        bail!(
212            "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
213            ack[0]
214        );
215    }
216
217    tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
218    Ok(())
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address every test uses to reach the server.
    const LOCALHOST: &str = "127.0.0.1";

    /// Start a server bound to port 0 and return the server handle, the port
    /// it reports, and the token that stops its accept loop.
    async fn spawn_server() -> (Arc<BootstrapServer>, u16, CancellationToken) {
        let cancel_token = CancellationToken::new();
        let server = BootstrapServer::start(0, cancel_token.clone())
            .await
            .unwrap();
        let port = server.port();
        (server, port, cancel_token)
    }

    /// Prefill marks the room complete before decode connects: decode must be
    /// acknowledged right away.
    #[tokio::test]
    async fn test_prefill_completes_first() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1001u64;

        server.complete_room(room_id);

        let result = connect_to_prefill(LOCALHOST, port, room_id).await;
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode connects and parks first; completing the room afterwards must
    /// release it.
    #[tokio::test]
    async fn test_decode_connects_first() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1002u64;

        // Decode runs in the background and blocks until the room completes.
        let decode_handle = tokio::spawn(connect_to_prefill(LOCALHOST, port, room_id));

        // Leave decode some time to connect and register as the waiter.
        tokio::time::sleep(Duration::from_millis(50)).await;

        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Decode starts connecting shortly after the test begins and prefill
    /// completes a little later.
    #[tokio::test]
    async fn test_interleaved_ordering() {
        let (server, port, cancel_token) = spawn_server().await;
        let room_id = 1003u64;

        let decode_handle = tokio::spawn(async move {
            // Short delay before decode dials in.
            tokio::time::sleep(Duration::from_millis(10)).await;
            connect_to_prefill(LOCALHOST, port, room_id).await
        });

        // Prefill finishes only after decode has started connecting.
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        cancel_token.cancel();
    }

    /// Three rooms exercised concurrently, one per arrival order.
    #[tokio::test]
    async fn test_multiple_rooms_concurrent() {
        let (server, port, cancel_token) = spawn_server().await;

        // Room 2001: prefill first.
        let prefill_first = {
            let server = server.clone();
            tokio::spawn(async move {
                server.complete_room(2001);
                tokio::time::sleep(Duration::from_millis(10)).await;
                connect_to_prefill(LOCALHOST, port, 2001).await
            })
        };

        // Room 2002: decode first.
        let decode_first = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2002));
                tokio::time::sleep(Duration::from_millis(50)).await;
                server.complete_room(2002);
                decode.await.unwrap()
            })
        };

        // Room 2003: both sides at once.
        let simultaneous = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOCALHOST, port, 2003));
                server.complete_room(2003);
                decode.await.unwrap()
            })
        };

        let handles = vec![prefill_first, decode_first, simultaneous];
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(
                result.is_ok(),
                "Room {} should succeed: {result:?}",
                2001 + i
            );
        }

        cancel_token.cancel();
    }

    /// With no prefill completion, decode must still be blocked when a short
    /// outer timeout expires.
    #[tokio::test]
    async fn test_decode_timeout_no_prefill() {
        // Keep the server handle bound for the whole test.
        let (_server, port, cancel_token) = spawn_server().await;
        let room_id = 9999u64;

        // The outer 100 ms timeout fires long before the inner RENDEZVOUS_TIMEOUT.
        let decode_attempt = connect_to_prefill(LOCALHOST, port, room_id);
        let result = tokio::time::timeout(Duration::from_millis(100), decode_attempt).await;

        assert!(result.is_err(), "Should timeout waiting for prefill");

        cancel_token.cancel();
    }
}