use crate::error::{NexarError, Result};
use crate::protocol::NexarMessage;
use crate::protocol::codec::decode_message;
use crate::transport::buffer_pool::{BufferPool, PooledBuf};
use crate::transport::connection::{
STREAM_TAG_FRAMED, STREAM_TAG_RAW, STREAM_TAG_RAW_COMM, STREAM_TAG_RAW_TAGGED,
};
use crate::types::Rank;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore, mpsc, oneshot};
/// Capacity of every bounded `mpsc` lane the router creates (fixed lanes and
/// per-comm / per-tag lanes alike).
const LANE_CAPACITY: usize = 256;
/// Maximum number of incoming streams handled concurrently by the accept loop.
const MAX_CONCURRENT_STREAMS: usize = 512;
/// Largest accepted length prefix for a single payload, in bytes (4 GiB).
const MAX_MESSAGE_SIZE: u64 = 4 << 30;
/// Per-peer demultiplexer for incoming QUIC unidirectional streams.
///
/// Each lane is the receiving half of an `mpsc` channel whose sender is owned
/// by the background task started in [`PeerRouter::spawn`]. Receivers sit
/// behind a `Mutex` so they can be polled through `&self`.
pub struct PeerRouter {
    /// Incoming `NexarMessage::Rpc` requests.
    pub rpc_requests: Mutex<mpsc::Receiver<NexarMessage>>,
    /// Oneshot senders keyed by request id; completed by incoming `RpcResponse`s.
    rpc_waiters: Arc<Mutex<HashMap<u64, oneshot::Sender<NexarMessage>>>>,
    /// Barrier, heartbeat, membership, handshake, recovery and elastic-checkpoint messages.
    pub control: Mutex<mpsc::Receiver<NexarMessage>>,
    /// `NexarMessage::Data` messages.
    pub data: Mutex<mpsc::Receiver<NexarMessage>>,
    /// Payloads from streams tagged `STREAM_TAG_RAW`.
    pub raw: Mutex<mpsc::Receiver<PooledBuf>>,
    /// Per-communicator raw lanes keyed by comm id (`STREAM_TAG_RAW_COMM` streams).
    raw_comms: Arc<Mutex<HashMap<u64, CommChannel>>>,
    /// Per-tag raw lanes keyed by tag (`STREAM_TAG_RAW_TAGGED` streams).
    tagged: Arc<Mutex<HashMap<u64, TaggedChannel>>>,
    /// `NexarMessage::Relay` messages; can be moved out with `take_relay_rx`.
    pub relay: Mutex<mpsc::Receiver<NexarMessage>>,
}
/// Channel pair backing one communicator's raw lane.
///
/// `rx` is `Some` while payloads have arrived for the comm id but nobody has
/// registered for it yet. It becomes `None` once `PeerRouter::register_comm`
/// hands the receiver out.
struct CommChannel {
    tx: mpsc::Sender<PooledBuf>,
    rx: Option<mpsc::Receiver<PooledBuf>>,
}
/// Channel pair backing one tag's raw lane.
///
/// `rx` is `Some` while payloads have arrived for the tag but nobody has
/// registered for it yet. It becomes `None` once `PeerRouter::register_tag`
/// hands the receiver out.
struct TaggedChannel {
    tx: mpsc::Sender<PooledBuf>,
    rx: Option<mpsc::Receiver<PooledBuf>>,
}
/// Sending halves of every router lane plus the shared registration maps.
///
/// One clone is moved into each stream-handling task spawned by the accept loop.
#[derive(Clone)]
struct RouterSenders {
    /// Peer rank, used for log fields and `PeerDisconnected` errors.
    rank: Rank,
    rpc_requests: mpsc::Sender<NexarMessage>,
    /// Shared with `PeerRouter::rpc_waiters`.
    rpc_waiters: Arc<Mutex<HashMap<u64, oneshot::Sender<NexarMessage>>>>,
    control: mpsc::Sender<NexarMessage>,
    data: mpsc::Sender<NexarMessage>,
    raw: mpsc::Sender<PooledBuf>,
    /// Shared with `PeerRouter::raw_comms`.
    raw_comms: Arc<Mutex<HashMap<u64, CommChannel>>>,
    /// Shared with `PeerRouter::tagged`.
    tagged: Arc<Mutex<HashMap<u64, TaggedChannel>>>,
    relay: mpsc::Sender<NexarMessage>,
    /// Pool that payload buffers are checked out from.
    pool: Arc<BufferPool>,
}
impl PeerRouter {
    /// Creates the router for peer `rank` and spawns the task that accepts and
    /// routes unidirectional streams arriving on `conn`.
    ///
    /// Returns the router together with the accept task's handle. The task
    /// finishes with `Ok(())` once `conn.accept_uni()` fails, after dropping
    /// every pending RPC waiter.
    ///
    /// # Panics
    ///
    /// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
    pub fn spawn(
        rank: Rank,
        conn: quinn::Connection,
        pool: Arc<BufferPool>,
    ) -> (Self, tokio::task::JoinHandle<Result<()>>) {
        let (rpc_req_tx, rpc_req_rx) = mpsc::channel(LANE_CAPACITY);
        let (ctrl_tx, ctrl_rx) = mpsc::channel(LANE_CAPACITY);
        let (data_tx, data_rx) = mpsc::channel(LANE_CAPACITY);
        let (raw_tx, raw_rx) = mpsc::channel(LANE_CAPACITY);
        let (relay_tx, relay_rx) = mpsc::channel(LANE_CAPACITY);
        let rpc_waiters: Arc<Mutex<HashMap<u64, oneshot::Sender<NexarMessage>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let raw_comms: Arc<Mutex<HashMap<u64, CommChannel>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let tagged: Arc<Mutex<HashMap<u64, TaggedChannel>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let senders = RouterSenders {
            rank,
            rpc_requests: rpc_req_tx,
            rpc_waiters: Arc::clone(&rpc_waiters),
            control: ctrl_tx,
            data: data_tx,
            raw: raw_tx,
            raw_comms: Arc::clone(&raw_comms),
            tagged: Arc::clone(&tagged),
            relay: relay_tx,
            pool,
        };
        let handle = tokio::spawn(accept_loop(conn, senders));
        let router = Self {
            rpc_requests: Mutex::new(rpc_req_rx),
            rpc_waiters,
            control: Mutex::new(ctrl_rx),
            data: Mutex::new(data_rx),
            raw: Mutex::new(raw_rx),
            raw_comms,
            tagged,
            relay: Mutex::new(relay_rx),
        };
        (router, handle)
    }

    /// Registers interest in the `RpcResponse` carrying `req_id`.
    ///
    /// The returned receiver resolves with that response. It errors if the
    /// waiter is removed or replaced, or if the connection's accept loop ends
    /// before the response arrives. Registering the same `req_id` twice
    /// replaces (drops) the earlier sender.
    pub async fn register_rpc_waiter(&self, req_id: u64) -> oneshot::Receiver<NexarMessage> {
        let (tx, rx) = oneshot::channel();
        self.rpc_waiters.lock().await.insert(req_id, tx);
        rx
    }

    /// Drops the waiter registered for `req_id`, if any.
    pub async fn remove_rpc_waiter(&self, req_id: u64) {
        self.rpc_waiters.lock().await.remove(&req_id);
    }

    /// Returns the receiver for communicator `comm_id`'s raw lane.
    ///
    /// If payloads for `comm_id` arrived before registration, the receiver
    /// that has been buffering them is returned. Otherwise a fresh channel is
    /// installed, replacing any existing entry for `comm_id`.
    pub async fn register_comm(&self, comm_id: u64) -> mpsc::Receiver<PooledBuf> {
        let mut comms = self.raw_comms.lock().await;
        if let Some(ch) = comms.get_mut(&comm_id)
            && let Some(rx) = ch.rx.take()
        {
            return rx;
        }
        let (tx, rx) = mpsc::channel(LANE_CAPACITY);
        comms.insert(comm_id, CommChannel { tx, rx: None });
        rx
    }

    /// Receives the next message from the control lane.
    ///
    /// # Errors
    ///
    /// [`NexarError::PeerDisconnected`] for `rank` once the lane is closed and drained.
    pub async fn recv_control(&self, rank: Rank) -> Result<NexarMessage> {
        Self::recv_lane(&self.control, rank).await
    }

    /// Receives the next incoming RPC request.
    ///
    /// # Errors
    ///
    /// [`NexarError::PeerDisconnected`] for `rank` once the lane is closed and drained.
    pub async fn recv_rpc_request(&self, rank: Rank) -> Result<NexarMessage> {
        Self::recv_lane(&self.rpc_requests, rank).await
    }

    /// Receives the next message from the data lane.
    ///
    /// # Errors
    ///
    /// [`NexarError::PeerDisconnected`] for `rank` once the lane is closed and drained.
    pub async fn recv_data(&self, rank: Rank) -> Result<NexarMessage> {
        Self::recv_lane(&self.data, rank).await
    }

    /// Returns the receiver for `tag`'s raw lane.
    ///
    /// If payloads for `tag` arrived before registration, the receiver that
    /// has been buffering them is returned. Otherwise a fresh channel is
    /// installed, replacing any existing entry for `tag`.
    pub async fn register_tag(&self, tag: u64) -> mpsc::Receiver<PooledBuf> {
        let mut tags = self.tagged.lock().await;
        // Same let-chain shape as `register_comm`.
        if let Some(ch) = tags.get_mut(&tag)
            && let Some(rx) = ch.rx.take()
        {
            return rx;
        }
        let (tx, rx) = mpsc::channel(LANE_CAPACITY);
        tags.insert(tag, TaggedChannel { tx, rx: None });
        rx
    }

    /// Removes the entry for `tag`, dropping the router's sender (and any
    /// not-yet-registered buffered receiver) for it.
    pub async fn remove_tag(&self, tag: u64) {
        self.tagged.lock().await.remove(&tag);
    }

    /// Moves the relay receiver out of the router.
    ///
    /// Always returns `Some`. The first call yields the live relay lane. It is
    /// replaced by a placeholder receiver whose sender is never kept, so later
    /// calls get a receiver that reports closed immediately.
    pub async fn take_relay_rx(&self) -> Option<mpsc::Receiver<NexarMessage>> {
        let mut rx = self.relay.lock().await;
        // The sender half is not bound, so it is dropped right away and `dummy`
        // is permanently closed.
        let (_, dummy) = mpsc::channel(1);
        Some(std::mem::replace(&mut *rx, dummy))
    }

    /// Receives the next payload from the untagged raw lane.
    ///
    /// # Errors
    ///
    /// [`NexarError::PeerDisconnected`] for `rank` once the lane is closed and drained.
    pub async fn recv_raw(&self, rank: Rank) -> Result<PooledBuf> {
        Self::recv_lane(&self.raw, rank).await
    }

    /// Shared body of the `recv_*` methods.
    ///
    /// Locks `lane`, awaits its next item, and maps a closed channel to
    /// `PeerDisconnected { rank }`.
    async fn recv_lane<T>(lane: &Mutex<mpsc::Receiver<T>>, rank: Rank) -> Result<T> {
        lane.lock()
            .await
            .recv()
            .await
            .ok_or(NexarError::PeerDisconnected { rank })
    }
}
/// Accepts unidirectional streams from `conn` until the connection fails,
/// handling each stream on its own Tokio task.
///
/// At most [`MAX_CONCURRENT_STREAMS`] streams are handled at once. Once that
/// many are in flight, the loop waits for a permit before spawning the next
/// handler. When `accept_uni` errors, all pending RPC waiters are dropped and
/// the loop returns `Ok(())`.
async fn accept_loop(conn: quinn::Connection, tx: RouterSenders) -> Result<()> {
    let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_STREAMS));
    loop {
        let stream = match conn.accept_uni().await {
            Ok(s) => s,
            Err(e) => {
                // Previously swallowed silently; keep a trace of why routing stopped.
                tracing::debug!(rank = tx.rank, "router: accept loop ending: {e}");
                // Dropping the oneshot senders makes any caller still awaiting an
                // RpcResponse observe an error instead of hanging.
                tx.rpc_waiters.lock().await.clear();
                return Ok(());
            }
        };
        // The semaphore is never closed here, so acquisition cannot fail in
        // practice; if it ever did, treat it as a shutdown.
        let Ok(permit) = Arc::clone(&semaphore).acquire_owned().await else {
            return Ok(());
        };
        let tx = tx.clone();
        tokio::spawn(async move {
            // `handle_stream` only returns an error when a local lane's
            // receiver has been dropped.
            if let Err(e) = handle_stream(stream, &tx).await {
                tracing::error!(
                    rank = tx.rank,
                    "router: local receiver dropped, messages will be lost: {e}"
                );
            }
            drop(permit);
        });
    }
}
/// Reads a stream's one-byte tag and routes its payload to the matching lane.
///
/// Layout after the tag byte:
/// - `STREAM_TAG_FRAMED`: length-prefixed encoded [`NexarMessage`].
/// - `STREAM_TAG_RAW`: length-prefixed raw payload.
/// - `STREAM_TAG_RAW_TAGGED`: little-endian `u64` tag, then a length-prefixed payload.
/// - `STREAM_TAG_RAW_COMM`: little-endian `u64` comm id, then a length-prefixed payload.
///
/// Truncated, oversized, undecodable or unknown-tag streams are logged and
/// skipped with `Ok(())`.
///
/// # Errors
///
/// Returns [`NexarError::PeerDisconnected`] when the receiving side of the
/// target lane has been dropped.
async fn handle_stream(mut stream: quinn::RecvStream, tx: &RouterSenders) -> Result<()> {
    let mut tag_buf = [0u8; 1];
    if stream.read_exact(&mut tag_buf).await.is_err() {
        tracing::warn!(
            rank = tx.rank,
            "router: failed to read stream tag, skipping stream"
        );
        return Ok(());
    }
    match tag_buf[0] {
        STREAM_TAG_FRAMED => {
            let Some(msg) = read_framed(&mut stream, tx.rank, &tx.pool).await else {
                return Ok(());
            };
            dispatch_framed(msg, tx).await?;
        }
        STREAM_TAG_RAW => {
            let Some(buf) = read_raw(&mut stream, tx.rank, &tx.pool).await else {
                return Ok(());
            };
            if tx.raw.send(buf).await.is_err() {
                return Err(NexarError::PeerDisconnected { rank: tx.rank });
            }
        }
        STREAM_TAG_RAW_TAGGED => {
            let mut tag_bytes = [0u8; 8];
            if stream.read_exact(&mut tag_bytes).await.is_err() {
                tracing::warn!(rank = tx.rank, "router: failed to read tagged tag");
                return Ok(());
            }
            let tag = u64::from_le_bytes(tag_bytes);
            let Some(buf) = read_raw(&mut stream, tx.rank, &tx.pool).await else {
                return Ok(());
            };
            // Clone the sender and release the map lock *before* awaiting the
            // send. `send` waits while the lane is full, and holding `tagged`
            // across that wait would stall every other tagged stream. It would
            // also block `register_tag`/`remove_tag`, which deadlocks when the
            // buffered receiver is still parked in the map: only `register_tag`
            // can hand that receiver out to be drained.
            let sender = {
                let mut tags = tx.tagged.lock().await;
                let ch = tags.entry(tag).or_insert_with(|| {
                    let (sender, receiver) = mpsc::channel(LANE_CAPACITY);
                    TaggedChannel {
                        tx: sender,
                        rx: Some(receiver),
                    }
                });
                ch.tx.clone()
            };
            if sender.send(buf).await.is_err() {
                return Err(NexarError::PeerDisconnected { rank: tx.rank });
            }
        }
        STREAM_TAG_RAW_COMM => {
            let mut comm_id_buf = [0u8; 8];
            if stream.read_exact(&mut comm_id_buf).await.is_err() {
                tracing::warn!(rank = tx.rank, "router: failed to read comm_id");
                return Ok(());
            }
            let comm_id = u64::from_le_bytes(comm_id_buf);
            let Some(buf) = read_raw(&mut stream, tx.rank, &tx.pool).await else {
                return Ok(());
            };
            // Same reasoning as the tagged path: never hold `raw_comms` across
            // a send that may wait for lane capacity.
            let sender = {
                let mut comms = tx.raw_comms.lock().await;
                let ch = comms.entry(comm_id).or_insert_with(|| {
                    let (sender, receiver) = mpsc::channel(LANE_CAPACITY);
                    CommChannel {
                        tx: sender,
                        rx: Some(receiver),
                    }
                });
                ch.tx.clone()
            };
            if sender.send(buf).await.is_err() {
                return Err(NexarError::PeerDisconnected { rank: tx.rank });
            }
        }
        other => {
            tracing::warn!(
                rank = tx.rank,
                "router: unknown stream tag 0x{:02x}, skipping stream",
                other
            );
        }
    }
    Ok(())
}
/// Routes a decoded framed message according to its variant.
///
/// An `RpcResponse` completes the oneshot registered for its `req_id`, or is
/// logged and discarded when no waiter exists. Every other variant is
/// forwarded to the rpc-request, control, data or relay lane.
///
/// # Errors
///
/// Returns [`NexarError::PeerDisconnected`] if the chosen lane's receiver is gone.
async fn dispatch_framed(msg: NexarMessage, tx: &RouterSenders) -> Result<()> {
    let lane = match msg {
        NexarMessage::RpcResponse { req_id, .. } => {
            let pending = tx.rpc_waiters.lock().await.remove(&req_id);
            match pending {
                Some(waiter) => {
                    // The caller may have stopped waiting; that is not an error here.
                    let _ = waiter.send(msg);
                }
                None => {
                    tracing::warn!(
                        rank = tx.rank,
                        req_id,
                        "router: RpcResponse with no registered waiter, discarding"
                    );
                }
            }
            return Ok(());
        }
        NexarMessage::Rpc { .. } => &tx.rpc_requests,
        NexarMessage::Barrier { .. }
        | NexarMessage::BarrierAck { .. }
        | NexarMessage::Heartbeat { .. }
        | NexarMessage::NodeJoined { .. }
        | NexarMessage::NodeLeft { .. }
        | NexarMessage::Hello { .. }
        | NexarMessage::Welcome { .. }
        | NexarMessage::RdmaEndpoint { .. }
        | NexarMessage::SplitRequest { .. }
        | NexarMessage::RecoveryVote { .. }
        | NexarMessage::RecoveryAgreement { .. }
        | NexarMessage::ElasticCheckpoint { .. }
        | NexarMessage::ElasticCheckpointAck { .. } => &tx.control,
        NexarMessage::Data { .. } => &tx.data,
        NexarMessage::Relay { .. } => &tx.relay,
    };
    lane.send(msg)
        .await
        .map_err(|_| NexarError::PeerDisconnected { rank: tx.rank })
}
/// Reads a little-endian `u64` length prefix followed by that many payload
/// bytes into a buffer checked out from `pool`.
///
/// Returns `None`, after logging a warning that includes `label`, when:
/// - either read fails,
/// - the declared length exceeds [`MAX_MESSAGE_SIZE`], or
/// - the length does not fit in `usize` on this target.
async fn read_length_prefixed(
    stream: &mut quinn::RecvStream,
    rank: Rank,
    pool: &Arc<BufferPool>,
    label: &str,
) -> Option<PooledBuf> {
    let mut len_buf = [0u8; 8];
    if let Err(e) = stream.read_exact(&mut len_buf).await {
        tracing::warn!(rank, "router: {label} length read failed: {e}");
        return None;
    }
    let len = u64::from_le_bytes(len_buf);
    if len > MAX_MESSAGE_SIZE {
        tracing::warn!(
            rank,
            "router: {label} message too large ({len} bytes), skipping"
        );
        return None;
    }
    // `len as usize` would silently truncate on 32-bit targets, where
    // MAX_MESSAGE_SIZE (2^32) does not fit in usize.
    let Ok(size) = usize::try_from(len) else {
        tracing::warn!(
            rank,
            "router: {label} message too large ({len} bytes), skipping"
        );
        return None;
    };
    // NOTE(review): `size` is peer-controlled and the buffer presumably gets
    // sized before any payload arrives, so a peer can make us reserve up to
    // MAX_MESSAGE_SIZE per concurrent stream. Confirm this limit is intended.
    let mut buf = pool.checkout(size);
    if let Err(e) = stream.read_exact(&mut buf).await {
        tracing::warn!(rank, "router: {label} payload read failed: {e}");
        return None;
    }
    Some(buf)
}
/// Reads one length-prefixed frame from `stream` and decodes it into a
/// [`NexarMessage`].
///
/// Yields `None` when the frame cannot be read (already logged by
/// `read_length_prefixed`) or when decoding fails (logged here).
async fn read_framed(
    stream: &mut quinn::RecvStream,
    rank: Rank,
    pool: &Arc<BufferPool>,
) -> Option<NexarMessage> {
    let frame = read_length_prefixed(stream, rank, pool, "framed").await?;
    let decoded = decode_message(&frame)
        .inspect_err(|e| tracing::warn!(rank, "router: framed decode failed: {e}"))
        .ok();
    decoded.map(|(_, message)| message)
}
/// Reads one length-prefixed raw payload from `stream` without decoding it.
///
/// Thin wrapper over `read_length_prefixed` whose log lines are labelled "raw".
async fn read_raw(
    stream: &mut quinn::RecvStream,
    rank: Rank,
    pool: &Arc<BufferPool>,
) -> Option<PooledBuf> {
    const LABEL: &str = "raw";
    read_length_prefixed(stream, rank, pool, LABEL).await
}