use crate::error::{ClusterError, Result};
use crate::node::NodeId;
use crate::protocol::{
decode_request, decode_response, encode_request, encode_response, ClusterRequest,
ClusterResponse,
};
use bytes::{BufMut, BytesMut};
use dashmap::DashMap;
use parking_lot::RwLock;
use quinn::{
congestion, ClientConfig, Connection, Endpoint, RecvStream, SendStream, ServerConfig,
TransportConfig as QuinnTransportConfig, VarInt,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::{debug, info, instrument, warn};
/// Tuning parameters for the QUIC transport.
#[derive(Debug, Clone)]
pub struct QuicConfig {
// Close a connection after this long with no traffic.
pub idle_timeout: Duration,
// Interval between keep-alive pings (kept shorter than idle_timeout).
pub keep_alive_interval: Duration,
// Cap on concurrent bidirectional streams per connection; also sizes the
// per-connection stream semaphore.
pub max_concurrent_streams: u32,
// Initial congestion window, bytes.
// NOTE(review): not consumed by build_transport_config — confirm whether it
// should be applied there.
pub initial_window: u32,
// Initial MTU handed to quinn's initial_mtu().
pub max_udp_payload_size: u16,
// Per-stream flow-control receive window, bytes.
pub stream_receive_window: u32,
// Connection-wide flow-control receive window, bytes.
pub connection_receive_window: u32,
// Whether 0-RTT resumption should be allowed.
// NOTE(review): the server enables early data unconditionally and this flag
// is not read anywhere in this file — confirm intended wiring.
pub enable_0rtt: bool,
// Maximum pooled connections per peer (see PeerConnectionPool).
pub max_connections_per_peer: usize,
// Timeout applied to each length-prefixed message read/write.
pub request_timeout: Duration,
// Use BBR congestion control when true, Cubic otherwise.
pub use_bbr: bool,
}
impl Default for QuicConfig {
fn default() -> Self {
Self {
idle_timeout: Duration::from_secs(30),
keep_alive_interval: Duration::from_secs(10),
max_concurrent_streams: 256,
initial_window: 14720,
max_udp_payload_size: 1350,
stream_receive_window: 1024 * 1024, connection_receive_window: 8 * 1024 * 1024, enable_0rtt: true,
max_connections_per_peer: 2,
request_timeout: Duration::from_secs(30),
use_bbr: true,
}
}
}
impl QuicConfig {
pub fn high_throughput() -> Self {
Self {
idle_timeout: Duration::from_secs(60),
keep_alive_interval: Duration::from_secs(15),
max_concurrent_streams: 512,
initial_window: 65535,
max_udp_payload_size: 1452,
stream_receive_window: 4 * 1024 * 1024, connection_receive_window: 32 * 1024 * 1024, enable_0rtt: true,
max_connections_per_peer: 4,
request_timeout: Duration::from_secs(60),
use_bbr: true,
}
}
pub fn low_latency() -> Self {
Self {
idle_timeout: Duration::from_secs(10),
keep_alive_interval: Duration::from_secs(3),
max_concurrent_streams: 64,
initial_window: 14720,
max_udp_payload_size: 1350,
stream_receive_window: 256 * 1024, connection_receive_window: 2 * 1024 * 1024, enable_0rtt: true,
max_connections_per_peer: 1,
request_timeout: Duration::from_secs(5),
use_bbr: false, }
}
}
/// Client-certificate policy applied by the server side of connections.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MtlsMode {
// No client certificate requested.
#[default]
Disabled,
// Client certificates verified when presented, but not required.
Optional,
// Handshake fails unless the client presents a certificate signed by a
// configured CA.
Required,
}
/// TLS material and policy for both the server and client side of the
/// transport.
pub struct TlsConfig {
// Local certificate chain presented to peers (leaf first).
pub cert_chain: Vec<CertificateDer<'static>>,
// Private key matching the leaf certificate.
pub private_key: PrivateKeyDer<'static>,
// CA roots used to verify peers (clients under mTLS, servers on dial).
pub ca_certs: Vec<CertificateDer<'static>>,
// Client-certificate requirement on the server side.
pub mtls_mode: MtlsMode,
// Skip server-certificate verification on dial; only honored when built
// with the 'dangerous-skip-verify' feature (or in tests).
pub skip_verification: bool,
}
impl TlsConfig {
pub fn self_signed(common_name: &str) -> Result<Self> {
let cert = rcgen::generate_simple_self_signed(vec![common_name.to_string()])
.map_err(|e| ClusterError::CryptoError(format!("Failed to generate cert: {}", e)))?;
let cert_der = CertificateDer::from(cert.cert.der().to_vec());
let key_der = PrivatePkcs8KeyDer::from(cert.key_pair.serialize_der());
Ok(Self {
cert_chain: vec![cert_der],
private_key: PrivateKeyDer::Pkcs8(key_der),
ca_certs: vec![],
mtls_mode: MtlsMode::Disabled,
skip_verification: false,
})
}
pub fn from_pem_files(cert_path: &str, key_path: &str) -> Result<Self> {
let cert_pem = std::fs::read(cert_path).map_err(ClusterError::Io)?;
let key_pem = std::fs::read(key_path).map_err(ClusterError::Io)?;
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_pem.as_slice())
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| ClusterError::CryptoError(format!("Failed to parse cert: {}", e)))?;
let key = rustls_pemfile::private_key(&mut key_pem.as_slice())
.map_err(|e| ClusterError::CryptoError(format!("Failed to parse key: {}", e)))?
.ok_or_else(|| ClusterError::CryptoError("No private key found".to_string()))?;
Ok(Self {
cert_chain: certs,
private_key: key,
ca_certs: vec![],
mtls_mode: MtlsMode::Disabled,
skip_verification: false,
})
}
pub fn mtls_from_pem_files(cert_path: &str, key_path: &str, ca_path: &str) -> Result<Self> {
let cert_pem = std::fs::read(cert_path).map_err(ClusterError::Io)?;
let key_pem = std::fs::read(key_path).map_err(ClusterError::Io)?;
let ca_pem = std::fs::read(ca_path).map_err(ClusterError::Io)?;
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_pem.as_slice())
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| ClusterError::CryptoError(format!("Failed to parse cert: {}", e)))?;
let key = rustls_pemfile::private_key(&mut key_pem.as_slice())
.map_err(|e| ClusterError::CryptoError(format!("Failed to parse key: {}", e)))?
.ok_or_else(|| ClusterError::CryptoError("No private key found".to_string()))?;
let ca_certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_pem.as_slice())
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| ClusterError::CryptoError(format!("Failed to parse CA cert: {}", e)))?;
Ok(Self {
cert_chain: certs,
private_key: key,
ca_certs,
mtls_mode: MtlsMode::Required,
skip_verification: false,
})
}
pub fn with_mtls_mode(mut self, mode: MtlsMode) -> Self {
self.mtls_mode = mode;
self
}
}
/// Identity details extracted from a peer's TLS certificate.
#[derive(Debug, Clone)]
pub struct PeerIdentity {
// Subject CN, if the certificate carries one.
pub common_name: Option<String>,
// DNS entries from the Subject Alternative Name extension.
pub dns_names: Vec<String>,
// Hex-encoded SHA-256 over the leaf certificate's DER bytes.
pub fingerprint: String,
}
impl PeerIdentity {
/// Extracts identity details from the peer's leaf certificate.
///
/// Returns `None` when the peer presented no certificate (e.g. mTLS
/// disabled) or when quinn's peer identity cannot be downcast to the
/// expected certificate list.
pub fn from_connection(connection: &Connection) -> Option<Self> {
let peer_certs = connection.peer_identity()?;
// quinn exposes the identity as Box<dyn Any>; with rustls this is a
// Vec<CertificateDer>.
let certs: &Vec<CertificateDer<'static>> = peer_certs.downcast_ref()?;
if certs.is_empty() {
return None;
}
// Leaf certificate comes first in the chain.
let cert_der = &certs[0];
use sha2::{Digest, Sha256};
// Fingerprint: SHA-256 of the raw DER, hex-encoded.
let mut hasher = Sha256::new();
hasher.update(cert_der.as_ref());
let fingerprint = hex::encode(hasher.finalize());
// CN/SAN extraction is best-effort: a certificate that fails x509
// parsing still yields a fingerprint-only identity.
let (common_name, dns_names) = match x509_parser::parse_x509_certificate(cert_der.as_ref())
{
Ok((_, cert)) => {
let cn = cert
.subject()
.iter_common_name()
.next()
.and_then(|attr| attr.as_str().ok())
.map(|s| s.to_string());
// Collect only DNS-type SAN entries; IP/URI/email SANs are ignored.
let sans: Vec<String> = cert
.subject_alternative_name()
.ok()
.flatten()
.map(|san| {
san.value
.general_names
.iter()
.filter_map(|gn| match gn {
x509_parser::prelude::GeneralName::DNSName(name) => {
Some(name.to_string())
}
_ => None,
})
.collect()
})
.unwrap_or_default();
(cn, sans)
}
Err(_) => (None, vec![]),
};
Some(Self {
common_name,
dns_names,
fingerprint,
})
}
}
/// A pooled QUIC connection with stream accounting and health tracking.
struct ManagedConnection {
// Underlying quinn connection handle (cheap to clone internally).
connection: Connection,
// Throttles concurrent open_bi() calls; sized from max_concurrent_streams.
stream_semaphore: Arc<Semaphore>,
// Number of streams currently considered in-flight; used by the pool to
// pick the least-loaded connection.
active_streams: AtomicU64,
// Creation timestamp; currently unreferenced elsewhere.
#[allow(dead_code)] created_at: std::time::Instant,
// Last time a stream was opened on this connection (see touch()).
last_used: RwLock<std::time::Instant>,
// Cleared by mark_unhealthy() after request failures.
healthy: AtomicBool,
}
impl ManagedConnection {
    /// Wraps a raw connection with stream accounting and health tracking.
    fn new(connection: Connection, max_streams: u32) -> Self {
        let now = std::time::Instant::now();
        Self {
            connection,
            stream_semaphore: Arc::new(Semaphore::new(max_streams as usize)),
            active_streams: AtomicU64::new(0),
            created_at: now,
            last_used: RwLock::new(now),
            healthy: AtomicBool::new(true),
        }
    }

    /// Usable while it hasn't been flagged bad and QUIC has not recorded a
    /// close reason for the connection.
    fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Acquire) && self.connection.close_reason().is_none()
    }

    /// Flags the connection so the pool stops handing it out.
    fn mark_unhealthy(&self) {
        self.healthy.store(false, Ordering::Release);
    }

    /// Records activity for idle tracking.
    fn touch(&self) {
        *self.last_used.write() = std::time::Instant::now();
    }

    /// Opens a bidirectional stream on this connection.
    ///
    /// On success `active_streams` is incremented; the caller is responsible
    /// for decrementing it once the stream is done.
    ///
    /// NOTE(review): the semaphore permit is dropped when this function
    /// returns, so it only throttles concurrent open_bi() calls rather than
    /// capping live streams — confirm whether the permit should be tied to
    /// the stream's lifetime instead.
    async fn open_bi_stream(&self) -> Result<(SendStream, RecvStream)> {
        let _permit = self
            .stream_semaphore
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| ClusterError::ConnectionClosed)?;
        let stream = self
            .connection
            .open_bi()
            .await
            .map_err(|e| ClusterError::ConnectionFailed(format!("Failed to open stream: {}", e)))?;
        // BUG FIX: increment only after open_bi() succeeds. Previously the
        // counter was bumped before the call, so a failed open leaked a count
        // (callers only decrement after a successful open), permanently
        // skewing the pool's least-loaded selection.
        self.active_streams.fetch_add(1, Ordering::SeqCst);
        Ok(stream)
    }
}
/// A small bounded pool of connections to a single peer.
struct PeerConnectionPool {
// Pooled connections; dead entries are evicted lazily on insert/cleanup.
connections: RwLock<Vec<Arc<ManagedConnection>>>,
// Upper bound on pooled entries (config.max_connections_per_peer).
max_connections: usize,
// Currently unused; presumably intended to gate concurrent dialing —
// TODO confirm before wiring it up.
#[allow(dead_code)] connecting: AtomicBool,
}
impl PeerConnectionPool {
fn new(max_connections: usize) -> Self {
Self {
connections: RwLock::new(Vec::with_capacity(max_connections)),
max_connections,
connecting: AtomicBool::new(false),
}
}
fn get_connection(&self) -> Option<Arc<ManagedConnection>> {
let conns = self.connections.read();
conns
.iter()
.filter(|c| c.is_healthy())
.min_by_key(|c| c.active_streams.load(Ordering::Relaxed))
.cloned()
}
fn add_connection(&self, conn: Arc<ManagedConnection>) {
let mut conns = self.connections.write();
conns.retain(|c| c.is_healthy());
if conns.len() < self.max_connections {
conns.push(conn);
}
}
#[allow(dead_code)] fn cleanup(&self) {
let mut conns = self.connections.write();
conns.retain(|c| c.is_healthy());
}
}
/// Monotonic transport counters, updated with relaxed atomics.
#[derive(Debug, Default)]
pub struct QuicStats {
// Successful connections, inbound and outbound combined.
pub connections_established: AtomicU64,
// Failed outbound connection attempts.
pub connections_failed: AtomicU64,
// Streams opened (counted on both client and server sides).
pub streams_opened: AtomicU64,
// Currently never incremented in this file — TODO confirm intended use.
pub streams_failed: AtomicU64,
// Payload bytes written (excluding the 4-byte length prefix).
pub bytes_sent: AtomicU64,
// Payload bytes read (excluding the 4-byte length prefix).
pub bytes_received: AtomicU64,
// Requests initiated via send().
pub requests_sent: AtomicU64,
// Responses successfully read back.
pub responses_received: AtomicU64,
// send() calls that timed out waiting for a response.
pub request_timeouts: AtomicU64,
// Currently never incremented in this file — TODO confirm intended use.
pub zero_rtt_connections: AtomicU64,
}
impl QuicStats {
pub fn snapshot(&self) -> QuicStatsSnapshot {
QuicStatsSnapshot {
connections_established: self.connections_established.load(Ordering::Relaxed),
connections_failed: self.connections_failed.load(Ordering::Relaxed),
streams_opened: self.streams_opened.load(Ordering::Relaxed),
streams_failed: self.streams_failed.load(Ordering::Relaxed),
bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.bytes_received.load(Ordering::Relaxed),
requests_sent: self.requests_sent.load(Ordering::Relaxed),
responses_received: self.responses_received.load(Ordering::Relaxed),
request_timeouts: self.request_timeouts.load(Ordering::Relaxed),
zero_rtt_connections: self.zero_rtt_connections.load(Ordering::Relaxed),
}
}
}
/// Plain-value copy of `QuicStats` produced by `QuicStats::snapshot`.
/// Field meanings mirror the atomic counterparts.
#[derive(Debug, Clone)]
pub struct QuicStatsSnapshot {
pub connections_established: u64,
pub connections_failed: u64,
pub streams_opened: u64,
pub streams_failed: u64,
pub bytes_sent: u64,
pub bytes_received: u64,
pub requests_sent: u64,
pub responses_received: u64,
pub request_timeouts: u64,
pub zero_rtt_connections: u64,
}
/// Server-side request handler: maps an inbound request to its response.
pub type QuicRequestHandler = Arc<dyn Fn(ClusterRequest) -> ClusterResponse + Send + Sync>;
/// QUIC-based cluster transport: one endpoint serving inbound connections
/// plus pooled outbound connections per peer.
pub struct QuicTransport {
// This node's identity (used in logging).
local_node: NodeId,
// Shared QUIC endpoint used for both dialing and accepting.
endpoint: Endpoint,
// Outbound connection pools keyed by peer.
connections: Arc<DashMap<NodeId, PeerConnectionPool>>,
// Known peer addresses, populated via add_peer().
peer_addrs: Arc<DashMap<NodeId, SocketAddr>>,
config: QuicConfig,
// Inbound request callback; None until set_handler() is called.
handler: Option<QuicRequestHandler>,
// Reserved; only used by the unused _next_correlation_id helper.
#[allow(dead_code)] correlation_id: AtomicU64,
stats: Arc<QuicStats>,
// Set by shutdown(); polled by the acceptor loop.
shutdown: Arc<AtomicBool>,
}
impl QuicTransport {
/// Builds the transport: server + default client config derived from
/// `tls_config`/`config`, bound to `bind_addr`.
///
/// The acceptor loop is not started here; call `start()` (after
/// `set_handler()`) to begin serving inbound connections.
///
/// # Errors
/// Crypto-material errors from the TLS builders, or a network error if the
/// UDP socket cannot be bound.
#[instrument(skip(tls_config, config))]
pub fn new(
local_node: NodeId,
bind_addr: SocketAddr,
tls_config: TlsConfig,
config: QuicConfig,
) -> Result<Self> {
let server_crypto = Self::build_server_crypto(&tls_config)?;
let mut server_config = ServerConfig::with_crypto(Arc::new(server_crypto));
let transport_config = Self::build_transport_config(&config);
server_config.transport_config(Arc::new(transport_config));
let client_config = Self::build_client_config(&tls_config, &config)?;
let mut endpoint = Endpoint::server(server_config, bind_addr)
.map_err(|e| ClusterError::Network(format!("Failed to create endpoint: {}", e)))?;
// Outbound connects made through this endpoint reuse the same socket and
// pick up this client config by default.
endpoint.set_default_client_config(client_config);
info!(
node = %local_node,
addr = %bind_addr,
"QUIC transport initialized"
);
Ok(Self {
local_node,
endpoint,
connections: Arc::new(DashMap::new()),
peer_addrs: Arc::new(DashMap::new()),
config,
handler: None,
correlation_id: AtomicU64::new(1),
stats: Arc::new(QuicStats::default()),
shutdown: Arc::new(AtomicBool::new(false)),
})
}
fn build_server_crypto(tls: &TlsConfig) -> Result<quinn::crypto::rustls::QuicServerConfig> {
let server_config = match tls.mtls_mode {
MtlsMode::Disabled => {
rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(tls.cert_chain.clone(), tls.private_key.clone_key())
.map_err(|e| ClusterError::CryptoError(format!("TLS config error: {}", e)))?
}
MtlsMode::Optional | MtlsMode::Required => {
let mut roots = rustls::RootCertStore::empty();
for cert in &tls.ca_certs {
roots.add(cert.clone()).map_err(|e| {
ClusterError::CryptoError(format!("Failed to add CA: {:?}", e))
})?;
}
if roots.is_empty() {
return Err(ClusterError::CryptoError(
"mTLS enabled but no CA certificates provided".to_string(),
));
}
let verifier = if tls.mtls_mode == MtlsMode::Required {
rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
.build()
.map_err(|e| {
ClusterError::CryptoError(format!("Failed to build verifier: {}", e))
})?
} else {
rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
.allow_unauthenticated()
.build()
.map_err(|e| {
ClusterError::CryptoError(format!("Failed to build verifier: {}", e))
})?
};
rustls::ServerConfig::builder()
.with_client_cert_verifier(verifier)
.with_single_cert(tls.cert_chain.clone(), tls.private_key.clone_key())
.map_err(|e| ClusterError::CryptoError(format!("TLS config error: {}", e)))?
}
};
let mut server_config = server_config;
server_config.max_early_data_size = u32::MAX;
server_config.alpn_protocols = vec![b"rivven".to_vec()];
quinn::crypto::rustls::QuicServerConfig::try_from(server_config)
.map_err(|e| ClusterError::CryptoError(format!("QUIC server config error: {}", e)))
}
/// Builds the quinn client config: trust anchors (explicit CAs + OS store),
/// optional client certificate for mTLS, ALPN, and the shared transport
/// tuning.
///
/// # Errors
/// Crypto errors from certificate handling, or when `skip_verification` is
/// requested without the 'dangerous-skip-verify' feature.
fn build_client_config(tls: &TlsConfig, config: &QuicConfig) -> Result<ClientConfig> {
    let mut roots = rustls::RootCertStore::empty();
    for cert in &tls.ca_certs {
        roots
            .add(cert.clone())
            .map_err(|e| ClusterError::CryptoError(format!("Failed to add CA: {:?}", e)))?;
    }
    // Best-effort: individual native certificates may fail to parse.
    let native_result = rustls_native_certs::load_native_certs();
    for cert in native_result.certs {
        let _ = roots.add(cert);
    }
    let mut crypto = if tls.skip_verification {
        #[cfg(any(test, feature = "dangerous-skip-verify"))]
        {
            if tls.mtls_mode != MtlsMode::Disabled {
                rustls::ClientConfig::builder()
                    .dangerous()
                    .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
                    .with_client_auth_cert(tls.cert_chain.clone(), tls.private_key.clone_key())
                    .map_err(|e| {
                        ClusterError::CryptoError(format!("Client cert error: {}", e))
                    })?
            } else {
                rustls::ClientConfig::builder()
                    .dangerous()
                    .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
                    .with_no_client_auth()
            }
        }
        #[cfg(not(any(test, feature = "dangerous-skip-verify")))]
        {
            return Err(ClusterError::CryptoError(
                "skip_verification requires the 'dangerous-skip-verify' feature".into(),
            ));
        }
    } else if tls.mtls_mode != MtlsMode::Disabled {
        rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_client_auth_cert(tls.cert_chain.clone(), tls.private_key.clone_key())
            .map_err(|e| ClusterError::CryptoError(format!("Client cert error: {}", e)))?
    } else {
        rustls::ClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth()
    };
    // BUG FIX: QUIC mandates ALPN and build_server_crypto only accepts the
    // "rivven" protocol; without offering it here the handshake fails with
    // "no application protocol". Keep this list in sync with the server.
    crypto.alpn_protocols = vec![b"rivven".to_vec()];
    // Honor the 0-RTT knob on the client side (the server already allows
    // early data unconditionally).
    crypto.enable_early_data = config.enable_0rtt;
    let mut client_config = ClientConfig::new(Arc::new(
        quinn::crypto::rustls::QuicClientConfig::try_from(crypto).map_err(|e| {
            ClusterError::CryptoError(format!("QUIC crypto config error: {}", e))
        })?,
    ));
    let transport = Self::build_transport_config(config);
    client_config.transport_config(Arc::new(transport));
    Ok(client_config)
}
/// Translates QuicConfig knobs into a quinn TransportConfig (shared by both
/// the client and server sides).
fn build_transport_config(config: &QuicConfig) -> QuinnTransportConfig {
    let mut tc = QuinnTransportConfig::default();
    tc.max_concurrent_bidi_streams(VarInt::from_u32(config.max_concurrent_streams));
    tc.initial_mtu(config.max_udp_payload_size);
    tc.stream_receive_window(VarInt::from_u32(config.stream_receive_window));
    tc.receive_window(VarInt::from_u32(config.connection_receive_window));
    tc.keep_alive_interval(Some(config.keep_alive_interval));
    // The IdleTimeout conversion only fails for absurdly large durations.
    tc.max_idle_timeout(Some(config.idle_timeout.try_into().unwrap()));
    if config.use_bbr {
        tc.congestion_controller_factory(Arc::new(congestion::BbrConfig::default()));
    } else {
        tc.congestion_controller_factory(Arc::new(congestion::CubicConfig::default()));
    }
    tc
}
/// Installs the callback invoked for every inbound request.
///
/// Must be called before `start()`: the acceptor task captures a clone of
/// the handler when it is spawned.
pub fn set_handler(&mut self, handler: QuicRequestHandler) {
self.handler = Some(handler);
}
/// Registers a peer's address and creates an (empty) connection pool for it.
pub fn add_peer(&self, node_id: NodeId, addr: SocketAddr) {
self.peer_addrs.insert(node_id.clone(), addr);
self.connections.insert(
node_id,
PeerConnectionPool::new(self.config.max_connections_per_peer),
);
}
/// Forgets a peer's address and drops its connection pool.
pub fn remove_peer(&self, node_id: &NodeId) {
self.peer_addrs.remove(node_id);
self.connections.remove(node_id);
}
/// Spawns the background acceptor loop for inbound connections.
///
/// Each accepted connection is handled on its own task. The loop exits when
/// the shutdown flag is set or the endpoint is closed. Call `set_handler()`
/// first — the handler is captured here by clone.
#[instrument(skip(self))]
pub async fn start(&self) -> Result<()> {
    let endpoint = self.endpoint.clone();
    let connections = self.connections.clone();
    let handler = self.handler.clone();
    let config = self.config.clone();
    let stats = self.stats.clone();
    let shutdown = self.shutdown.clone();
    tokio::spawn(async move {
        while !shutdown.load(Ordering::Relaxed) {
            tokio::select! {
                incoming = endpoint.accept() => {
                    match incoming {
                        Some(incoming) => {
                            let handler = handler.clone();
                            let config = config.clone();
                            let stats = stats.clone();
                            let connections = connections.clone();
                            tokio::spawn(async move {
                                if let Err(e) = Self::handle_incoming(
                                    incoming, handler, config, stats, connections,
                                )
                                .await
                                {
                                    debug!(error = %e, "Incoming connection error");
                                }
                            });
                        }
                        // BUG FIX: accept() yields None once the endpoint is
                        // closed and keeps yielding None immediately; the old
                        // code looped on it, busy-spinning the select until
                        // the shutdown flag happened to be set. Exit instead.
                        None => break,
                    }
                }
                // Periodic wakeup so the shutdown flag is observed even when
                // no connections arrive.
                _ = tokio::time::sleep(Duration::from_millis(100)) => {}
            }
        }
        info!("QUIC transport acceptor shutting down");
    });
    Ok(())
}
/// Drives one inbound connection: completes the handshake, then accepts
/// bidirectional streams until the peer closes or errors, dispatching each
/// stream to its own task.
async fn handle_incoming(
incoming: quinn::Incoming,
handler: Option<QuicRequestHandler>,
config: QuicConfig,
stats: Arc<QuicStats>,
// Accepted (inbound) connections are not pooled; only outbound dials go
// through the per-peer pools.
_connections: Arc<DashMap<NodeId, PeerConnectionPool>>,
) -> Result<()> {
let connection = incoming
.await
.map_err(|e| ClusterError::ConnectionFailed(format!("Accept failed: {}", e)))?;
stats
.connections_established
.fetch_add(1, Ordering::Relaxed);
let remote = connection.remote_address();
debug!(peer = %remote, "Accepted QUIC connection");
loop {
match connection.accept_bi().await {
Ok((send, recv)) => {
// One task per stream so a slow request doesn't block the accept loop.
let handler = handler.clone();
let config = config.clone();
let stats = stats.clone();
tokio::spawn(async move {
if let Err(e) =
Self::handle_stream(send, recv, handler, config, stats).await
{
debug!(error = %e, "Stream handling error");
}
});
}
// Orderly application close is expected; log at debug and stop.
Err(quinn::ConnectionError::ApplicationClosed(_)) => {
debug!(peer = %remote, "Connection closed by peer");
break;
}
// Any other error (timeout, transport error) also ends the loop.
Err(e) => {
warn!(peer = %remote, error = %e, "Connection error");
break;
}
}
}
Ok(())
}
/// Serves one request/response exchange on an inbound stream: read a
/// length-prefixed request, run the handler, write the length-prefixed
/// response, then finish the stream.
async fn handle_stream(
    mut send: SendStream,
    mut recv: RecvStream,
    handler: Option<QuicRequestHandler>,
    config: QuicConfig,
    stats: Arc<QuicStats>,
) -> Result<()> {
    stats.streams_opened.fetch_add(1, Ordering::Relaxed);
    let payload = Self::read_message(&mut recv, &config).await?;
    stats
        .bytes_received
        .fetch_add(payload.len() as u64, Ordering::Relaxed);
    let request = decode_request(&payload)?;
    // With no handler installed the request is dropped and the stream is
    // finished without a response; the remote side's read then fails.
    if let Some(dispatch) = handler {
        let reply = dispatch(request);
        let reply_bytes = encode_response(&reply)?;
        Self::write_message(&mut send, &reply_bytes, &config).await?;
        stats
            .bytes_sent
            .fetch_add(reply_bytes.len() as u64, Ordering::Relaxed);
    }
    send.finish()
        .map_err(|e| ClusterError::Network(format!("Failed to finish stream: {}", e)))?;
    Ok(())
}
/// Reads one length-prefixed message: a 4-byte big-endian length followed by
/// the body. Each read is bounded by `config.request_timeout`, and the
/// declared length is validated against MAX_MESSAGE_SIZE before allocating.
async fn read_message(recv: &mut RecvStream, config: &QuicConfig) -> Result<Vec<u8>> {
    let mut header = [0u8; 4];
    timeout(config.request_timeout, recv.read_exact(&mut header))
        .await
        .map_err(|_| ClusterError::Timeout)?
        .map_err(|e| ClusterError::Network(format!("Failed to read length: {}", e)))?;
    let declared = u32::from_be_bytes(header) as usize;
    if declared > crate::protocol::MAX_MESSAGE_SIZE {
        return Err(ClusterError::MessageTooLarge {
            size: declared,
            max: crate::protocol::MAX_MESSAGE_SIZE,
        });
    }
    let mut payload = vec![0u8; declared];
    timeout(config.request_timeout, recv.read_exact(&mut payload))
        .await
        .map_err(|_| ClusterError::Timeout)?
        .map_err(|e| ClusterError::Network(format!("Failed to read body: {}", e)))?;
    Ok(payload)
}
/// Writes one length-prefixed message (4-byte big-endian length + body)
/// within the configured request timeout.
async fn write_message(send: &mut SendStream, data: &[u8], config: &QuicConfig) -> Result<()> {
    let mut framed = BytesMut::with_capacity(data.len() + 4);
    framed.put_u32(data.len() as u32);
    framed.put_slice(data);
    timeout(config.request_timeout, send.write_all(&framed))
        .await
        .map_err(|_| ClusterError::Timeout)?
        .map_err(|e| ClusterError::Network(format!("Failed to write: {}", e)))?;
    Ok(())
}
#[instrument(skip(self, request), fields(node = %node_id))]
pub async fn send(&self, node_id: &NodeId, request: ClusterRequest) -> Result<ClusterResponse> {
let conn = self.get_or_create_connection(node_id).await?;
self.stats.requests_sent.fetch_add(1, Ordering::Relaxed);
let (mut send, mut recv) = conn.open_bi_stream().await?;
self.stats.streams_opened.fetch_add(1, Ordering::Relaxed);
conn.touch();
let request_bytes = encode_request(&request)?;
Self::write_message(&mut send, &request_bytes, &self.config).await?;
self.stats
.bytes_sent
.fetch_add(request_bytes.len() as u64, Ordering::Relaxed);
send.finish()
.map_err(|e| ClusterError::Network(format!("Failed to finish send: {}", e)))?;
let response_bytes = match Self::read_message(&mut recv, &self.config).await {
Ok(bytes) => bytes,
Err(ClusterError::Timeout) => {
self.stats.request_timeouts.fetch_add(1, Ordering::Relaxed);
conn.active_streams.fetch_sub(1, Ordering::SeqCst);
conn.mark_unhealthy();
return Err(ClusterError::Timeout);
}
Err(e) => {
conn.active_streams.fetch_sub(1, Ordering::SeqCst);
conn.mark_unhealthy();
return Err(e);
}
};
self.stats
.bytes_received
.fetch_add(response_bytes.len() as u64, Ordering::Relaxed);
self.stats
.responses_received
.fetch_add(1, Ordering::Relaxed);
let response = decode_response(&response_bytes)?;
conn.active_streams.fetch_sub(1, Ordering::SeqCst);
Ok(response)
}
/// Fire-and-forget send: writes the request and finishes the stream without
/// waiting for a response (the receive half is dropped immediately).
///
/// # Errors
/// Connection, codec, and network errors from the dial/encode/write path.
pub async fn send_async(&self, node_id: &NodeId, request: ClusterRequest) -> Result<()> {
    let conn = self.get_or_create_connection(node_id).await?;
    let (mut send, _recv) = conn.open_bi_stream().await?;
    let result = async {
        let request_bytes = encode_request(&request)?;
        Self::write_message(&mut send, &request_bytes, &self.config).await?;
        send.finish()
            .map_err(|e| ClusterError::Network(format!("Failed to finish: {}", e)))
    }
    .await;
    // BUG FIX: decrement on failure paths too. Previously an encode/write
    // error returned via `?` before the decrement, leaving active_streams
    // permanently inflated.
    conn.active_streams.fetch_sub(1, Ordering::SeqCst);
    result
}
/// Sends `request` to every known peer concurrently and collects each
/// peer's outcome. Result ordering follows the peer map's iteration order.
pub async fn broadcast(
    &self,
    request: ClusterRequest,
) -> Vec<(NodeId, Result<ClusterResponse>)> {
    let targets: Vec<NodeId> = self.peer_addrs.iter().map(|entry| entry.key().clone()).collect();
    let calls = targets.into_iter().map(|peer| {
        let req = request.clone();
        async move {
            let outcome = self.send(&peer, req).await;
            (peer, outcome)
        }
    });
    futures::future::join_all(calls).await
}
/// Returns a pooled healthy connection to `node_id`, dialing a new one when
/// none is available.
async fn get_or_create_connection(&self, node_id: &NodeId) -> Result<Arc<ManagedConnection>> {
// Fast path: reuse the least-loaded healthy pooled connection.
if let Some(pool) = self.connections.get(node_id) {
if let Some(conn) = pool.get_connection() {
return Ok(conn);
}
}
let addr = *self
.peer_addrs
.get(node_id)
.ok_or_else(|| ClusterError::NodeNotFound(node_id.clone()))?;
// SNI is derived from the raw IP address, so the peer's certificate must
// match it unless verification is skipped.
// NOTE(review): confirm peers are issued certificates covering their IP.
let sni_name = addr.ip().to_string();
let connection = self
.endpoint
.connect(addr, &sni_name) .map_err(|e| ClusterError::ConnectionFailed(format!("Connect error: {}", e)))?
.await
.map_err(|e| {
self.stats
.connections_failed
.fetch_add(1, Ordering::Relaxed);
ClusterError::ConnectionFailed(format!("Connection failed: {}", e))
})?;
self.stats
.connections_established
.fetch_add(1, Ordering::Relaxed);
let managed = Arc::new(ManagedConnection::new(
connection,
self.config.max_concurrent_streams,
));
// NOTE(review): two tasks can race here and both dial; the else-branch
// insert below can also overwrite a pool created concurrently, dropping
// its pooled connection. Benign (extra dial) but worth confirming.
if let Some(pool) = self.connections.get(node_id) {
pool.add_connection(managed.clone());
} else {
let pool = PeerConnectionPool::new(self.config.max_connections_per_peer);
pool.add_connection(managed.clone());
self.connections.insert(node_id.clone(), pool);
}
Ok(managed)
}
// Reserved for future request correlation; currently unused (each request
// gets its own stream, so no correlation id is carried on the wire).
fn _next_correlation_id(&self) -> u64 {
self.correlation_id.fetch_add(1, Ordering::SeqCst)
}
/// Read-only access to the transport's counters.
pub fn stats(&self) -> &QuicStats {
&self.stats
}
/// Gracefully shuts the transport down: signals the acceptor loop, closes
/// every pooled connection and the endpoint, then waits for in-flight
/// activity to drain.
pub async fn shutdown(&self) {
// Flag first so the acceptor loop exits before the endpoint closes.
self.shutdown.store(true, Ordering::Release);
for pool in self.connections.iter() {
let conns = pool.connections.read();
for conn in conns.iter() {
conn.connection.close(VarInt::from_u32(0), b"shutdown");
}
}
self.endpoint.close(VarInt::from_u32(0), b"shutdown");
self.endpoint.wait_idle().await;
info!(node = %self.local_node, "QUIC transport shutdown complete");
}
}
// DANGER: a verifier that accepts any server certificate. Only compiled for
// tests or behind the 'dangerous-skip-verify' feature, and only reachable
// when TlsConfig::skip_verification is set. Never enable in production.
#[cfg(any(test, feature = "dangerous-skip-verify"))]
#[derive(Debug)]
struct SkipServerVerification;
#[cfg(any(test, feature = "dangerous-skip-verify"))]
impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
// Accepts every certificate without inspecting it.
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
// Accepts any TLS 1.2 handshake signature without checking it.
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
// Accepts any TLS 1.3 handshake signature without checking it.
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
// Advertises a broad scheme list so negotiation never fails on our side.
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Config profiles expose the expected knob values.
    #[test]
    fn test_quic_config_defaults() {
        let cfg = QuicConfig::default();
        assert_eq!(cfg.max_concurrent_streams, 256);
        assert!(cfg.enable_0rtt);
        assert!(cfg.use_bbr);
    }

    #[test]
    fn test_quic_config_high_throughput() {
        let cfg = QuicConfig::high_throughput();
        assert_eq!(cfg.max_concurrent_streams, 512);
        assert_eq!(cfg.connection_receive_window, 32 * 1024 * 1024);
    }

    #[test]
    fn test_quic_config_low_latency() {
        let cfg = QuicConfig::low_latency();
        assert_eq!(cfg.max_concurrent_streams, 64);
        assert!(!cfg.use_bbr);
    }

    // Self-signed generation yields at least one certificate.
    #[test]
    fn test_tls_self_signed() {
        let tls = TlsConfig::self_signed("test.rivven.local").unwrap();
        assert!(!tls.cert_chain.is_empty());
    }

    // Snapshot reflects counter increments.
    #[test]
    fn test_stats_snapshot() {
        let stats = QuicStats::default();
        stats.connections_established.fetch_add(5, Ordering::Relaxed);
        stats.bytes_sent.fetch_add(1000, Ordering::Relaxed);
        let snap = stats.snapshot();
        assert_eq!(snap.connections_established, 5);
        assert_eq!(snap.bytes_sent, 1000);
    }
}