dynamo_mocker/common/
bootstrap.rs1use std::sync::Arc;
17use std::time::Duration;
18
19use anyhow::{Result, bail};
20use dashmap::DashMap;
21use dashmap::mapref::entry::Entry;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpListener, TcpStream};
24use tokio::sync::oneshot;
25use tokio_util::sync::CancellationToken;
26
/// Deadline applied to the individual waits of the rendezvous handshake
/// (for example the server waiting for a room to be completed, and the client
/// waiting for the connection to open and for the ACK to arrive).
const RENDEZVOUS_TIMEOUT: Duration = Duration::from_secs(30);

/// The one-byte acknowledgement the server writes to a decode connection once
/// its room is complete; the client rejects any other value.
const ACK_BYTE: u8 = 0x01;
32
33struct RoomState {
35 prefill_completed: bool,
37 decode_waiting: Option<oneshot::Sender<()>>,
39}
40
41pub struct BootstrapServer {
44 port: u16,
45 rooms: Arc<DashMap<u64, RoomState>>,
46}
47
48impl BootstrapServer {
49 pub async fn start(port: u16, cancel_token: CancellationToken) -> Result<Arc<Self>> {
51 let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
52 let actual_port = listener.local_addr()?.port();
53
54 tracing::info!("Bootstrap server started on port {actual_port}");
55
56 let rooms: Arc<DashMap<u64, RoomState>> = Arc::new(DashMap::new());
57 let server = Arc::new(Self {
58 port: actual_port,
59 rooms: rooms.clone(),
60 });
61
62 tokio::spawn(async move {
64 loop {
65 tokio::select! {
66 result = listener.accept() => {
67 match result {
68 Ok((stream, addr)) => {
69 tracing::debug!("Bootstrap: accepted connection from {addr}");
70 let rooms_clone = rooms.clone();
71 tokio::spawn(async move {
72 if let Err(e) = Self::handle_connection(stream, rooms_clone).await {
73 tracing::warn!("Bootstrap: connection error: {e}");
74 }
75 });
76 }
77 Err(e) => {
78 tracing::warn!("Bootstrap: accept failed: {e}");
79 }
80 }
81 }
82 _ = cancel_token.cancelled() => {
83 tracing::debug!("Bootstrap server shutting down");
84 break;
85 }
86 }
87 }
88 });
89
90 Ok(server)
91 }
92
93 async fn handle_connection(
95 mut stream: TcpStream,
96 rooms: Arc<DashMap<u64, RoomState>>,
97 ) -> Result<()> {
98 let mut buf = [0u8; 8];
100 stream.read_exact(&mut buf).await?;
101 let room_id = u64::from_le_bytes(buf);
102
103 tracing::debug!("Bootstrap: decode connected for room {room_id}");
104
105 let rx = match rooms.entry(room_id) {
107 Entry::Occupied(mut entry) => {
108 if entry.get().prefill_completed {
109 entry.remove();
111 tracing::debug!("Bootstrap: room {room_id} already completed, immediate ACK");
112 None
113 } else {
114 let (tx, rx) = oneshot::channel();
116 entry.get_mut().decode_waiting = Some(tx);
117 tracing::debug!("Bootstrap: room {room_id} waiting for prefill to complete");
118 Some(rx)
119 }
120 }
121 Entry::Vacant(entry) => {
122 let (tx, rx) = oneshot::channel();
124 entry.insert(RoomState {
125 prefill_completed: false,
126 decode_waiting: Some(tx),
127 });
128 tracing::debug!("Bootstrap: room {room_id} decode arrived first, waiting");
129 Some(rx)
130 }
131 };
132
133 if let Some(rx) = rx {
135 match tokio::time::timeout(RENDEZVOUS_TIMEOUT, rx).await {
136 Ok(Ok(())) => {
137 tracing::debug!("Bootstrap: room {room_id} prefill completed, sending ACK");
138 }
139 Ok(Err(_)) => {
140 bail!("Bootstrap: room {room_id} sender dropped");
141 }
142 Err(_) => {
143 rooms.remove(&room_id);
144 bail!("Bootstrap: room {room_id} timeout waiting for prefill");
145 }
146 }
147 }
148
149 stream.write_all(&[ACK_BYTE]).await?;
151 Ok(())
152 }
153
154 pub fn complete_room(&self, room_id: u64) {
157 match self.rooms.entry(room_id) {
158 Entry::Occupied(mut entry) => {
159 if let Some(sender) = entry.get_mut().decode_waiting.take() {
160 let _ = sender.send(());
162 entry.remove();
163 tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
164 } else {
165 entry.get_mut().prefill_completed = true;
167 tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
168 }
169 }
170 Entry::Vacant(entry) => {
171 entry.insert(RoomState {
173 prefill_completed: true,
174 decode_waiting: None,
175 });
176 tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
177 }
178 }
179 }
180
181 pub fn port(&self) -> u16 {
183 self.port
184 }
185}
186
187pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
189 let host = host.trim_matches(|c| c == '[' || c == ']');
190 let addr = format!("{host}:{port}");
191
192 tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");
193
194 let mut stream = tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr))
196 .await
197 .map_err(|_| anyhow::anyhow!("Bootstrap: connect timeout to {addr}"))?
198 .map_err(|e| anyhow::anyhow!("Bootstrap: connect failed to {addr}: {e}"))?;
199
200 stream.write_all(&room_id.to_le_bytes()).await?;
202
203 let mut ack = [0u8; 1];
205 tokio::time::timeout(RENDEZVOUS_TIMEOUT, stream.read_exact(&mut ack))
206 .await
207 .map_err(|_| anyhow::anyhow!("Bootstrap: ACK timeout for room {room_id}"))?
208 .map_err(|e| anyhow::anyhow!("Bootstrap: read ACK failed: {e}"))?;
209
210 if ack[0] != ACK_BYTE {
211 bail!(
212 "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
213 ack[0]
214 );
215 }
216
217 tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
218 Ok(())
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Address every test uses to reach the locally started server.
    const LOOPBACK: &str = "127.0.0.1";

    /// Starts a server on an OS-assigned port (`0`) and returns the server,
    /// the port it bound, and the token that stops its accept loop.
    async fn boot() -> (Arc<BootstrapServer>, u16, CancellationToken) {
        let shutdown = CancellationToken::new();
        let server = BootstrapServer::start(0, shutdown.clone())
            .await
            .unwrap();
        let port = server.port();
        (server, port, shutdown)
    }

    /// The room is completed before any decode connects, so the decode should
    /// be acknowledged straight away.
    #[tokio::test]
    async fn test_prefill_completes_first() {
        let (server, port, shutdown) = boot().await;
        let room_id = 1001u64;

        server.complete_room(room_id);

        let result = connect_to_prefill(LOOPBACK, port, room_id).await;
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        shutdown.cancel();
    }

    /// The decode connects first and is released when the room is completed
    /// a little later.
    #[tokio::test]
    async fn test_decode_connects_first() {
        let (server, port, shutdown) = boot().await;
        let room_id = 1002u64;

        let pending_decode =
            tokio::spawn(async move { connect_to_prefill(LOOPBACK, port, room_id).await });

        // Give the decode task a head start before completing the room.
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = pending_decode.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        shutdown.cancel();
    }

    /// The decode starts after a short delay of its own while completion
    /// follows after a longer one.
    #[tokio::test]
    async fn test_interleaved_ordering() {
        let (server, port, shutdown) = boot().await;
        let room_id = 1003u64;

        let pending_decode = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            connect_to_prefill(LOOPBACK, port, room_id).await
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = pending_decode.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");

        shutdown.cancel();
    }

    /// Three rooms run side by side, each exercising a different ordering of
    /// completion and decode.
    #[tokio::test]
    async fn test_multiple_rooms_concurrent() {
        let (server, port, shutdown) = boot().await;

        // Room 2001: completion first, decode shortly after.
        let completion_first = {
            let server = server.clone();
            tokio::spawn(async move {
                server.complete_room(2001);
                tokio::time::sleep(Duration::from_millis(10)).await;
                connect_to_prefill(LOOPBACK, port, 2001).await
            })
        };

        // Room 2002: decode first, completion after a delay.
        let decode_first = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOOPBACK, port, 2002));
                tokio::time::sleep(Duration::from_millis(50)).await;
                server.complete_room(2002);
                decode.await.unwrap()
            })
        };

        // Room 2003: decode and completion issued back to back, no delay.
        let back_to_back = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOOPBACK, port, 2003));
                server.complete_room(2003);
                decode.await.unwrap()
            })
        };

        let handles = vec![completion_first, decode_first, back_to_back];
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(
                result.is_ok(),
                "Room {} should succeed: {result:?}",
                2001 + i
            );
        }

        shutdown.cancel();
    }

    /// With no completion for the room, the decode must still be waiting when
    /// a short outer deadline expires.
    #[tokio::test]
    async fn test_decode_timeout_no_prefill() {
        let (_server, port, shutdown) = boot().await;
        let room_id = 9999u64;

        let result = tokio::time::timeout(
            Duration::from_millis(100),
            connect_to_prefill(LOOPBACK, port, room_id),
        )
        .await;

        assert!(result.is_err(), "Should timeout waiting for prefill");

        shutdown.cancel();
    }
}