solana_streamer/nonblocking/quic.rs

#[allow(deprecated)]
use crate::quic::QuicServerParams;
use {
    crate::{
        nonblocking::{
            connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
            qos::{ConnectionContext, QosController},
            stream_throttle::ConnectionStreamCounter,
            swqos::{SwQos, SwQosConfig},
        },
        quic::{configure_server, QuicServerError, QuicStreamerConfig, StreamerStats},
        streamer::StakedNodes,
    },
    bytes::{BufMut, Bytes, BytesMut},
    crossbeam_channel::{bounded, Receiver, Sender, TryRecvError, TrySendError},
    futures::{stream::FuturesUnordered, Future, StreamExt as _},
    indexmap::map::{Entry, IndexMap},
    quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime},
    rand::{thread_rng, Rng},
    smallvec::SmallVec,
    solana_keypair::Keypair,
    solana_measure::measure::Measure,
    solana_packet::{Meta, PACKET_DATA_SIZE},
    solana_perf::packet::{BytesPacket, BytesPacketBatch, PacketBatch, PACKETS_PER_BATCH},
    solana_pubkey::Pubkey,
    solana_signature::Signature,
    solana_tls_utils::get_pubkey_from_tls_certificate,
    solana_transaction_metrics_tracker::signature_if_should_track_packet,
    std::{
        array,
        fmt,
        iter::repeat_with,
        net::{IpAddr, SocketAddr, UdpSocket},
        pin::Pin,
        // CAUTION: be careful not to introduce any awaits while holding an RwLock.
        sync::{
            atomic::{AtomicBool, AtomicU64, Ordering},
            Arc, RwLock,
        },
        task::Poll,
        time::{Duration, Instant},
    },
    tokio::{
        // CAUTION: It's kind of sketch that we're mixing async and sync locks (see the RwLock above).
        // This is done so that sync code can also access the stake table.
        // Make sure we don't hold a sync lock across an await - including the await to
        // lock an async Mutex. This does not happen now and should not happen as long as we
        // don't hold an async Mutex and sync RwLock at the same time (currently true)
        // but if we do, the scope of the RwLock must always be a subset of the async Mutex
        // (i.e. lock order is always async Mutex -> RwLock). Also, be careful not to
        // introduce any other awaits while holding the RwLock.
        select,
        task::{self, JoinHandle},
        time::{sleep, timeout},
    },
    tokio_util::{sync::CancellationToken, task::TaskTracker},
};

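/// Default timeout for receiving the next chunk(s) of an open stream before the
/// stream is considered dead (see the read loop in `handle_connection`).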
pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);

pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";

const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";

pub(crate) const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
pub(crate) const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";

pub(crate) const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
pub(crate) const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] =
    b"exceed_max_stream_count";

const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";

const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";

/// Total number of new connections allowed per second. Heuristically taken from
/// the default staked and unstaked connection limits. Might be adjusted
/// later.
const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2500;

/// The threshold of the size of the connection rate limiter map. When
/// the map size is above this, we will trigger a cleanup of older
/// entries used by past requests.
const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;

/// Timeout for connection handshake. Timer starts once we get Initial from the
/// peer, and is canceled when we get a Handshake packet from them.
const QUIC_CONNECTION_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(2);

/// Absolute max RTT to allow for a legitimate connection.
/// Enough to cover any non-malicious link on Earth.
pub(crate) const MAX_RTT: Duration = Duration::from_millis(320);
/// Prevent connections from having 0 RTT when RTT is too small,
/// as this would break some BDP calculations and assign zero bandwidth
pub(crate) const MIN_RTT: Duration = Duration::from_millis(2);

// A struct to accumulate the chunks of bytes making up
// a packet, along with the packet metadata. We use this
// accumulator to avoid multiple copies of the Bytes (when
// building up the Packet and then when copying the Packet
// into a PacketBatch).
#[derive(Clone)]
struct PacketAccumulator {
    pub meta: Meta,
    pub chunks: SmallVec<[Bytes; 2]>,
    pub start_time: Instant,
}

impl PacketAccumulator {
    fn new(meta: Meta) -> Self {
        Self {
            meta,
            chunks: SmallVec::default(),
            start_time: Instant::now(),
        }
    }
}

#[derive(Copy, Clone, Debug)]
pub enum ConnectionPeerType {
    Unstaked,
    Staked(u64),
}

impl ConnectionPeerType {
    pub(crate) fn is_staked(&self) -> bool {
        matches!(self, ConnectionPeerType::Staked(_))
    }
}

pub struct SpawnNonBlockingServerResult {
    pub endpoints: Vec<Endpoint>,
    pub stats: Arc<StreamerStats>,
    pub thread: JoinHandle<()>,
    pub max_concurrent_connections: usize,
}

#[deprecated(since = "3.0.0", note = "Use spawn_server instead")]
#[allow(deprecated)]
pub fn spawn_server_multi(
    name: &'static str,
    sockets: impl IntoIterator<Item = UdpSocket>,
    keypair: &Keypair,
    packet_sender: Sender<PacketBatch>,
    exit: Arc<AtomicBool>,
    staked_nodes: Arc<RwLock<StakedNodes>>,
    quic_server_params: QuicServerParams,
) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
    #[allow(deprecated)]
    spawn_server(
        name,
        sockets,
        keypair,
        packet_sender,
        exit,
        staked_nodes,
        quic_server_params,
    )
}

#[deprecated(since = "3.1.0", note = "Use spawn_server_with_cancel instead")]
#[allow(deprecated)]
pub fn spawn_server(
    name: &'static str,
    sockets: impl IntoIterator<Item = UdpSocket>,
    keypair: &Keypair,
    packet_sender: Sender<PacketBatch>,
    exit: Arc<AtomicBool>,
    staked_nodes: Arc<RwLock<StakedNodes>>,
    quic_server_params: QuicServerParams,
) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
    let cancel = CancellationToken::new();
    tokio::spawn({
        let cancel = cancel.clone();
        async move {
            loop {
                if exit.load(Ordering::Relaxed) {
                    cancel.cancel();
                    break;
                }
                sleep(Duration::from_millis(100)).await;
            }
        }
    });
    let quic_server_config = QuicStreamerConfig::from(&quic_server_params);
    let qos_config = SwQosConfig {
        max_streams_per_ms: quic_server_params.max_streams_per_ms,
    };
    spawn_server_with_cancel(
        name,
        sockets,
        keypair,
        packet_sender,
        staked_nodes,
        quic_server_config,
        qos_config,
        cancel,
    )
}

/// Spawn a streamer instance in the current tokio runtime.
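///
/// A minimal usage sketch, assuming a tokio runtime is already running; the
/// socket, channel, and config values below are illustrative placeholders, not
/// prescribed settings:
///
/// ```ignore
/// use std::sync::{Arc, RwLock};
///
/// let keypair = Keypair::new();
/// let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
/// let (packet_sender, _packet_receiver) = crossbeam_channel::unbounded();
/// let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
/// let cancel = CancellationToken::new();
///
/// let SpawnNonBlockingServerResult { thread, .. } = spawn_server_with_cancel(
///     "quic_streamer_example",
///     [socket],
///     &keypair,
///     packet_sender,
///     staked_nodes,
///     // Test-oriented defaults; substitute a production config as appropriate.
///     QuicStreamerConfig::default_for_tests(),
///     SwQosConfig::default(),
///     cancel.clone(),
/// )
/// .unwrap();
///
/// // ... later: request shutdown and wait for the server task to finish.
/// cancel.cancel();
/// thread.await.unwrap();
/// ```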
pub fn spawn_server_with_cancel(
    name: &'static str,
    sockets: impl IntoIterator<Item = UdpSocket>,
    keypair: &Keypair,
    packet_sender: Sender<PacketBatch>,
    staked_nodes: Arc<RwLock<StakedNodes>>,
    quic_server_params: QuicStreamerConfig,
    qos_config: SwQosConfig,
    cancel: CancellationToken,
) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
    let stats = Arc::<StreamerStats>::default();

    let swqos = Arc::new(SwQos::new(
        qos_config,
        quic_server_params.max_staked_connections,
        quic_server_params.max_unstaked_connections,
        quic_server_params.max_connections_per_peer,
        stats.clone(),
        staked_nodes,
        cancel.clone(),
    ));

    spawn_server_with_cancel_and_qos(
        name,
        stats,
        sockets,
        keypair,
        packet_sender,
        quic_server_params,
        swqos,
        cancel,
    )
}

/// Spawn a streamer instance in the current tokio runtime.
pub(crate) fn spawn_server_with_cancel_and_qos<Q, C>(
    name: &'static str,
    stats: Arc<StreamerStats>,
    sockets: impl IntoIterator<Item = UdpSocket>,
    keypair: &Keypair,
    packet_sender: Sender<PacketBatch>,
    quic_server_params: QuicStreamerConfig,
    qos: Arc<Q>,
    cancel: CancellationToken,
) -> Result<SpawnNonBlockingServerResult, QuicServerError>
where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let sockets: Vec<_> = sockets.into_iter().collect();
    info!("Start {name} quic server on {sockets:?}");
    let (config, _) = configure_server(keypair)?;

    let endpoints = sockets
        .into_iter()
        .map(|sock| {
            Endpoint::new(
                EndpointConfig::default(),
                Some(config.clone()),
                sock,
                Arc::new(TokioRuntime),
            )
            .map_err(QuicServerError::EndpointFailed)
        })
        .collect::<Result<Vec<_>, _>>()?;
    let (packet_batch_sender, packet_batch_receiver) =
        bounded(quic_server_params.accumulator_channel_size);
    task::spawn_blocking({
        let cancel = cancel.clone();
        let stats = stats.clone();
        move || {
            run_packet_batch_sender(packet_sender, packet_batch_receiver, stats, cancel);
        }
    });

    let max_concurrent_connections = quic_server_params.max_concurrent_connections();
    let handle = tokio::spawn({
        let endpoints = endpoints.clone();
        let stats = stats.clone();
        async move {
            let tasks = run_server(
                name,
                endpoints.clone(),
                packet_batch_sender,
                stats.clone(),
                quic_server_params,
                cancel,
                qos,
            )
            .await;
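            // run_server returns once the server is cancelled; close the tracker and
            // wait for the remaining per-connection tasks before the server task exits.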
            tasks.close();
            tasks.wait().await;
        }
    });

    Ok(SpawnNonBlockingServerResult {
        endpoints,
        stats,
        thread: handle,
        max_concurrent_connections,
    })
}

/// A struct to ease tracking connections across all stages, so that we do not have
/// to litter the code with open-connection tracking. It is added into the
/// connection table as part of the ConnectionEntry. The open connection count is
/// automatically decremented when it is dropped.
pub struct ClientConnectionTracker {
    pub(crate) stats: Arc<StreamerStats>,
}

/// This is required by ConnectionEntry to support the Debug format.
impl fmt::Debug for ClientConnectionTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StreamerClientConnection")
            .field(
                "open_connections:",
                &self.stats.open_connections.load(Ordering::Relaxed),
            )
            .finish()
    }
}

impl Drop for ClientConnectionTracker {
    /// When this is dropped, reduce the open connection count.
    fn drop(&mut self) {
        self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
    }
}

impl ClientConnectionTracker {
    /// Check the max_concurrent_connections limit; if within the limit, create a
    /// ClientConnectionTracker and increment the open connection count. Otherwise return Err.
    fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
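        // Optimistic increment: bump the counter first, then roll it back if the
        // previous value shows we were already at the limit.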
        let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
        if open_connections >= max_concurrent_connections {
            stats.open_connections.fetch_sub(1, Ordering::Relaxed);
            debug!(
                "There are too many concurrent connections opened already: open: \
                 {open_connections}, max: {max_concurrent_connections}"
            );
            return Err(());
        }

        Ok(Self { stats })
    }
}

#[allow(clippy::too_many_arguments)]
async fn run_server<Q, C>(
    name: &'static str,
    endpoints: Vec<Endpoint>,
    packet_batch_sender: Sender<PacketAccumulator>,
    stats: Arc<StreamerStats>,
    quic_server_params: QuicStreamerConfig,
    cancel: CancellationToken,
    qos: Arc<Q>,
) -> TaskTracker
where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let quic_server_params = Arc::new(quic_server_params);
    let rate_limiter = Arc::new(ConnectionRateLimiter::new(
        quic_server_params.max_connections_per_ipaddr_per_min,
    ));
    let overall_connection_rate_limiter = Arc::new(TotalConnectionRateLimiter::new(
        TOTAL_CONNECTIONS_PER_SECOND,
    ));

    const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
    debug!("spawn quic server");
    let mut last_datapoint = Instant::now();
    stats
        .quic_endpoints_count
        .store(endpoints.len(), Ordering::Relaxed);

    let mut accepts = endpoints
        .iter()
        .enumerate()
        .map(|(i, incoming)| {
            Box::pin(EndpointAccept {
                accept: incoming.accept(),
                endpoint: i,
            })
        })
        .collect::<FuturesUnordered<_>>();

    let tasks = TaskTracker::new();
    loop {
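        // Wait for one of: a new incoming connection on any endpoint, a short
        // timeout (so stats still get reported while idle), or server cancellation.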
        let timeout_connection = select! {
            ready = accepts.next() => {
                if let Some((connecting, i)) = ready {
                    accepts.push(
                        Box::pin(EndpointAccept {
                            accept: endpoints[i].accept(),
                            endpoint: i,
                        }
                    ));
                    Ok(connecting)
                } else {
                    // we can't really get here - we never poll an empty FuturesUnordered
                    continue
                }
            }
            _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
                Err(())
            }
            _ = cancel.cancelled() => break,
        };

        if last_datapoint.elapsed().as_secs() >= 5 {
            stats.report(name);
            last_datapoint = Instant::now();
        }

        if let Ok(Some(incoming)) = timeout_connection {
            stats
                .total_incoming_connection_attempts
                .fetch_add(1, Ordering::Relaxed);

            // first do per IpAddr rate limiting
            if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
                rate_limiter.retain_recent();
            }
            stats
                .connection_rate_limiter_length
                .store(rate_limiter.len(), Ordering::Relaxed);

            let Ok(client_connection_tracker) = ClientConnectionTracker::new(
                stats.clone(),
                quic_server_params.max_concurrent_connections(),
            ) else {
                stats
                    .refused_connections_too_many_open_connections
                    .fetch_add(1, Ordering::Relaxed);
                incoming.refuse();
                continue;
            };

            stats
                .outstanding_incoming_connection_attempts
                .fetch_add(1, Ordering::Relaxed);
            let connecting = incoming.accept();
            match connecting {
                Ok(connecting) => {
                    let rate_limiter = rate_limiter.clone();
                    let overall_connection_rate_limiter = overall_connection_rate_limiter.clone();
                    tasks.spawn(setup_connection(
                        connecting,
                        rate_limiter,
                        overall_connection_rate_limiter,
                        client_connection_tracker,
                        packet_batch_sender.clone(),
                        stats.clone(),
                        quic_server_params.clone(),
                        qos.clone(),
                        tasks.clone(),
                    ));
                }
                Err(err) => {
                    stats
                        .outstanding_incoming_connection_attempts
                        .fetch_sub(1, Ordering::Relaxed);
                    debug!("Incoming::accept(): error {err:?}");
                }
            }
        } else {
            debug!("accept(): Timed out waiting for connection");
        }
    }
    tasks
}

pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
    // Use the client cert only if it is self signed and the chain length is 1.
    connection
        .peer_identity()?
        .downcast::<Vec<rustls::pki_types::CertificateDer>>()
        .ok()
        .filter(|certs| certs.len() == 1)?
        .first()
        .and_then(get_pubkey_from_tls_certificate)
}

pub fn get_connection_stake(
    connection: &Connection,
    staked_nodes: &RwLock<StakedNodes>,
) -> Option<(Pubkey, u64, u64)> {
    let pubkey = get_remote_pubkey(connection)?;
    debug!("Peer public key is {pubkey:?}");
    let staked_nodes = staked_nodes.read().unwrap();
    Some((
        pubkey,
        staked_nodes.get_node_stake(&pubkey)?,
        staked_nodes.total_stake(),
    ))
}

#[derive(Debug)]
pub(crate) enum ConnectionHandlerError {
    ConnectionAddError,
    MaxStreamError,
}

pub(crate) fn update_open_connections_stat(
    stats: &StreamerStats,
    connection_table: &ConnectionTable,
) {
    if connection_table.is_staked() {
        stats
            .open_staked_connections
            .store(connection_table.table_size(), Ordering::Relaxed);
    } else {
        stats
            .open_unstaked_connections
            .store(connection_table.table_size(), Ordering::Relaxed);
    }
}

#[allow(clippy::too_many_arguments)]
async fn setup_connection<Q, C>(
    connecting: Connecting,
    rate_limiter: Arc<ConnectionRateLimiter>,
    overall_connection_rate_limiter: Arc<TotalConnectionRateLimiter>,
    client_connection_tracker: ClientConnectionTracker,
    packet_sender: Sender<PacketAccumulator>,
    stats: Arc<StreamerStats>,
    server_params: Arc<QuicStreamerConfig>,
    qos: Arc<Q>,
    tasks: TaskTracker,
) where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let from = connecting.remote_address();
    let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
    stats
        .outstanding_incoming_connection_attempts
        .fetch_sub(1, Ordering::Relaxed);
    if let Ok(connecting_result) = res {
        match connecting_result {
            Ok(new_connection) => {
                debug!("Got a connection {from:?}");
                if !rate_limiter.is_allowed(&from.ip()) {
                    debug!("Reject connection from {from:?} -- rate limiting exceeded");
                    stats
                        .connection_rate_limited_per_ipaddr
                        .fetch_add(1, Ordering::Relaxed);
                    new_connection.close(
                        CONNECTION_CLOSE_CODE_DISALLOWED.into(),
                        CONNECTION_CLOSE_REASON_DISALLOWED,
                    );
                    return;
                }

                if !overall_connection_rate_limiter.is_allowed() {
                    debug!(
                        "Reject connection from {:?} -- total rate limiting exceeded",
                        from.ip()
                    );
                    stats
                        .connection_rate_limited_across_all
                        .fetch_add(1, Ordering::Relaxed);
                    new_connection.close(
                        CONNECTION_CLOSE_CODE_DISALLOWED.into(),
                        CONNECTION_CLOSE_REASON_DISALLOWED,
                    );
                    return;
                }

                stats.total_new_connections.fetch_add(1, Ordering::Relaxed);

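                // Hand the connection to the QoS controller: it builds the per-peer
                // context, enforces connection limits, and returns a connection-scoped
                // cancellation token only if the connection is admitted.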
                let mut conn_context = qos.build_connection_context(&new_connection);
                if let Some(cancel_connection) = qos
                    .try_add_connection(
                        client_connection_tracker,
                        &new_connection,
                        &mut conn_context,
                    )
                    .await
                {
                    tasks.spawn(handle_connection(
                        packet_sender.clone(),
                        from,
                        new_connection,
                        stats,
                        server_params.wait_for_chunk_timeout,
                        conn_context.clone(),
                        qos,
                        cancel_connection,
                    ));
                }
            }
            Err(e) => {
                handle_connection_error(e, &stats, from);
            }
        }
    } else {
        stats
            .connection_setup_timeout
            .fetch_add(1, Ordering::Relaxed);
    }
}

fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
    debug!("error: {e:?} from: {from:?}");
    stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
    match e {
        quinn::ConnectionError::TimedOut => {
            stats
                .connection_setup_error_timed_out
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::ConnectionClosed(_) => {
            stats
                .connection_setup_error_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::TransportError(_) => {
            stats
                .connection_setup_error_transport
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::ApplicationClosed(_) => {
            stats
                .connection_setup_error_app_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::Reset => {
            stats
                .connection_setup_error_reset
                .fetch_add(1, Ordering::Relaxed);
        }
        quinn::ConnectionError::LocallyClosed => {
            stats
                .connection_setup_error_locally_closed
                .fetch_add(1, Ordering::Relaxed);
        }
        _ => {}
    }
}

// Holder(s) of the Sender<PacketAccumulator> on the other end should not
// wait for this function to exit
fn run_packet_batch_sender(
    packet_sender: Sender<PacketBatch>,
    packet_receiver: Receiver<PacketAccumulator>,
    stats: Arc<StreamerStats>,
    cancel: CancellationToken,
) {
    let mut channel_disconnected = false;
    trace!("enter packet_batch_sender");
    loop {
        let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
        let mut packet_batch = BytesPacketBatch::with_capacity(PACKETS_PER_BATCH);
        let mut total_bytes: usize = 0;

        stats
            .total_packet_batches_allocated
            .fetch_add(1, Ordering::Relaxed);
        stats
            .total_packets_allocated
            .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);

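        // Inner loop: fill the current batch from the accumulator channel (blocking
        // for the first packet, then draining without blocking), send the batch
        // downstream once it has contents, and break to start a new batch.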
        loop {
            if cancel.is_cancelled() || channel_disconnected {
                return;
            }
            if !packet_batch.is_empty() {
                let len = packet_batch.len();
                track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);

                if let Err(e) = packet_sender.try_send(packet_batch.into()) {
                    stats
                        .total_packet_batch_send_err
                        .fetch_add(1, Ordering::Relaxed);
                    trace!("Send error: {e}");

                    // The downstream channel is disconnected, this error is not recoverable.
                    if matches!(e, TrySendError::Disconnected(_)) {
                        cancel.cancel();
                        return;
                    }
                } else {
                    stats
                        .total_packet_batches_sent
                        .fetch_add(1, Ordering::Relaxed);

                    stats
                        .total_packets_sent_to_consumer
                        .fetch_add(len, Ordering::Relaxed);

                    stats
                        .total_bytes_sent_to_consumer
                        .fetch_add(total_bytes, Ordering::Relaxed);

                    trace!("Sent {len} packet batch");
                }
                break;
            }

            // On the first receive, we block on recv so as not to use excessive CPU when the channel
            // is idle. This will not block the exit, as the channel will be dropped on the sender's side.
            //
            // On subsequent receives, we call try_recv so that we do not block waiting for packets
            // when we already have something in the batch.
            //
            // For setting channel_disconnected, we can ignore TryRecvError::Disconnected on try_recv and
            // set it on the next iteration (if we don't exit early via the cancel token anyway).
            let mut first = true;
            let mut recv = || {
                if first {
                    first = false;
                    // recv is only an error if empty and disconnected
                    packet_receiver.recv().map_err(|_| {
                        channel_disconnected = true;
                        TryRecvError::Disconnected
                    })
                } else {
                    packet_receiver.try_recv()
                }
            };
            while let Ok(mut packet_accumulator) = recv() {
                // 86% of transactions/packets come in one chunk. In that case,
                // we can just move the chunk to the `Packet` and no copy is
                // made.
                // 14% of them come in multiple chunks. In that case, we copy
                // them into one `Bytes` buffer. We make a copy once, with
                // intention to not do it again.
                let num_chunks = packet_accumulator.chunks.len();
                let mut packet = if packet_accumulator.chunks.len() == 1 {
                    BytesPacket::new(
                        packet_accumulator.chunks.pop().expect("expected one chunk"),
                        packet_accumulator.meta,
                    )
                } else {
                    let size: usize = packet_accumulator.chunks.iter().map(Bytes::len).sum();
                    let mut buf = BytesMut::with_capacity(size);
                    for chunk in packet_accumulator.chunks {
                        buf.put_slice(&chunk);
                    }
                    BytesPacket::new(buf.freeze(), packet_accumulator.meta)
                };

                total_bytes += packet.meta().size;

                if let Some(signature) = signature_if_should_track_packet(&packet).ok().flatten() {
                    packet_perf_measure.push((*signature, packet_accumulator.start_time));
                    // we set the PERF_TRACK_PACKET flag on
                    packet.meta_mut().set_track_performance(true);
                }
                packet_batch.push(packet);
                stats
                    .total_chunks_processed_by_batcher
                    .fetch_add(num_chunks, Ordering::Relaxed);

                // prevent getting stuck in loop
                if packet_batch.len() >= PACKETS_PER_BATCH {
                    break;
                }
            }
        }
    }
}

fn track_streamer_fetch_packet_performance(
    packet_perf_measure: &[([u8; 64], Instant)],
    stats: &StreamerStats,
) {
    if packet_perf_measure.is_empty() {
        return;
    }
    let mut measure = Measure::start("track_perf");
    let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();

    let now = Instant::now();
    for (signature, start_time) in packet_perf_measure {
        let duration = now.duration_since(*start_time);
        debug!(
            "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
            Signature::from(*signature)
        );
        process_sampled_packets_us_hist
            .increment(duration.as_micros() as u64)
            .unwrap();
    }

    drop(process_sampled_packets_us_hist);
    measure.stop();
    stats
        .perf_track_overhead_us
        .fetch_add(measure.as_us(), Ordering::Relaxed);
}

async fn handle_connection<Q, C>(
    packet_sender: Sender<PacketAccumulator>,
    remote_address: SocketAddr,
    connection: Connection,
    stats: Arc<StreamerStats>,
    wait_for_chunk_timeout: Duration,
    context: C,
    qos: Arc<Q>,
    cancel: CancellationToken,
) where
    Q: QosController<C> + Send + Sync + 'static,
    C: ConnectionContext + Send + Sync + 'static,
{
    let peer_type = context.peer_type();
    debug!(
        "quic new connection {} streams: {} connections: {}",
        remote_address,
        stats.active_streams.load(Ordering::Relaxed),
        stats.total_connections.load(Ordering::Relaxed),
    );
    stats.total_connections.fetch_add(1, Ordering::Relaxed);

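    // Each iteration of this loop accepts and processes one unidirectional stream;
    // the labeled 'conn loop exits when the peer disconnects, the QoS layer cancels
    // the connection, or an invalid stream is received.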
    'conn: loop {
        // Wait for new streams. If the peer is disconnected we get a cancellation signal and stop
        // the connection task.
        let mut stream = select! {
            stream = connection.accept_uni() => match stream {
                Ok(stream) => stream,
                Err(e) => {
                    debug!("stream error: {e:?}");
                    break;
                }
            },
            _ = cancel.cancelled() => break,
        };

        qos.on_new_stream(&context).await;
        qos.on_stream_accepted(&context);
        stats.active_streams.fetch_add(1, Ordering::Relaxed);
        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);

        let mut meta = Meta::default();
        meta.set_socket_addr(&remote_address);
        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
        let mut accum = PacketAccumulator::new(meta);

        // Virtually all small transactions will fit in 1 chunk. Larger transactions will fit in 1
        // or 2 chunks if the first chunk starts towards the end of a datagram. A small number of
        // transactions will have other protocol frames inserted in the middle. Empirically it's
        // been observed that 4 is the maximum number of chunks txs get split into.
        //
        // Bytes values are small, so overall the array takes only 128 bytes, and the "cost" of
        // overallocating a few bytes is negligible compared to the cost of having to do multiple
        // read_chunks() calls.
        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());

        loop {
            // Read the next chunks, waiting up to `wait_for_chunk_timeout`. If we don't get chunks
            // before then, we assume the stream is dead. This can only happen if there's severe
            // packet loss or the peer stops sending for whatever reason.
            let n_chunks = match tokio::select! {
                chunk = tokio::time::timeout(
                    wait_for_chunk_timeout,
                    stream.read_chunks(&mut chunks)) => chunk,

                // If the peer gets disconnected stop the task right away.
                _ = cancel.cancelled() => break,
            } {
                // read_chunk returned success
                Ok(Ok(chunk)) => chunk.unwrap_or(0),
                // read_chunk returned error
                Ok(Err(e)) => {
                    debug!("Received stream error: {e:?}");
                    stats
                        .total_stream_read_errors
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
                // timeout elapsed
                Err(_) => {
                    debug!("Timeout in receiving on stream");
                    stats
                        .total_stream_read_timeouts
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
            };

            match handle_chunks(
                // Bytes::clone() is a cheap atomic inc
                chunks.iter().take(n_chunks).cloned(),
                &mut accum,
                &packet_sender,
                &stats,
                peer_type,
            ) {
                // The stream is finished, break out of the loop and close the stream.
                Ok(StreamState::Finished) => {
                    qos.on_stream_finished(&context);
                    break;
                }
                // The stream is still active, continue reading.
                Ok(StreamState::Receiving) => {}
                Err(_) => {
                    // Disconnect peers that send invalid streams.
                    connection.close(
                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
                    );
                    stats.active_streams.fetch_sub(1, Ordering::Relaxed);
                    qos.on_stream_error(&context);
                    break 'conn;
                }
            }
        }

        stats.active_streams.fetch_sub(1, Ordering::Relaxed);
        qos.on_stream_closed(&context);
    }

    let removed_connection_count = qos.remove_connection(&context, connection).await;
    if removed_connection_count > 0 {
        stats
            .connection_removed
            .fetch_add(removed_connection_count, Ordering::Relaxed);
    } else {
        stats
            .connection_remove_failed
            .fetch_add(1, Ordering::Relaxed);
    }
    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
}

enum StreamState {
    // Stream is not finished, keep receiving chunks
    Receiving,
    // Stream is finished
    Finished,
}

// Handle the chunks received from the stream. If the stream is finished, send the packet to the
// packet sender.
//
// Returns Err(()) if the stream is invalid.
fn handle_chunks(
    chunks: impl ExactSizeIterator<Item = Bytes>,
    accum: &mut PacketAccumulator,
    packet_sender: &Sender<PacketAccumulator>,
    stats: &StreamerStats,
    peer_type: ConnectionPeerType,
) -> Result<StreamState, ()> {
    let n_chunks = chunks.len();
    for chunk in chunks {
        accum.meta.size += chunk.len();
        if accum.meta.size > PACKET_DATA_SIZE {
            // The stream window size is set to PACKET_DATA_SIZE, so one individual chunk can
            // never exceed this size. A peer can send two chunks that together exceed the size
            // though, in which case we report the error.
            stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
            debug!("invalid stream size {}", accum.meta.size);
            return Err(());
        }
        accum.chunks.push(chunk);
        if peer_type.is_staked() {
            stats
                .total_staked_chunks_received
                .fetch_add(1, Ordering::Relaxed);
        } else {
            stats
                .total_unstaked_chunks_received
                .fetch_add(1, Ordering::Relaxed);
        }
    }

    // n_chunks == 0 marks the end of a stream
    if n_chunks != 0 {
        return Ok(StreamState::Receiving);
    }

    if accum.chunks.is_empty() {
        debug!("stream is empty");
        stats
            .total_packet_batches_none
            .fetch_add(1, Ordering::Relaxed);
        return Err(());
    }

    // done receiving chunks
    let bytes_sent = accum.meta.size;
    let chunks_sent = accum.chunks.len();

    if let Err(err) = packet_sender.try_send(accum.clone()) {
        stats
            .total_handle_chunk_to_packet_batcher_send_err
            .fetch_add(1, Ordering::Relaxed);
        match err {
            TrySendError::Full(_) => {
                stats
                    .total_handle_chunk_to_packet_batcher_send_full_err
                    .fetch_add(1, Ordering::Relaxed);
            }
            TrySendError::Disconnected(_) => {
                stats
                    .total_handle_chunk_to_packet_batcher_send_disconnected_err
                    .fetch_add(1, Ordering::Relaxed);
            }
        }
        trace!("packet batch send error {err:?}");
    } else {
        stats
            .total_packets_sent_for_batching
            .fetch_add(1, Ordering::Relaxed);
        stats
            .total_bytes_sent_for_batching
            .fetch_add(bytes_sent, Ordering::Relaxed);
        stats
            .total_chunks_sent_for_batching
            .fetch_add(chunks_sent, Ordering::Relaxed);

        match peer_type {
            ConnectionPeerType::Unstaked => {
                stats
                    .total_unstaked_packets_sent_for_batching
                    .fetch_add(1, Ordering::Relaxed);
            }
            ConnectionPeerType::Staked(_) => {
                stats
                    .total_staked_packets_sent_for_batching
                    .fetch_add(1, Ordering::Relaxed);
            }
        }

        trace!("sent {bytes_sent} byte packet for batching");
    }

    Ok(StreamState::Finished)
}

#[derive(Debug)]
struct ConnectionEntry {
    cancel: CancellationToken,
    peer_type: ConnectionPeerType,
    last_update: Arc<AtomicU64>,
    port: u16,
    // We do not explicitly use it, but its drop is triggered when ConnectionEntry is dropped.
    _client_connection_tracker: ClientConnectionTracker,
    connection: Option<Connection>,
    stream_counter: Arc<ConnectionStreamCounter>,
}

impl ConnectionEntry {
    fn new(
        cancel: CancellationToken,
        peer_type: ConnectionPeerType,
        last_update: Arc<AtomicU64>,
        port: u16,
        client_connection_tracker: ClientConnectionTracker,
        connection: Option<Connection>,
        stream_counter: Arc<ConnectionStreamCounter>,
    ) -> Self {
        Self {
            cancel,
            peer_type,
            last_update,
            port,
            _client_connection_tracker: client_connection_tracker,
            connection,
            stream_counter,
        }
    }

    fn last_update(&self) -> u64 {
        self.last_update.load(Ordering::Relaxed)
    }

    fn stake(&self) -> u64 {
        match self.peer_type {
            ConnectionPeerType::Unstaked => 0,
            ConnectionPeerType::Staked(stake) => stake,
        }
    }
}

impl Drop for ConnectionEntry {
    fn drop(&mut self) {
        if let Some(conn) = self.connection.take() {
            conn.close(
                CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
                CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
            );
        }
        self.cancel.cancel();
    }
}

#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) enum ConnectionTableKey {
    IP(IpAddr),
    Pubkey(Pubkey),
}

impl ConnectionTableKey {
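    /// Key connections by pubkey when one was extracted from the peer's TLS
    /// certificate; fall back to the peer's IP address otherwise.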
    pub(crate) fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
        maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
            ConnectionTableKey::Pubkey(pubkey)
        })
    }
}

pub(crate) enum ConnectionTableType {
    Staked,
    Unstaked,
}

// Map of peer key (pubkey if available, otherwise IP) to its list of connection entries
pub(crate) struct ConnectionTable {
    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
    pub(crate) total_size: usize,
    table_type: ConnectionTableType,
    cancel: CancellationToken,
}

impl ConnectionTable {
    pub(crate) fn new(table_type: ConnectionTableType, cancel: CancellationToken) -> Self {
        Self {
            table: IndexMap::default(),
            total_size: 0,
            table_type,
            cancel,
        }
    }

    fn table_size(&self) -> usize {
        self.total_size
    }

    fn is_staked(&self) -> bool {
        matches!(self.table_type, ConnectionTableType::Staked)
    }

    /// Prune entries with the oldest last-update times until the table holds at
    /// most `max_size` connections.
    ///
    /// Returns the number of pruned connections.
    pub(crate) fn prune_oldest(&mut self, max_size: usize) -> usize {
        let mut num_pruned = 0;
        let key = |(_, connections): &(_, &Vec<_>)| {
            connections.iter().map(ConnectionEntry::last_update).min()
        };
        while self.total_size.saturating_sub(num_pruned) > max_size {
            match self.table.values().enumerate().min_by_key(key) {
                None => break,
                Some((index, connections)) => {
                    num_pruned += connections.len();
                    self.table.swap_remove_index(index);
                }
            }
        }
        self.total_size = self.total_size.saturating_sub(num_pruned);
        num_pruned
    }

    // Randomly selects sample_size many connections, evicts the one with the
    // lowest stake, and returns the number of pruned connections.
    // If the stakes of all the sampled connections are higher than the
    // threshold_stake, rejects the pruning attempt, and returns 0.
    pub(crate) fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
        let num_pruned = std::iter::once(self.table.len())
            .filter(|&size| size > 0)
            .flat_map(|size| {
                let mut rng = thread_rng();
                repeat_with(move || rng.gen_range(0..size))
            })
            .map(|index| {
                let connection = self.table[index].first();
                let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
                (index, stake)
            })
            .take(sample_size)
            .min_by_key(|&(_, stake)| stake)
            .filter(|&(_, stake)| stake < Some(threshold_stake))
            .and_then(|(index, _)| self.table.swap_remove_index(index))
            .map(|(_, connections)| connections.len())
            .unwrap_or_default();
        self.total_size = self.total_size.saturating_sub(num_pruned);
        num_pruned
    }

    pub(crate) fn try_add_connection(
        &mut self,
        key: ConnectionTableKey,
        port: u16,
        client_connection_tracker: ClientConnectionTracker,
        connection: Option<Connection>,
        peer_type: ConnectionPeerType,
        last_update: Arc<AtomicU64>,
        max_connections_per_peer: usize,
    ) -> Option<(
        Arc<AtomicU64>,
        CancellationToken,
        Arc<ConnectionStreamCounter>,
    )> {
        let connection_entry = self.table.entry(key).or_default();
        let has_connection_capacity = connection_entry
            .len()
            .checked_add(1)
            .map(|c| c <= max_connections_per_peer)
            .unwrap_or(false);
        if has_connection_capacity {
            let cancel = self.cancel.child_token();
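            // All connections from the same peer share a single stream counter, so
            // per-peer stream throttling applies across that peer's connections.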
            let stream_counter = connection_entry
                .first()
                .map(|entry| entry.stream_counter.clone())
                .unwrap_or(Arc::new(ConnectionStreamCounter::new()));
            connection_entry.push(ConnectionEntry::new(
                cancel.clone(),
                peer_type,
                last_update.clone(),
                port,
                client_connection_tracker,
                connection,
                stream_counter.clone(),
            ));
            self.total_size += 1;
            Some((last_update, cancel, stream_counter))
        } else {
            if let Some(connection) = connection {
                connection.close(
                    CONNECTION_CLOSE_CODE_TOO_MANY.into(),
                    CONNECTION_CLOSE_REASON_TOO_MANY,
                );
            }
            None
        }
    }

    // Returns number of connections that were removed
    pub(crate) fn remove_connection(
        &mut self,
        key: ConnectionTableKey,
        port: u16,
        stable_id: usize,
    ) -> usize {
        if let Entry::Occupied(mut e) = self.table.entry(key) {
            let e_ref = e.get_mut();
            let old_size = e_ref.len();

            e_ref.retain(|connection_entry| {
                // Retain the connection entry if the port is different, or if the connection's
                // stable_id doesn't match the provided stable_id.
                // (Some unit tests do not fill in a valid connection in the table. To support that,
                // if the connection is none, the stable_id check is ignored. i.e. if the port matches,
                // the connection gets removed)
                connection_entry.port != port
                    || connection_entry
                        .connection
                        .as_ref()
                        .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
                        .is_some()
            });
            let new_size = e_ref.len();
            if e_ref.is_empty() {
                e.swap_remove_entry();
            }
            let connections_removed = old_size.saturating_sub(new_size);
            self.total_size = self.total_size.saturating_sub(connections_removed);
            connections_removed
        } else {
            0
        }
    }
}

struct EndpointAccept<'a> {
    endpoint: usize,
    accept: Accept<'a>,
}

impl Future for EndpointAccept<'_> {
    type Output = (Option<quinn::Incoming>, usize);

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        let i = self.endpoint;
        // Safety:
        // self is pinned and accept is a field so it can't get moved out. See safety docs of
        // map_unchecked_mut.
        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
            .poll(cx)
            .map(|r| (r, i))
    }
}

#[cfg(test)]
pub mod test {
    use {
        super::*,
        crate::nonblocking::testing_utilities::{
            check_multiple_streams, get_client_config, make_client_endpoint, setup_quic_server,
            SpawnTestServerResult,
        },
        assert_matches::assert_matches,
        crossbeam_channel::{unbounded, Receiver},
        quinn::{ApplicationClose, ConnectionError},
        solana_keypair::Keypair,
        solana_net_utils::sockets::bind_to_localhost_unique,
        solana_signer::Signer,
        std::collections::HashMap,
        tokio::time::sleep,
    };

    pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let total = 30;
        for i in 0..total {
            let mut s1 = conn1.open_uni().await.unwrap();
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            info!("done {i}");
            sleep(Duration::from_millis(1000)).await;
        }
        let mut received = 0;
        loop {
            if let Ok(_x) = receiver.try_recv() {
                received += 1;
                info!("got {received}");
            } else {
                sleep(Duration::from_millis(500)).await;
            }
            if received >= total {
                break;
            }
        }
    }

    pub async fn check_block_multiple_connections(server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let conn2 = make_client_endpoint(&server_address, None).await;
        let mut s1 = conn1.open_uni().await.unwrap();
        let s2 = conn2.open_uni().await;
        if let Ok(mut s2) = s2 {
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            // Send enough data to create more than 1 chunk.
            // The first will try to open the connection (which should fail).
            // The following chunks will enable the detection of connection failure.
            let data = vec![1u8; PACKET_DATA_SIZE * 2];
            s2.write_all(&data)
                .await
                .expect_err("shouldn't be able to open 2 connections");
        } else {
            // It has been noticed that if there is already a connection open against the server,
            // this open_uni can fail with an ApplicationClosed(ApplicationClose) error due to
            // CONNECTION_CLOSE_CODE_TOO_MANY before writing to the stream -- expect it.
            assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
        }
    }

    pub async fn check_multiple_writes(
        receiver: Receiver<PacketBatch>,
        server_address: SocketAddr,
        client_keypair: Option<&Keypair>,
    ) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);

        // Send a full size packet with single byte writes.
        let num_bytes = PACKET_DATA_SIZE;
        let num_expected_packets = 1;
        let mut s1 = conn1.open_uni().await.unwrap();
        for _ in 0..num_bytes {
            s1.write_all(&[0u8]).await.unwrap();
        }
        s1.finish().unwrap();

        check_received_packets(receiver, num_expected_packets, num_bytes).await;
    }

    pub async fn check_multiple_packets(
        receiver: Receiver<PacketBatch>,
        server_address: SocketAddr,
        client_keypair: Option<&Keypair>,
        num_expected_packets: usize,
    ) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);

        // Send full-size packets, each written in a single write.
        let num_bytes = PACKET_DATA_SIZE;
        let packet = vec![1u8; num_bytes];
        for _ in 0..num_expected_packets {
            let mut s1 = conn1.open_uni().await.unwrap();
            s1.write_all(&packet).await.unwrap();
            s1.finish().unwrap();
        }

        check_received_packets(receiver, num_expected_packets, num_bytes).await;
    }

    async fn check_received_packets(
        receiver: Receiver<PacketBatch>,
        num_expected_packets: usize,
        num_bytes: usize,
    ) {
        let mut all_packets = vec![];
        let now = Instant::now();
        let mut total_packets = 0;
        while now.elapsed().as_secs() < 5 {
            // We're running in an async environment, so we (almost) never
            // want to block.
1408            if let Ok(packets) = receiver.try_recv() {
1409                total_packets += packets.len();
1410                all_packets.push(packets)
1411            } else {
1412                sleep(Duration::from_secs(1)).await;
1413            }
1414            if total_packets >= num_expected_packets {
1415                break;
1416            }
1417        }
1418        for batch in all_packets {
1419            for p in batch.iter() {
1420                assert_eq!(p.meta().size, num_bytes);
1421            }
1422        }
1423        assert_eq!(total_packets, num_expected_packets);
1424    }
1425
1426    pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
1427        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
1428
1429        // Send a full size packet with single byte writes.
1430        if let Ok(mut s1) = conn1.open_uni().await {
1431            for _ in 0..PACKET_DATA_SIZE {
1432                // Ignoring any errors here. s1.finish() will test the error condition
1433                s1.write_all(&[0u8]).await.unwrap_or_default();
1434            }
1435            s1.finish().unwrap_or_default();
1436            s1.stopped().await.unwrap_err();
1437        }
1438    }
1439
1440    #[tokio::test(flavor = "multi_thread")]
1441    async fn test_quic_server_exit_on_cancel() {
1442        let SpawnTestServerResult {
1443            join_handle,
1444            receiver,
1445            server_address: _,
1446            stats: _,
1447            cancel,
1448        } = setup_quic_server(
1449            None,
1450            QuicStreamerConfig::default_for_tests(),
1451            SwQosConfig::default(),
1452        );
1453        cancel.cancel();
1454        join_handle.await.unwrap();
1455        // test that it is stopped by cancel, not due to receiver
1456        // dropped.
1457        drop(receiver);
1458    }
1459
1460    #[tokio::test(flavor = "multi_thread")]
1461    async fn test_quic_timeout() {
1462        agave_logger::setup();
1463        let SpawnTestServerResult {
1464            join_handle,
1465            receiver,
1466            server_address,
1467            stats: _,
1468            cancel,
1469        } = setup_quic_server(
1470            None,
1471            QuicStreamerConfig::default_for_tests(),
1472            SwQosConfig::default(),
1473        );
1474
1475        check_timeout(receiver, server_address).await;
1476        cancel.cancel();
1477        join_handle.await.unwrap();
1478    }
1479
1480    #[tokio::test(flavor = "multi_thread")]
1481    async fn test_packet_batcher() {
1482        agave_logger::setup();
1483        let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
1484        let (pkt_sender, pkt_receiver) = unbounded();
1485        let cancel = CancellationToken::new();
1486        let stats = Arc::new(StreamerStats::default());
1487
1488        let handle = task::spawn_blocking({
1489            let cancel = cancel.clone();
1490            move || {
1491                run_packet_batch_sender(pkt_batch_sender, pkt_receiver, stats, cancel);
1492            }
1493        });
1494
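        // Feed the batcher with small single-chunk accumulators; the blocking
        // run_packet_batch_sender task should coalesce them into PacketBatch
        // values delivered on pkt_batch_receiver.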
1495        let num_packets = 1000;
1496
1497        for _i in 0..num_packets {
1498            let mut meta = Meta::default();
1499            let bytes = Bytes::from("Hello world");
1500            let size = bytes.len();
1501            meta.size = size;
1502            let packet_accum = PacketAccumulator {
1503                meta,
1504                chunks: smallvec::smallvec![bytes],
1505                start_time: Instant::now(),
1506            };
1507            pkt_sender.send(packet_accum).unwrap();
1508        }
1509        let mut i = 0;
1510        let start = Instant::now();
1511        while i < num_packets && start.elapsed().as_secs() < 2 {
1512            if let Ok(batch) = pkt_batch_receiver.try_recv() {
1513                i += batch.len();
1514            } else {
1515                sleep(Duration::from_millis(1)).await;
1516            }
1517        }
1518        assert_eq!(i, num_packets);
1519        cancel.cancel();
1520        // Dropping the accumulator sender wakes run_packet_batch_sender so it can exit.
1521        drop(pkt_sender);
1522        handle.await.unwrap();
1523    }
1524
1525    #[tokio::test(flavor = "multi_thread")]
1526    async fn test_quic_stream_timeout() {
1527        agave_logger::setup();
1528        let SpawnTestServerResult {
1529            join_handle,
1530            receiver,
1531            server_address,
1532            stats,
1533            cancel,
1534        } = setup_quic_server(
1535            None,
1536            QuicStreamerConfig::default_for_tests(),
1537            SwQosConfig::default(),
1538        );
1539
1540        let conn1 = make_client_endpoint(&server_address, None).await;
1541        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1542        assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1543
1544        // Send one byte to start the stream
1545        let mut s1 = conn1.open_uni().await.unwrap();
1546        s1.write_all(&[0u8]).await.unwrap_or_default();
1547
1548        // Wait long enough for the stream to time out while receiving chunks.
1549        let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
1550        sleep(sleep_time).await;
1551
1552        // Test that the stream was created, but timed out in read
1553        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1554        assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1555
1556        // Test that more writes to the stream will fail (i.e. the stream is no longer writable
1557        // after the timeouts)
1558        assert!(s1.write_all(&[0u8]).await.is_err());
1559
1560        cancel.cancel();
1561        drop(receiver);
1562        join_handle.await.unwrap();
1563    }
1564
1565    #[tokio::test(flavor = "multi_thread")]
1566    async fn test_quic_server_block_multiple_connections() {
1567        agave_logger::setup();
1568        let SpawnTestServerResult {
1569            join_handle,
1570            receiver,
1571            server_address,
1572            stats: _,
1573            cancel,
1574        } = setup_quic_server(
1575            None,
1576            QuicStreamerConfig::default_for_tests(),
1577            SwQosConfig::default(),
1578        );
1579        check_block_multiple_connections(server_address).await;
1580        cancel.cancel();
1581        drop(receiver);
1582        join_handle.await.unwrap();
1583    }
1584
1585    #[tokio::test(flavor = "multi_thread")]
1586    async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1587        agave_logger::setup();
1588
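        // Open two connections from the same client endpoint (the server allows
        // max_connections_per_peer = 2), close the first, and verify the second
        // stream still completes and both removals are counted in the stats.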
1589        let SpawnTestServerResult {
1590            join_handle,
1591            receiver,
1592            server_address,
1593            stats,
1594            cancel,
1595        } = setup_quic_server(
1596            None,
1597            QuicStreamerConfig {
1598                max_connections_per_peer: 2,
1599                ..QuicStreamerConfig::default_for_tests()
1600            },
1601            SwQosConfig::default(),
1602        );
1603
1604        let client_socket = bind_to_localhost_unique().expect("should bind - client");
1605        let mut endpoint = quinn::Endpoint::new(
1606            EndpointConfig::default(),
1607            None,
1608            client_socket,
1609            Arc::new(TokioRuntime),
1610        )
1611        .unwrap();
1612        let default_keypair = Keypair::new();
1613        endpoint.set_default_client_config(get_client_config(&default_keypair));
1614        let conn1 = endpoint
1615            .connect(server_address, "localhost")
1616            .expect("Failed in connecting")
1617            .await
1618            .expect("Failed in waiting");
1619
1620        let conn2 = endpoint
1621            .connect(server_address, "localhost")
1622            .expect("Failed in connecting")
1623            .await
1624            .expect("Failed in waiting");
1625
1626        let mut s1 = conn1.open_uni().await.unwrap();
1627        s1.write_all(&[0u8]).await.unwrap();
1628        s1.finish().unwrap();
1629
1630        let mut s2 = conn2.open_uni().await.unwrap();
1631        conn1.close(
1632            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1633            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1634        );
1635
1636        let start = Instant::now();
1637        while stats.connection_removed.load(Ordering::Relaxed) != 1 && start.elapsed().as_secs() < 1
1638        {
1639            debug!("First connection not removed yet");
1640            sleep(Duration::from_millis(10)).await;
1641        }
1642        assert!(start.elapsed().as_secs() < 1);
1643
1644        s2.write_all(&[0u8]).await.unwrap();
1645        s2.finish().unwrap();
1646
1647        conn2.close(
1648            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1649            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1650        );
1651
1652        let start = Instant::now();
1653        while stats.connection_removed.load(Ordering::Relaxed) != 2 && start.elapsed().as_secs() < 1
1654        {
1655            debug!("Second connection not removed yet");
1656            sleep(Duration::from_millis(10)).await;
1657        }
1658        assert!(start.elapsed().as_secs() < 1);
1659
1660        cancel.cancel();
1661        // Explicitly drop receiver here so that it doesn't get implicitly
1662        // dropped earlier. This is necessary to ensure the server stays alive
1663        // and doesn't issue a cancel to kill the connection earlier than
1664        // expected.
1665        drop(receiver);
1666        join_handle.await.unwrap();
1667    }
1668
1669    #[tokio::test(flavor = "multi_thread")]
1670    async fn test_quic_server_multiple_writes() {
1671        agave_logger::setup();
1672        let SpawnTestServerResult {
1673            join_handle,
1674            receiver,
1675            server_address,
1676            stats: _,
1677            cancel,
1678        } = setup_quic_server(
1679            None,
1680            QuicStreamerConfig::default_for_tests(),
1681            SwQosConfig::default(),
1682        );
1683        check_multiple_writes(receiver, server_address, None).await;
1684        cancel.cancel();
1685        join_handle.await.unwrap();
1686    }
1687
1688    #[tokio::test(flavor = "multi_thread")]
1689    async fn test_quic_server_staked_connection_removal() {
1690        agave_logger::setup();
1691
1692        let client_keypair = Keypair::new();
1693        let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
1694        let staked_nodes = StakedNodes::new(
1695            Arc::new(stakes),
1696            HashMap::<Pubkey, u64>::default(), // overrides
1697        );
1698        let SpawnTestServerResult {
1699            join_handle,
1700            receiver,
1701            server_address,
1702            stats,
1703            cancel,
1704        } = setup_quic_server(
1705            Some(staked_nodes),
1706            QuicStreamerConfig::default_for_tests(),
1707            SwQosConfig::default(),
1708        );
1709        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1710        cancel.cancel();
1711        join_handle.await.unwrap();
1712
1713        assert_eq!(
1714            stats
1715                .connection_added_from_staked_peer
1716                .load(Ordering::Relaxed),
1717            1
1718        );
1719        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1720        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1721    }
1722
1723    #[tokio::test(flavor = "multi_thread")]
1724    async fn test_quic_server_zero_staked_connection_removal() {
1725        // In this test, the client has a pubkey but zero stake, so it is treated as unstaked.
1726        agave_logger::setup();
1727
1728        let client_keypair = Keypair::new();
1729        let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
1730        let staked_nodes = StakedNodes::new(
1731            Arc::new(stakes),
1732            HashMap::<Pubkey, u64>::default(), // overrides
1733        );
1734        let SpawnTestServerResult {
1735            join_handle,
1736            receiver,
1737            server_address,
1738            stats,
1739            cancel,
1740        } = setup_quic_server(
1741            Some(staked_nodes),
1742            QuicStreamerConfig::default_for_tests(),
1743            SwQosConfig::default(),
1744        );
1745        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1746        cancel.cancel();
1747        join_handle.await.unwrap();
1748
1749        assert_eq!(
1750            stats
1751                .connection_added_from_staked_peer
1752                .load(Ordering::Relaxed),
1753            0
1754        );
1755        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1756        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1757    }
1758
1759    #[tokio::test(flavor = "multi_thread")]
1760    async fn test_quic_server_unstaked_connection_removal() {
1761        agave_logger::setup();
1762        let SpawnTestServerResult {
1763            join_handle,
1764            receiver,
1765            server_address,
1766            stats,
1767            cancel,
1768        } = setup_quic_server(
1769            None,
1770            QuicStreamerConfig::default_for_tests(),
1771            SwQosConfig::default(),
1772        );
1773        check_multiple_writes(receiver, server_address, None).await;
1774        cancel.cancel();
1775        join_handle.await.unwrap();
1776
1777        assert_eq!(
1778            stats
1779                .connection_added_from_staked_peer
1780                .load(Ordering::Relaxed),
1781            0
1782        );
1783        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1784        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1785    }
1786
1787    #[tokio::test(flavor = "multi_thread")]
1788    async fn test_quic_server_unstaked_node_connect_failure() {
1789        agave_logger::setup();
1790        let s = bind_to_localhost_unique().expect("should bind");
1791        let (sender, _) = unbounded();
1792        let keypair = Keypair::new();
1793        let server_address = s.local_addr().unwrap();
1794        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
1795        let cancel = CancellationToken::new();
1796        let SpawnNonBlockingServerResult {
1797            endpoints: _,
1798            stats: _,
1799            thread: t,
1800            max_concurrent_connections: _,
1801        } = spawn_server_with_cancel(
1802            "quic_streamer_test",
1803            [s],
1804            &keypair,
1805            sender,
1806            staked_nodes,
1807            QuicStreamerConfig {
1808                max_unstaked_connections: 0, // Do not allow any connection from unstaked clients/nodes
1809                ..QuicStreamerConfig::default_for_tests()
1810            },
1811            SwQosConfig::default(),
1812            cancel.clone(),
1813        )
1814        .unwrap();
1815
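        // The QUIC handshake may still complete, but with unstaked connections
        // disallowed any stream writes should ultimately fail;
        // check_unstaked_node_connect_failure asserts this via s1.stopped().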
1816        check_unstaked_node_connect_failure(server_address).await;
1817        cancel.cancel();
1818        t.await.unwrap();
1819    }
1820
1821    #[tokio::test(flavor = "multi_thread")]
1822    async fn test_quic_server_multiple_streams() {
1823        agave_logger::setup();
1824        let s = bind_to_localhost_unique().expect("should bind");
1825        let (sender, receiver) = unbounded();
1826        let keypair = Keypair::new();
1827        let server_address = s.local_addr().unwrap();
1828        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
1829        let cancel = CancellationToken::new();
1830        let SpawnNonBlockingServerResult {
1831            endpoints: _,
1832            stats,
1833            thread: t,
1834            max_concurrent_connections: _,
1835        } = spawn_server_with_cancel(
1836            "quic_streamer_test",
1837            [s],
1838            &keypair,
1839            sender,
1840            staked_nodes,
1841            QuicStreamerConfig {
1842                max_connections_per_peer: 2,
1843                ..QuicStreamerConfig::default_for_tests()
1844            },
1845            SwQosConfig::default(),
1846            cancel.clone(),
1847        )
1848        .unwrap();
1849
1850        check_multiple_streams(receiver, server_address, None).await;
1851        assert_eq!(stats.active_streams.load(Ordering::Relaxed), 0);
1852        assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
1853        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
1854        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
1855        cancel.cancel();
1856        t.await.unwrap();
1857
1858        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
1859        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
1860    }
1861
1862    #[test]
1863    fn test_prune_table_with_ip() {
1864        use std::net::Ipv4Addr;
1865        agave_logger::setup();
1866        let cancel = CancellationToken::new();
1867        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1868        let mut num_entries = 5;
1869        let max_connections_per_peer = 10;
1870        let sockets: Vec<_> = (0..num_entries)
1871            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
1872            .collect();
1873        let stats = Arc::new(StreamerStats::default());
1874        for (i, socket) in sockets.iter().enumerate() {
1875            table
1876                .try_add_connection(
1877                    ConnectionTableKey::IP(socket.ip()),
1878                    socket.port(),
1879                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1880                    None,
1881                    ConnectionPeerType::Unstaked,
1882                    Arc::new(AtomicU64::new(i as u64)),
1883                    max_connections_per_peer,
1884                )
1885                .unwrap();
1886        }
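        // Add one more connection for the first IP, bringing the table to six
        // entries across five distinct IPs.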
1887        num_entries += 1;
1888        table
1889            .try_add_connection(
1890                ConnectionTableKey::IP(sockets[0].ip()),
1891                sockets[0].port(),
1892                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1893                None,
1894                ConnectionPeerType::Unstaked,
1895                Arc::new(AtomicU64::new(5)),
1896                max_connections_per_peer,
1897            )
1898            .unwrap();
1899
1900        let new_size = 3;
1901        let pruned = table.prune_oldest(new_size);
1902        assert_eq!(pruned, num_entries as usize - new_size);
1903        for v in table.table.values() {
1904            for x in v {
1905                assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
1906            }
1907        }
1908        assert_eq!(table.table.len(), new_size);
1909        assert_eq!(table.total_size, new_size);
1910        for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
1911            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
1912        }
1913        assert_eq!(table.total_size, 0);
1914        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
1915    }
1916
1917    #[test]
1918    fn test_prune_table_with_unique_pubkeys() {
1919        agave_logger::setup();
1920        let cancel = CancellationToken::new();
1921        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1922
1923        // We should be able to add more entries than max_connections_per_peer, since each entry is
1924        // from a different peer pubkey.
1925        let num_entries = 15;
1926        let max_connections_per_peer = 10;
1927        let stats = Arc::new(StreamerStats::default());
1928
1929        let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
1930        for (i, pubkey) in pubkeys.iter().enumerate() {
1931            table
1932                .try_add_connection(
1933                    ConnectionTableKey::Pubkey(*pubkey),
1934                    0,
1935                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1936                    None,
1937                    ConnectionPeerType::Unstaked,
1938                    Arc::new(AtomicU64::new(i as u64)),
1939                    max_connections_per_peer,
1940                )
1941                .unwrap();
1942        }
1943
1944        let new_size = 3;
1945        let pruned = table.prune_oldest(new_size);
1946        assert_eq!(pruned, num_entries as usize - new_size);
1947        assert_eq!(table.table.len(), new_size);
1948        assert_eq!(table.total_size, new_size);
1949        for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
1950            table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
1951        }
1952        assert_eq!(table.total_size, 0);
1953        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
1954    }
1955
1956    #[test]
1957    fn test_prune_table_with_non_unique_pubkeys() {
1958        agave_logger::setup();
1959        let cancel = CancellationToken::new();
1960        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
1961
1962        let max_connections_per_peer = 10;
1963        let pubkey = Pubkey::new_unique();
1964        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
1965
1966        (0..max_connections_per_peer).for_each(|i| {
1967            table
1968                .try_add_connection(
1969                    ConnectionTableKey::Pubkey(pubkey),
1970                    0,
1971                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1972                    None,
1973                    ConnectionPeerType::Unstaked,
1974                    Arc::new(AtomicU64::new(i as u64)),
1975                    max_connections_per_peer,
1976                )
1977                .unwrap();
1978        });
1979
1980        // We should NOT be able to add more entries than max_connections_per_peer, since we are
1981        // using the same peer pubkey.
1982        assert!(table
1983            .try_add_connection(
1984                ConnectionTableKey::Pubkey(pubkey),
1985                0,
1986                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
1987                None,
1988                ConnectionPeerType::Unstaked,
1989                Arc::new(AtomicU64::new(10)),
1990                max_connections_per_peer,
1991            )
1992            .is_none());
1993
1994        // We should be able to add an entry from another peer pubkey
1995        let num_entries = max_connections_per_peer + 1;
1996        let pubkey2 = Pubkey::new_unique();
1997        assert!(table
1998            .try_add_connection(
1999                ConnectionTableKey::Pubkey(pubkey2),
2000                0,
2001                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2002                None,
2003                ConnectionPeerType::Unstaked,
2004                Arc::new(AtomicU64::new(10)),
2005                max_connections_per_peer,
2006            )
2007            .is_some());
2008
2009        assert_eq!(table.total_size, num_entries);
2010
2011        let new_max_size = 3;
2012        let pruned = table.prune_oldest(new_max_size);
2013        assert!(pruned >= num_entries - new_max_size);
2014        assert!(table.table.len() <= new_max_size);
2015        assert!(table.total_size <= new_max_size);
2016
2017        table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
2018        assert_eq!(table.total_size, 0);
2019        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2020    }
2021
2022    #[test]
2023    fn test_prune_table_random() {
2024        use std::net::Ipv4Addr;
2025        agave_logger::setup();
2026        let cancel = CancellationToken::new();
2027        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
2028
2029        let num_entries = 5;
2030        let max_connections_per_peer = 10;
2031        let sockets: Vec<_> = (0..num_entries)
2032            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2033            .collect();
2034        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2035
2036        for (i, socket) in sockets.iter().enumerate() {
2037            table
2038                .try_add_connection(
2039                    ConnectionTableKey::IP(socket.ip()),
2040                    socket.port(),
2041                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2042                    None,
2043                    ConnectionPeerType::Staked((i + 1) as u64),
2044                    Arc::new(AtomicU64::new(i as u64)),
2045                    max_connections_per_peer,
2046                )
2047                .unwrap();
2048        }
2049
2050        // Try pruning with a threshold stake lower than all the entries in the table.
2051        // It should fail to prune (i.e. return 0 pruned entries).
2052        let pruned = table.prune_random(/*sample_size:*/ 2, /*threshold_stake:*/ 0);
2053        assert_eq!(pruned, 0);
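        // The entries were added with stakes 1..=5, so none of them fall below a
        // threshold of 0 and nothing is eligible for eviction.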
2054
2055        // Try pruning with a threshold stake higher than all the entries in the table.
2056        // It should succeed in pruning (i.e. return 1 pruned entry).
2057        let pruned = table.prune_random(
2058            2,                      // sample_size
2059            num_entries as u64 + 1, // threshold_stake
2060        );
2061        assert_eq!(pruned, 1);
2062        // We had 5 connections and pruned 1, so 4 should remain.
2063        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
2064    }
2065
2066    #[test]
2067    fn test_remove_connections() {
2068        use std::net::Ipv4Addr;
2069        agave_logger::setup();
2070        let cancel = CancellationToken::new();
2071        let mut table = ConnectionTable::new(ConnectionTableType::Unstaked, cancel);
2072
2073        let num_ips = 5;
2074        let max_connections_per_peer = 10;
2075        let mut sockets: Vec<_> = (0..num_ips)
2076            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2077            .collect();
2078        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2079
2080        for (i, socket) in sockets.iter().enumerate() {
2081            table
2082                .try_add_connection(
2083                    ConnectionTableKey::IP(socket.ip()),
2084                    socket.port(),
2085                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2086                    None,
2087                    ConnectionPeerType::Unstaked,
2088                    Arc::new(AtomicU64::new((i * 2) as u64)),
2089                    max_connections_per_peer,
2090                )
2091                .unwrap();
2092
2093            table
2094                .try_add_connection(
2095                    ConnectionTableKey::IP(socket.ip()),
2096                    socket.port(),
2097                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2098                    None,
2099                    ConnectionPeerType::Unstaked,
2100                    Arc::new(AtomicU64::new((i * 2 + 1) as u64)),
2101                    max_connections_per_peer,
2102                )
2103                .unwrap();
2104        }
2105
2106        let single_connection_addr =
2107            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
2108        table
2109            .try_add_connection(
2110                ConnectionTableKey::IP(single_connection_addr.ip()),
2111                single_connection_addr.port(),
2112                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2113                None,
2114                ConnectionPeerType::Unstaked,
2115                Arc::new(AtomicU64::new((num_ips * 2) as u64)),
2116                max_connections_per_peer,
2117            )
2118            .unwrap();
2119
2120        let zero_connection_addr =
2121            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);
2122
2123        sockets.push(single_connection_addr);
2124        sockets.push(zero_connection_addr);
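        // zero_connection_addr was never added to the table; removing it below
        // should be a harmless no-op.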
2125
2126        for socket in sockets.iter() {
2127            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2128        }
2129        assert_eq!(table.total_size, 0);
2130        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2131    }
2132
2133    #[tokio::test(flavor = "multi_thread")]
2134    async fn test_throttling_check_no_packet_drop() {
2135        agave_logger::setup_with_default_filter();
2136
2137        let SpawnTestServerResult {
2138            join_handle,
2139            receiver,
2140            server_address,
2141            stats,
2142            cancel,
2143        } = setup_quic_server(
2144            None,
2145            QuicStreamerConfig::default_for_tests(),
2146            SwQosConfig::default(),
2147        );
2148
2149        let client_connection = make_client_endpoint(&server_address, None).await;
2150
2151        // An unstaked connection can handle up to 100 TPS, so sending should take ~1s.
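        // (This figure assumes the default unstaked stream throttle in effect for
        // QuicStreamerConfig::default_for_tests(); adjust it if that default changes.)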
2152        let expected_num_txs = 100;
2153        let start_time = tokio::time::Instant::now();
2154        for i in 0..expected_num_txs {
2155            let mut send_stream = client_connection.open_uni().await.unwrap();
2156            let data = format!("{i}").into_bytes();
2157            send_stream.write_all(&data).await.unwrap();
2158            send_stream.finish().unwrap();
2159        }
2160        let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2161        info!("Elapsed sending: {elapsed_sending}");
2162
2163        // Check that all of them were delivered.
2164        let start_time = tokio::time::Instant::now();
2165        let mut num_txs_received = 0;
2166        while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2167            if let Ok(packets) = receiver.try_recv() {
2168                num_txs_received += packets.len();
2169            } else {
2170                sleep(Duration::from_millis(100)).await;
2171            }
2172        }
2173        assert_eq!(expected_num_txs, num_txs_received);
2174
2175        cancel.cancel();
2176        join_handle.await.unwrap();
2177
2178        assert_eq!(
2179            stats.total_new_streams.load(Ordering::Relaxed),
2180            expected_num_txs
2181        );
2182        assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2183    }
2184
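    // ClientConnectionTracker increments the open_connections stat on creation
    // (and fails once the configured limit is reached) and decrements it on drop.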
2185    #[test]
2186    fn test_client_connection_tracker() {
2187        let stats = Arc::new(StreamerStats::default());
2188        let tracker_1 = ClientConnectionTracker::new(stats.clone(), 1);
2189        assert!(tracker_1.is_ok());
2190        assert!(ClientConnectionTracker::new(stats.clone(), 1).is_err());
2191        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);
2192        // Dropping the tracker should bring open connections back to 0.
2193        drop(tracker_1);
2194        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2195    }
2196
2197    #[tokio::test(flavor = "multi_thread")]
2198    async fn test_client_connection_close_invalid_stream() {
2199        let SpawnTestServerResult {
2200            join_handle,
2201            server_address,
2202            stats,
2203            cancel,
2204            ..
2205        } = setup_quic_server(
2206            None,
2207            QuicStreamerConfig::default_for_tests(),
2208            SwQosConfig::default(),
2209        );
2210
2211        let client_connection = make_client_endpoint(&server_address, None).await;
2212
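        // Writing PACKET_DATA_SIZE + 1 bytes exceeds the maximum packet size, so
        // the server should close the connection with
        // CONNECTION_CLOSE_CODE_INVALID_STREAM.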
2213        let mut send_stream = client_connection.open_uni().await.unwrap();
2214        send_stream
2215            .write_all(&[42; PACKET_DATA_SIZE + 1])
2216            .await
2217            .unwrap();
2218        match client_connection.closed().await {
2219            ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2220                assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2221                assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2222            }
2223            _ => panic!("unexpected close"),
2224        }
2225        assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2226        cancel.cancel();
2227        join_handle.await.unwrap();
2228    }
2229}