use {
bytes::Bytes,
crossbeam_channel::Sender,
futures::future::{TryJoin, TryJoin3},
log::error,
quinn::{
ClientConfig, ConnectError, Connecting, Connection, ConnectionError, Endpoint,
EndpointConfig, IdleTimeout, SendDatagramError, ServerConfig, TokioRuntime,
TransportConfig, VarInt,
crypto::rustls::{QuicClientConfig, QuicServerConfig},
},
rustls::{
CertificateError, KeyLogFile,
pki_types::{CertificateDer, PrivateKeyDer},
},
solana_gossip::contact_info::Protocol,
solana_keypair::Keypair,
solana_pubkey::Pubkey,
solana_runtime::bank_forks::BankForks,
solana_tls_utils::{
new_dummy_x509_certificate, socket_addr_to_quic_server_name, tls_client_config_builder,
tls_server_config_builder,
},
std::{
cmp::Reverse,
collections::{HashMap, hash_map::Entry},
io::Error as IoError,
net::{SocketAddr, UdpSocket},
sync::{
Arc, RwLock,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::Duration,
},
thiserror::Error,
tokio::{
sync::{
Mutex, RwLock as AsyncRwLock,
mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender, error::TrySendError},
},
task::JoinHandle,
},
};
// Capacity of the channel feeding outbound `(address, payload)` pairs to an endpoint's client task.
const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
// Capacity of each per-remote-address channel between the client task and a connection task.
const ROUTER_CHANNEL_BUFFER: usize = 64;
// Number of connections retained by pruning; pruning triggers once the cache holds twice this.
const CONNECTION_CACHE_CAPACITY: usize = 3072;
// ALPN protocol id set on both the client and the server TLS configs.
const ALPN_REPAIR_PROTOCOL_ID: &[u8] = b"solana-repair";

// Transport parameters consumed by `new_transport_config`.
const DATAGRAM_RECEIVE_BUFFER_SIZE: usize = 256 * 1024 * 1024;
const DATAGRAM_SEND_BUFFER_SIZE: usize = 128 * 1024 * 1024;
const MINIMUM_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280;
const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = MINIMUM_MAXIMUM_TRANSMISSION_UNIT;
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(4);
const MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(10);

// Close codes, each listed next to the reason bytes sent with it.
const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";
/// Combined join handle returned by `new_quic_endpoints`: one `(server task, client task)`
/// `TryJoin` pair for each of the three endpoints (repair server, repair client, ancestor
/// hashes), joined together with `try_join3`.
pub(crate) type AsyncTryJoinHandle = TryJoin3<
    TryJoin<JoinHandle<()>, JoinHandle<()>>,
    TryJoin<JoinHandle<()>, JoinHandle<()>>,
    TryJoin<JoinHandle<()>, JoinHandle<()>>,
>;
/// A repair request received from a remote node.
pub(crate) struct RemoteRequest {
    /// Identity of the sender. Set to `Some` by the `From<(Pubkey, SocketAddr, Bytes)>`
    /// conversion used for QUIC datagrams. `None` presumably marks requests built elsewhere
    /// (e.g. UDP); confirm at those construction sites.
    pub(crate) remote_pubkey: Option<Pubkey>,
    /// Socket address the request came from.
    pub(crate) remote_address: SocketAddr,
    /// Raw request payload.
    pub(crate) bytes: Bytes,
}
/// Async senders used to submit outbound `(remote_address, payload)` datagrams to the client
/// task of each repair QUIC endpoint.
pub(crate) struct RepairQuicAsyncSenders {
    /// Outbound channel of the repair-server endpoint.
    pub(crate) repair_response_quic_sender: AsyncSender<(SocketAddr, Bytes)>,
    /// Outbound channel of the repair-client endpoint.
    pub(crate) repair_request_quic_sender: AsyncSender<(SocketAddr, Bytes)>,
    /// Outbound channel of the ancestor-hashes endpoint.
    pub(crate) ancestor_hashes_request_quic_sender: AsyncSender<(SocketAddr, Bytes)>,
}
/// UDP sockets backing each of the three repair QUIC endpoints.
pub(crate) struct RepairQuicSockets {
    pub(crate) repair_server_quic_socket: UdpSocket,
    pub(crate) repair_client_quic_socket: UdpSocket,
    pub(crate) ancestor_hashes_quic_socket: UdpSocket,
}
/// Crossbeam senders on which each endpoint delivers datagrams received from remote peers.
pub(crate) struct RepairQuicSenders {
    /// Receives inbound datagrams of the repair-server endpoint.
    pub(crate) repair_request_quic_sender: Sender<RemoteRequest>,
    /// Receives inbound datagrams of the repair-client endpoint.
    pub(crate) repair_response_quic_sender: Sender<(Pubkey, SocketAddr, Bytes)>,
    /// Receives inbound datagrams of the ancestor-hashes endpoint.
    pub(crate) ancestor_hashes_response_quic_sender: Sender<(Pubkey, SocketAddr, Bytes)>,
}
/// Errors produced while creating or running the repair QUIC endpoints.
#[derive(Error, Debug)]
// Every variant deliberately mirrors the wrapped error type's name, which ends in "Error".
#[allow(clippy::enum_variant_names)]
pub(crate) enum Error {
    /// A crossbeam send failed (converted from `crossbeam_channel::SendError`).
    #[error("Channel Send Error")]
    ChannelSendError,
    /// `Endpoint::connect` rejected an outbound connection attempt.
    #[error(transparent)]
    ConnectError(#[from] ConnectError),
    /// A connection failed or was closed.
    #[error(transparent)]
    ConnectionError(#[from] ConnectionError),
    /// The peer at this address did not present a usable identity.
    #[error("Invalid Identity: {0:?}")]
    InvalidIdentity(SocketAddr),
    /// I/O failure, e.g. while constructing the endpoint over its UDP socket.
    #[error(transparent)]
    IoError(#[from] IoError),
    /// Sending an unreliable datagram failed.
    #[error(transparent)]
    SendDatagramError(#[from] SendDatagramError),
    /// TLS configuration failure.
    #[error(transparent)]
    TlsError(#[from] rustls::Error),
}
/// Adds one to an `AtomicU64` counter with relaxed ordering.
///
/// `Ordering` is resolved at the call site, so `std::sync::atomic::Ordering` must be in scope
/// there.
macro_rules! add_metric {
    ($counter: expr) => {{
        let _ = $counter.fetch_add(1, Ordering::Relaxed);
    }};
}
pub(crate) fn new_quic_endpoints(
runtime: &tokio::runtime::Handle,
keypair: &Keypair,
sockets: RepairQuicSockets,
senders: RepairQuicSenders,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<([Endpoint; 3], RepairQuicAsyncSenders, AsyncTryJoinHandle), Error> {
let (repair_server_quic_endpoint, repair_response_quic_sender, repair_server_join_handle) =
new_quic_endpoint(
runtime,
"repair_server_quic_client",
"repair_server_quic_server",
keypair,
sockets.repair_server_quic_socket,
senders.repair_request_quic_sender,
bank_forks.clone(),
)?;
let (repair_client_quic_endpoint, repair_request_quic_sender, repair_client_join_handle) =
new_quic_endpoint(
runtime,
"repair_client_quic_client",
"repair_client_quic_server",
keypair,
sockets.repair_client_quic_socket,
senders.repair_response_quic_sender,
bank_forks.clone(),
)?;
let (
ancestor_hashes_quic_endpoint,
ancestor_hashes_request_quic_sender,
ancestor_hashes_join_handle,
) = new_quic_endpoint(
runtime,
"ancestor_hashes_quic_client",
"ancestor_hashes_quic_server",
keypair,
sockets.ancestor_hashes_quic_socket,
senders.ancestor_hashes_response_quic_sender,
bank_forks,
)?;
Ok((
[
repair_server_quic_endpoint,
repair_client_quic_endpoint,
ancestor_hashes_quic_endpoint,
],
RepairQuicAsyncSenders {
repair_response_quic_sender,
repair_request_quic_sender,
ancestor_hashes_request_quic_sender,
},
futures::future::try_join3(
repair_server_join_handle,
repair_client_join_handle,
ancestor_hashes_join_handle,
),
))
}
/// Builds a single QUIC endpoint on `socket` and spawns its server (accept) and client
/// (connect) tasks on `runtime`.
///
/// Both tasks share one connection cache, one address router and one prune flag. Inbound
/// datagrams are converted to `T` and forwarded on `sender`.
///
/// Returns the endpoint, the async sender used to queue outbound `(address, payload)` pairs,
/// and a `TryJoin` of the two spawned tasks.
// The return type is a nested tuple containing a `TryJoin`; an alias would not add clarity.
#[allow(clippy::type_complexity)]
fn new_quic_endpoint<T>(
    runtime: &tokio::runtime::Handle,
    client_name: &'static str,
    server_name: &'static str,
    keypair: &Keypair,
    socket: UdpSocket,
    sender: Sender<T>,
    bank_forks: Arc<RwLock<BankForks>>,
) -> Result<
    (
        Endpoint,
        AsyncSender<(SocketAddr, Bytes)>,
        TryJoin<JoinHandle<()>, JoinHandle<()>>,
    ),
    Error,
>
where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let (certificate, private_key) = new_dummy_x509_certificate(keypair);
    let server_config = new_server_config(certificate.clone(), private_key.clone_key())?;
    let client_config = new_client_config(certificate, private_key)?;
    // Construct the endpoint inside `runtime`'s context; presumably the endpoint driver is
    // spawned onto the current Tokio runtime (TokioRuntime).
    let mut endpoint = {
        let _runtime_guard = runtime.enter();
        Endpoint::new(
            EndpointConfig::default(),
            Some(server_config),
            socket,
            Arc::new(TokioRuntime),
        )?
    };
    endpoint.set_default_client_config(client_config);

    // State shared by the server and client halves of this endpoint.
    let prune_cache_pending = Arc::new(AtomicBool::new(false));
    let cache: Arc<Mutex<HashMap<Pubkey, Connection>>> = Arc::default();
    let router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>> = Arc::default();
    let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);

    let server_future = run_server(
        endpoint.clone(),
        server_name,
        sender.clone(),
        bank_forks.clone(),
        Arc::clone(&prune_cache_pending),
        Arc::clone(&router),
        Arc::clone(&cache),
    );
    let server_task = runtime.spawn(server_future);
    let client_future = run_client(
        endpoint.clone(),
        client_name,
        client_receiver,
        sender,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    );
    let client_task = runtime.spawn(client_future);
    Ok((
        endpoint,
        client_sender,
        futures::future::try_join(server_task, client_task),
    ))
}
/// Closes every connection on `endpoint`, using the shutdown error code and reason.
pub(crate) fn close_quic_endpoint(endpoint: &Endpoint) {
    let (code, reason) = (
        CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN,
        CONNECTION_CLOSE_REASON_SHUTDOWN,
    );
    endpoint.close(code, reason);
}
/// Builds the QUIC server config: a TLS config with the given self-signed cert/key,
/// `ALPN_REPAIR_PROTOCOL_ID` as the only ALPN protocol, a `KeyLogFile` key logger, the shared
/// transport config, and connection migration disabled.
///
/// # Errors
/// Returns the TLS builder's error if the cert/key are rejected. A failed conversion into a
/// `QuicServerConfig` is reported as `InvalidCertificate(BadSignature)`.
fn new_server_config(
    cert: CertificateDer<'static>,
    key: PrivateKeyDer<'static>,
) -> Result<ServerConfig, rustls::Error> {
    let mut tls_config = tls_server_config_builder().with_single_cert(vec![cert], key)?;
    tls_config.alpn_protocols = vec![ALPN_REPAIR_PROTOCOL_ID.to_vec()];
    tls_config.key_log = Arc::new(KeyLogFile::new());
    let quic_crypto = match QuicServerConfig::try_from(tls_config) {
        Ok(quic_crypto) => quic_crypto,
        Err(_) => {
            return Err(rustls::Error::InvalidCertificate(
                CertificateError::BadSignature,
            ));
        }
    };
    let mut server_config = ServerConfig::with_crypto(Arc::new(quic_crypto));
    server_config.transport_config(Arc::new(new_transport_config()));
    server_config.migration(false);
    Ok(server_config)
}
fn new_client_config(
cert: CertificateDer<'static>,
key: PrivateKeyDer<'static>,
) -> Result<ClientConfig, rustls::Error> {
let mut config = tls_client_config_builder().with_client_auth_cert(vec![cert], key)?;
config.enable_early_data = true;
config.alpn_protocols = vec![ALPN_REPAIR_PROTOCOL_ID.to_vec()];
let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(config).unwrap()));
config.transport_config(Arc::new(new_transport_config()));
Ok(config)
}
/// Transport settings shared by client and server: large datagram buffers, a fixed MTU with
/// MTU discovery disabled, keep-alives, an idle timeout, and no bidirectional or unidirectional
/// streams. Only datagrams are used.
fn new_transport_config() -> TransportConfig {
    let idle_timeout = IdleTimeout::try_from(MAX_IDLE_TIMEOUT).unwrap();
    let no_streams = VarInt::from(0u8);
    let mut transport = TransportConfig::default();
    transport.datagram_receive_buffer_size(Some(DATAGRAM_RECEIVE_BUFFER_SIZE));
    transport.datagram_send_buffer_size(DATAGRAM_SEND_BUFFER_SIZE);
    transport.initial_mtu(INITIAL_MAXIMUM_TRANSMISSION_UNIT);
    transport.keep_alive_interval(Some(KEEP_ALIVE_INTERVAL));
    transport.max_concurrent_bidi_streams(no_streams);
    transport.max_concurrent_uni_streams(no_streams);
    transport.max_idle_timeout(Some(idle_timeout));
    transport.min_mtu(MINIMUM_MAXIMUM_TRANSMISSION_UNIT);
    transport.mtu_discovery_config(None);
    transport
}
/// Server-side accept loop.
///
/// Accepts each incoming connection attempt and spawns `handle_connecting_task` for it. Failed
/// accepts are logged and counted. Runs until `endpoint.accept()` yields `None`, then stops the
/// metrics reporter.
async fn run_server<T>(
    endpoint: Endpoint,
    server_name: &'static str,
    sender: Sender<T>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let stats = Arc::<RepairQuicStats>::default();
    let metrics_reporter =
        tokio::task::spawn(report_metrics_task(server_name, Arc::clone(&stats)));
    loop {
        let Some(incoming) = endpoint.accept().await else {
            break;
        };
        let remote_addr: SocketAddr = incoming.remote_address();
        let connecting = match incoming.accept() {
            Ok(connecting) => connecting,
            Err(err) => {
                debug!("Error while accepting incoming connection: {err:?} from {remote_addr}");
                record_error(&Error::from(err), &stats);
                continue;
            }
        };
        tokio::task::spawn(handle_connecting_task(
            endpoint.clone(),
            connecting,
            sender.clone(),
            bank_forks.clone(),
            Arc::clone(&prune_cache_pending),
            Arc::clone(&router),
            Arc::clone(&cache),
            Arc::clone(&stats),
        ));
    }
    metrics_reporter.abort();
}
/// Client-side dispatch loop.
///
/// Drains outbound `(remote_address, bytes)` pairs from `receiver`. Each payload is routed to
/// the per-address channel of an existing connection task when a live route exists. Otherwise
/// a new route is registered and `make_connection_task` is spawned to connect to the peer.
/// When the inbound channel closes, this closes the endpoint, clears the router and stops the
/// metrics reporter.
async fn run_client<T>(
    endpoint: Endpoint,
    client_name: &'static str,
    mut receiver: AsyncReceiver<(SocketAddr, Bytes)>,
    sender: Sender<T>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let stats = Arc::<RepairQuicStats>::default();
    let report_metrics_task = tokio::task::spawn(report_metrics_task(client_name, stats.clone()));
    while let Some((remote_address, bytes)) = receiver.recv().await {
        // Fast path: route through an existing entry while holding only the read lock.
        // `None` means the payload was handled (sent, or dropped because the route was full).
        let Some(bytes) = try_route_bytes(&remote_address, bytes, &*router.read().await, &stats)
        else {
            continue;
        };
        let receiver = {
            let mut router = router.write().await;
            // Re-check under the write lock: a route for this address may have been inserted
            // (e.g. by `handle_connecting`) after the read lock was released.
            let Some(bytes) = try_route_bytes(&remote_address, bytes, &router, &stats) else {
                continue;
            };
            let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
            // Cannot fail: the channel was just created, is empty, and its receiver is alive.
            sender.try_send(bytes).unwrap();
            // Replaces any closed route for this address.
            router.insert(remote_address, sender);
            // Note: this `receiver` shadows the function parameter for the rest of the iteration.
            receiver
        };
        tokio::task::spawn(make_connection_task(
            endpoint.clone(),
            remote_address,
            sender.clone(),
            receiver,
            bank_forks.clone(),
            prune_cache_pending.clone(),
            router.clone(),
            cache.clone(),
            stats.clone(),
        ));
    }
    close_quic_endpoint(&endpoint);
    // Dropping every route sender lets the per-connection send tasks observe channel closure.
    router.write().await.clear();
    report_metrics_task.abort();
}
/// Attempts to hand `bytes` to the route registered for `remote_address`.
///
/// Returns `None` when the payload was consumed: either it was queued, or it was dropped
/// because the route's channel is full (that case is counted in `router_try_send_error_full`).
/// Returns the payload back when there is no route, or when the route's receiver has gone
/// away, so the caller can establish a new connection.
fn try_route_bytes(
    remote_address: &SocketAddr,
    bytes: Bytes,
    router: &HashMap<SocketAddr, AsyncSender<Bytes>>,
    stats: &RepairQuicStats,
) -> Option<Bytes> {
    let Some(route) = router.get(remote_address) else {
        return Some(bytes);
    };
    match route.try_send(bytes) {
        Ok(()) => None,
        Err(TrySendError::Closed(bytes)) => Some(bytes),
        Err(TrySendError::Full(_)) => {
            debug!("TrySendError::Full {remote_address}");
            add_metric!(stats.router_try_send_error_full);
            None
        }
    }
}
/// Task wrapper around `handle_connecting`: logs and counts the error it returns, if any.
async fn handle_connecting_task<T>(
    endpoint: Endpoint,
    connecting: Connecting,
    sender: Sender<T>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
    stats: Arc<RepairQuicStats>,
) where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let outcome = handle_connecting(
        endpoint,
        connecting,
        sender,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
        Arc::clone(&stats),
    )
    .await;
    match outcome {
        Ok(()) => (),
        Err(err) => {
            debug!("handle_connecting: {err:?}");
            record_error(&err, &stats);
        }
    }
}
/// Completes the handshake of an incoming connection, verifies the peer's identity, registers
/// a fresh route for the peer's address, and then serves the connection until it ends.
///
/// # Errors
/// Fails if the handshake fails, or if the peer's pubkey cannot be extracted (in which case
/// the connection is closed).
async fn handle_connecting<T>(
    endpoint: Endpoint,
    connecting: Connecting,
    sender: Sender<T>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
    stats: Arc<RepairQuicStats>,
) -> Result<(), Error>
where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let connection = connecting.await?;
    let remote_address = connection.remote_address();
    let remote_pubkey = get_remote_pubkey(&connection)?;
    // Any route previously registered for this address is replaced by this connection's.
    let (route_sender, route_receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
    router.write().await.insert(remote_address, route_sender);
    handle_connection(
        endpoint,
        remote_address,
        remote_pubkey,
        connection,
        sender,
        route_receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
        stats,
    )
    .await;
    Ok(())
}
/// Serves an established connection until both of its datagram tasks finish.
///
/// The connection is first added to the cache. Then a send task (draining `receiver` onto the
/// connection) and a read task (forwarding inbound datagrams to `sender`) are spawned. Task
/// errors are logged and counted. Afterwards the connection is closed and removed from the
/// cache, and its router entry is dropped if that entry's channel is closed.
// All of the shared endpoint state has to be threaded through to the spawned tasks.
#[allow(clippy::too_many_arguments)]
async fn handle_connection<T>(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    remote_pubkey: Pubkey,
    connection: Connection,
    sender: Sender<T>,
    receiver: AsyncReceiver<Bytes>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
    stats: Arc<RepairQuicStats>,
) where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    cache_connection(
        remote_pubkey,
        connection.clone(),
        bank_forks,
        prune_cache_pending,
        Arc::clone(&router),
        Arc::clone(&cache),
    )
    .await;
    let send_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver));
    let read_task = tokio::task::spawn(read_datagram_task(
        endpoint,
        remote_address,
        remote_pubkey,
        connection.clone(),
        sender,
        Arc::clone(&stats),
    ));
    match futures::future::try_join(send_task, read_task).await {
        Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"),
        Ok((send_result, read_result)) => {
            if let Err(err) = &send_result {
                debug!("send_datagram_task: {remote_pubkey}, {remote_address}, {err:?}");
                record_error(err, &stats);
            }
            if let Err(err) = &read_result {
                debug!("read_datagram_task: {remote_pubkey}, {remote_address}, {err:?}");
                record_error(err, &stats);
            }
        }
    }
    drop_connection(remote_pubkey, &connection, &cache).await;
    // Only remove the route if it is dead; a newer connection may have registered a live one.
    let mut router = router.write().await;
    if let Entry::Occupied(route) = router.entry(remote_address) {
        if route.get().is_closed() {
            route.remove();
        }
    }
}
async fn read_datagram_task<T>(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
sender: Sender<T>,
stats: Arc<RepairQuicStats>,
) -> Result<(), Error>
where
T: From<(Pubkey, SocketAddr, Bytes)>,
{
debug_assert_eq!(sender.capacity(), None);
loop {
match connection.read_datagram().await {
Ok(bytes) => {
let value = T::from((remote_pubkey, remote_address, bytes));
if let Err(err) = sender.send(value) {
close_quic_endpoint(&endpoint);
return Err(Error::from(err));
}
}
Err(err) => {
if let Some(err) = connection.close_reason() {
return Err(Error::from(err));
}
debug!("connection.read_datagram: {remote_pubkey}, {remote_address}, {err:?}");
record_error(&Error::from(err), &stats);
}
};
}
}
/// Sends each payload received on `receiver` over `connection` as an unreliable datagram.
///
/// Returns `Ok(())` once `receiver`'s channel is closed. Returns the connection's close reason
/// as an error once the connection is closed. A failing `send_datagram` also ends the task with
/// that error.
async fn send_datagram_task(
    connection: Connection,
    mut receiver: AsyncReceiver<Bytes>,
) -> Result<(), Error> {
    // Pin the `closed()` future once so the same future is polled on every loop iteration.
    tokio::pin! {
        let connection_closed = connection.closed();
    }
    loop {
        tokio::select! {
            // Poll branches in declaration order, so a ready payload is handled before the
            // connection-closed check.
            biased;
            bytes = receiver.recv() => {
                match bytes {
                    None => return Ok(()),
                    Some(bytes) => connection.send_datagram(bytes)?,
                }
            }
            err = &mut connection_closed => return Err(Error::from(err)),
        }
    }
}
/// Task wrapper around `make_connection`: logs and counts the error it returns, if any.
async fn make_connection_task<T>(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    sender: Sender<T>,
    receiver: AsyncReceiver<Bytes>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
    stats: Arc<RepairQuicStats>,
) where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let outcome = make_connection(
        endpoint,
        remote_address,
        sender,
        receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
        Arc::clone(&stats),
    )
    .await;
    match outcome {
        Ok(()) => (),
        Err(err) => {
            debug!("make_connection: {remote_address}, {err:?}");
            record_error(&err, &stats);
        }
    }
}
/// Opens an outbound connection to `remote_address`, verifies the peer's identity, and then
/// serves the connection (sending the payloads queued on `receiver`) until it ends.
///
/// # Errors
/// Fails if the connect attempt or handshake fails, or if the peer's pubkey cannot be extracted
/// (in which case the connection is closed).
async fn make_connection<T>(
    endpoint: Endpoint,
    remote_address: SocketAddr,
    sender: Sender<T>,
    receiver: AsyncReceiver<Bytes>,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
    stats: Arc<RepairQuicStats>,
) -> Result<(), Error>
where
    T: 'static + From<(Pubkey, SocketAddr, Bytes)> + Send,
{
    let server_name = socket_addr_to_quic_server_name(remote_address);
    let connecting = endpoint.connect(remote_address, &server_name)?;
    let connection = connecting.await?;
    let connected_address = connection.remote_address();
    let remote_pubkey = get_remote_pubkey(&connection)?;
    handle_connection(
        endpoint,
        connected_address,
        remote_pubkey,
        connection,
        sender,
        receiver,
        bank_forks,
        prune_cache_pending,
        router,
        cache,
        stats,
    )
    .await;
    Ok(())
}
/// Extracts the peer's pubkey from `connection` via the streamer helper.
///
/// # Errors
/// If no pubkey can be extracted, closes the connection with the invalid-identity code and
/// returns `Error::InvalidIdentity` carrying the peer's address.
fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
    if let Some(remote_pubkey) = solana_streamer::nonblocking::quic::get_remote_pubkey(connection)
    {
        return Ok(remote_pubkey);
    }
    connection.close(
        CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY,
        CONNECTION_CLOSE_REASON_INVALID_IDENTITY,
    );
    Err(Error::InvalidIdentity(connection.remote_address()))
}
/// Stores `connection` as the cached connection for `remote_pubkey`.
///
/// Any connection it replaces is closed with the "replaced" code. If the cache has reached
/// twice `CONNECTION_CACHE_CAPACITY` and no prune is already pending, spawns
/// `prune_connection_cache`.
async fn cache_connection(
    remote_pubkey: Pubkey,
    connection: Connection,
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    let mut guard = cache.lock().await;
    let replaced = guard.insert(remote_pubkey, connection);
    let should_prune_cache = guard.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2);
    drop(guard);

    if let Some(replaced) = replaced {
        replaced.close(
            CONNECTION_CLOSE_ERROR_CODE_REPLACED,
            CONNECTION_CLOSE_REASON_REPLACED,
        );
    }
    if !should_prune_cache {
        return;
    }
    // Claim the prune; if another task already claimed it, let that prune do the work.
    if prune_cache_pending.swap(true, Ordering::Relaxed) {
        return;
    }
    tokio::task::spawn(prune_connection_cache(
        bank_forks,
        prune_cache_pending,
        router,
        cache,
    ));
}
/// Closes `connection` with the "dropped" code and evicts it from `cache`, but only if the
/// cached entry for `remote_pubkey` is this same connection (matched by `stable_id`).
async fn drop_connection(
    remote_pubkey: Pubkey,
    connection: &Connection,
    cache: &Mutex<HashMap<Pubkey, Connection>>,
) {
    connection.close(
        CONNECTION_CLOSE_ERROR_CODE_DROPPED,
        CONNECTION_CLOSE_REASON_DROPPED,
    );
    let mut cache = cache.lock().await;
    match cache.entry(remote_pubkey) {
        Entry::Occupied(cached) if cached.get().stable_id() == connection.stable_id() => {
            cached.remove();
        }
        // Either nothing is cached, or a newer connection has replaced this one.
        _ => (),
    }
}
/// Shrinks the connection cache to at most `CONNECTION_CACHE_CAPACITY` entries and keeps the
/// highest-staked peers.
///
/// Stakes come from the root bank's current-epoch staked nodes; unknown peers count as
/// stake 0. Already-closed connections are discarded outright, and evicted connections are
/// closed with the "pruned" code. Afterwards, router entries whose channels are closed are
/// removed. Must only run while `prune_cache_pending` is set, and it clears that flag when
/// done.
async fn prune_connection_cache(
    bank_forks: Arc<RwLock<BankForks>>,
    prune_cache_pending: Arc<AtomicBool>,
    router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
    cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
    debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
    let staked_nodes = {
        let root_bank = bank_forks.read().unwrap().root_bank();
        root_bank.current_epoch_staked_nodes()
    };
    {
        let mut cache = cache.lock().await;
        if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
            prune_cache_pending.store(false, Ordering::Relaxed);
            return;
        }
        let mut connections: Vec<_> = cache
            .drain()
            .filter(|(_, connection)| connection.close_reason().is_none())
            .map(|entry @ (pubkey, _)| {
                let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
                (stake, entry)
            })
            .collect();
        // Filtering out closed connections can leave CONNECTION_CACHE_CAPACITY or fewer
        // entries. In that case `select_nth_unstable_by_key` would panic (index >= len) after
        // the cache was already drained, and `prune_cache_pending` would stay set forever.
        // Only select and evict when there is actually a surplus.
        if connections.len() > CONNECTION_CACHE_CAPACITY {
            connections.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| {
                Reverse(stake)
            });
            for (_, (_, connection)) in connections.drain(CONNECTION_CACHE_CAPACITY..) {
                connection.close(
                    CONNECTION_CLOSE_ERROR_CODE_PRUNED,
                    CONNECTION_CLOSE_REASON_PRUNED,
                );
            }
        }
        cache.extend(connections.into_iter().map(|(_, entry)| entry));
        prune_cache_pending.store(false, Ordering::Relaxed);
    }
    router.write().await.retain(|_, sender| !sender.is_closed());
}
impl RemoteRequest {
    /// Reports `Protocol::QUIC` when the request carries a remote pubkey (as set by the
    /// `From<(Pubkey, SocketAddr, Bytes)>` conversion), and `Protocol::UDP` otherwise.
    #[inline]
    pub(crate) fn protocol(&self) -> Protocol {
        match self.remote_pubkey {
            Some(_) => Protocol::QUIC,
            None => Protocol::UDP,
        }
    }
}
impl From<(Pubkey, SocketAddr, Bytes)> for RemoteRequest {
    /// Wraps a datagram received over QUIC, recording the sender's pubkey.
    #[inline]
    fn from(datagram: (Pubkey, SocketAddr, Bytes)) -> Self {
        let (pubkey, remote_address, bytes) = datagram;
        Self {
            remote_pubkey: Some(pubkey),
            remote_address,
            bytes,
        }
    }
}
impl RepairQuicAsyncSenders {
    /// Builds senders over fresh capacity-1 channels whose receivers are dropped immediately,
    /// so the senders report their channels as closed. Presumably used where no QUIC endpoint
    /// is running; confirm with callers.
    pub(crate) fn new_dummy() -> Self {
        let closed_sender = || tokio::sync::mpsc::channel(1).0;
        Self {
            repair_response_quic_sender: closed_sender(),
            repair_request_quic_sender: closed_sender(),
            ancestor_hashes_request_quic_sender: closed_sender(),
        }
    }
}
impl<T> From<crossbeam_channel::SendError<T>> for Error {
    /// Collapses a crossbeam send failure into `Error::ChannelSendError`; the unsent value is
    /// dropped.
    fn from(_unsent: crossbeam_channel::SendError<T>) -> Self {
        Self::ChannelSendError
    }
}
/// Per-task error counters. `report_metrics` drains them periodically; each is bumped by
/// `record_error`, or by `try_route_bytes` for `router_try_send_error_full`.
#[derive(Default)]
struct RepairQuicStats {
    // `ConnectError` outcomes of outbound connection attempts.
    connect_error_cids_exhausted: AtomicU64,
    connect_error_invalid_remote_address: AtomicU64,
    connect_error_other: AtomicU64,
    // `ConnectionError` outcomes, one counter per variant.
    connection_error_application_closed: AtomicU64,
    connection_error_cids_exhausted: AtomicU64,
    connection_error_connection_closed: AtomicU64,
    connection_error_locally_closed: AtomicU64,
    connection_error_reset: AtomicU64,
    connection_error_timed_out: AtomicU64,
    connection_error_transport_error: AtomicU64,
    connection_error_version_mismatch: AtomicU64,
    // `SendDatagramError` outcomes (except `Disabled`, which is not counted).
    send_datagram_error_connection_lost: AtomicU64,
    send_datagram_error_too_large: AtomicU64,
    send_datagram_error_unsupported_by_peer: AtomicU64,
    // Peer identity and routing failures.
    invalid_identity: AtomicU64,
    router_try_send_error_full: AtomicU64,
}
/// Reports and resets `stats` under `name` every two seconds, forever. Callers stop it by
/// aborting the spawned task.
async fn report_metrics_task(name: &'static str, stats: Arc<RepairQuicStats>) {
    const REPORT_INTERVAL: Duration = Duration::from_secs(2);
    loop {
        tokio::time::sleep(REPORT_INTERVAL).await;
        report_metrics(name, &stats);
    }
}
fn record_error(err: &Error, stats: &RepairQuicStats) {
match err {
Error::ChannelSendError => (),
Error::ConnectError(ConnectError::CidsExhausted) => {
add_metric!(stats.connect_error_cids_exhausted)
}
Error::ConnectError(ConnectError::EndpointStopping) => {
add_metric!(stats.connect_error_other)
}
Error::ConnectError(ConnectError::InvalidRemoteAddress(_)) => {
add_metric!(stats.connect_error_invalid_remote_address)
}
Error::ConnectError(ConnectError::InvalidServerName(_)) => {
add_metric!(stats.connect_error_other)
}
Error::ConnectError(ConnectError::NoDefaultClientConfig) => {
add_metric!(stats.connect_error_other)
}
Error::ConnectError(ConnectError::UnsupportedVersion) => {
add_metric!(stats.connect_error_other)
}
Error::ConnectionError(ConnectionError::CidsExhausted) => {
add_metric!(stats.connection_error_cids_exhausted)
}
Error::ConnectionError(ConnectionError::VersionMismatch) => {
add_metric!(stats.connection_error_version_mismatch)
}
Error::ConnectionError(ConnectionError::TransportError(_)) => {
add_metric!(stats.connection_error_transport_error)
}
Error::ConnectionError(ConnectionError::ConnectionClosed(_)) => {
add_metric!(stats.connection_error_connection_closed)
}
Error::ConnectionError(ConnectionError::ApplicationClosed(_)) => {
add_metric!(stats.connection_error_application_closed)
}
Error::ConnectionError(ConnectionError::Reset) => add_metric!(stats.connection_error_reset),
Error::ConnectionError(ConnectionError::TimedOut) => {
add_metric!(stats.connection_error_timed_out)
}
Error::ConnectionError(ConnectionError::LocallyClosed) => {
add_metric!(stats.connection_error_locally_closed)
}
Error::InvalidIdentity(_) => add_metric!(stats.invalid_identity),
Error::IoError(_) => (),
Error::SendDatagramError(SendDatagramError::UnsupportedByPeer) => {
add_metric!(stats.send_datagram_error_unsupported_by_peer)
}
Error::SendDatagramError(SendDatagramError::Disabled) => (),
Error::SendDatagramError(SendDatagramError::TooLarge) => {
add_metric!(stats.send_datagram_error_too_large)
}
Error::SendDatagramError(SendDatagramError::ConnectionLost(_)) => {
add_metric!(stats.send_datagram_error_connection_lost)
}
Error::TlsError(_) => (),
}
}
/// Emits one `datapoint_info!` under `name` containing every counter in `stats`. Each counter
/// is reset to zero as it is read, so every report covers only the interval since the
/// previous one.
fn report_metrics(name: &'static str, stats: &RepairQuicStats) {
    // Atomically takes the current value of a counter and leaves zero in its place.
    macro_rules! reset_metric {
        ($metric: expr) => {
            $metric.swap(0, Ordering::Relaxed)
        };
    }
    datapoint_info!(
        name,
        (
            "connect_error_cids_exhausted",
            reset_metric!(stats.connect_error_cids_exhausted),
            i64
        ),
        (
            "connect_error_invalid_remote_address",
            reset_metric!(stats.connect_error_invalid_remote_address),
            i64
        ),
        (
            "connect_error_other",
            reset_metric!(stats.connect_error_other),
            i64
        ),
        (
            "connection_error_application_closed",
            reset_metric!(stats.connection_error_application_closed),
            i64
        ),
        (
            "connection_error_cids_exhausted",
            reset_metric!(stats.connection_error_cids_exhausted),
            i64
        ),
        (
            "connection_error_connection_closed",
            reset_metric!(stats.connection_error_connection_closed),
            i64
        ),
        (
            "connection_error_locally_closed",
            reset_metric!(stats.connection_error_locally_closed),
            i64
        ),
        (
            "connection_error_reset",
            reset_metric!(stats.connection_error_reset),
            i64
        ),
        (
            "connection_error_timed_out",
            reset_metric!(stats.connection_error_timed_out),
            i64
        ),
        (
            "connection_error_transport_error",
            reset_metric!(stats.connection_error_transport_error),
            i64
        ),
        (
            "connection_error_version_mismatch",
            reset_metric!(stats.connection_error_version_mismatch),
            i64
        ),
        (
            "invalid_identity",
            reset_metric!(stats.invalid_identity),
            i64
        ),
        (
            "router_try_send_error_full",
            reset_metric!(stats.router_try_send_error_full),
            i64
        ),
        (
            "send_datagram_error_connection_lost",
            reset_metric!(stats.send_datagram_error_connection_lost),
            i64
        ),
        (
            "send_datagram_error_too_large",
            reset_metric!(stats.send_datagram_error_too_large),
            i64
        ),
        (
            "send_datagram_error_unsupported_by_peer",
            reset_metric!(stats.send_datagram_error_unsupported_by_peer),
            i64
        ),
    );
}
#[cfg(test)]
mod tests {
    use {
        super::*,
        itertools::{izip, multiunzip},
        solana_ledger::genesis_utils::{GenesisConfigInfo, create_genesis_config},
        solana_net_utils::sockets::{bind_to, localhost_port_range_for_tests},
        solana_runtime::bank::Bank,
        solana_signer::Signer,
        std::{
            iter::repeat_with,
            net::{IpAddr, Ipv4Addr},
            time::Duration,
        },
    };

    /// End-to-end exchange between `NUM_ENDPOINTS` localhost endpoints. Endpoint `i` sends the
    /// payload "i=>j" to every other endpoint `j`, and each receiver must observe that payload
    /// tagged with endpoint `i`'s pubkey and socket address. Afterwards all endpoints are shut
    /// down and their tasks must complete without a join error.
    #[test]
    fn test_quic_endpoint() {
        const NUM_ENDPOINTS: usize = 3;
        const RECV_TIMEOUT: Duration = Duration::from_secs(60);
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(8)
            .enable_all()
            .build()
            .unwrap();
        let keypairs: Vec<Keypair> = repeat_with(Keypair::new).take(NUM_ENDPOINTS).collect();
        // One localhost UDP socket per endpoint, taken from the test port range.
        let port_range = localhost_port_range_for_tests();
        let ip_addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let sockets: Vec<UdpSocket> = (port_range.0..port_range.1)
            .map(|port| bind_to(ip_addr, port).unwrap())
            .take(NUM_ENDPOINTS)
            .collect();
        let addresses: Vec<SocketAddr> = sockets
            .iter()
            .map(UdpSocket::local_addr)
            .collect::<Result<_, _>>()
            .unwrap();
        // Each endpoint delivers inbound datagrams on its own unbounded crossbeam channel.
        let (senders, receivers): (Vec<_>, Vec<_>) =
            repeat_with(crossbeam_channel::unbounded::<(Pubkey, SocketAddr, Bytes)>)
                .take(NUM_ENDPOINTS)
                .unzip();
        let bank_forks = {
            let GenesisConfigInfo { genesis_config, .. } =
                create_genesis_config( 100_000);
            let bank = Bank::new_for_tests(&genesis_config);
            BankForks::new_rw_arc(bank)
        };
        // `senders` is rebound here to the async outbound senders returned per endpoint.
        let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) =
            multiunzip(keypairs.iter().zip(sockets).zip(senders).map(
                |((keypair, socket), sender)| {
                    new_quic_endpoint(
                        runtime.handle(),
                        "test_quic_client",
                        "test_quic_server",
                        keypair,
                        socket,
                        sender,
                        bank_forks.clone(),
                    )
                    .unwrap()
                },
            ));
        for (i, (keypair, &address, sender)) in izip!(&keypairs, &addresses, &senders).enumerate() {
            // The inner `address` is the destination; it shadows the outer `address`
            // (endpoint i's own address) only inside this loop.
            for (j, &address) in addresses.iter().enumerate() {
                if i != j {
                    let bytes = Bytes::from(format!("{i}=>{j}"));
                    sender.blocking_send((address, bytes)).unwrap();
                }
            }
            // Every other endpoint must report endpoint i's pubkey and address as the origin.
            for (j, receiver) in receivers.iter().enumerate() {
                if i != j {
                    let bytes = Bytes::from(format!("{i}=>{j}"));
                    let entry = (keypair.pubkey(), address, bytes);
                    assert_eq!(receiver.recv_timeout(RECV_TIMEOUT).unwrap(), entry);
                }
            }
        }
        // Dropping the async senders ends each client loop; closing endpoints ends each server
        // accept loop, so every joined task can finish.
        drop(senders);
        for endpoint in endpoints {
            close_quic_endpoint(&endpoint);
        }
        for task in tasks {
            runtime.block_on(task).unwrap();
        }
    }
}