use std::sync::Arc;
use std::time::Duration;
use anyhow::{Result, bail};
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// Upper bound applied to each rendezvous wait in this file: the server's wait
/// for a room to be completed, and the client's connect and ACK-read steps.
const RENDEZVOUS_TIMEOUT: Duration = Duration::from_secs(30);
/// Single byte the server writes to release a decode connection; the client
/// rejects any other value.
const ACK_BYTE: u8 = 0x01;
/// Rendezvous bookkeeping for one room id, stored in the server's room map.
struct RoomState {
/// Set when `complete_room` runs and no decode connection is waiting; the next
/// decode connection for the room is then acknowledged immediately.
prefill_completed: bool,
/// Wake-up handle for a decode connection that is parked waiting for the room
/// to be completed; `complete_room` takes it and fires it.
decode_waiting: Option<oneshot::Sender<()>>,
}
/// Handle to a running bootstrap (rendezvous) server.
pub struct BootstrapServer {
/// Port the listener is actually bound to (taken from the listener's local
/// address, so it is meaningful even when `0` was requested).
port: u16,
/// Per-room state keyed by room id, shared with the connection handlers.
rooms: Arc<DashMap<u64, RoomState>>,
}
impl BootstrapServer {
pub async fn start(port: u16, cancel_token: CancellationToken) -> Result<Arc<Self>> {
let listener = TcpListener::bind(format!("0.0.0.0:{port}")).await?;
let actual_port = listener.local_addr()?.port();
tracing::info!("Bootstrap server started on port {actual_port}");
let rooms: Arc<DashMap<u64, RoomState>> = Arc::new(DashMap::new());
let server = Arc::new(Self {
port: actual_port,
rooms: rooms.clone(),
});
tokio::spawn(async move {
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
tracing::debug!("Bootstrap: accepted connection from {addr}");
let rooms_clone = rooms.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, rooms_clone).await {
tracing::warn!("Bootstrap: connection error: {e}");
}
});
}
Err(e) => {
tracing::warn!("Bootstrap: accept failed: {e}");
}
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Bootstrap server shutting down");
break;
}
}
}
});
Ok(server)
}
async fn handle_connection(
mut stream: TcpStream,
rooms: Arc<DashMap<u64, RoomState>>,
) -> Result<()> {
let mut buf = [0u8; 8];
stream.read_exact(&mut buf).await?;
let room_id = u64::from_le_bytes(buf);
tracing::debug!("Bootstrap: decode connected for room {room_id}");
let rx = match rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if entry.get().prefill_completed {
entry.remove();
tracing::debug!("Bootstrap: room {room_id} already completed, immediate ACK");
None
} else {
let (tx, rx) = oneshot::channel();
entry.get_mut().decode_waiting = Some(tx);
tracing::debug!("Bootstrap: room {room_id} waiting for prefill to complete");
Some(rx)
}
}
Entry::Vacant(entry) => {
let (tx, rx) = oneshot::channel();
entry.insert(RoomState {
prefill_completed: false,
decode_waiting: Some(tx),
});
tracing::debug!("Bootstrap: room {room_id} decode arrived first, waiting");
Some(rx)
}
};
if let Some(rx) = rx {
match tokio::time::timeout(RENDEZVOUS_TIMEOUT, rx).await {
Ok(Ok(())) => {
tracing::debug!("Bootstrap: room {room_id} prefill completed, sending ACK");
}
Ok(Err(_)) => {
bail!("Bootstrap: room {room_id} sender dropped");
}
Err(_) => {
rooms.remove(&room_id);
bail!("Bootstrap: room {room_id} timeout waiting for prefill");
}
}
}
stream.write_all(&[ACK_BYTE]).await?;
Ok(())
}
pub fn complete_room(&self, room_id: u64) {
match self.rooms.entry(room_id) {
Entry::Occupied(mut entry) => {
if let Some(sender) = entry.get_mut().decode_waiting.take() {
let _ = sender.send(());
entry.remove();
tracing::debug!("Bootstrap: room {room_id} completed, decode unblocked");
} else {
entry.get_mut().prefill_completed = true;
tracing::debug!("Bootstrap: room {room_id} completed, awaiting decode");
}
}
Entry::Vacant(entry) => {
entry.insert(RoomState {
prefill_completed: true,
decode_waiting: None,
});
tracing::debug!("Bootstrap: room {room_id} completed (no decode yet)");
}
}
}
pub fn port(&self) -> u16 {
self.port
}
}
/// Decode-side half of the rendezvous.
///
/// Connects to `host:port` (square brackets around `host` are stripped first),
/// sends `room_id` as 8 little-endian bytes and waits for a one-byte ACK.
///
/// # Errors
///
/// Fails if the connection cannot be established or the ACK does not arrive
/// within [`RENDEZVOUS_TIMEOUT`] each, on any socket error, or if the byte
/// received is not [`ACK_BYTE`].
pub async fn connect_to_prefill(host: &str, port: u16, room_id: u64) -> Result<()> {
    let brackets: &[char] = &['[', ']'];
    let host = host.trim_matches(brackets);
    let addr = format!("{host}:{port}");
    tracing::debug!("Bootstrap: decode connecting to {addr} for room {room_id}");

    let connecting = tokio::time::timeout(RENDEZVOUS_TIMEOUT, TcpStream::connect(&addr));
    let mut stream = match connecting.await {
        Ok(Ok(stream)) => stream,
        Ok(Err(e)) => bail!("Bootstrap: connect failed to {addr}: {e}"),
        Err(_) => bail!("Bootstrap: connect timeout to {addr}"),
    };

    // Announce which room this connection belongs to.
    stream.write_all(&room_id.to_le_bytes()).await?;

    // Wait for the server to release us.
    let mut ack = [0u8; 1];
    match tokio::time::timeout(RENDEZVOUS_TIMEOUT, stream.read_exact(&mut ack)).await {
        Ok(Ok(_)) => {}
        Ok(Err(e)) => bail!("Bootstrap: read ACK failed: {e}"),
        Err(_) => bail!("Bootstrap: ACK timeout for room {room_id}"),
    }

    if ack[0] == ACK_BYTE {
        tracing::debug!("Bootstrap: decode received ACK for room {room_id}");
        Ok(())
    } else {
        bail!(
            "Bootstrap: invalid ACK byte {:02x} for room {room_id}",
            ack[0]
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback host every test dials.
    const LOOPBACK: &str = "127.0.0.1";

    /// Starts a server on an OS-assigned port and returns it together with that
    /// port and the token that stops its accept loop.
    async fn boot() -> (Arc<BootstrapServer>, u16, CancellationToken) {
        let cancel_token = CancellationToken::new();
        let server = BootstrapServer::start(0, cancel_token.clone())
            .await
            .unwrap();
        let port = server.port();
        (server, port, cancel_token)
    }

    /// The room is completed before decode connects, so the ACK is immediate.
    #[tokio::test]
    async fn test_prefill_completes_first() {
        let (server, port, cancel_token) = boot().await;
        let room_id = 1001u64;

        server.complete_room(room_id);
        let result = connect_to_prefill(LOOPBACK, port, room_id).await;

        assert!(result.is_ok(), "Decode should succeed: {result:?}");
        cancel_token.cancel();
    }

    /// Decode connects first and is released once the room is completed.
    #[tokio::test]
    async fn test_decode_connects_first() {
        let (server, port, cancel_token) = boot().await;
        let room_id = 1002u64;

        let decode_handle =
            tokio::spawn(async move { connect_to_prefill(LOOPBACK, port, room_id).await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");
        cancel_token.cancel();
    }

    /// Decode connects after a short delay; the room is completed later still.
    #[tokio::test]
    async fn test_interleaved_ordering() {
        let (server, port, cancel_token) = boot().await;
        let room_id = 1003u64;

        let decode_handle = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            connect_to_prefill(LOOPBACK, port, room_id).await
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        server.complete_room(room_id);

        let result = decode_handle.await.unwrap();
        assert!(result.is_ok(), "Decode should succeed: {result:?}");
        cancel_token.cancel();
    }

    /// Three rooms with different orderings run side by side on one server.
    #[tokio::test]
    async fn test_multiple_rooms_concurrent() {
        let (server, port, cancel_token) = boot().await;

        // Room 2001: completed first, decode connects afterwards.
        let prefill_first = {
            let server = server.clone();
            tokio::spawn(async move {
                server.complete_room(2001);
                tokio::time::sleep(Duration::from_millis(10)).await;
                connect_to_prefill(LOOPBACK, port, 2001).await
            })
        };

        // Room 2002: decode connects first, completion follows after a delay.
        let decode_first = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOOPBACK, port, 2002));
                tokio::time::sleep(Duration::from_millis(50)).await;
                server.complete_room(2002);
                decode.await.unwrap()
            })
        };

        // Room 2003: decode and completion race with no delay between them.
        let racing = {
            let server = server.clone();
            tokio::spawn(async move {
                let decode = tokio::spawn(connect_to_prefill(LOOPBACK, port, 2003));
                server.complete_room(2003);
                decode.await.unwrap()
            })
        };

        let handles = vec![prefill_first, decode_first, racing];
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            assert!(
                result.is_ok(),
                "Room {} should succeed: {result:?}",
                2001 + i
            );
        }
        cancel_token.cancel();
    }

    /// With no completion, decode is still waiting when a short outer timeout fires.
    #[tokio::test]
    async fn test_decode_timeout_no_prefill() {
        let (_server, port, cancel_token) = boot().await;
        let room_id = 9999u64;

        let attempt = connect_to_prefill(LOOPBACK, port, room_id);
        let result = tokio::time::timeout(Duration::from_millis(100), attempt).await;

        assert!(result.is_err(), "Should timeout waiting for prefill");
        cancel_token.cancel();
    }
}