use std::sync::Arc;
use tokio::sync::watch;
use tracing::debug;
use crate::error::{ClusterError, Result};
use crate::rpc_codec::{self, MAX_RPC_PAYLOAD_SIZE, RaftRpc, auth_envelope};
use crate::transport::auth_context::AuthContext;
/// Application-side handler for decoded Raft RPCs arriving over the transport.
///
/// Implementors must be `Send + Sync + 'static` because a shared `Arc` of the handler
/// is moved into a separately spawned task for every accepted stream.
pub trait RaftRpcHandler: Send + Sync + 'static {
/// Handles one decoded request and resolves to the RPC to send back, or to an error.
///
/// The returned future must be `Send` since it is awaited inside a spawned task.
/// In this file's stream handler an `Err` ends the stream before any response is
/// written.
fn handle_rpc(&self, rpc: RaftRpc)
-> impl std::future::Future<Output = Result<RaftRpc>> + Send;
}
/// Serves inbound Raft RPC streams on one QUIC connection until the connection is
/// closed or shutdown is requested.
///
/// Every accepted bidirectional stream is handed to its own spawned task running
/// `handle_stream`. A failure on an individual stream is logged at debug level and
/// does not end the accept loop.
///
/// Shutdown is signalled through `shutdown`: the loop exits with `Ok(())` once the
/// watched value is `true` or the sending half of the channel has been dropped.
///
/// # Errors
///
/// Returns `ClusterError::Transport` when `accept_bi` fails for any reason other than
/// the connection having been closed by the application on either side.
pub(crate) async fn handle_connection<H: RaftRpcHandler>(
    conn: quinn::Connection,
    handler: Arc<H>,
    auth: Arc<AuthContext>,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    // A receiver that has already seen the raised flag never observes a further
    // "change", so inspect the current value once before waiting on the channel.
    if *shutdown.borrow() {
        return Ok(());
    }

    loop {
        let accepted = tokio::select! {
            // Poll the shutdown signal first so a pending stop wins over new streams.
            biased;
            changed = shutdown.changed() => {
                // Once the sender is gone `changed()` resolves to `Err` immediately on
                // every call. Treating that as shutdown keeps this biased branch from
                // spinning forever and starving `accept_bi`.
                if changed.is_err() || *shutdown.borrow() {
                    return Ok(());
                }
                // The value changed but is still `false`: keep serving.
                continue;
            }
            result = conn.accept_bi() => result,
        };

        let (send, recv) = match accepted {
            Ok(streams) => streams,
            // An application-level close by either side is the normal end of a connection.
            Err(quinn::ConnectionError::ApplicationClosed(_))
            | Err(quinn::ConnectionError::LocallyClosed) => return Ok(()),
            Err(e) => {
                return Err(ClusterError::Transport {
                    detail: format!("accept_bi: {e}"),
                });
            }
        };

        // Each stream task owns its own handles so it can outlive this loop iteration.
        let stream_handler = Arc::clone(&handler);
        let stream_auth = Arc::clone(&auth);
        let stream_shutdown = shutdown.clone();
        tokio::spawn(async move {
            if let Err(e) =
                handle_stream(stream_handler, stream_auth, send, recv, stream_shutdown).await
            {
                debug!(error = %e, "raft RPC stream error");
            }
        });
    }
}
/// Serves exactly one request/response exchange on a bidirectional stream.
///
/// Steps, in order: read one envelope, parse it with the shared MAC key, run the
/// inbound sequence check, decode the inner frame, invoke the handler, then wrap the
/// encoded response in a fresh envelope, write it, and finish the send side.
///
/// The exchange is abandoned with `Ok(())` as soon as shutdown is requested, i.e. the
/// watched value becomes `true` or the sending half of the channel is dropped. A change
/// that leaves the value `false` does not interrupt the exchange, matching the accept
/// loop in `handle_connection`.
///
/// # Errors
///
/// Propagates any error from reading, parsing, sequence checking, decoding, the
/// handler itself, encoding, or envelope writing. Failures to write or finish the
/// response stream are reported as `ClusterError::Transport`.
async fn handle_stream<H: RaftRpcHandler>(
    handler: Arc<H>,
    auth: Arc<AuthContext>,
    mut send: quinn::SendStream,
    mut recv: quinn::RecvStream,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    let work = async {
        let envelope = read_envelope(&mut recv).await?;
        let (fields, inner_frame) = auth_envelope::parse_envelope(&envelope, &auth.mac_key)?;

        // Frames that claim to come from this node skip the inbound sequence check.
        // NOTE(review): presumably this exists for loopback delivery; confirm that a
        // peer cannot use the local node id to bypass the check.
        if fields.from_node_id != auth.local_node_id {
            auth.peer_seq_in.accept(fields.from_node_id, fields.seq)?;
        }

        let request = rpc_codec::decode(inner_frame)?;
        // A handler error ends the exchange here; no response is written.
        let response = handler.handle_rpc(request).await?;

        let response_inner = rpc_codec::encode(&response)?;
        let response_seq = auth.peer_seq_out.next();
        let mut response_envelope =
            Vec::with_capacity(auth_envelope::ENVELOPE_OVERHEAD + response_inner.len());
        auth_envelope::write_envelope(
            auth.local_node_id,
            response_seq,
            &response_inner,
            &auth.mac_key,
            &mut response_envelope,
        )?;

        send.write_all(&response_envelope)
            .await
            .map_err(|e| ClusterError::Transport {
                detail: format!("write response: {e}"),
            })?;
        send.finish().map_err(|e| ClusterError::Transport {
            detail: format!("finish response: {e}"),
        })?;
        Ok::<(), ClusterError>(())
    };

    // Resolves only when shutdown is really requested: the flag reads `true`, or the
    // sender has been dropped (`changed()` then returns `Err`).
    let stop = async {
        while shutdown.changed().await.is_ok() {
            let stop_requested = *shutdown.borrow();
            if stop_requested {
                break;
            }
        }
    };

    tokio::select! {
        // Check the shutdown signal before making progress on the exchange.
        biased;
        _ = stop => Ok(()),
        result = work => result,
    }
}
/// Reads one complete envelope from `recv` and returns its raw bytes.
///
/// The returned buffer holds the fixed 21-byte header, then the inner frame whose
/// length is taken from the last four header bytes (little-endian `u32`), then
/// `rpc_codec::MAC_LEN` trailing bytes. Nothing is parsed or verified here beyond the
/// length bound.
///
/// # Errors
///
/// * `ClusterError::Transport` if the header or the remainder cannot be read in full.
/// * `ClusterError::Codec` if the declared inner length exceeds
///   `MAX_RPC_PAYLOAD_SIZE`. This is checked before the payload buffer is allocated.
pub(crate) async fn read_envelope(recv: &mut quinn::RecvStream) -> Result<Vec<u8>> {
    const ENV_HDR_LEN: usize = 21;

    let mut header = [0u8; ENV_HDR_LEN];
    if let Err(e) = recv.read_exact(&mut header).await {
        return Err(ClusterError::Transport {
            detail: format!("read envelope header: {e}"),
        });
    }

    // The inner-frame length occupies the final four header bytes.
    let [.., len0, len1, len2, len3] = header;
    let inner_len = u32::from_le_bytes([len0, len1, len2, len3]);
    if inner_len > MAX_RPC_PAYLOAD_SIZE {
        return Err(ClusterError::Codec {
            detail: format!(
                "envelope inner length {inner_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"
            ),
        });
    }

    let total = ENV_HDR_LEN + inner_len as usize + rpc_codec::MAC_LEN;
    let mut envelope = vec![0u8; total];
    let (head, tail) = envelope.split_at_mut(ENV_HDR_LEN);
    head.copy_from_slice(&header);

    // Only touch the stream again when bytes are expected after the header.
    if !tail.is_empty() {
        if let Err(e) = recv.read_exact(tail).await {
            return Err(ClusterError::Transport {
                detail: format!("read envelope payload+mac: {e}"),
            });
        }
    }

    Ok(envelope)
}