1#[allow(deprecated)]
2use crate::quic::QuicServerParams;
3use {
4 crate::{
5 nonblocking::{
6 connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
7 qos::{ConnectionContext, QosController},
8 stream_throttle::ConnectionStreamCounter,
9 swqos::{SwQos, SwQosConfig},
10 },
11 quic::{configure_server, QuicServerError, QuicStreamerConfig, StreamerStats},
12 streamer::StakedNodes,
13 },
14 bytes::{BufMut, Bytes, BytesMut},
15 crossbeam_channel::{bounded, Receiver, Sender, TryRecvError, TrySendError},
16 futures::{stream::FuturesUnordered, Future, StreamExt as _},
17 indexmap::map::{Entry, IndexMap},
18 quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime},
19 rand::{thread_rng, Rng},
20 smallvec::SmallVec,
21 solana_keypair::Keypair,
22 solana_measure::measure::Measure,
23 solana_packet::{Meta, PACKET_DATA_SIZE},
24 solana_perf::packet::{BytesPacket, BytesPacketBatch, PacketBatch, PACKETS_PER_BATCH},
25 solana_pubkey::Pubkey,
26 solana_signature::Signature,
27 solana_tls_utils::get_pubkey_from_tls_certificate,
28 solana_transaction_metrics_tracker::signature_if_should_track_packet,
29 std::{
30 array,
31 fmt,
32 iter::repeat_with,
33 net::{IpAddr, SocketAddr, UdpSocket},
34 pin::Pin,
35 sync::{
37 atomic::{AtomicBool, AtomicU64, Ordering},
38 Arc, RwLock,
39 },
40 task::Poll,
41 time::{Duration, Instant},
42 },
43 tokio::{
44 select,
53 task::{self, JoinHandle},
54 time::{sleep, timeout},
55 },
56 tokio_util::{sync::CancellationToken, task::TaskTracker},
57};
58
59pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
60
61pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";
62
63const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
64const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";
65
66pub(crate) const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
67pub(crate) const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";
68
69pub(crate) const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
70pub(crate) const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] =
71 b"exceed_max_stream_count";
72
73const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
74const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";
75
76const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
77const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";
78
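/// Total number of new connection attempts accepted per second across all
/// peers combined, enforced by the `TotalConnectionRateLimiter`.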
79const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2500;
83
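/// Once the per-IP connection rate limiter tracks more than this many entries,
/// stale entries are purged via `retain_recent`.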
84const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;
88
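/// Maximum time a client is given to complete the QUIC handshake before the
/// attempt is counted as a connection setup timeout.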
89const QUIC_CONNECTION_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(2);
92
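/// Bounds used to clamp a connection's estimated round-trip time elsewhere in
/// this crate (e.g. when deriving per-connection limits).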
93pub(crate) const MAX_RTT: Duration = Duration::from_millis(320);
96pub(crate) const MIN_RTT: Duration = Duration::from_millis(2);
99
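/// Accumulates the chunks of a single packet received over a QUIC stream,
/// together with its metadata and the time the first chunk arrived.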
100#[derive(Clone)]
106struct PacketAccumulator {
107 pub meta: Meta,
108 pub chunks: SmallVec<[Bytes; 2]>,
109 pub start_time: Instant,
110}
111
112impl PacketAccumulator {
113 fn new(meta: Meta) -> Self {
114 Self {
115 meta,
116 chunks: SmallVec::default(),
117 start_time: Instant::now(),
118 }
119 }
120}
121
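/// Classification of the remote peer; `Staked` carries the peer's stake amount.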
122#[derive(Copy, Clone, Debug)]
123pub enum ConnectionPeerType {
124 Unstaked,
125 Staked(u64),
126}
127
128impl ConnectionPeerType {
129 pub(crate) fn is_staked(&self) -> bool {
130 matches!(self, ConnectionPeerType::Staked(_))
131 }
132}
133
134pub struct SpawnNonBlockingServerResult {
135 pub endpoints: Vec<Endpoint>,
136 pub stats: Arc<StreamerStats>,
137 pub thread: JoinHandle<()>,
138 pub max_concurrent_connections: usize,
139}
140
141#[deprecated(since = "3.0.0", note = "Use spawn_server instead")]
142#[allow(deprecated)]
143pub fn spawn_server_multi(
144 name: &'static str,
145 sockets: impl IntoIterator<Item = UdpSocket>,
146 keypair: &Keypair,
147 packet_sender: Sender<PacketBatch>,
148 exit: Arc<AtomicBool>,
149 staked_nodes: Arc<RwLock<StakedNodes>>,
150 quic_server_params: QuicServerParams,
151) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
152 #[allow(deprecated)]
153 spawn_server(
154 name,
155 sockets,
156 keypair,
157 packet_sender,
158 exit,
159 staked_nodes,
160 quic_server_params,
161 )
162}
163
164#[deprecated(since = "3.1.0", note = "Use spawn_server_with_cancel instead")]
165#[allow(deprecated)]
166pub fn spawn_server(
167 name: &'static str,
168 sockets: impl IntoIterator<Item = UdpSocket>,
169 keypair: &Keypair,
170 packet_sender: Sender<PacketBatch>,
171 exit: Arc<AtomicBool>,
172 staked_nodes: Arc<RwLock<StakedNodes>>,
173 quic_server_params: QuicServerParams,
174) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
175 let cancel = CancellationToken::new();
176 tokio::spawn({
177 let cancel = cancel.clone();
178 async move {
179 loop {
180 if exit.load(Ordering::Relaxed) {
181 cancel.cancel();
182 break;
183 }
184 sleep(Duration::from_millis(100)).await;
185 }
186 }
187 });
188 let quic_server_config = QuicStreamerConfig::from(&quic_server_params);
189 let qos_config = SwQosConfig {
190 max_streams_per_ms: quic_server_params.max_streams_per_ms,
191 };
192 spawn_server_with_cancel(
193 name,
194 sockets,
195 keypair,
196 packet_sender,
197 staked_nodes,
198 quic_server_config,
199 qos_config,
200 cancel,
201 )
202}
203
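/// Spawns the non-blocking QUIC server on the given sockets using the default
/// stake-weighted QoS controller (`SwQos`). Received packets are batched and
/// delivered through `packet_sender`; the server shuts down when `cancel` is
/// cancelled.
///
/// A minimal usage sketch (not compiled as a doctest); it assumes a
/// `QuicStreamerConfig` value is available -- the tests use
/// `QuicStreamerConfig::default_for_tests()`, which may be test-only:
///
/// ```ignore
/// let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
/// let keypair = solana_keypair::Keypair::new();
/// let (packet_sender, _packet_receiver) = crossbeam_channel::unbounded();
/// let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
/// let cancel = CancellationToken::new();
/// let SpawnNonBlockingServerResult { thread, .. } = spawn_server_with_cancel(
///     "quic_streamer",
///     [socket],
///     &keypair,
///     packet_sender,
///     staked_nodes,
///     QuicStreamerConfig::default_for_tests(),
///     SwQosConfig::default(),
///     cancel.clone(),
/// )
/// .unwrap();
/// // ... later: cancel.cancel(); thread.await.unwrap();
/// ```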
204pub fn spawn_server_with_cancel(
206 name: &'static str,
207 sockets: impl IntoIterator<Item = UdpSocket>,
208 keypair: &Keypair,
209 packet_sender: Sender<PacketBatch>,
210 staked_nodes: Arc<RwLock<StakedNodes>>,
211 quic_server_params: QuicStreamerConfig,
212 qos_config: SwQosConfig,
213 cancel: CancellationToken,
) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
217 let stats = Arc::<StreamerStats>::default();
218
219 let swqos = Arc::new(SwQos::new(
220 qos_config,
221 quic_server_params.max_staked_connections,
222 quic_server_params.max_unstaked_connections,
223 quic_server_params.max_connections_per_peer,
224 stats.clone(),
225 staked_nodes,
226 cancel.clone(),
227 ));
228
229 spawn_server_with_cancel_and_qos(
230 name,
231 stats,
232 sockets,
233 keypair,
234 packet_sender,
235 quic_server_params,
236 swqos,
237 cancel,
238 )
239}
240
241pub(crate) fn spawn_server_with_cancel_and_qos<Q, C>(
243 name: &'static str,
244 stats: Arc<StreamerStats>,
245 sockets: impl IntoIterator<Item = UdpSocket>,
246 keypair: &Keypair,
247 packet_sender: Sender<PacketBatch>,
248 quic_server_params: QuicStreamerConfig,
249 qos: Arc<Q>,
250 cancel: CancellationToken,
251) -> Result<SpawnNonBlockingServerResult, QuicServerError>
252where
253 Q: QosController<C> + Send + Sync + 'static,
254 C: ConnectionContext + Send + Sync + 'static,
255{
256 let sockets: Vec<_> = sockets.into_iter().collect();
257 info!("Start {name} quic server on {sockets:?}");
258 let (config, _) = configure_server(keypair)?;
259
260 let endpoints = sockets
261 .into_iter()
262 .map(|sock| {
263 Endpoint::new(
264 EndpointConfig::default(),
265 Some(config.clone()),
266 sock,
267 Arc::new(TokioRuntime),
268 )
269 .map_err(QuicServerError::EndpointFailed)
270 })
271 .collect::<Result<Vec<_>, _>>()?;
272 let (packet_batch_sender, packet_batch_receiver) =
273 bounded(quic_server_params.accumulator_channel_size);
274 task::spawn_blocking({
275 let cancel = cancel.clone();
276 let stats = stats.clone();
277 move || {
278 run_packet_batch_sender(packet_sender, packet_batch_receiver, stats, cancel);
279 }
280 });
281
282 let max_concurrent_connections = quic_server_params.max_concurrent_connections();
283 let handle = tokio::spawn({
284 let endpoints = endpoints.clone();
285 let stats = stats.clone();
286 async move {
287 let tasks = run_server(
288 name,
289 endpoints.clone(),
290 packet_batch_sender,
291 stats.clone(),
292 quic_server_params,
293 cancel,
294 qos,
295 )
296 .await;
297 tasks.close();
298 tasks.wait().await;
299 }
300 });
301
302 Ok(SpawnNonBlockingServerResult {
303 endpoints,
304 stats,
305 thread: handle,
306 max_concurrent_connections,
307 })
308}
309
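/// RAII guard counting one open client connection in `StreamerStats`.
/// Construction increments `open_connections` and fails once
/// `max_concurrent_connections` is reached; dropping the guard decrements it.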
310pub struct ClientConnectionTracker {
315 pub(crate) stats: Arc<StreamerStats>,
316}
317
318impl fmt::Debug for ClientConnectionTracker {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ClientConnectionTracker")
            .field(
                "open_connections",
324 &self.stats.open_connections.load(Ordering::Relaxed),
325 )
326 .finish()
327 }
328}
329
330impl Drop for ClientConnectionTracker {
331 fn drop(&mut self) {
333 self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
334 }
335}
336
337impl ClientConnectionTracker {
338 fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
341 let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
342 if open_connections >= max_concurrent_connections {
343 stats.open_connections.fetch_sub(1, Ordering::Relaxed);
344 debug!(
345 "There are too many concurrent connections opened already: open: \
346 {open_connections}, max: {max_concurrent_connections}"
347 );
348 return Err(());
349 }
350
351 Ok(Self { stats })
352 }
353}
354
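/// Main accept loop: polls every endpoint for incoming connections, refuses
/// attempts once the open-connection limit is reached, maintains the per-IP
/// and global connection rate limiters, and spawns `setup_connection` for each
/// accepted attempt. Returns the `TaskTracker` owning the spawned
/// per-connection tasks.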
355#[allow(clippy::too_many_arguments)]
356async fn run_server<Q, C>(
357 name: &'static str,
358 endpoints: Vec<Endpoint>,
359 packet_batch_sender: Sender<PacketAccumulator>,
360 stats: Arc<StreamerStats>,
361 quic_server_params: QuicStreamerConfig,
362 cancel: CancellationToken,
363 qos: Arc<Q>,
364) -> TaskTracker
365where
366 Q: QosController<C> + Send + Sync + 'static,
367 C: ConnectionContext + Send + Sync + 'static,
368{
369 let quic_server_params = Arc::new(quic_server_params);
370 let rate_limiter = Arc::new(ConnectionRateLimiter::new(
371 quic_server_params.max_connections_per_ipaddr_per_min,
372 ));
373 let overall_connection_rate_limiter = Arc::new(TotalConnectionRateLimiter::new(
374 TOTAL_CONNECTIONS_PER_SECOND,
375 ));
376
377 const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
378 debug!("spawn quic server");
379 let mut last_datapoint = Instant::now();
380 stats
381 .quic_endpoints_count
382 .store(endpoints.len(), Ordering::Relaxed);
383
384 let mut accepts = endpoints
385 .iter()
386 .enumerate()
387 .map(|(i, incoming)| {
388 Box::pin(EndpointAccept {
389 accept: incoming.accept(),
390 endpoint: i,
391 })
392 })
393 .collect::<FuturesUnordered<_>>();
394
395 let tasks = TaskTracker::new();
396 loop {
397 let timeout_connection = select! {
398 ready = accepts.next() => {
399 if let Some((connecting, i)) = ready {
400 accepts.push(
401 Box::pin(EndpointAccept {
402 accept: endpoints[i].accept(),
403 endpoint: i,
404 }
405 ));
406 Ok(connecting)
407 } else {
408 continue
410 }
411 }
412 _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
413 Err(())
414 }
415 _ = cancel.cancelled() => break,
416 };
417
418 if last_datapoint.elapsed().as_secs() >= 5 {
419 stats.report(name);
420 last_datapoint = Instant::now();
421 }
422
423 if let Ok(Some(incoming)) = timeout_connection {
424 stats
425 .total_incoming_connection_attempts
426 .fetch_add(1, Ordering::Relaxed);
427
428 if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
430 rate_limiter.retain_recent();
431 }
432 stats
433 .connection_rate_limiter_length
434 .store(rate_limiter.len(), Ordering::Relaxed);
435
436 let Ok(client_connection_tracker) = ClientConnectionTracker::new(
437 stats.clone(),
438 quic_server_params.max_concurrent_connections(),
439 ) else {
440 stats
441 .refused_connections_too_many_open_connections
442 .fetch_add(1, Ordering::Relaxed);
443 incoming.refuse();
444 continue;
445 };
446
447 stats
448 .outstanding_incoming_connection_attempts
449 .fetch_add(1, Ordering::Relaxed);
450 let connecting = incoming.accept();
451 match connecting {
452 Ok(connecting) => {
453 let rate_limiter = rate_limiter.clone();
454 let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
455 tasks.spawn(setup_connection(
456 connecting,
457 rate_limiter,
458 overall_connection_rate_limiter,
459 client_connection_tracker,
460 packet_batch_sender.clone(),
461 stats.clone(),
462 quic_server_params.clone(),
463 qos.clone(),
464 tasks.clone(),
465 ));
466 }
467 Err(err) => {
468 stats
469 .outstanding_incoming_connection_attempts
470 .fetch_sub(1, Ordering::Relaxed);
471 debug!("Incoming::accept(): error {err:?}");
472 }
473 }
474 } else {
475 debug!("accept(): Timed out waiting for connection");
476 }
477 }
478 tasks
479}
480
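/// Extracts the peer's pubkey from the single TLS client certificate it
/// presented; returns `None` if the identity is missing or malformed.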
481pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
482 connection
484 .peer_identity()?
485 .downcast::<Vec<rustls::pki_types::CertificateDer>>()
486 .ok()
487 .filter(|certs| certs.len() == 1)?
488 .first()
489 .and_then(get_pubkey_from_tls_certificate)
490}
491
492pub fn get_connection_stake(
493 connection: &Connection,
494 staked_nodes: &RwLock<StakedNodes>,
495) -> Option<(Pubkey, u64, u64)> {
496 let pubkey = get_remote_pubkey(connection)?;
497 debug!("Peer public key is {pubkey:?}");
498 let staked_nodes = staked_nodes.read().unwrap();
499 Some((
500 pubkey,
501 staked_nodes.get_node_stake(&pubkey)?,
502 staked_nodes.total_stake(),
503 ))
504}
505
506#[derive(Debug)]
507pub(crate) enum ConnectionHandlerError {
508 ConnectionAddError,
509 MaxStreamError,
510}
511
512pub(crate) fn update_open_connections_stat(
513 stats: &StreamerStats,
514 connection_table: &ConnectionTable,
515) {
516 if connection_table.is_staked() {
517 stats
518 .open_staked_connections
519 .store(connection_table.table_size(), Ordering::Relaxed);
520 } else {
521 stats
522 .open_unstaked_connections
523 .store(connection_table.table_size(), Ordering::Relaxed);
524 }
525}
526
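/// Completes the QUIC handshake under `QUIC_CONNECTION_HANDSHAKE_TIMEOUT`,
/// applies the per-IP and global connection rate limits, and asks the QoS
/// controller to admit the connection before spawning `handle_connection`.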
527#[allow(clippy::too_many_arguments)]
528async fn setup_connection<Q, C>(
529 connecting: Connecting,
530 rate_limiter: Arc<ConnectionRateLimiter>,
531 overall_connection_rate_limiter: Arc<TotalConnectionRateLimiter>,
532 client_connection_tracker: ClientConnectionTracker,
533 packet_sender: Sender<PacketAccumulator>,
534 stats: Arc<StreamerStats>,
535 server_params: Arc<QuicStreamerConfig>,
536 qos: Arc<Q>,
537 tasks: TaskTracker,
538) where
539 Q: QosController<C> + Send + Sync + 'static,
540 C: ConnectionContext + Send + Sync + 'static,
541{
542 let from = connecting.remote_address();
543 let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
544 stats
545 .outstanding_incoming_connection_attempts
546 .fetch_sub(1, Ordering::Relaxed);
547 if let Ok(connecting_result) = res {
548 match connecting_result {
549 Ok(new_connection) => {
550 debug!("Got a connection {from:?}");
551 if !rate_limiter.is_allowed(&from.ip()) {
552 debug!("Reject connection from {from:?} -- rate limiting exceeded");
553 stats
554 .connection_rate_limited_per_ipaddr
555 .fetch_add(1, Ordering::Relaxed);
556 new_connection.close(
557 CONNECTION_CLOSE_CODE_DISALLOWED.into(),
558 CONNECTION_CLOSE_REASON_DISALLOWED,
559 );
560 return;
561 }
562
563 if !overall_connection_rate_limiter.is_allowed() {
564 debug!(
565 "Reject connection from {:?} -- total rate limiting exceeded",
566 from.ip()
567 );
568 stats
569 .connection_rate_limited_across_all
570 .fetch_add(1, Ordering::Relaxed);
571 new_connection.close(
572 CONNECTION_CLOSE_CODE_DISALLOWED.into(),
573 CONNECTION_CLOSE_REASON_DISALLOWED,
574 );
575 return;
576 }
577
578 stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
579
580 let mut conn_context = qos.build_connection_context(&new_connection);
581 if let Some(cancel_connection) = qos
582 .try_add_connection(
583 client_connection_tracker,
584 &new_connection,
585 &mut conn_context,
586 )
587 .await
588 {
589 tasks.spawn(handle_connection(
590 packet_sender.clone(),
591 from,
592 new_connection,
593 stats,
594 server_params.wait_for_chunk_timeout,
595 conn_context.clone(),
596 qos,
597 cancel_connection,
598 ));
599 }
600 }
601 Err(e) => {
602 handle_connection_error(e, &stats, from);
603 }
604 }
605 } else {
606 stats
607 .connection_setup_timeout
608 .fetch_add(1, Ordering::Relaxed);
609 }
610}
611
612fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
613 debug!("error: {e:?} from: {from:?}");
614 stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
615 match e {
616 quinn::ConnectionError::TimedOut => {
617 stats
618 .connection_setup_error_timed_out
619 .fetch_add(1, Ordering::Relaxed);
620 }
621 quinn::ConnectionError::ConnectionClosed(_) => {
622 stats
623 .connection_setup_error_closed
624 .fetch_add(1, Ordering::Relaxed);
625 }
626 quinn::ConnectionError::TransportError(_) => {
627 stats
628 .connection_setup_error_transport
629 .fetch_add(1, Ordering::Relaxed);
630 }
631 quinn::ConnectionError::ApplicationClosed(_) => {
632 stats
633 .connection_setup_error_app_closed
634 .fetch_add(1, Ordering::Relaxed);
635 }
636 quinn::ConnectionError::Reset => {
637 stats
638 .connection_setup_error_reset
639 .fetch_add(1, Ordering::Relaxed);
640 }
641 quinn::ConnectionError::LocallyClosed => {
642 stats
643 .connection_setup_error_locally_closed
644 .fetch_add(1, Ordering::Relaxed);
645 }
646 _ => {}
647 }
648}
649
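/// Blocking loop that drains `PacketAccumulator`s from the batching channel,
/// coalesces their chunks into `BytesPacket`s, and forwards batches of up to
/// `PACKETS_PER_BATCH` packets to the consumer. Exits when cancelled or when
/// either channel disconnects.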
650fn run_packet_batch_sender(
653 packet_sender: Sender<PacketBatch>,
654 packet_receiver: Receiver<PacketAccumulator>,
655 stats: Arc<StreamerStats>,
656 cancel: CancellationToken,
657) {
658 let mut channel_disconnected = false;
659 trace!("enter packet_batch_sender");
660 loop {
661 let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
662 let mut packet_batch = BytesPacketBatch::with_capacity(PACKETS_PER_BATCH);
663 let mut total_bytes: usize = 0;
664
665 stats
666 .total_packet_batches_allocated
667 .fetch_add(1, Ordering::Relaxed);
668 stats
669 .total_packets_allocated
670 .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);
671
672 loop {
673 if cancel.is_cancelled() || channel_disconnected {
674 return;
675 }
676 if !packet_batch.is_empty() {
677 let len = packet_batch.len();
678 track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);
679
680 if let Err(e) = packet_sender.try_send(packet_batch.into()) {
681 stats
682 .total_packet_batch_send_err
683 .fetch_add(1, Ordering::Relaxed);
684 trace!("Send error: {e}");
685
686 if matches!(e, TrySendError::Disconnected(_)) {
688 cancel.cancel();
689 return;
690 }
691 } else {
692 stats
693 .total_packet_batches_sent
694 .fetch_add(1, Ordering::Relaxed);
695
696 stats
697 .total_packets_sent_to_consumer
698 .fetch_add(len, Ordering::Relaxed);
699
700 stats
701 .total_bytes_sent_to_consumer
702 .fetch_add(total_bytes, Ordering::Relaxed);
703
704 trace!("Sent {len} packet batch");
705 }
706 break;
707 }
708
709 let mut first = true;
718 let mut recv = || {
719 if first {
720 first = false;
721 packet_receiver.recv().map_err(|_| {
723 channel_disconnected = true;
724 TryRecvError::Disconnected
725 })
726 } else {
727 packet_receiver.try_recv()
728 }
729 };
730 while let Ok(mut packet_accumulator) = recv() {
731 let num_chunks = packet_accumulator.chunks.len();
738 let mut packet = if packet_accumulator.chunks.len() == 1 {
739 BytesPacket::new(
740 packet_accumulator.chunks.pop().expect("expected one chunk"),
741 packet_accumulator.meta,
742 )
743 } else {
744 let size: usize = packet_accumulator.chunks.iter().map(Bytes::len).sum();
745 let mut buf = BytesMut::with_capacity(size);
746 for chunk in packet_accumulator.chunks {
747 buf.put_slice(&chunk);
748 }
749 BytesPacket::new(buf.freeze(), packet_accumulator.meta)
750 };
751
752 total_bytes += packet.meta().size;
753
754 if let Some(signature) = signature_if_should_track_packet(&packet).ok().flatten() {
755 packet_perf_measure.push((*signature, packet_accumulator.start_time));
756 packet.meta_mut().set_track_performance(true);
758 }
759 packet_batch.push(packet);
760 stats
761 .total_chunks_processed_by_batcher
762 .fetch_add(num_chunks, Ordering::Relaxed);
763
764 if packet_batch.len() >= PACKETS_PER_BATCH {
766 break;
767 }
768 }
769 }
770 }
771}
772
773fn track_streamer_fetch_packet_performance(
774 packet_perf_measure: &[([u8; 64], Instant)],
775 stats: &StreamerStats,
776) {
777 if packet_perf_measure.is_empty() {
778 return;
779 }
780 let mut measure = Measure::start("track_perf");
781 let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
782
783 let now = Instant::now();
784 for (signature, start_time) in packet_perf_measure {
785 let duration = now.duration_since(*start_time);
786 debug!(
787 "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
788 Signature::from(*signature)
789 );
790 process_sampled_packets_us_hist
791 .increment(duration.as_micros() as u64)
792 .unwrap();
793 }
794
795 drop(process_sampled_packets_us_hist);
796 measure.stop();
797 stats
798 .perf_track_overhead_us
799 .fetch_add(measure.as_us(), Ordering::Relaxed);
800}
801
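/// Per-connection task: accepts unidirectional streams, reads chunks with
/// `wait_for_chunk_timeout`, and forwards completed packets to the batcher,
/// notifying the QoS controller of stream lifecycle events. Closes the
/// connection on an invalid stream and unregisters it from QoS on exit.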
802async fn handle_connection<Q, C>(
803 packet_sender: Sender<PacketAccumulator>,
804 remote_address: SocketAddr,
805 connection: Connection,
806 stats: Arc<StreamerStats>,
807 wait_for_chunk_timeout: Duration,
808 context: C,
809 qos: Arc<Q>,
810 cancel: CancellationToken,
811) where
812 Q: QosController<C> + Send + Sync + 'static,
813 C: ConnectionContext + Send + Sync + 'static,
814{
815 let peer_type = context.peer_type();
816 debug!(
817 "quic new connection {} streams: {} connections: {}",
818 remote_address,
819 stats.active_streams.load(Ordering::Relaxed),
820 stats.total_connections.load(Ordering::Relaxed),
821 );
822 stats.total_connections.fetch_add(1, Ordering::Relaxed);
823
824 'conn: loop {
825 let mut stream = select! {
828 stream = connection.accept_uni() => match stream {
829 Ok(stream) => stream,
830 Err(e) => {
831 debug!("stream error: {e:?}");
832 break;
833 }
834 },
835 _ = cancel.cancelled() => break,
836 };
837
838 qos.on_new_stream(&context).await;
839 qos.on_stream_accepted(&context);
840 stats.active_streams.fetch_add(1, Ordering::Relaxed);
841 stats.total_new_streams.fetch_add(1, Ordering::Relaxed);
842
843 let mut meta = Meta::default();
844 meta.set_socket_addr(&remote_address);
845 meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
846 let mut accum = PacketAccumulator::new(meta);
847
848 let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());
857
858 loop {
859 let n_chunks = match tokio::select! {
863 chunk = tokio::time::timeout(
864 wait_for_chunk_timeout,
865 stream.read_chunks(&mut chunks)) => chunk,
866
867 _ = cancel.cancelled() => break,
869 } {
870 Ok(Ok(chunk)) => chunk.unwrap_or(0),
872 Ok(Err(e)) => {
874 debug!("Received stream error: {e:?}");
875 stats
876 .total_stream_read_errors
877 .fetch_add(1, Ordering::Relaxed);
878 break;
879 }
880 Err(_) => {
882 debug!("Timeout in receiving on stream");
883 stats
884 .total_stream_read_timeouts
885 .fetch_add(1, Ordering::Relaxed);
886 break;
887 }
888 };
889
890 match handle_chunks(
891 chunks.iter().take(n_chunks).cloned(),
893 &mut accum,
894 &packet_sender,
895 &stats,
896 peer_type,
897 ) {
898 Ok(StreamState::Finished) => {
900 qos.on_stream_finished(&context);
901 break;
902 }
903 Ok(StreamState::Receiving) => {}
905 Err(_) => {
906 connection.close(
908 CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
909 CONNECTION_CLOSE_REASON_INVALID_STREAM,
910 );
911 stats.active_streams.fetch_sub(1, Ordering::Relaxed);
912 qos.on_stream_error(&context);
913 break 'conn;
914 }
915 }
916 }
917
918 stats.active_streams.fetch_sub(1, Ordering::Relaxed);
919 qos.on_stream_closed(&context);
920 }
921
922 let removed_connection_count = qos.remove_connection(&context, connection).await;
923 if removed_connection_count > 0 {
924 stats
925 .connection_removed
926 .fetch_add(removed_connection_count, Ordering::Relaxed);
927 } else {
928 stats
929 .connection_remove_failed
930 .fetch_add(1, Ordering::Relaxed);
931 }
932 stats.total_connections.fetch_sub(1, Ordering::Relaxed);
933}
934
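/// Whether a stream is still producing chunks or has been fully received.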
935enum StreamState {
936 Receiving,
938 Finished,
940}
941
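/// Appends the received chunks to the accumulator. A read yielding zero chunks
/// marks the end of the stream, at which point the accumulated packet is
/// handed to the batcher. Returns `Err` if the stream exceeds
/// `PACKET_DATA_SIZE` or ends without any data.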
942fn handle_chunks(
947 chunks: impl ExactSizeIterator<Item = Bytes>,
948 accum: &mut PacketAccumulator,
949 packet_sender: &Sender<PacketAccumulator>,
950 stats: &StreamerStats,
951 peer_type: ConnectionPeerType,
952) -> Result<StreamState, ()> {
953 let n_chunks = chunks.len();
954 for chunk in chunks {
955 accum.meta.size += chunk.len();
956 if accum.meta.size > PACKET_DATA_SIZE {
957 stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
961 debug!("invalid stream size {}", accum.meta.size);
962 return Err(());
963 }
964 accum.chunks.push(chunk);
965 if peer_type.is_staked() {
966 stats
967 .total_staked_chunks_received
968 .fetch_add(1, Ordering::Relaxed);
969 } else {
970 stats
971 .total_unstaked_chunks_received
972 .fetch_add(1, Ordering::Relaxed);
973 }
974 }
975
976 if n_chunks != 0 {
978 return Ok(StreamState::Receiving);
979 }
980
981 if accum.chunks.is_empty() {
982 debug!("stream is empty");
983 stats
984 .total_packet_batches_none
985 .fetch_add(1, Ordering::Relaxed);
986 return Err(());
987 }
988
989 let bytes_sent = accum.meta.size;
991 let chunks_sent = accum.chunks.len();
992
993 if let Err(err) = packet_sender.try_send(accum.clone()) {
994 stats
995 .total_handle_chunk_to_packet_batcher_send_err
996 .fetch_add(1, Ordering::Relaxed);
997 match err {
998 TrySendError::Full(_) => {
999 stats
1000 .total_handle_chunk_to_packet_batcher_send_full_err
1001 .fetch_add(1, Ordering::Relaxed);
1002 }
1003 TrySendError::Disconnected(_) => {
1004 stats
1005 .total_handle_chunk_to_packet_batcher_send_disconnected_err
1006 .fetch_add(1, Ordering::Relaxed);
1007 }
1008 }
1009 trace!("packet batch send error {err:?}");
1010 } else {
1011 stats
1012 .total_packets_sent_for_batching
1013 .fetch_add(1, Ordering::Relaxed);
1014 stats
1015 .total_bytes_sent_for_batching
1016 .fetch_add(bytes_sent, Ordering::Relaxed);
1017 stats
1018 .total_chunks_sent_for_batching
1019 .fetch_add(chunks_sent, Ordering::Relaxed);
1020
1021 match peer_type {
1022 ConnectionPeerType::Unstaked => {
1023 stats
1024 .total_unstaked_packets_sent_for_batching
1025 .fetch_add(1, Ordering::Relaxed);
1026 }
1027 ConnectionPeerType::Staked(_) => {
1028 stats
1029 .total_staked_packets_sent_for_batching
1030 .fetch_add(1, Ordering::Relaxed);
1031 }
1032 }
1033
1034 trace!("sent {bytes_sent} byte packet for batching");
1035 }
1036
1037 Ok(StreamState::Finished)
1038}
1039
1040#[derive(Debug)]
1041struct ConnectionEntry {
1042 cancel: CancellationToken,
1043 peer_type: ConnectionPeerType,
1044 last_update: Arc<AtomicU64>,
1045 port: u16,
1046 _client_connection_tracker: ClientConnectionTracker,
1048 connection: Option<Connection>,
1049 stream_counter: Arc<ConnectionStreamCounter>,
1050}
1051
1052impl ConnectionEntry {
1053 fn new(
1054 cancel: CancellationToken,
1055 peer_type: ConnectionPeerType,
1056 last_update: Arc<AtomicU64>,
1057 port: u16,
1058 client_connection_tracker: ClientConnectionTracker,
1059 connection: Option<Connection>,
1060 stream_counter: Arc<ConnectionStreamCounter>,
1061 ) -> Self {
1062 Self {
1063 cancel,
1064 peer_type,
1065 last_update,
1066 port,
1067 _client_connection_tracker: client_connection_tracker,
1068 connection,
1069 stream_counter,
1070 }
1071 }
1072
1073 fn last_update(&self) -> u64 {
1074 self.last_update.load(Ordering::Relaxed)
1075 }
1076
1077 fn stake(&self) -> u64 {
1078 match self.peer_type {
1079 ConnectionPeerType::Unstaked => 0,
1080 ConnectionPeerType::Staked(stake) => stake,
1081 }
1082 }
1083}
1084
1085impl Drop for ConnectionEntry {
1086 fn drop(&mut self) {
1087 if let Some(conn) = self.connection.take() {
1088 conn.close(
1089 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1090 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1091 );
1092 }
1093 self.cancel.cancel();
1094 }
1095}
1096
1097#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
1098pub(crate) enum ConnectionTableKey {
1099 IP(IpAddr),
1100 Pubkey(Pubkey),
1101}
1102
1103impl ConnectionTableKey {
1104 pub(crate) fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
1105 maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
1106 ConnectionTableKey::Pubkey(pubkey)
1107 })
1108 }
1109}
1110
1111pub(crate) enum ConnectionTableType {
1112 Staked,
1113 Unstaked,
1114}
1115
1116pub(crate) struct ConnectionTable {
1118 table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
1119 pub(crate) total_size: usize,
1120 table_type: ConnectionTableType,
1121 cancel: CancellationToken,
1122}
1123
1124impl ConnectionTable {
1128 pub(crate) fn new(table_type: ConnectionTableType, cancel: CancellationToken) -> Self {
1129 Self {
1130 table: IndexMap::default(),
1131 total_size: 0,
1132 table_type,
1133 cancel,
1134 }
1135 }
1136
1137 fn table_size(&self) -> usize {
1138 self.total_size
1139 }
1140
1141 fn is_staked(&self) -> bool {
1142 matches!(self.table_type, ConnectionTableType::Staked)
1143 }
1144
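    /// Evicts the entries with the oldest `last_update` until the table holds
    /// at most `max_size` connections; returns the number of connections
    /// pruned.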
1145 pub(crate) fn prune_oldest(&mut self, max_size: usize) -> usize {
1146 let mut num_pruned = 0;
1147 let key = |(_, connections): &(_, &Vec<_>)| {
1148 connections.iter().map(ConnectionEntry::last_update).min()
1149 };
1150 while self.total_size.saturating_sub(num_pruned) > max_size {
1151 match self.table.values().enumerate().min_by_key(key) {
1152 None => break,
1153 Some((index, connections)) => {
1154 num_pruned += connections.len();
1155 self.table.swap_remove_index(index);
1156 }
1157 }
1158 }
1159 self.total_size = self.total_size.saturating_sub(num_pruned);
1160 num_pruned
1161 }
1162
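    /// Samples `sample_size` random entries and evicts the one with the lowest
    /// stake, provided that stake is below `threshold_stake`; returns the
    /// number of connections pruned (possibly zero).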
1163 pub(crate) fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
1168 let num_pruned = std::iter::once(self.table.len())
1169 .filter(|&size| size > 0)
1170 .flat_map(|size| {
1171 let mut rng = thread_rng();
1172 repeat_with(move || rng.gen_range(0..size))
1173 })
1174 .map(|index| {
1175 let connection = self.table[index].first();
1176 let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
1177 (index, stake)
1178 })
1179 .take(sample_size)
1180 .min_by_key(|&(_, stake)| stake)
1181 .filter(|&(_, stake)| stake < Some(threshold_stake))
1182 .and_then(|(index, _)| self.table.swap_remove_index(index))
1183 .map(|(_, connections)| connections.len())
1184 .unwrap_or_default();
1185 self.total_size = self.total_size.saturating_sub(num_pruned);
1186 num_pruned
1187 }
1188
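    /// Registers a connection under `key` if the peer is below
    /// `max_connections_per_peer`. On success returns the shared `last_update`
    /// counter, a child cancellation token for the connection, and the peer's
    /// stream counter; otherwise closes the connection with
    /// `CONNECTION_CLOSE_CODE_TOO_MANY` and returns `None`.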
1189 pub(crate) fn try_add_connection(
1190 &mut self,
1191 key: ConnectionTableKey,
1192 port: u16,
1193 client_connection_tracker: ClientConnectionTracker,
1194 connection: Option<Connection>,
1195 peer_type: ConnectionPeerType,
1196 last_update: Arc<AtomicU64>,
1197 max_connections_per_peer: usize,
1198 ) -> Option<(
1199 Arc<AtomicU64>,
1200 CancellationToken,
1201 Arc<ConnectionStreamCounter>,
1202 )> {
1203 let connection_entry = self.table.entry(key).or_default();
1204 let has_connection_capacity = connection_entry
1205 .len()
1206 .checked_add(1)
1207 .map(|c| c <= max_connections_per_peer)
1208 .unwrap_or(false);
1209 if has_connection_capacity {
1210 let cancel = self.cancel.child_token();
1211 let stream_counter = connection_entry
1212 .first()
1213 .map(|entry| entry.stream_counter.clone())
1214 .unwrap_or(Arc::new(ConnectionStreamCounter::new()));
1215 connection_entry.push(ConnectionEntry::new(
1216 cancel.clone(),
1217 peer_type,
1218 last_update.clone(),
1219 port,
1220 client_connection_tracker,
1221 connection,
1222 stream_counter.clone(),
1223 ));
1224 self.total_size += 1;
1225 Some((last_update, cancel, stream_counter))
1226 } else {
1227 if let Some(connection) = connection {
1228 connection.close(
1229 CONNECTION_CLOSE_CODE_TOO_MANY.into(),
1230 CONNECTION_CLOSE_REASON_TOO_MANY,
1231 );
1232 }
1233 None
1234 }
1235 }
1236
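    /// Removes the entries under `key` that match `port` and whose connection
    /// has the given `stable_id` (entries without a connection handle are
    /// removed as well); returns the number of connections removed.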
1237 pub(crate) fn remove_connection(
1239 &mut self,
1240 key: ConnectionTableKey,
1241 port: u16,
1242 stable_id: usize,
1243 ) -> usize {
1244 if let Entry::Occupied(mut e) = self.table.entry(key) {
1245 let e_ref = e.get_mut();
1246 let old_size = e_ref.len();
1247
1248 e_ref.retain(|connection_entry| {
1249 connection_entry.port != port
1255 || connection_entry
1256 .connection
1257 .as_ref()
1258 .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
1259 .is_some()
1260 });
1261 let new_size = e_ref.len();
1262 if e_ref.is_empty() {
1263 e.swap_remove_entry();
1264 }
1265 let connections_removed = old_size.saturating_sub(new_size);
1266 self.total_size = self.total_size.saturating_sub(connections_removed);
1267 connections_removed
1268 } else {
1269 0
1270 }
1271 }
1272}
1273
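/// Future adapter pairing an `Endpoint::accept()` future with the index of the
/// endpoint it belongs to, so the accept loop knows which endpoint to re-arm.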
1274struct EndpointAccept<'a> {
1275 endpoint: usize,
1276 accept: Accept<'a>,
1277}
1278
1279impl Future for EndpointAccept<'_> {
1280 type Output = (Option<quinn::Incoming>, usize);
1281
1282 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
1283 let i = self.endpoint;
1284 unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
1288 .poll(cx)
1289 .map(|r| (r, i))
1290 }
1291}
1292
1293#[cfg(test)]
1294pub mod test {
1295 use {
1296 super::*,
1297 crate::nonblocking::testing_utilities::{
1298 check_multiple_streams, get_client_config, make_client_endpoint, setup_quic_server,
1299 SpawnTestServerResult,
1300 },
1301 assert_matches::assert_matches,
1302 crossbeam_channel::{unbounded, Receiver},
1303 quinn::{ApplicationClose, ConnectionError},
1304 solana_keypair::Keypair,
1305 solana_net_utils::sockets::bind_to_localhost_unique,
1306 solana_signer::Signer,
1307 std::collections::HashMap,
1308 tokio::time::sleep,
1309 };
1310
1311 pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
1312 let conn1 = make_client_endpoint(&server_address, None).await;
1313 let total = 30;
1314 for i in 0..total {
1315 let mut s1 = conn1.open_uni().await.unwrap();
1316 s1.write_all(&[0u8]).await.unwrap();
1317 s1.finish().unwrap();
1318 info!("done {i}");
1319 sleep(Duration::from_millis(1000)).await;
1320 }
1321 let mut received = 0;
1322 loop {
1323 if let Ok(_x) = receiver.try_recv() {
1324 received += 1;
1325 info!("got {received}");
1326 } else {
1327 sleep(Duration::from_millis(500)).await;
1328 }
1329 if received >= total {
1330 break;
1331 }
1332 }
1333 }
1334
1335 pub async fn check_block_multiple_connections(server_address: SocketAddr) {
1336 let conn1 = make_client_endpoint(&server_address, None).await;
1337 let conn2 = make_client_endpoint(&server_address, None).await;
1338 let mut s1 = conn1.open_uni().await.unwrap();
1339 let s2 = conn2.open_uni().await;
1340 if let Ok(mut s2) = s2 {
1341 s1.write_all(&[0u8]).await.unwrap();
1342 s1.finish().unwrap();
1343 let data = vec![1u8; PACKET_DATA_SIZE * 2];
1347 s2.write_all(&data)
1348 .await
1349 .expect_err("shouldn't be able to open 2 connections");
1350 } else {
1351 assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
1355 }
1356 }
1357
1358 pub async fn check_multiple_writes(
1359 receiver: Receiver<PacketBatch>,
1360 server_address: SocketAddr,
1361 client_keypair: Option<&Keypair>,
1362 ) {
1363 let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
1364
1365 let num_bytes = PACKET_DATA_SIZE;
1367 let num_expected_packets = 1;
1368 let mut s1 = conn1.open_uni().await.unwrap();
1369 for _ in 0..num_bytes {
1370 s1.write_all(&[0u8]).await.unwrap();
1371 }
1372 s1.finish().unwrap();
1373
1374 check_received_packets(receiver, num_expected_packets, num_bytes).await;
1375 }
1376
1377 pub async fn check_multiple_packets(
1378 receiver: Receiver<PacketBatch>,
1379 server_address: SocketAddr,
1380 client_keypair: Option<&Keypair>,
1381 num_expected_packets: usize,
1382 ) {
1383 let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
1384
1385 let num_bytes = PACKET_DATA_SIZE;
1387 let packet = vec![1u8; num_bytes];
1388 for _ in 0..num_expected_packets {
1389 let mut s1 = conn1.open_uni().await.unwrap();
1390 s1.write_all(&packet).await.unwrap();
1391 s1.finish().unwrap();
1392 }
1393
1394 check_received_packets(receiver, num_expected_packets, num_bytes).await;
1395 }
1396
1397 async fn check_received_packets(
1398 receiver: Receiver<PacketBatch>,
1399 num_expected_packets: usize,
1400 num_bytes: usize,
1401 ) {
1402 let mut all_packets = vec![];
1403 let now = Instant::now();
1404 let mut total_packets = 0;
1405 while now.elapsed().as_secs() < 5 {
1406 if let Ok(packets) = receiver.try_recv() {
1409 total_packets += packets.len();
1410 all_packets.push(packets)
1411 } else {
1412 sleep(Duration::from_secs(1)).await;
1413 }
1414 if total_packets >= num_expected_packets {
1415 break;
1416 }
1417 }
1418 for batch in all_packets {
1419 for p in batch.iter() {
1420 assert_eq!(p.meta().size, num_bytes);
1421 }
1422 }
1423 assert_eq!(total_packets, num_expected_packets);
1424 }
1425
1426 pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
1427 let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
1428
1429 if let Ok(mut s1) = conn1.open_uni().await {
1431 for _ in 0..PACKET_DATA_SIZE {
1432 s1.write_all(&[0u8]).await.unwrap_or_default();
1434 }
1435 s1.finish().unwrap_or_default();
1436 s1.stopped().await.unwrap_err();
1437 }
1438 }
1439
1440 #[tokio::test(flavor = "multi_thread")]
1441 async fn test_quic_server_exit_on_cancel() {
1442 let SpawnTestServerResult {
1443 join_handle,
1444 receiver,
1445 server_address: _,
1446 stats: _,
1447 cancel,
1448 } = setup_quic_server(
1449 None,
1450 QuicStreamerConfig::default_for_tests(),
1451 SwQosConfig::default(),
1452 );
1453 cancel.cancel();
1454 join_handle.await.unwrap();
1455 drop(receiver);
1458 }
1459
1460 #[tokio::test(flavor = "multi_thread")]
1461 async fn test_quic_timeout() {
1462 agave_logger::setup();
1463 let SpawnTestServerResult {
1464 join_handle,
1465 receiver,
1466 server_address,
1467 stats: _,
1468 cancel,
1469 } = setup_quic_server(
1470 None,
1471 QuicStreamerConfig::default_for_tests(),
1472 SwQosConfig::default(),
1473 );
1474
1475 check_timeout(receiver, server_address).await;
1476 cancel.cancel();
1477 join_handle.await.unwrap();
1478 }
1479
1480 #[tokio::test(flavor = "multi_thread")]
1481 async fn test_packet_batcher() {
1482 agave_logger::setup();
1483 let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
        let (pkt_sender, pkt_receiver) = unbounded();
1485 let cancel = CancellationToken::new();
1486 let stats = Arc::new(StreamerStats::default());
1487
1488 let handle = task::spawn_blocking({
1489 let cancel = cancel.clone();
1490 move || {
1491 run_packet_batch_sender(pkt_batch_sender, pkt_receiver, stats, cancel);
1492 }
1493 });
1494
1495 let num_packets = 1000;
1496
1497 for _i in 0..num_packets {
1498 let mut meta = Meta::default();
1499 let bytes = Bytes::from("Hello world");
1500 let size = bytes.len();
1501 meta.size = size;
1502 let packet_accum = PacketAccumulator {
1503 meta,
1504 chunks: smallvec::smallvec![bytes],
1505 start_time: Instant::now(),
1506 };
            pkt_sender.send(packet_accum).unwrap();
1508 }
1509 let mut i = 0;
1510 let start = Instant::now();
1511 while i < num_packets && start.elapsed().as_secs() < 2 {
1512 if let Ok(batch) = pkt_batch_receiver.try_recv() {
1513 i += batch.len();
1514 } else {
1515 sleep(Duration::from_millis(1)).await;
1516 }
1517 }
1518 assert_eq!(i, num_packets);
1519 cancel.cancel();
        drop(pkt_sender);
1522 handle.await.unwrap();
1523 }
1524
1525 #[tokio::test(flavor = "multi_thread")]
1526 async fn test_quic_stream_timeout() {
1527 agave_logger::setup();
1528 let SpawnTestServerResult {
1529 join_handle,
1530 receiver,
1531 server_address,
1532 stats,
1533 cancel,
1534 } = setup_quic_server(
1535 None,
1536 QuicStreamerConfig::default_for_tests(),
1537 SwQosConfig::default(),
1538 );
1539
1540 let conn1 = make_client_endpoint(&server_address, None).await;
1541 assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1542 assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1543
1544 let mut s1 = conn1.open_uni().await.unwrap();
1546 s1.write_all(&[0u8]).await.unwrap_or_default();
1547
1548 let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
1550 sleep(sleep_time).await;
1551
1552 assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1554 assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1555
1556 assert!(s1.write_all(&[0u8]).await.is_err());
1559
1560 cancel.cancel();
1561 drop(receiver);
1562 join_handle.await.unwrap();
1563 }
1564
1565 #[tokio::test(flavor = "multi_thread")]
1566 async fn test_quic_server_block_multiple_connections() {
1567 agave_logger::setup();
1568 let SpawnTestServerResult {
1569 join_handle,
1570 receiver,
1571 server_address,
1572 stats: _,
1573 cancel,
1574 } = setup_quic_server(
1575 None,
1576 QuicStreamerConfig::default_for_tests(),
1577 SwQosConfig::default(),
1578 );
1579 check_block_multiple_connections(server_address).await;
1580 cancel.cancel();
1581 drop(receiver);
1582 join_handle.await.unwrap();
1583 }
1584
1585 #[tokio::test(flavor = "multi_thread")]
1586 async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1587 agave_logger::setup();
1588
1589 let SpawnTestServerResult {
1590 join_handle,
1591 receiver,
1592 server_address,
1593 stats,
1594 cancel,
1595 } = setup_quic_server(
1596 None,
1597 QuicStreamerConfig {
1598 max_connections_per_peer: 2,
1599 ..QuicStreamerConfig::default_for_tests()
1600 },
1601 SwQosConfig::default(),
1602 );
1603
1604 let client_socket = bind_to_localhost_unique().expect("should bind - client");
1605 let mut endpoint = quinn::Endpoint::new(
1606 EndpointConfig::default(),
1607 None,
1608 client_socket,
1609 Arc::new(TokioRuntime),
1610 )
1611 .unwrap();
1612 let default_keypair = Keypair::new();
1613 endpoint.set_default_client_config(get_client_config(&default_keypair));
1614 let conn1 = endpoint
1615 .connect(server_address, "localhost")
1616 .expect("Failed in connecting")
1617 .await
1618 .expect("Failed in waiting");
1619
1620 let conn2 = endpoint
1621 .connect(server_address, "localhost")
1622 .expect("Failed in connecting")
1623 .await
1624 .expect("Failed in waiting");
1625
1626 let mut s1 = conn1.open_uni().await.unwrap();
1627 s1.write_all(&[0u8]).await.unwrap();
1628 s1.finish().unwrap();
1629
1630 let mut s2 = conn2.open_uni().await.unwrap();
1631 conn1.close(
1632 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1633 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1634 );
1635
1636 let start = Instant::now();
1637 while stats.connection_removed.load(Ordering::Relaxed) != 1 && start.elapsed().as_secs() < 1
1638 {
1639 debug!("First connection not removed yet");
1640 sleep(Duration::from_millis(10)).await;
1641 }
1642 assert!(start.elapsed().as_secs() < 1);
1643
1644 s2.write_all(&[0u8]).await.unwrap();
1645 s2.finish().unwrap();
1646
1647 conn2.close(
1648 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1649 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1650 );
1651
1652 let start = Instant::now();
1653 while stats.connection_removed.load(Ordering::Relaxed) != 2 && start.elapsed().as_secs() < 1
1654 {
1655 debug!("Second connection not removed yet");
1656 sleep(Duration::from_millis(10)).await;
1657 }
1658 assert!(start.elapsed().as_secs() < 1);
1659
1660 cancel.cancel();
1661 drop(receiver);
1666 join_handle.await.unwrap();
1667 }
1668
1669 #[tokio::test(flavor = "multi_thread")]
1670 async fn test_quic_server_multiple_writes() {
1671 agave_logger::setup();
1672 let SpawnTestServerResult {
1673 join_handle,
1674 receiver,
1675 server_address,
1676 stats: _,
1677 cancel,
1678 } = setup_quic_server(
1679 None,
1680 QuicStreamerConfig::default_for_tests(),
1681 SwQosConfig::default(),
1682 );
1683 check_multiple_writes(receiver, server_address, None).await;
1684 cancel.cancel();
1685 join_handle.await.unwrap();
1686 }
1687
1688 #[tokio::test(flavor = "multi_thread")]
1689 async fn test_quic_server_staked_connection_removal() {
1690 agave_logger::setup();
1691
1692 let client_keypair = Keypair::new();
1693 let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
1694 let staked_nodes = StakedNodes::new(
1695 Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(),
        );
1698 let SpawnTestServerResult {
1699 join_handle,
1700 receiver,
1701 server_address,
1702 stats,
1703 cancel,
1704 } = setup_quic_server(
1705 Some(staked_nodes),
1706 QuicStreamerConfig::default_for_tests(),
1707 SwQosConfig::default(),
1708 );
1709 check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1710 cancel.cancel();
1711 join_handle.await.unwrap();
1712
1713 assert_eq!(
1714 stats
1715 .connection_added_from_staked_peer
1716 .load(Ordering::Relaxed),
1717 1
1718 );
1719 assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1720 assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1721 }
1722
1723 #[tokio::test(flavor = "multi_thread")]
1724 async fn test_quic_server_zero_staked_connection_removal() {
1725 agave_logger::setup();
1727
1728 let client_keypair = Keypair::new();
1729 let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
1730 let staked_nodes = StakedNodes::new(
1731 Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(),
        );
1734 let SpawnTestServerResult {
1735 join_handle,
1736 receiver,
1737 server_address,
1738 stats,
1739 cancel,
1740 } = setup_quic_server(
1741 Some(staked_nodes),
1742 QuicStreamerConfig::default_for_tests(),
1743 SwQosConfig::default(),
1744 );
1745 check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1746 cancel.cancel();
1747 join_handle.await.unwrap();
1748
1749 assert_eq!(
1750 stats
1751 .connection_added_from_staked_peer
1752 .load(Ordering::Relaxed),
1753 0
1754 );
1755 assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1756 assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1757 }
1758
1759 #[tokio::test(flavor = "multi_thread")]
1760 async fn test_quic_server_unstaked_connection_removal() {
1761 agave_logger::setup();
1762 let SpawnTestServerResult {
1763 join_handle,
1764 receiver,
1765 server_address,
1766 stats,
1767 cancel,
1768 } = setup_quic_server(
1769 None,
1770 QuicStreamerConfig::default_for_tests(),
1771 SwQosConfig::default(),
1772 );
1773 check_multiple_writes(receiver, server_address, None).await;
1774 cancel.cancel();
1775 join_handle.await.unwrap();
1776
1777 assert_eq!(
1778 stats
1779 .connection_added_from_staked_peer
1780 .load(Ordering::Relaxed),
1781 0
1782 );
1783 assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1784 assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1785 }
1786
1787 #[tokio::test(flavor = "multi_thread")]
1788 async fn test_quic_server_unstaked_node_connect_failure() {
1789 agave_logger::setup();
1790 let s = bind_to_localhost_unique().expect("should bind");
1791 let (sender, _) = unbounded();
1792 let keypair = Keypair::new();
1793 let server_address = s.local_addr().unwrap();
1794 let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
1795 let cancel = CancellationToken::new();
1796 let SpawnNonBlockingServerResult {
1797 endpoints: _,
1798 stats: _,
1799 thread: t,
1800 max_concurrent_connections: _,
1801 } = spawn_server_with_cancel(
1802 "quic_streamer_test",
1803 [s],
1804 &keypair,
1805 sender,
1806 staked_nodes,
1807 QuicStreamerConfig {
                max_unstaked_connections: 0,
                ..QuicStreamerConfig::default_for_tests()
1810 },
1811 SwQosConfig::default(),
1812 cancel.clone(),
1813 )
1814 .unwrap();
1815
1816 check_unstaked_node_connect_failure(server_address).await;
1817 cancel.cancel();
1818 t.await.unwrap();
1819 }
1820
1821 #[tokio::test(flavor = "multi_thread")]
1822 async fn test_quic_server_multiple_streams() {
1823 agave_logger::setup();
1824 let s = bind_to_localhost_unique().expect("should bind");
1825 let (sender, receiver) = unbounded();
1826 let keypair = Keypair::new();
1827 let server_address = s.local_addr().unwrap();
1828 let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
1829 let cancel = CancellationToken::new();
1830 let SpawnNonBlockingServerResult {
1831 endpoints: _,
1832 stats,
1833 thread: t,
1834 max_concurrent_connections: _,
1835 } = spawn_server_with_cancel(
1836 "quic_streamer_test",
1837 [s],
1838 &keypair,
1839 sender,
1840 staked_nodes,
1841 QuicStreamerConfig {
1842 max_connections_per_peer: 2,
1843 ..QuicStreamerConfig::default_for_tests()
1844 },
1845 SwQosConfig::default(),
1846 cancel.clone(),
1847 )
1848 .unwrap();
1849
1850 check_multiple_streams(receiver, server_address, None).await;
1851 assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1852 assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
1853 assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
1854 assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
1855 cancel.cancel();
1856 t.await.unwrap();
1857
1858 assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
1859 assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
1860 }
1861
1862 #[test]
1863 fn test_prune_table_with_ip() {
1864 use std::net::Ipv4Addr;
1865 agave_logger::setup();
1866 let cancel = CancellationToken::new();
1867 let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1868 let mut num_entries = 5;
1869 let max_connections_per_peer = 10;
1870 let sockets: Vec<_> = (0..num_entries)
1871 .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
1872 .collect();
1873 let stats = Arc::new(StreamerStats::default());
1874 for (i, socket) in sockets.iter().enumerate() {
1875 table
1876 .try_add_connection(
1877 ConnectionTableKey::IP(socket.ip()),
1878 socket.port(),
1879 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1880 None,
1881 ConnectionPeerType::Unstaked,
1882 Arc::new(AtomicU64::new(i as u64)),
1883 max_connections_per_peer,
1884 )
1885 .unwrap();
1886 }
1887 num_entries += 1;
1888 table
1889 .try_add_connection(
1890 ConnectionTableKey::IP(sockets[0].ip()),
1891 sockets[0].port(),
1892 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1893 None,
1894 ConnectionPeerType::Unstaked,
1895 Arc::new(AtomicU64::new(5)),
1896 max_connections_per_peer,
1897 )
1898 .unwrap();
1899
1900 let new_size = 3;
1901 let pruned = table.prune_oldest(new_size);
1902 assert_eq!(pruned, num_entries as usize - new_size);
1903 for v in table.table.values() {
1904 for x in v {
1905 assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
1906 }
1907 }
1908 assert_eq!(table.table.len(), new_size);
1909 assert_eq!(table.total_size, new_size);
1910 for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
1911 table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
1912 }
1913 assert_eq!(table.total_size, 0);
1914 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
1915 }
1916
1917 #[test]
1918 fn test_prune_table_with_unique_pubkeys() {
1919 agave_logger::setup();
1920 let cancel = CancellationToken::new();
1921 let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1922
1923 let num_entries = 15;
1926 let max_connections_per_peer = 10;
1927 let stats = Arc::new(StreamerStats::default());
1928
1929 let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
1930 for (i, pubkey) in pubkeys.iter().enumerate() {
1931 table
1932 .try_add_connection(
1933 ConnectionTableKey::Pubkey(*pubkey),
1934 0,
1935 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1936 None,
1937 ConnectionPeerType::Unstaked,
1938 Arc::new(AtomicU64::new(i as u64)),
1939 max_connections_per_peer,
1940 )
1941 .unwrap();
1942 }
1943
1944 let new_size = 3;
1945 let pruned = table.prune_oldest(new_size);
1946 assert_eq!(pruned, num_entries as usize - new_size);
1947 assert_eq!(table.table.len(), new_size);
1948 assert_eq!(table.total_size, new_size);
1949 for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
1950 table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
1951 }
1952 assert_eq!(table.total_size, 0);
1953 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
1954 }
1955
1956 #[test]
1957 fn test_prune_table_with_non_unique_pubkeys() {
1958 agave_logger::setup();
1959 let cancel = CancellationToken::new();
1960 let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1961
1962 let max_connections_per_peer = 10;
1963 let pubkey = Pubkey::new_unique();
1964 let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
1965
1966 (0..max_connections_per_peer).for_each(|i| {
1967 table
1968 .try_add_connection(
1969 ConnectionTableKey::Pubkey(pubkey),
1970 0,
1971 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1972 None,
1973 ConnectionPeerType::Unstaked,
1974 Arc::new(AtomicU64::new(i as u64)),
1975 max_connections_per_peer,
1976 )
1977 .unwrap();
1978 });
1979
1980 assert!(table
1983 .try_add_connection(
1984 ConnectionTableKey::Pubkey(pubkey),
1985 0,
1986 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1987 None,
1988 ConnectionPeerType::Unstaked,
1989 Arc::new(AtomicU64::new(10)),
1990 max_connections_per_peer,
1991 )
1992 .is_none());
1993
1994 let num_entries = max_connections_per_peer + 1;
1996 let pubkey2 = Pubkey::new_unique();
1997 assert!(table
1998 .try_add_connection(
1999 ConnectionTableKey::Pubkey(pubkey2),
2000 0,
2001 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2002 None,
2003 ConnectionPeerType::Unstaked,
2004 Arc::new(AtomicU64::new(10)),
2005 max_connections_per_peer,
2006 )
2007 .is_some());
2008
2009 assert_eq!(table.total_size, num_entries);
2010
2011 let new_max_size = 3;
2012 let pruned = table.prune_oldest(new_max_size);
2013 assert!(pruned >= num_entries - new_max_size);
2014 assert!(table.table.len() <= new_max_size);
2015 assert!(table.total_size <= new_max_size);
2016
2017 table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
2018 assert_eq!(table.total_size, 0);
2019 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2020 }
2021
2022 #[test]
2023 fn test_prune_table_random() {
2024 use std::net::Ipv4Addr;
2025 agave_logger::setup();
2026 let cancel = CancellationToken::new();
2027 let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
2028
2029 let num_entries = 5;
2030 let max_connections_per_peer = 10;
2031 let sockets: Vec<_> = (0..num_entries)
2032 .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2033 .collect();
2034 let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2035
2036 for (i, socket) in sockets.iter().enumerate() {
2037 table
2038 .try_add_connection(
2039 ConnectionTableKey::IP(socket.ip()),
2040 socket.port(),
2041 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2042 None,
2043 ConnectionPeerType::Staked((i + 1) as u64),
2044 Arc::new(AtomicU64::new(i as u64)),
2045 max_connections_per_peer,
2046 )
2047 .unwrap();
2048 }
2049
2050 let pruned = table.prune_random(2, 0);
2053 assert_eq!(pruned, 0);
2054
        let pruned = table.prune_random(2, num_entries as u64 + 1);
2061 assert_eq!(pruned, 1);
2062 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
2064 }
2065
2066 #[test]
2067 fn test_remove_connections() {
2068 use std::net::Ipv4Addr;
2069 agave_logger::setup();
2070 let cancel = CancellationToken::new();
2071 let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
2072
2073 let num_ips = 5;
2074 let max_connections_per_peer = 10;
2075 let mut sockets: Vec<_> = (0..num_ips)
2076 .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2077 .collect();
2078 let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2079
2080 for (i, socket) in sockets.iter().enumerate() {
2081 table
2082 .try_add_connection(
2083 ConnectionTableKey::IP(socket.ip()),
2084 socket.port(),
2085 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2086 None,
2087 ConnectionPeerType::Unstaked,
2088 Arc::new(AtomicU64::new((i * 2) as u64)),
2089 max_connections_per_peer,
2090 )
2091 .unwrap();
2092
2093 table
2094 .try_add_connection(
2095 ConnectionTableKey::IP(socket.ip()),
2096 socket.port(),
2097 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2098 None,
2099 ConnectionPeerType::Unstaked,
2100 Arc::new(AtomicU64::new((i * 2 + 1) as u64)),
2101 max_connections_per_peer,
2102 )
2103 .unwrap();
2104 }
2105
2106 let single_connection_addr =
2107 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
2108 table
2109 .try_add_connection(
2110 ConnectionTableKey::IP(single_connection_addr.ip()),
2111 single_connection_addr.port(),
2112 ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2113 None,
2114 ConnectionPeerType::Unstaked,
2115 Arc::new(AtomicU64::new((num_ips * 2) as u64)),
2116 max_connections_per_peer,
2117 )
2118 .unwrap();
2119
2120 let zero_connection_addr =
2121 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);
2122
2123 sockets.push(single_connection_addr);
2124 sockets.push(zero_connection_addr);
2125
2126 for socket in sockets.iter() {
2127 table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2128 }
2129 assert_eq!(table.total_size, 0);
2130 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2131 }
2132
2133 #[tokio::test(flavor = "multi_thread")]
2134 async fn test_throttling_check_no_packet_drop() {
2135 agave_logger::setup_with_default_filter();
2136
2137 let SpawnTestServerResult {
2138 join_handle,
2139 receiver,
2140 server_address,
2141 stats,
2142 cancel,
2143 } = setup_quic_server(
2144 None,
2145 QuicStreamerConfig::default_for_tests(),
2146 SwQosConfig::default(),
2147 );
2148
2149 let client_connection = make_client_endpoint(&server_address, None).await;
2150
2151 let expected_num_txs = 100;
2153 let start_time = tokio::time::Instant::now();
2154 for i in 0..expected_num_txs {
2155 let mut send_stream = client_connection.open_uni().await.unwrap();
2156 let data = format!("{i}").into_bytes();
2157 send_stream.write_all(&data).await.unwrap();
2158 send_stream.finish().unwrap();
2159 }
2160 let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2161 info!("Elapsed sending: {elapsed_sending}");
2162
2163 let start_time = tokio::time::Instant::now();
2165 let mut num_txs_received = 0;
2166 while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2167 if let Ok(packets) = receiver.try_recv() {
2168 num_txs_received += packets.len();
2169 } else {
2170 sleep(Duration::from_millis(100)).await;
2171 }
2172 }
2173 assert_eq!(expected_num_txs, num_txs_received);
2174
2175 cancel.cancel();
2176 join_handle.await.unwrap();
2177
2178 assert_eq!(
2179 stats.total_new_streams.load(Ordering::Relaxed),
2180 expected_num_txs
2181 );
2182 assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2183 }
2184
2185 #[test]
2186 fn test_client_connection_tracker() {
2187 let stats = Arc::new(StreamerStats::default());
2188 let tracker_1 = ClientConnectionTracker::new(stats.clone(), 1);
2189 assert!(tracker_1.is_ok());
2190 assert!(ClientConnectionTracker::new(stats.clone(), 1).is_err());
2191 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);
2192 drop(tracker_1);
2194 assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2195 }
2196
2197 #[tokio::test(flavor = "multi_thread")]
2198 async fn test_client_connection_close_invalid_stream() {
2199 let SpawnTestServerResult {
2200 join_handle,
2201 server_address,
2202 stats,
2203 cancel,
2204 ..
2205 } = setup_quic_server(
2206 None,
2207 QuicStreamerConfig::default_for_tests(),
2208 SwQosConfig::default(),
2209 );
2210
2211 let client_connection = make_client_endpoint(&server_address, None).await;
2212
2213 let mut send_stream = client_connection.open_uni().await.unwrap();
2214 send_stream
2215 .write_all(&[42; PACKET_DATA_SIZE + 1])
2216 .await
2217 .unwrap();
2218 match client_connection.closed().await {
2219 ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2220 assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2221 assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2222 }
2223 _ => panic!("unexpected close"),
2224 }
2225 assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2226 cancel.cancel();
2227 join_handle.await.unwrap();
2228 }
2229}