use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use std::{
collections::hash_map::DefaultHasher,
future::Future,
hash::{Hash, Hasher},
sync::atomic::{AtomicU64, Ordering},
};
use dashmap::DashMap;
use futures::{stream, StreamExt};
use quinn::Endpoint;
use tokio::sync::{broadcast, mpsc, watch, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
use ulid::Ulid;
use crate::config::PeerConfig;
use crate::error::{QrpcError, QrpcResult};
use crate::message::{
Ctx, FromRef, FromRefCallback, QrpcCallback, QrpcDispatcher, QrpcMessage, State,
};
use crate::protocol::{PacketKind, WirePacket, BROADCAST_TARGET, MAX_PACKET_SIZE};
use crate::tls::{
build_client_config, build_server_config, ensure_rustls_provider, TransportOptions,
};
/// A registered connection to a remote peer, stored in `Inner::peers` under the id the
/// peer announced during the register handshake.
///
/// Sessions are cloned out of the map (see the snapshots in `shutdown_inner` and
/// `broadcast_inner`, and the lookup in `send_to_inner`) before any `.await`.
#[derive(Clone)]
struct PeerSession {
    /// QUIC connection handle used to open outbound streams to this peer.
    connection: quinn::Connection,
    /// Value taken from `Inner::next_session_id` at registration; lets
    /// `unregister_peer_if_current` tell this registration apart from a later
    /// re-registration of the same peer id.
    session_id: u64,
}
/// How a failure while connecting to a bootstrap peer is treated by the connect loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionFailureKind {
    /// May succeed later; the connect loop retries after a backoff delay.
    Transient,
    /// Retrying cannot help (e.g. invalid address, bad client config, peer id mismatch);
    /// the connect loop for that peer stops.
    Permanent,
}
/// Connection lifecycle notification published on the instance's bounded broadcast
/// channel (see `QrpcInstance::subscribe_peer_events`).
///
/// Publishing is best-effort: send errors (e.g. no active subscribers) are ignored.
#[derive(Debug, Clone)]
pub enum PeerConnectionEvent {
    /// A connect loop is about to dial `peer`; `attempt` is 1 + the number of consecutive
    /// failures so far.
    Connecting { peer: PeerConfig, attempt: u32 },
    /// A peer finished the register handshake (inbound or outbound) and is now registered.
    Connected {
        /// Id the peer announced in its register packet.
        peer_id: String,
        /// Remote address of the underlying QUIC connection.
        remote_addr: SocketAddr,
    },
    /// The currently registered session for `peer_id` was removed.
    Disconnected { peer_id: String },
    /// Connecting to, or running a session with, a bootstrap peer failed.
    ConnectFailed {
        /// Bootstrap configuration of the peer this failure relates to.
        peer: PeerConfig,
        /// Whether the connect loop will retry.
        kind: ConnectionFailureKind,
        /// Human-readable error text.
        detail: String,
        /// Delay before the next attempt; `None` when the loop gives up.
        retry_in: Option<Duration>,
    },
}
/// Outbound command consumed by `QrpcInstance::serve_with_rx`.
#[derive(Debug)]
pub enum OutboundCmd<M> {
    /// Send `message` to `peer_id`, first waiting (until shutdown) for that peer to be
    /// registered.
    SendTo { peer_id: String, message: M },
    /// Send `message` to every currently registered peer.
    Broadcast { message: M },
}
/// Shared state behind every `QrpcInstance` handle, background task and dispatcher.
struct Inner<S, M, H> {
    /// Id announced in the register handshake and used as the source id of sent packets.
    id: String,
    /// User state handed to the callback for each decoded data packet.
    state: S,
    /// Message callback invoked from `on_packet`.
    handler: Arc<H>,
    /// QUIC endpoint used both to accept inbound peers and to dial bootstrap peers.
    endpoint: Endpoint,
    /// Currently registered peers, keyed by the id each announced.
    peers: DashMap<String, PeerSession>,
    /// Peers dialed by `start`, one connect loop each.
    bootstrap_peers: Vec<PeerConfig>,
    /// Client identity used when a bootstrap peer needs its own client config (peer CA).
    local_client_cert_path: String,
    local_client_key_path: String,
    /// Transport settings reused when building per-peer client configs.
    transport_options: TransportOptions,
    /// Shutdown flag; `shutdown_inner` sets it to `true`.
    shutdown_tx: watch::Sender<bool>,
    /// Publisher for `PeerConnectionEvent`s; send errors are ignored.
    peer_event_tx: broadcast::Sender<PeerConnectionEvent>,
    /// Source of `PeerSession::session_id` values (starts at 1).
    next_session_id: AtomicU64,
    /// Accept/connect loop handles spawned by `start`; aborted by `shutdown_inner`.
    tasks: RwLock<Vec<JoinHandle<()>>>,
    /// Ties the message type `M` to this struct without storing a value of it.
    _marker: std::marker::PhantomData<M>,
}
/// Handle to a QRPC node: a QUIC endpoint that accepts peers, dials bootstrap peers and
/// dispatches decoded messages to a callback.
pub struct QrpcInstance<S, M, H> {
    inner: Arc<Inner<S, M, H>>,
}
/// Builder type-state marker: user state was supplied through `with_state`.
pub struct WithState;
/// Builder type-state marker: no user state supplied (`build` is available when `S = ()`).
pub struct WithoutState;
/// Builder for `QrpcInstance`; `StateStatus` (`WithState`/`WithoutState`) tracks at the
/// type level whether user state has been supplied.
pub struct QrpcInstanceBuilder<StateStatus, S, M, H> {
    /// User state; `Some` only after `with_state`.
    state: Option<S>,
    /// Message callback.
    handler: H,
    /// Instance id; a ULID is generated at build time when unset.
    id: Option<String>,
    /// CA certificate path (required).
    ca_cert_path: Option<String>,
    /// Server identity certificate/key paths (required).
    cert_path: Option<String>,
    key_path: Option<String>,
    /// Client identity used when dialing; falls back to the server identity when unset.
    client_cert_path: Option<String>,
    client_key_path: Option<String>,
    /// QUIC transport settings; `None` is passed through to the TLS/transport builders.
    keep_alive_interval: Option<Duration>,
    max_idle_timeout: Option<Duration>,
    /// Listen port (required).
    port: Option<u16>,
    /// Listen address; defaults to `0.0.0.0`.
    listen_ip: IpAddr,
    /// Bootstrap peers dialed on `start`.
    peers: Vec<PeerConfig>,
    _state_status: std::marker::PhantomData<StateStatus>,
    _marker: std::marker::PhantomData<M>,
}
impl<S, M, H> QrpcInstanceBuilder<WithoutState, S, M, H>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    /// Creates a builder around `handler` with no state, identity, port or peers.
    ///
    /// Keep-alive and idle timeout start at `crate::tls::DEFAULT_KEEP_ALIVE_INTERVAL` and
    /// `crate::tls::DEFAULT_MAX_IDLE_TIMEOUT`; the listen address starts at `0.0.0.0`.
    pub fn new(handler: H) -> Self {
        let any_v4 = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
        Self {
            handler,
            state: None,
            id: None,
            port: None,
            listen_ip: any_v4,
            ca_cert_path: None,
            cert_path: None,
            key_path: None,
            client_cert_path: None,
            client_key_path: None,
            keep_alive_interval: Some(crate::tls::DEFAULT_KEEP_ALIVE_INTERVAL),
            max_idle_timeout: Some(crate::tls::DEFAULT_MAX_IDLE_TIMEOUT),
            peers: Vec::new(),
            _state_status: std::marker::PhantomData,
            _marker: std::marker::PhantomData,
        }
    }

    /// Attaches the shared user state and moves the builder into the `WithState` mode,
    /// carrying every other setting over unchanged.
    pub fn with_state(self, state: S) -> QrpcInstanceBuilder<WithState, S, M, H> {
        QrpcInstanceBuilder {
            state: Some(state),
            handler: self.handler,
            id: self.id,
            port: self.port,
            listen_ip: self.listen_ip,
            ca_cert_path: self.ca_cert_path,
            cert_path: self.cert_path,
            key_path: self.key_path,
            client_cert_path: self.client_cert_path,
            client_key_path: self.client_key_path,
            keep_alive_interval: self.keep_alive_interval,
            max_idle_timeout: self.max_idle_timeout,
            peers: self.peers,
            _state_status: std::marker::PhantomData,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<StateStatus, S, M, H> QrpcInstanceBuilder<StateStatus, S, M, H>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    /// Sets the id announced to peers; a ULID is generated at build time when unset.
    pub fn with_id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }

    /// Sets the CA certificate used to verify peers (required).
    pub fn with_ca_cert(self, path: impl Into<String>) -> Self {
        Self {
            ca_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Sets the server identity (certificate + private key), required for building.
    pub fn with_identity(
        self,
        cert_path: impl Into<String>,
        key_path: impl Into<String>,
    ) -> Self {
        Self {
            cert_path: Some(cert_path.into()),
            key_path: Some(key_path.into()),
            ..self
        }
    }

    /// Sets a separate client identity for dialing; defaults to the server identity.
    pub fn with_client_identity(
        self,
        cert_path: impl Into<String>,
        key_path: impl Into<String>,
    ) -> Self {
        Self {
            client_cert_path: Some(cert_path.into()),
            client_key_path: Some(key_path.into()),
            ..self
        }
    }

    /// Overrides the QUIC keep-alive interval (`None` is passed through as-is).
    pub fn with_keep_alive_interval(self, interval: Option<Duration>) -> Self {
        Self {
            keep_alive_interval: interval,
            ..self
        }
    }

    /// Overrides the QUIC max idle timeout (`None` is passed through as-is).
    pub fn with_max_idle_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            max_idle_timeout: timeout,
            ..self
        }
    }

    /// Sets the listen port (required; `0` lets the OS choose).
    pub fn with_port(self, port: u16) -> Self {
        Self {
            port: Some(port),
            ..self
        }
    }

    /// Sets the listen address (defaults to `0.0.0.0`).
    pub fn with_listen_ip(self, ip: IpAddr) -> Self {
        Self {
            listen_ip: ip,
            ..self
        }
    }

    /// Appends one bootstrap peer dialed on `start`.
    pub fn add_peer(mut self, peer: PeerConfig) -> Self {
        self.peers.push(peer);
        self
    }

    /// Appends several bootstrap peers, in iteration order.
    pub fn add_peers(self, peers: impl IntoIterator<Item=PeerConfig>) -> Self {
        peers.into_iter().fold(self, Self::add_peer)
    }

    /// Validates the settings, binds the QUIC endpoint and assembles the instance.
    ///
    /// # Errors
    /// `QrpcError::MissingField` for a missing CA, certificate, key or port (checked in
    /// that order, after rustls provider setup), plus any error from TLS config loading or
    /// endpoint creation.
    fn build_inner(self, state: S) -> QrpcResult<QrpcInstance<S, M, H>> {
        ensure_rustls_provider()?;
        let Self {
            handler,
            id,
            ca_cert_path,
            cert_path,
            key_path,
            client_cert_path,
            client_key_path,
            keep_alive_interval,
            max_idle_timeout,
            port,
            listen_ip,
            peers,
            ..
        } = self;

        let ca = ca_cert_path.ok_or(QrpcError::MissingField("ca_cert_path"))?;
        let cert = cert_path.ok_or(QrpcError::MissingField("cert_path"))?;
        let key = key_path.ok_or(QrpcError::MissingField("key_path"))?;
        // Without a dedicated client identity the server identity is reused for dialing.
        let client_cert = client_cert_path.unwrap_or_else(|| cert.clone());
        let client_key = client_key_path.unwrap_or_else(|| key.clone());
        let transport_options = TransportOptions {
            keep_alive_interval,
            max_idle_timeout,
        };
        let port = port.ok_or(QrpcError::MissingField("port"))?;
        let id = id.unwrap_or_else(|| Ulid::new().to_string());
        info!(
            instance_id = %id,
            listen_ip = %listen_ip,
            port = port,
            peers = peers.len(),
            "building qrpc instance"
        );

        let server_config = build_server_config(&ca, &cert, &key, transport_options)?;
        let listen_addr = SocketAddr::new(listen_ip, port);
        let mut endpoint = Endpoint::server(server_config, listen_addr).map_err(|e| {
            QrpcError::MessageDecode(format!("failed to create server endpoint: {e}"))
        })?;
        let client_config =
            build_client_config(&ca, &client_cert, &client_key, transport_options)?;
        endpoint.set_default_client_config(client_config);

        let (shutdown_tx, _) = watch::channel(false);
        let (peer_event_tx, _) = broadcast::channel(256);
        let inner = Arc::new(Inner {
            id,
            state,
            handler: Arc::new(handler),
            endpoint,
            peers: DashMap::new(),
            bootstrap_peers: peers,
            local_client_cert_path: client_cert,
            local_client_key_path: client_key,
            transport_options,
            shutdown_tx,
            peer_event_tx,
            next_session_id: AtomicU64::new(1),
            tasks: RwLock::new(Vec::new()),
            _marker: std::marker::PhantomData,
        });
        info!(instance_id = %inner.id, "qrpc instance built");
        Ok(QrpcInstance { inner })
    }
}
impl<S, M, H> QrpcInstanceBuilder<WithState, S, M, H>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    /// Builds the instance using the state supplied through `with_state`.
    ///
    /// # Errors
    /// `QrpcError::MissingField("state")` if no state is present, otherwise any error from
    /// validating settings or binding the endpoint.
    pub fn build(mut self) -> QrpcResult<QrpcInstance<S, M, H>> {
        match self.state.take() {
            Some(state) => self.build_inner(state),
            None => Err(QrpcError::MissingField("state")),
        }
    }
}
impl<M, H> QrpcInstanceBuilder<WithoutState, (), M, H>
where
    M: QrpcMessage,
    H: QrpcCallback<(), M>,
{
    /// Builds a stateless instance; the callback receives `()` as its state.
    ///
    /// # Errors
    /// Any error from validating settings or binding the endpoint.
    pub fn build(self) -> QrpcResult<QrpcInstance<(), M, H>> {
        Self::build_inner(self, ())
    }
}
impl<S, M> QrpcInstance<S, M, ()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
{
    /// Starts building an instance whose message callback is the async function `handler`.
    ///
    /// `handler` is called with a `State<T>` (where `T: FromRef<S>`), a dispatcher `Ctx`,
    /// the source peer id and the decoded message, and is wrapped in a `FromRefCallback`.
    pub fn builder<T, F, Fut>(
        handler: F,
    ) -> QrpcInstanceBuilder<WithoutState, S, M, impl QrpcCallback<S, M>>
    where
        T: FromRef<S> + Send + 'static,
        F: Fn(State<T>, Ctx<M>, String, M) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output=QrpcResult<()>> + Send + 'static,
    {
        let callback = FromRefCallback::<T, F, Fut>::new(handler);
        QrpcInstanceBuilder::new(callback)
    }
}
impl<S, M, H> QrpcInstance<S, M, H>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    /// Returns the id this instance announces to peers in the register handshake.
    pub fn id(&self) -> &str {
        &self.inner.id
    }

    /// Returns the local socket address the QUIC endpoint is bound to.
    ///
    /// # Errors
    /// Returns `QrpcError::MessageDecode` carrying the I/O error text if the address cannot
    /// be read.
    pub fn local_addr(&self) -> QrpcResult<SocketAddr> {
        self.inner
            .endpoint
            .local_addr()
            .map_err(|e| QrpcError::MessageDecode(format!("failed to read local addr: {e}")))
    }

    /// Spawns the accept loop and one connect loop per bootstrap peer.
    ///
    /// Does nothing while previously spawned tasks are still recorded (i.e. the instance was
    /// started and has not been shut down since).
    pub async fn start(&self) {
        let mut tasks = self.inner.tasks.write().await;
        if !tasks.is_empty() {
            debug!(instance_id = %self.inner.id, "instance already started");
            return;
        }
        info!(instance_id = %self.inner.id, "starting qrpc instance");
        let accept_inner = Arc::clone(&self.inner);
        tasks.push(tokio::spawn(async move {
            accept_loop(accept_inner).await;
        }));
        for peer in self.inner.bootstrap_peers.clone() {
            let connect_inner = Arc::clone(&self.inner);
            tasks.push(tokio::spawn(async move {
                connect_loop(connect_inner, peer).await;
            }));
        }
    }

    /// Starts the instance and waits until shutdown is signalled.
    pub async fn serve(&self) {
        self.start().await;
        wait_for_shutdown_signal(self.inner.shutdown_tx.subscribe()).await;
    }

    /// Starts the instance, runs `worker` with a dispatcher context, and keeps serving until
    /// shutdown is signalled.
    ///
    /// The worker's outcome is only logged: whether it succeeds or fails, serving continues
    /// until shutdown and `Ok(())` is returned.
    pub async fn serve_with<F, Fut>(&self, worker: F) -> QrpcResult<()>
    where
        F: FnOnce(Ctx<M>) -> Fut + Send,
        Fut: Future<Output=QrpcResult<()>> + Send,
    {
        self.start().await;
        let ctx = Ctx::new(Arc::new(InstanceDispatcher {
            inner: Arc::clone(&self.inner),
        }));
        let worker_fut = worker(ctx);
        let shutdown_fut = wait_for_shutdown_signal(self.inner.shutdown_tx.subscribe());
        tokio::pin!(worker_fut);
        tokio::pin!(shutdown_fut);
        tokio::select! {
            _ = &mut shutdown_fut => Ok(()),
            result = &mut worker_fut => {
                match result {
                    Ok(()) => {
                        info!(
                            instance_id = %self.inner.id,
                            "serve_with worker finished; continue serving until shutdown"
                        );
                    }
                    Err(err) => {
                        error!(
                            instance_id = %self.inner.id,
                            error = %err,
                            "serve_with worker returned error; keep serving until shutdown"
                        );
                    }
                }
                shutdown_fut.await;
                Ok(())
            }
        }
    }

    /// Starts the instance and executes `OutboundCmd`s received on `rx` until shutdown.
    ///
    /// `SendTo` first waits (until shutdown) for the target peer to be registered. Send
    /// failures are logged and do not stop the loop. Once `rx` is closed this waits for
    /// shutdown. Always returns `Ok(())`, and returns immediately if the instance is already
    /// shut down.
    pub async fn serve_with_rx(&self, mut rx: mpsc::Receiver<OutboundCmd<M>>) -> QrpcResult<()> {
        self.start().await;
        let mut shutdown_rx = self.inner.shutdown_tx.subscribe();
        loop {
            // `subscribe()` treats the current value as already seen, so a shutdown that
            // happened before subscribing would never wake `changed()` below; check the flag.
            if *shutdown_rx.borrow() {
                return Ok(());
            }
            tokio::select! {
                _ = shutdown_rx.changed() => {
                    if *shutdown_rx.borrow() {
                        return Ok(());
                    }
                }
                cmd = rx.recv() => {
                    match cmd {
                        Some(OutboundCmd::SendTo { peer_id, message }) => {
                            if let Err(err) = wait_for_peer_until_shutdown(&self.inner, &peer_id).await {
                                error!(
                                    instance_id = %self.inner.id,
                                    peer_id = %peer_id,
                                    error = %err,
                                    "serve_with_rx send_to wait failed"
                                );
                                continue;
                            }
                            if let Err(err) = self.send_to(&peer_id, &message).await {
                                error!(
                                    instance_id = %self.inner.id,
                                    peer_id = %peer_id,
                                    error = %err,
                                    "serve_with_rx send_to failed"
                                );
                            }
                        }
                        Some(OutboundCmd::Broadcast { message }) => {
                            if let Err(err) = self.broadcast(&message).await {
                                error!(
                                    instance_id = %self.inner.id,
                                    error = %err,
                                    "serve_with_rx broadcast failed"
                                );
                            }
                        }
                        None => {
                            wait_for_shutdown_signal(self.inner.shutdown_tx.subscribe()).await;
                            return Ok(());
                        }
                    }
                }
            }
        }
    }

    /// Signals shutdown, disconnects all peers, aborts background tasks and closes the
    /// endpoint.
    pub async fn shutdown(&self) {
        shutdown_inner(&self.inner).await;
    }

    /// Sends `message` to the registered peer `target_id`.
    ///
    /// # Errors
    /// `QrpcError::PeerNotFound` if the peer is not registered, or any stream/write error.
    pub async fn send_to(&self, target_id: &str, message: &M) -> QrpcResult<()> {
        send_to_inner(&self.inner, target_id, message).await
    }

    /// Sends `message` to every registered peer and returns how many sends succeeded.
    pub async fn broadcast(&self, message: &M) -> QrpcResult<usize> {
        broadcast_inner(&self.inner, message).await
    }

    /// Returns the ids of all currently registered peers (in no particular order).
    pub async fn peer_ids(&self) -> Vec<String> {
        peer_ids_inner(&self.inner).await
    }

    /// Subscribes to connection lifecycle events published after this call.
    pub fn subscribe_peer_events(&self) -> broadcast::Receiver<PeerConnectionEvent> {
        self.inner.peer_event_tx.subscribe()
    }

    /// Waits up to `timeout` for `peer_id` to be registered.
    ///
    /// # Errors
    /// `QrpcError::PeerWaitTimeout` if the peer does not appear in time.
    pub async fn wait_for_peer(&self, peer_id: &str, timeout: Duration) -> QrpcResult<()> {
        wait_for_peer_inner(&self.inner, peer_id, timeout).await
    }
}
/// Collects the ids of all currently registered peers (order follows map iteration).
async fn peer_ids_inner<S, M, H>(inner: &Arc<Inner<S, M, H>>) -> Vec<String>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let mut ids = Vec::new();
    for entry in inner.peers.iter() {
        ids.push(entry.key().clone());
    }
    ids
}
/// Polls (every 50 ms at most) until `peer_id` is registered or `timeout` elapses.
///
/// The peer table is checked before the deadline. As a result, a zero timeout still
/// succeeds for an already-connected peer, and a peer that registers during the final
/// sleep is reported as found rather than as a timeout.
///
/// # Errors
/// `QrpcError::PeerWaitTimeout` when the deadline passes without the peer registering.
async fn wait_for_peer_inner<S, M, H>(
    inner: &Arc<Inner<S, M, H>>,
    peer_id: &str,
    timeout: Duration,
) -> QrpcResult<()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    const POLL_INTERVAL: Duration = Duration::from_millis(50);
    let deadline = tokio::time::Instant::now() + timeout;
    loop {
        if inner.peers.contains_key(peer_id) {
            return Ok(());
        }
        let now = tokio::time::Instant::now();
        if now >= deadline {
            return Err(QrpcError::PeerWaitTimeout {
                peer_id: peer_id.to_string(),
                timeout,
            });
        }
        let remain = deadline.saturating_duration_since(now);
        tokio::time::sleep(remain.min(POLL_INTERVAL)).await;
    }
}
async fn wait_for_peer_until_shutdown<S, M, H>(
inner: &Arc<Inner<S, M, H>>,
peer_id: &str,
) -> QrpcResult<()>
where
S: Send + Sync + 'static,
M: QrpcMessage,
H: QrpcCallback<S, M>,
{
let mut shutdown_rx = inner.shutdown_tx.subscribe();
loop {
if inner.peers.contains_key(peer_id) {
return Ok(());
}
tokio::select! {
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
return Err(QrpcError::PeerNotFound(peer_id.to_string()));
}
}
_ = tokio::time::sleep(Duration::from_millis(50)) => {}
}
}
}
async fn shutdown_inner<S, M, H>(inner: &Arc<Inner<S, M, H>>)
where
S: Send + Sync + 'static,
M: QrpcMessage,
H: QrpcCallback<S, M>,
{
info!(instance_id = %inner.id, "shutting down qrpc instance");
let _ = inner.shutdown_tx.send(true);
let snapshot: Vec<(String, PeerSession)> = inner
.peers
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
for (peer_id, peer) in snapshot {
let _ = send_packet_over_connection(
&peer.connection,
WirePacket::disconnect(inner.id.clone(), peer_id.clone()),
)
.await;
debug!(instance_id = %inner.id, peer_id = %peer_id, "closing peer connection");
peer.connection.close(0u32.into(), b"shutdown");
}
let mut tasks = inner.tasks.write().await;
for handle in tasks.drain(..) {
handle.abort();
}
inner.endpoint.close(0u32.into(), b"shutdown");
info!(instance_id = %inner.id, "qrpc instance shutdown completed");
}
/// `QrpcDispatcher` handed to callbacks and `serve_with` workers through `Ctx`.
///
/// Every operation forwards to the same free function the owning `QrpcInstance` uses, so a
/// `Ctx` acts on the instance it was created from.
#[derive(Clone)]
struct InstanceDispatcher<S, M, H> {
    /// Shared instance state.
    inner: Arc<Inner<S, M, H>>,
}

impl<S, M, H> QrpcDispatcher<M> for InstanceDispatcher<S, M, H>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    fn instance_id(&self) -> &str {
        self.inner.id.as_str()
    }

    fn send_to<'a>(
        &'a self,
        target_id: &'a str,
        message: &'a M,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output=QrpcResult<()>> + Send + 'a>> {
        Box::pin(send_to_inner(&self.inner, target_id, message))
    }

    fn broadcast<'a>(
        &'a self,
        message: &'a M,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output=QrpcResult<usize>> + Send + 'a>> {
        Box::pin(broadcast_inner(&self.inner, message))
    }

    fn peer_ids<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output=Vec<String>> + Send + 'a>> {
        Box::pin(peer_ids_inner(&self.inner))
    }

    fn wait_for_peer<'a>(
        &'a self,
        peer_id: &'a str,
        timeout: Duration,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output=QrpcResult<()>> + Send + 'a>> {
        Box::pin(wait_for_peer_inner(&self.inner, peer_id, timeout))
    }

    fn shutdown<'a>(
        &'a self,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output=()> + Send + 'a>> {
        Box::pin(shutdown_inner(&self.inner))
    }
}
/// Sends `message` as a data packet addressed to `target_id` over that peer's connection.
///
/// The session is cloned out of the peer map before awaiting, so no map guard is held
/// across the send.
///
/// # Errors
/// `QrpcError::PeerNotFound` when `target_id` is not registered; otherwise any error from
/// opening, writing or finishing the stream.
async fn send_to_inner<S, M, H>(
    inner: &Arc<Inner<S, M, H>>,
    target_id: &str,
    message: &M,
) -> QrpcResult<()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    debug!(
        instance_id = %inner.id,
        target_id = target_id,
        cmd_id = message.cmd_id(),
        "sending message to peer"
    );
    let Some(session) = inner.peers.get(target_id).map(|entry| entry.value().clone()) else {
        return Err(QrpcError::PeerNotFound(target_id.to_string()));
    };
    let outgoing = WirePacket::data(
        inner.id.clone(),
        target_id.to_string(),
        message.cmd_id(),
        message.encode_vec(),
    );
    send_packet_over_connection(&session.connection, outgoing).await
}
async fn broadcast_inner<S, M, H>(inner: &Arc<Inner<S, M, H>>, message: &M) -> QrpcResult<usize>
where
S: Send + Sync + 'static,
M: QrpcMessage,
H: QrpcCallback<S, M>,
{
debug!(
instance_id = %inner.id,
cmd_id = message.cmd_id(),
"broadcasting message"
);
let snapshot: Vec<(String, PeerSession)> = inner
.peers
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
let mut sent = 0usize;
for (peer_id, peer) in snapshot {
let packet = WirePacket::data(
inner.id.clone(),
BROADCAST_TARGET,
message.cmd_id(),
message.encode_vec(),
);
if send_packet_over_connection(&peer.connection, packet)
.await
.is_ok()
{
sent += 1;
} else {
error!(
instance_id = %inner.id,
peer_id = %peer_id,
"broadcast send failed, removing peer"
);
inner.peers.remove(&peer_id);
}
}
info!(instance_id = %inner.id, sent = sent, "broadcast finished");
Ok(sent)
}
/// Accepts inbound QUIC connections until shutdown or until the endpoint stops yielding
/// connections.
///
/// Each connection's handshake and session run in a detached task; errors there are logged.
async fn accept_loop<S, M, H>(inner: Arc<Inner<S, M, H>>)
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let mut stop_rx = inner.shutdown_tx.subscribe();
    info!(instance_id = %inner.id, "accept loop started");
    loop {
        tokio::select! {
            _ = stop_rx.changed() => {
                let stopping = *stop_rx.borrow();
                if stopping {
                    break;
                }
            }
            next = inner.endpoint.accept() => {
                let Some(pending) = next else {
                    break;
                };
                let conn_inner = Arc::clone(&inner);
                tokio::spawn(async move {
                    let connection = match pending.await {
                        Ok(connection) => connection,
                        Err(err) => {
                            error!("incoming quic handshake failed: {err}");
                            return;
                        }
                    };
                    let peer_addr = connection.remote_address();
                    debug!(instance_id = %conn_inner.id, remote = %peer_addr, "accepted incoming quic connection");
                    if let Err(err) = handle_incoming_connection(conn_inner, connection).await {
                        error!("incoming connection handling failed: {err}");
                    }
                });
            }
        }
    }
    info!(instance_id = %inner.id, "accept loop stopped");
}
/// Keeps an outbound connection to one bootstrap `peer` alive until shutdown.
///
/// Each iteration does the following:
/// - While a peer with `peer.expected_id` is already registered, it re-checks every 300 ms
///   instead of dialing.
/// - Otherwise it resolves the address and starts the dial. The dial uses a dedicated client
///   config when `peer.ca_cert_path` is set, and the endpoint default otherwise.
/// - It completes the handshake and runs the session through `handle_outgoing_connection`,
///   which returns when the session ends.
///
/// Failures are classified:
/// - Transient failures emit `ConnectFailed` with `retry_in` and retry after `retry_delay`
///   backoff with a per-peer jitter seed.
/// - Permanent failures emit `ConnectFailed` with `retry_in: None` and end this loop.
async fn connect_loop<S, M, H>(inner: Arc<Inner<S, M, H>>, peer: PeerConfig)
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let mut shutdown_rx = inner.shutdown_tx.subscribe();
    info!(instance_id = %inner.id, peer = ?peer, "connect loop started");
    // Consecutive failures; drives the backoff and is reset after a successful handshake.
    let mut failed_attempts = 0u32;
    // Stable per (instance, peer) seed so different peers spread their retries apart.
    let jitter_seed = backoff_seed(&inner.id, &peer);
    loop {
        if *shutdown_rx.borrow() {
            break;
        }
        // The expected peer is already registered (through this loop or an inbound
        // connection): re-check later instead of opening another connection.
        if let Some(expected_id) = peer.expected_id.as_ref() {
            if inner.peers.contains_key(expected_id) {
                if sleep_or_shutdown(&mut shutdown_rx, Duration::from_millis(300)).await {
                    break;
                }
                continue;
            }
        }
        emit_peer_event(
            &inner,
            PeerConnectionEvent::Connecting {
                peer: peer.clone(),
                attempt: failed_attempts.saturating_add(1),
            },
        );
        // An address that cannot be resolved here is treated as permanent: stop this loop.
        let addr = match peer.socket_addr() {
            Ok(addr) => addr,
            Err(err) => {
                emit_peer_event(
                    &inner,
                    PeerConnectionEvent::ConnectFailed {
                        peer: peer.clone(),
                        kind: ConnectionFailureKind::Permanent,
                        detail: err.to_string(),
                        retry_in: None,
                    },
                );
                error!(instance_id = %inner.id, "invalid peer socket addr: {err}");
                return;
            }
        };
        // Phase 1: start the dial. A peer-specific CA needs its own client config (built
        // with this instance's client identity); otherwise the endpoint default is used.
        let connecting = if let Some(peer_ca) = &peer.ca_cert_path {
            match build_client_config(
                peer_ca,
                &inner.local_client_cert_path,
                &inner.local_client_key_path,
                inner.transport_options,
            ) {
                Ok(cfg) => match inner.endpoint.connect_with(cfg, addr, &peer.server_name) {
                    Ok(conn) => conn,
                    Err(err) => {
                        let kind = classify_connect_init_error(&err);
                        if kind == ConnectionFailureKind::Permanent {
                            emit_peer_event(
                                &inner,
                                PeerConnectionEvent::ConnectFailed {
                                    peer: peer.clone(),
                                    kind,
                                    detail: err.to_string(),
                                    retry_in: None,
                                },
                            );
                            error!(instance_id = %inner.id, peer_addr = %addr, "connect_with build failed permanently: {err}");
                            return;
                        }
                        let retry_in = retry_delay(failed_attempts, jitter_seed);
                        failed_attempts = failed_attempts.saturating_add(1);
                        emit_peer_event(
                            &inner,
                            PeerConnectionEvent::ConnectFailed {
                                peer: peer.clone(),
                                kind,
                                detail: err.to_string(),
                                retry_in: Some(retry_in),
                            },
                        );
                        debug!(instance_id = %inner.id, peer_addr = %addr, retry_ms = retry_in.as_millis(), "connect_with build failed, retrying");
                        if sleep_or_shutdown(&mut shutdown_rx, retry_in).await {
                            break;
                        }
                        continue;
                    }
                },
                // Building the per-peer TLS config failed: treated as permanent.
                Err(err) => {
                    emit_peer_event(
                        &inner,
                        PeerConnectionEvent::ConnectFailed {
                            peer: peer.clone(),
                            kind: ConnectionFailureKind::Permanent,
                            detail: err.to_string(),
                            retry_in: None,
                        },
                    );
                    error!(instance_id = %inner.id, peer_addr = %addr, "build client config failed permanently: {err}");
                    return;
                }
            }
        } else {
            match inner.endpoint.connect(addr, &peer.server_name) {
                Ok(conn) => conn,
                Err(err) => {
                    let kind = classify_connect_init_error(&err);
                    if kind == ConnectionFailureKind::Permanent {
                        emit_peer_event(
                            &inner,
                            PeerConnectionEvent::ConnectFailed {
                                peer: peer.clone(),
                                kind,
                                detail: err.to_string(),
                                retry_in: None,
                            },
                        );
                        error!(instance_id = %inner.id, peer_addr = %addr, "connect build failed permanently: {err}");
                        return;
                    }
                    let retry_in = retry_delay(failed_attempts, jitter_seed);
                    failed_attempts = failed_attempts.saturating_add(1);
                    emit_peer_event(
                        &inner,
                        PeerConnectionEvent::ConnectFailed {
                            peer: peer.clone(),
                            kind,
                            detail: err.to_string(),
                            retry_in: Some(retry_in),
                        },
                    );
                    debug!(instance_id = %inner.id, peer_addr = %addr, retry_ms = retry_in.as_millis(), "connect build failed, retrying");
                    if sleep_or_shutdown(&mut shutdown_rx, retry_in).await {
                        break;
                    }
                    continue;
                }
            }
        };
        // Phase 2: finish the QUIC/TLS handshake, abandoning it if shutdown is signalled.
        // Handshake failures are always reported as transient and retried.
        let connection = tokio::select! {
            _ = shutdown_rx.changed() => {
                if *shutdown_rx.borrow() {
                    break;
                }
                continue;
            }
            result = connecting => {
                match result {
                    Ok(conn) => {
                        failed_attempts = 0;
                        conn
                    }
                    Err(err) => {
                        let retry_in = retry_delay(failed_attempts, jitter_seed);
                        failed_attempts = failed_attempts.saturating_add(1);
                        emit_peer_event(
                            &inner,
                            PeerConnectionEvent::ConnectFailed {
                                peer: peer.clone(),
                                kind: ConnectionFailureKind::Transient,
                                detail: err.to_string(),
                                retry_in: Some(retry_in),
                            },
                        );
                        debug!(instance_id = %inner.id, peer_addr = %addr, retry_ms = retry_in.as_millis(), "connect failed, retrying");
                        if sleep_or_shutdown(&mut shutdown_rx, retry_in).await {
                            break;
                        }
                        continue;
                    }
                }
            }
        };
        // Phase 3: register and serve the session; returns once the session ends.
        let result =
            handle_outgoing_connection(Arc::clone(&inner), connection, peer.expected_id.clone())
                .await;
        if let Err(err) = result {
            let kind = classify_outgoing_error(&err);
            if kind == ConnectionFailureKind::Permanent {
                emit_peer_event(
                    &inner,
                    PeerConnectionEvent::ConnectFailed {
                        peer: peer.clone(),
                        kind,
                        detail: err.to_string(),
                        retry_in: None,
                    },
                );
                error!(instance_id = %inner.id, peer_addr = %addr, "outgoing connection handling failed permanently: {err}");
                return;
            }
            let retry_in = retry_delay(failed_attempts, jitter_seed);
            failed_attempts = failed_attempts.saturating_add(1);
            emit_peer_event(
                &inner,
                PeerConnectionEvent::ConnectFailed {
                    peer: peer.clone(),
                    kind,
                    detail: err.to_string(),
                    retry_in: Some(retry_in),
                },
            );
            error!(
                instance_id = %inner.id,
                peer_addr = %addr,
                retry_ms = retry_in.as_millis(),
                "outgoing connection handling failed, retrying: {err}"
            );
            if sleep_or_shutdown(&mut shutdown_rx, retry_in).await {
                break;
            }
        }
        // NOTE(review): when the session ends without an error, the loop redials immediately
        // with no delay.
    }
    info!(instance_id = %inner.id, "connect loop stopped");
}
/// Performs the server side of the register handshake on an accepted connection, then
/// registers the peer and serves the session until it ends.
///
/// The first bidirectional stream must carry a `Register` packet; this instance answers
/// with its own `Register` packet on the same stream.
///
/// # Errors
/// `QrpcError::MessageDecode` if the first packet is not `Register`; any stream error.
async fn handle_incoming_connection<S, M, H>(
    inner: Arc<Inner<S, M, H>>,
    connection: quinn::Connection,
) -> QrpcResult<()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let (mut reply_stream, mut hello_stream) = connection.accept_bi().await?;
    let hello = read_packet_from_stream(&mut hello_stream).await?;
    if hello.kind != PacketKind::Register {
        return Err(QrpcError::MessageDecode(
            "incoming connection first packet must be register".to_string(),
        ));
    }
    let peer_id = hello.source_id;
    info!(instance_id = %inner.id, peer_id = %peer_id, "incoming peer registered");
    let reply = WirePacket::register(inner.id.clone()).encode_frame();
    reply_stream.write_all(&reply).await?;
    reply_stream.finish()?;
    let session_id = register_peer(Arc::clone(&inner), peer_id.clone(), connection.clone());
    run_connection_machine(inner, connection, peer_id, session_id).await
}
/// Performs the client side of the register handshake on a dialed connection, then
/// registers the peer and serves the session until it ends.
///
/// # Errors
/// `QrpcError::MessageDecode` if the reply is not a `Register` packet;
/// `QrpcError::PeerIdMismatch` if `expected_id` is set and differs from the announced id;
/// any stream error.
async fn handle_outgoing_connection<S, M, H>(
    inner: Arc<Inner<S, M, H>>,
    connection: quinn::Connection,
    expected_id: Option<String>,
) -> QrpcResult<()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let (mut outbound, mut inbound) = connection.open_bi().await?;
    let hello = WirePacket::register(inner.id.clone()).encode_frame();
    outbound.write_all(&hello).await?;
    outbound.finish()?;
    let reply = read_packet_from_stream(&mut inbound).await?;
    if reply.kind != PacketKind::Register {
        return Err(QrpcError::MessageDecode(
            "outgoing connection register response must be register".to_string(),
        ));
    }
    if let Some(expected) = expected_id.filter(|wanted| *wanted != reply.source_id) {
        return Err(QrpcError::PeerIdMismatch {
            expected,
            actual: reply.source_id,
        });
    }
    let peer_id = reply.source_id;
    info!(instance_id = %inner.id, peer_id = %peer_id, "outgoing peer registered");
    let session_id = register_peer(Arc::clone(&inner), peer_id.clone(), connection.clone());
    run_connection_machine(inner, connection, peer_id, session_id).await
}
/// Serves one registered session: accepts one bidirectional stream per inbound packet and
/// hands each decoded packet to `on_packet`.
///
/// The session stops when accepting a stream fails, a packet cannot be read, or
/// `on_packet` returns `false`. It then unregisters the peer, but only if this session is
/// still the current one. Always returns `Ok(())`.
async fn run_connection_machine<S, M, H>(
    inner: Arc<Inner<S, M, H>>,
    connection: quinn::Connection,
    peer_id: String,
    session_id: u64,
) -> QrpcResult<()>
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    info!(instance_id = %inner.id, peer_id = %peer_id, "connection state machine started");
    // `Some(..)` while the session is live; `None` once it should stop.
    let initial = Some((Arc::clone(&inner), connection, peer_id.clone()));
    let steps = stream::unfold(initial, |live| async move {
        let Some((shared, conn, remote_id)) = live else {
            return None;
        };
        let keep_running = match conn.accept_bi().await {
            Ok((_send, mut recv)) => match read_packet_from_stream(&mut recv).await {
                Ok(packet) => on_packet(&shared, &remote_id, packet).await,
                Err(_) => false,
            },
            Err(_) => false,
        };
        Some(((), keep_running.then_some((shared, conn, remote_id))))
    });
    steps.for_each(|()| async {}).await;
    unregister_peer_if_current(&inner, &peer_id, session_id);
    info!(instance_id = %inner.id, peer_id = %peer_id, "connection state machine stopped");
    Ok(())
}
/// Handles one packet received from `peer_id`; returns `false` when the session should stop.
///
/// Behaviour by packet kind:
/// - `Register` packets are ignored.
/// - `Disconnect` ends the session.
/// - `Data` packets addressed to this instance (or to `BROADCAST_TARGET`) are decoded and
///   passed to the callback together with a fresh dispatcher `Ctx`.
///
/// Decode failures and callback errors are only logged and keep the session running.
async fn on_packet<S, M, H>(inner: &Arc<Inner<S, M, H>>, peer_id: &str, packet: WirePacket) -> bool
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    match packet.kind {
        PacketKind::Register => return true,
        PacketKind::Disconnect => {
            info!(instance_id = %inner.id, peer_id = peer_id, "received disconnect packet");
            return false;
        }
        PacketKind::Data => {}
    }

    let addressed_here = packet.target_id == BROADCAST_TARGET || packet.target_id == inner.id;
    if !addressed_here {
        debug!(
            instance_id = %inner.id,
            peer_id = peer_id,
            target_id = %packet.target_id,
            "skipping packet not targeted to current instance"
        );
        return true;
    }

    let Ok(message) = M::decode_vec(packet.cmd_id, &packet.payload) else {
        error!(
            instance_id = %inner.id,
            peer_id = peer_id,
            cmd_id = packet.cmd_id,
            "message decode failed"
        );
        return true;
    };

    let ctx = Ctx::new(Arc::new(InstanceDispatcher {
        inner: Arc::clone(inner),
    }));
    let outcome = QrpcCallback::call(
        &*inner.handler,
        &inner.state,
        ctx,
        peer_id.to_string(),
        message,
    )
    .await;
    if let Err(err) = outcome {
        error!(
            instance_id = %inner.id,
            peer_id = peer_id,
            "callback returned error: {err}"
        );
    }
    true
}
/// Records `connection` as the current session for `peer_id` and publishes `Connected`.
///
/// Any existing entry for the same id is replaced. Returns the new session id, which the
/// caller later passes to `unregister_peer_if_current`.
fn register_peer<S, M, H>(
    inner: Arc<Inner<S, M, H>>,
    peer_id: String,
    connection: quinn::Connection,
) -> u64
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    let remote_addr = connection.remote_address();
    let session_id = inner.next_session_id.fetch_add(1, Ordering::Relaxed);
    let session = PeerSession {
        connection,
        session_id,
    };
    inner.peers.insert(peer_id.clone(), session);
    emit_peer_event(
        &inner,
        PeerConnectionEvent::Connected {
            peer_id,
            remote_addr,
        },
    );
    debug!(instance_id = %inner.id, peers = inner.peers.len(), "peer registered");
    session_id
}
fn unregister_peer_if_current<S, M, H>(inner: &Arc<Inner<S, M, H>>, peer_id: &str, session_id: u64)
where
S: Send + Sync + 'static,
M: QrpcMessage,
H: QrpcCallback<S, M>,
{
let removed = inner
.peers
.remove_if(peer_id, |_, session| session.session_id == session_id)
.is_some();
if !removed {
return;
}
emit_peer_event(
inner,
PeerConnectionEvent::Disconnected {
peer_id: peer_id.to_string(),
},
);
debug!(
instance_id = %inner.id,
peer_id = peer_id,
peers = inner.peers.len(),
"peer unregistered"
);
}
/// Opens a fresh bidirectional stream, writes one framed packet and finishes the stream.
///
/// The receive half is not used and is dropped right away.
///
/// # Errors
/// Any error from opening, writing or finishing the stream.
async fn send_packet_over_connection(
    connection: &quinn::Connection,
    packet: WirePacket,
) -> QrpcResult<()> {
    let (mut outgoing, _) = connection.open_bi().await?;
    let frame = packet.encode_frame();
    outgoing.write_all(&frame).await?;
    outgoing.finish()?;
    Ok(())
}
async fn read_packet_from_stream(recv: &mut quinn::RecvStream) -> QrpcResult<WirePacket> {
let bytes = recv.read_to_end(MAX_PACKET_SIZE).await?;
WirePacket::decode_frame(&bytes)
}
/// Publishes `event` to peer-event subscribers, best-effort.
///
/// A send error (e.g. no subscribers) is deliberately ignored.
fn emit_peer_event<S, M, H>(inner: &Arc<Inner<S, M, H>>, event: PeerConnectionEvent)
where
    S: Send + Sync + 'static,
    M: QrpcMessage,
    H: QrpcCallback<S, M>,
{
    inner.peer_event_tx.send(event).ok();
}
/// Classifies a failure to *start* a QUIC connection (before any packet is sent).
///
/// Classification works by matching substrings of the error's lower-cased `Display`
/// output. Configuration problems that retrying cannot fix are `Permanent`: an invalid
/// server name or remote address, a missing default client config, or an unsupported
/// QUIC version. Everything else is `Transient`.
///
/// NOTE(review): the message texts come from quinn's `ConnectError` and may change across
/// quinn versions — confirm when upgrading.
fn classify_connect_init_error(err: &quinn::ConnectError) -> ConnectionFailureKind {
    const PERMANENT_HINTS: [&str; 5] = [
        "invalid",
        "server name",
        "dns",
        "default client config",
        // e.g. "unsupported QUIC version": retrying with the same config cannot succeed.
        "unsupported",
    ];
    let msg = err.to_string().to_ascii_lowercase();
    if PERMANENT_HINTS.iter().any(|hint| msg.contains(*hint)) {
        ConnectionFailureKind::Permanent
    } else {
        ConnectionFailureKind::Transient
    }
}
/// Classifies an error from the outgoing handshake/session as permanent or transient.
///
/// These are permanent: peer id mismatches, decode errors, rustls errors, and connection
/// errors whose text mentions certificates, TLS or crypto. Everything else is transient.
fn classify_outgoing_error(err: &QrpcError) -> ConnectionFailureKind {
    const FATAL_TRANSPORT_HINTS: [&str; 3] = ["certificate", "tls", "crypto"];
    let permanent = match err {
        QrpcError::PeerIdMismatch { .. } | QrpcError::MessageDecode(_) | QrpcError::Rustls(_) => {
            true
        }
        QrpcError::QuinnConnection(conn) => {
            let text = conn.to_string().to_ascii_lowercase();
            FATAL_TRANSPORT_HINTS.iter().any(|hint| text.contains(*hint))
        }
        _ => false,
    };
    if permanent {
        ConnectionFailureKind::Permanent
    } else {
        ConnectionFailureKind::Transient
    }
}
/// Derives a jitter seed from the instance id and the peer's address, port and server name.
///
/// The seed is stable for the lifetime of the process (`DefaultHasher` keys are fixed),
/// so each (instance, peer) pair keeps its own retry offsets.
fn backoff_seed(instance_id: &str, peer: &PeerConfig) -> u64 {
    let mut state = DefaultHasher::new();
    // Tuples hash their elements in order, exactly like hashing each field in turn.
    (instance_id, &peer.address, &peer.port, &peer.server_name).hash(&mut state);
    state.finish()
}
/// Computes the reconnect delay after `failed_attempts` consecutive failures.
///
/// The base delay doubles per attempt (1 s, 2 s, 4 s, ...). A deterministic jitter in
/// `0..250` ms derived from `seed` is added so different peers do not retry in lockstep.
/// The base is capped at `30 s - 250 ms` *before* the jitter is added, which keeps the
/// total at or below 30 s while keeping the jitter at the cap. The old code capped after
/// adding jitter, so every peer retried at exactly 30 000 ms.
fn retry_delay(failed_attempts: u32, seed: u64) -> Duration {
    const BASE_DELAY_MS: u64 = 1_000;
    const MAX_EXPONENT: u32 = 5;
    const MAX_JITTER_MS: u64 = 250;
    const MAX_DELAY_MS: u64 = 30_000;

    let exp = failed_attempts.min(MAX_EXPONENT);
    let base_ms = ((1u64 << exp) * BASE_DELAY_MS).min(MAX_DELAY_MS - MAX_JITTER_MS);
    // Mix the attempt number into the seed (golden-ratio constant) so consecutive attempts
    // of the same peer also get different offsets.
    let mixed = seed
        .wrapping_add(u64::from(failed_attempts).wrapping_mul(0x9E37_79B9_7F4A_7C15))
        .rotate_left(17);
    let jitter_ms = mixed % MAX_JITTER_MS;
    Duration::from_millis(base_ms + jitter_ms)
}
/// Sleeps for `delay` unless the shutdown flag changes first.
///
/// Returns `true` only when the wait was cut short by a change whose current value is
/// `true`. Returns `false` after a full sleep, or when the change left the value `false`.
async fn sleep_or_shutdown(shutdown_rx: &mut watch::Receiver<bool>, delay: Duration) -> bool {
    let timer = tokio::time::sleep(delay);
    tokio::select! {
        _ = timer => false,
        _ = shutdown_rx.changed() => *shutdown_rx.borrow(),
    }
}
/// Resolves once the shutdown flag is `true`, checking the current value first.
///
/// Also resolves if the sender is dropped, since no further change can arrive.
async fn wait_for_shutdown_signal(mut shutdown_rx: watch::Receiver<bool>) {
    loop {
        if *shutdown_rx.borrow() {
            return;
        }
        if shutdown_rx.changed().await.is_err() {
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{Ctx, State};
    use super::*;

    /// CA certificate trusted by every test instance.
    const CA_CERT: &str = "tests/certs/ca.crt";
    /// Server-side identity used by most tests.
    const SERVER_CERT: &str = "tests/certs/server.crt";
    const SERVER_KEY: &str = "tests/certs/server.key";
    /// Client-side identity, used where a separate client certificate is exercised.
    const CLIENT_CERT: &str = "tests/certs/client.crt";
    const CLIENT_KEY: &str = "tests/certs/client.key";

    /// Upper bound for waiting on a single peer event. Kept generous because, if the
    /// peer address is resolved via DNS, an invalid hostname can be slow to fail on
    /// some CI hosts; a passing test never waits the full duration.
    const EVENT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Grace period after a `serve_with` worker signals completion. The worker sends
    /// its signal *before* returning, so this gives `serve_with` time to observe the
    /// worker's return value before the test asserts on the outcome.
    const WORKER_SETTLE: Duration = Duration::from_millis(50);

    /// Minimal message type: fixed command id, empty payload.
    #[derive(Clone)]
    struct DummyMessage;

    impl QrpcMessage for DummyMessage {
        fn cmd_id(&self) -> u32 {
            1
        }

        fn encode_vec(&self) -> Vec<u8> {
            vec![]
        }

        fn decode_vec(_cmd_id: u32, _data: &[u8]) -> QrpcResult<Self> {
            Ok(Self)
        }
    }

    /// No-op handler for instances whose state type is `usize`.
    fn cb(
        _state: State<usize>,
        _ctx: Ctx<DummyMessage>,
        _source_peer_id: String,
        _msg: DummyMessage,
    ) -> impl std::future::Future<Output = QrpcResult<()>> {
        async { Ok(()) }
    }

    /// `build()` must report the missing `port` field instead of succeeding.
    #[test]
    fn builder_missing_required_field() {
        let builder = QrpcInstance::<usize, DummyMessage, _>::builder(cb)
            .with_state(1usize)
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY);
        // `match` rather than `expect_err`: the success type need not implement `Debug`.
        let err = match builder.build() {
            Ok(_) => panic!("must fail without port"),
            Err(err) => err,
        };
        assert!(matches!(err, QrpcError::MissingField("port")));
    }

    /// Without `with_id`, the built instance still has a non-empty id.
    #[tokio::test]
    async fn builder_auto_generates_id() {
        let instance = QrpcInstance::<usize, DummyMessage, _>::builder(cb)
            .with_state(1usize)
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY)
            .with_port(0)
            .build()
            .expect("build must succeed");
        assert!(!instance.id().is_empty());
    }

    /// An id supplied via `with_id` is used verbatim.
    #[tokio::test]
    async fn builder_custom_id() {
        let instance = QrpcInstance::<usize, DummyMessage, _>::builder(cb)
            .with_state(1usize)
            .with_id("node-a")
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY)
            .with_port(0)
            .build()
            .expect("build must succeed");
        assert_eq!(instance.id(), "node-a");
    }

    /// A peer added through `add_peer` is kept as a bootstrap peer.
    #[tokio::test]
    async fn builder_add_peers() {
        let peer = PeerConfig {
            address: "127.0.0.1".to_string(),
            port: 1234,
            server_name: "localhost".to_string(),
            ca_cert_path: None,
            expected_id: Some("target".to_string()),
        };
        let instance = QrpcInstance::<usize, DummyMessage, _>::builder(cb)
            .with_state(1usize)
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY)
            .with_port(0)
            .add_peer(peer)
            .build()
            .expect("build must succeed");
        let peers = &instance.inner.bootstrap_peers;
        assert_eq!(peers.len(), 1, "the added peer must be retained");
        assert_eq!(peers[0].port, 1234);
        assert_eq!(peers[0].expected_id.as_deref(), Some("target"));
    }

    /// Omitting `with_state` must still build, with `()` as the state type.
    #[tokio::test]
    async fn builder_without_state_defaults_to_unit() {
        let instance = QrpcInstance::<(), DummyMessage, _>::builder(
            |_state: State<()>,
             _ctx: Ctx<DummyMessage>,
             _source_peer_id: String,
             _msg: DummyMessage| async move { Ok(()) },
        )
        .with_ca_cert(CA_CERT)
        .with_identity(SERVER_CERT, SERVER_KEY)
        .with_port(0)
        .build()
        .expect("build must succeed without explicit state");
        assert!(!instance.id().is_empty());
    }

    /// A distinct client identity can be configured alongside the server identity.
    #[tokio::test]
    async fn builder_supports_separate_client_identity() {
        let instance = QrpcInstance::<(), DummyMessage, _>::builder(
            |_state: State<()>,
             _ctx: Ctx<DummyMessage>,
             _source_peer_id: String,
             _msg: DummyMessage| async move { Ok(()) },
        )
        .with_ca_cert(CA_CERT)
        .with_identity(SERVER_CERT, SERVER_KEY)
        .with_client_identity(CLIENT_CERT, CLIENT_KEY)
        .with_port(0)
        .build()
        .expect("build must succeed with separate client identity");
        assert!(!instance.id().is_empty());
    }

    /// Waiting for a peer that never connects yields `PeerWaitTimeout` naming that peer.
    #[tokio::test]
    async fn wait_for_peer_timeout_returns_error() {
        let instance = QrpcInstance::<(), DummyMessage, _>::builder(
            |_state: State<()>,
             _ctx: Ctx<DummyMessage>,
             _source_peer_id: String,
             _msg: DummyMessage| async move { Ok(()) },
        )
        .with_ca_cert(CA_CERT)
        .with_identity(SERVER_CERT, SERVER_KEY)
        .with_port(0)
        .build()
        .expect("build must succeed");
        let err = instance
            .wait_for_peer("missing-peer", Duration::from_millis(50))
            .await
            .expect_err("must timeout");
        assert!(matches!(
            err,
            QrpcError::PeerWaitTimeout { ref peer_id, .. } if peer_id == "missing-peer"
        ));
    }

    /// An unusable peer address produces `Connecting` followed by a permanent
    /// `ConnectFailed` with no retry scheduled.
    #[tokio::test]
    async fn emits_permanent_connect_error_for_invalid_peer_addr() {
        let instance = QrpcInstance::<(), DummyMessage, _>::builder(
            |_state: State<()>,
             _ctx: Ctx<DummyMessage>,
             _source_peer_id: String,
             _msg: DummyMessage| async move { Ok(()) },
        )
        .with_id("node-a")
        .with_ca_cert(CA_CERT)
        .with_identity(CLIENT_CERT, CLIENT_KEY)
        .with_port(0)
        .add_peer(PeerConfig {
            address: "invalid-host".to_string(),
            port: 12345,
            server_name: "localhost".to_string(),
            ca_cert_path: Some(CA_CERT.to_string()),
            expected_id: None,
        })
        .build()
        .expect("build must succeed");
        // Subscribe before `start` so the very first event cannot be missed.
        let mut rx = instance.subscribe_peer_events();
        instance.start().await;

        let event = tokio::time::timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("event must arrive")
            .expect("event channel open");
        match event {
            PeerConnectionEvent::Connecting { .. } => {}
            other => panic!("first event must be connecting, got: {other:?}"),
        }

        let event = tokio::time::timeout(EVENT_TIMEOUT, rx.recv())
            .await
            .expect("event must arrive")
            .expect("event channel open");
        match event {
            PeerConnectionEvent::ConnectFailed { kind, retry_in, .. } => {
                assert_eq!(kind, ConnectionFailureKind::Permanent);
                assert!(retry_in.is_none());
            }
            other => panic!("unexpected event: {other:?}"),
        }
        instance.shutdown().await;
    }

    /// A worker returning `Ok` must not stop `serve_with`; only an explicit
    /// `shutdown` ends it, and it then returns `Ok`.
    #[tokio::test]
    async fn serve_with_keeps_serving_after_worker_finishes() {
        let instance = Arc::new(
            QrpcInstance::<(), DummyMessage, _>::builder(
                |_state: State<()>,
                 _ctx: Ctx<DummyMessage>,
                 _source_peer_id: String,
                 _msg: DummyMessage| async move { Ok(()) },
            )
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY)
            .with_port(0)
            .build()
            .expect("build must succeed"),
        );
        let (worker_done_tx, worker_done_rx) = tokio::sync::oneshot::channel();
        let serve_instance = Arc::clone(&instance);
        let serve_task = tokio::spawn(async move {
            serve_instance
                .serve_with(|_ctx| async {
                    tokio::time::sleep(Duration::from_millis(30)).await;
                    let _ = worker_done_tx.send(());
                    Ok(())
                })
                .await
        });
        tokio::time::timeout(Duration::from_secs(1), worker_done_rx)
            .await
            .expect("worker must finish")
            .expect("worker signal channel open");
        // The signal is sent before the worker returns; wait so that `serve_with`
        // has actually observed the `Ok` before asserting it did not shut down.
        tokio::time::sleep(WORKER_SETTLE).await;
        assert!(
            !*instance.inner.shutdown_tx.borrow(),
            "serve_with should keep serving after worker returns Ok"
        );
        assert!(
            !serve_task.is_finished(),
            "serve_with should still wait for shutdown"
        );
        instance.shutdown().await;
        tokio::time::timeout(Duration::from_secs(1), serve_task)
            .await
            .expect("serve task should finish after shutdown")
            .expect("serve task join should succeed")
            .expect("serve_with should return Ok after shutdown");
    }

    /// A worker returning `Err` must not stop `serve_with` either; after an
    /// explicit `shutdown` it still returns `Ok`.
    #[tokio::test]
    async fn serve_with_keeps_serving_after_worker_error() {
        let instance = Arc::new(
            QrpcInstance::<(), DummyMessage, _>::builder(
                |_state: State<()>,
                 _ctx: Ctx<DummyMessage>,
                 _source_peer_id: String,
                 _msg: DummyMessage| async move { Ok(()) },
            )
            .with_ca_cert(CA_CERT)
            .with_identity(SERVER_CERT, SERVER_KEY)
            .with_port(0)
            .build()
            .expect("build must succeed"),
        );
        let (worker_done_tx, worker_done_rx) = tokio::sync::oneshot::channel();
        let serve_instance = Arc::clone(&instance);
        let serve_task = tokio::spawn(async move {
            serve_instance
                .serve_with(|_ctx| async {
                    let _ = worker_done_tx.send(());
                    Err(QrpcError::MessageDecode("worker failed".to_string()))
                })
                .await
        });
        tokio::time::timeout(Duration::from_secs(1), worker_done_rx)
            .await
            .expect("worker must finish")
            .expect("worker signal channel open");
        // The signal is sent before the worker returns its `Err`; wait so the
        // assertion below reflects how `serve_with` reacted to that error.
        tokio::time::sleep(WORKER_SETTLE).await;
        assert!(
            !*instance.inner.shutdown_tx.borrow(),
            "serve_with should keep serving after worker returns Err"
        );
        assert!(
            !serve_task.is_finished(),
            "serve_with should still wait for shutdown after worker Err"
        );
        instance.shutdown().await;
        tokio::time::timeout(Duration::from_secs(1), serve_task)
            .await
            .expect("serve task should finish after shutdown")
            .expect("serve task join should succeed")
            .expect("serve_with should return Ok after shutdown");
    }
}