use std::sync::Arc;
use rustls::pki_types::CertificateDer;
use tokio::sync::watch;
use tracing::{debug, warn};
use crate::error::{ClusterError, Result};
use crate::rpc_codec::{self, MAX_RPC_PAYLOAD_SIZE, RaftRpc, auth_envelope};
use crate::topology::NodeInfo;
use crate::transport::auth_context::AuthContext;
use crate::transport::peer_identity_verifier::{
IDENTITY_MISMATCH_QUIC_ERROR, VerifyOutcome, verify_peer_identity,
};
use crate::wire_version::handshake_io::{local_version_range, perform_version_handshake_server};
/// Source of the identity records used to check which node a QUIC peer is.
///
/// Implementations must be shareable across tasks (`Send + Sync + 'static`)
/// because connection and stream handlers hold them behind an `Arc`.
/// Every lookup returns `None` when no matching node is known.
pub trait PeerIdentityStore: Send + Sync + 'static {
    /// Looks up the topology record for the node with id `id`.
    fn get_node_info(&self, id: u64) -> Option<NodeInfo>;
    /// Looks up a node by a 32-byte SPKI fingerprint
    /// (presumably a SHA-256 digest of the SubjectPublicKeyInfo — confirm).
    fn find_by_spki(&self, spki_digest: &[u8; 32]) -> Option<NodeInfo>;
    /// Looks up a node by its SPIFFE ID string.
    fn find_by_spiffe(&self, spiffe: &str) -> Option<NodeInfo>;
}
/// A [`PeerIdentityStore`] that knows no nodes: every lookup returns `None`.
///
/// When this store is used, `handle_stream` never finds a topology record for
/// a remote sender. It therefore always takes the "node not in topology —
/// bootstrap window accepted" path, which means no certificate pinning is
/// enforced.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopIdentityStore;

impl PeerIdentityStore for NoopIdentityStore {
    fn get_node_info(&self, _node_id: u64) -> Option<NodeInfo> {
        None
    }

    fn find_by_spki(&self, _spki: &[u8; 32]) -> Option<NodeInfo> {
        None
    }

    fn find_by_spiffe(&self, _spiffe_id: &str) -> Option<NodeInfo> {
        None
    }
}
/// Application-side handler that turns one inbound Raft RPC into its reply.
///
/// The returned future must be `Send` because each RPC stream is served on
/// its own spawned tokio task.
pub trait RaftRpcHandler: Send + Sync + 'static {
    /// Handles `request` and resolves to the response RPC to send back, or an
    /// error, in which case no response is written on the stream.
    fn handle_rpc(
        &self,
        request: RaftRpc,
    ) -> impl std::future::Future<Output = Result<RaftRpc>> + Send;
}
/// Copies out the DER bytes of the first certificate in the peer's chain.
///
/// Returns `None` if the connection exposes no peer identity, if that
/// identity is not a `Vec<CertificateDer<'static>>`, or if the chain is empty.
fn peer_leaf_cert_der(conn: &quinn::Connection) -> Option<Vec<u8>> {
    let any_identity = conn.peer_identity()?;
    let chain = any_identity.downcast_ref::<Vec<CertificateDer<'static>>>()?;
    let leaf = chain.first()?;
    Some(leaf.as_ref().to_vec())
}
/// Waits for the peer to open the next bidirectional stream on `conn`, giving
/// up early if shutdown is requested.
///
/// Returns:
/// - `Ok(Some(streams))` for a newly accepted stream.
/// - `Ok(None)` when the caller should stop cleanly. That covers `shutdown`
///   flipped to `true`, the shutdown sender dropped, and the connection
///   closed by the application on either side.
/// - `Err` for any other connection failure.
async fn accept_bi_or_shutdown(
    conn: &quinn::Connection,
    shutdown: &mut watch::Receiver<bool>,
) -> std::result::Result<Option<(quinn::SendStream, quinn::RecvStream)>, quinn::ConnectionError>
{
    let accepted = loop {
        tokio::select! {
            biased;
            changed = shutdown.changed() => {
                // Once the sender is dropped, `changed()` errors and then
                // resolves immediately on every poll. With `biased` that would
                // starve `accept_bi` and spin forever, so treat it as shutdown.
                if changed.is_err() || *shutdown.borrow() {
                    return Ok(None);
                }
                // The value changed but is still `false`: keep waiting,
                // still watching for shutdown.
            }
            result = conn.accept_bi() => break result,
        }
    };
    match accepted {
        Ok(streams) => Ok(Some(streams)),
        Err(quinn::ConnectionError::ApplicationClosed(_))
        | Err(quinn::ConnectionError::LocallyClosed) => Ok(None),
        Err(e) => Err(e),
    }
}

/// Serves a single inbound QUIC connection from a cluster peer.
///
/// The first bidirectional stream the peer opens carries the wire-version
/// handshake. Every later stream is handed to `handle_stream` on its own
/// spawned task.
///
/// Returns `Ok(())` in any of these cases:
/// - `shutdown` becomes `true`.
/// - The shutdown sender is dropped.
/// - The connection is closed by the application on either side.
///
/// # Errors
/// Returns the handshake error if version negotiation fails, or
/// `ClusterError::Transport` if accepting a stream fails for any other reason.
pub(crate) async fn handle_connection<H: RaftRpcHandler, S: PeerIdentityStore>(
    conn: quinn::Connection,
    handler: Arc<H>,
    auth: Arc<AuthContext>,
    identity_store: Arc<S>,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    // Captured once per connection and handed to every stream task.
    let peer_cert_der: Option<Vec<u8>> = peer_leaf_cert_der(&conn);
    let peer_addr = conn.remote_address();

    // The handshake streams are scoped to this block so they are dropped as
    // soon as the version has been agreed.
    let agreed_version = {
        let Some((mut hs_send, mut hs_recv)) = accept_bi_or_shutdown(&conn, &mut shutdown)
            .await
            .map_err(|e| ClusterError::Transport {
                detail: format!("accept handshake stream from {peer_addr}: {e}"),
            })?
        else {
            return Ok(());
        };
        let local = local_version_range();
        match perform_version_handshake_server(&conn, &mut hs_send, &mut hs_recv).await {
            Ok(v) => v,
            Err(e) => {
                warn!(
                    peer_addr = %peer_addr,
                    local_min = %local.min,
                    local_max = %local.max,
                    error = %e,
                    "wire version handshake failed; closing connection"
                );
                return Err(e);
            }
        }
    };
    debug!(
        peer_addr = %peer_addr,
        agreed_version = %agreed_version,
        "wire version handshake complete"
    );

    loop {
        let Some((send, recv)) = accept_bi_or_shutdown(&conn, &mut shutdown)
            .await
            .map_err(|e| ClusterError::Transport {
                detail: format!("accept_bi: {e}"),
            })?
        else {
            return Ok(());
        };

        let stream_handler = Arc::clone(&handler);
        let stream_auth = Arc::clone(&auth);
        let stream_id_store = Arc::clone(&identity_store);
        let stream_cert = peer_cert_der.clone();
        let stream_conn = conn.clone();
        let stream_shutdown = shutdown.clone();
        // Detached: a failing stream is logged and does not tear down the
        // connection.
        tokio::spawn(async move {
            if let Err(e) = handle_stream(
                stream_handler,
                stream_auth,
                stream_id_store,
                stream_cert,
                stream_conn,
                send,
                recv,
                stream_shutdown,
            )
            .await
            {
                debug!(error = %e, "raft RPC stream error");
            }
        });
    }
}
/// Serves one inbound RPC stream. It performs these steps:
/// 1. Reads one authenticated envelope and checks who sent it.
/// 2. Dispatches the inner RPC to `handler`.
/// 3. Writes back one authenticated response envelope and finishes the send
///    side.
///
/// Work is abandoned, returning `Ok(())`, once `shutdown` becomes `true` or
/// its sender is dropped.
///
/// # Errors
/// - Envelope read or parse failures.
/// - Replay-window rejections from `peer_seq_in`.
/// - A peer-identity mismatch, which also closes `conn` with
///   `IDENTITY_MISMATCH_QUIC_ERROR`.
/// - Handler errors.
/// - Response encode or write failures.
#[allow(clippy::too_many_arguments)]
async fn handle_stream<H: RaftRpcHandler, S: PeerIdentityStore>(
    handler: Arc<H>,
    auth: Arc<AuthContext>,
    identity_store: Arc<S>,
    peer_cert_der: Option<Vec<u8>>,
    conn: quinn::Connection,
    mut send: quinn::SendStream,
    mut recv: quinn::RecvStream,
    mut shutdown: watch::Receiver<bool>,
) -> Result<()> {
    let work = async {
        let envelope = read_envelope(&mut recv).await?;
        let (fields, inner_frame) = auth_envelope::parse_envelope(&envelope, &auth.mac_key)?;
        let from_node_id = fields.from_node_id;
        // NOTE(review): envelopes claiming our own node id skip both the
        // certificate check and the replay window below. Confirm that
        // loopback is the only legitimate source of such frames.
        let is_remote = from_node_id != auth.local_node_id;

        // Check the TLS identity *before* touching replay state, so a peer
        // presenting the wrong certificate is rejected before it can feed
        // sequence numbers into `peer_seq_in` under another node's id.
        if is_remote && let Some(cert_der) = &peer_cert_der {
            match identity_store.get_node_info(from_node_id) {
                Some(info) => match verify_peer_identity(&info, cert_der) {
                    VerifyOutcome::Accepted { method } => {
                        debug!(node_id = from_node_id, ?method, "peer identity verified");
                    }
                    VerifyOutcome::BootstrapAccepted => {
                        warn!(
                            node_id = from_node_id,
                            "peer identity not pinned — bootstrap window accepted"
                        );
                    }
                    VerifyOutcome::Rejected => {
                        warn!(
                            node_id = from_node_id,
                            "peer identity mismatch — closing connection"
                        );
                        conn.close(IDENTITY_MISMATCH_QUIC_ERROR, b"peer identity mismatch");
                        return Err(ClusterError::Transport {
                            detail: format!("peer identity mismatch for node {}", from_node_id),
                        });
                    }
                },
                None => {
                    warn!(
                        node_id = from_node_id,
                        "node not in topology — bootstrap window accepted"
                    );
                }
            }
        }
        if is_remote {
            auth.peer_seq_in.accept(from_node_id, fields.seq)?;
        }

        let request = rpc_codec::decode(inner_frame)?;
        let response = handler.handle_rpc(request).await?;
        let response_inner = rpc_codec::encode(&response)?;
        let response_seq = auth.peer_seq_out.next();
        let mut response_envelope =
            Vec::with_capacity(auth_envelope::ENVELOPE_OVERHEAD + response_inner.len());
        auth_envelope::write_envelope(
            auth.local_node_id,
            response_seq,
            &response_inner,
            &auth.mac_key,
            &mut response_envelope,
        )?;
        send.write_all(&response_envelope)
            .await
            .map_err(|e| ClusterError::Transport {
                detail: format!("write response: {e}"),
            })?;
        send.finish().map_err(|e| ClusterError::Transport {
            detail: format!("finish response: {e}"),
        })?;
        Ok::<(), ClusterError>(())
    };

    // Resolves once shutdown is actually requested. A change back to `false`
    // is ignored. A dropped sender counts as shutdown, because `changed()`
    // can never succeed again after that.
    let shutdown_requested = async {
        loop {
            if shutdown.changed().await.is_err() {
                break;
            }
            if *shutdown.borrow() {
                break;
            }
        }
    };

    tokio::select! {
        biased;
        _ = shutdown_requested => Ok(()),
        result = work => result,
    }
}
/// Reads one complete envelope from `recv` and returns its raw bytes. The
/// bytes are the 21-byte header, then the inner frame, then
/// `rpc_codec::MAC_LEN` trailing bytes.
///
/// The only validation here is the declared length: header bytes 17..21 are
/// a little-endian `u32` inner-frame length, which must not exceed
/// `MAX_RPC_PAYLOAD_SIZE`. The MAC is not checked in this function.
///
/// # Errors
/// - `ClusterError::Transport` if the stream fails or ends early.
/// - `ClusterError::Codec` if the declared inner length is over the limit.
pub(crate) async fn read_envelope(recv: &mut quinn::RecvStream) -> Result<Vec<u8>> {
    const ENV_HDR_LEN: usize = 21;
    // Position of the 4-byte little-endian inner-length field in the header.
    const INNER_LEN_AT: usize = 17;

    // Read the header straight into the output buffer, then grow it in place.
    let mut envelope = vec![0u8; ENV_HDR_LEN];
    if let Err(e) = recv.read_exact(&mut envelope[..]).await {
        return Err(ClusterError::Transport {
            detail: format!("read envelope header: {e}"),
        });
    }

    let mut len_field = [0u8; 4];
    len_field.copy_from_slice(&envelope[INNER_LEN_AT..INNER_LEN_AT + 4]);
    let inner_len = u32::from_le_bytes(len_field);
    // Reject oversized frames before allocating for them.
    if inner_len > MAX_RPC_PAYLOAD_SIZE {
        return Err(ClusterError::Codec {
            detail: format!(
                "envelope inner length {inner_len} exceeds maximum {MAX_RPC_PAYLOAD_SIZE}"
            ),
        });
    }

    let total = ENV_HDR_LEN + inner_len as usize + rpc_codec::MAC_LEN;
    envelope.resize(total, 0);
    if total > ENV_HDR_LEN {
        recv.read_exact(&mut envelope[ENV_HDR_LEN..])
            .await
            .map_err(|e| ClusterError::Transport {
                detail: format!("read envelope payload+mac: {e}"),
            })?;
    }
    Ok(envelope)
}