// solana-client 1.10.20
//
// Solana Client
// Documentation
use {
    crate::{
        quic_client::QuicTpuConnection,
        tpu_connection::{ClientStats, TpuConnection},
        udp_client::UdpTpuConnection,
    },
    indexmap::map::IndexMap,
    lazy_static::lazy_static,
    rand::{thread_rng, Rng},
    solana_measure::measure::Measure,
    solana_sdk::{
        timing::AtomicInterval, transaction::VersionedTransaction, transport::TransportError,
    },
    std::{
        net::SocketAddr,
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc, RwLock,
        },
    },
};

/// Upper bound on the number of connections kept in the cache.
///
/// It is used as the initial capacity of the connection map, as the threshold at which
/// entries are evicted before an insert, and as the exclusive upper bound of the randomly
/// chosen eviction index.
// Should be non-zero
static MAX_CONNECTIONS: usize = 1024;

/// A cached TPU connection, using either the UDP or the QUIC client implementation.
///
/// Each variant holds its connection behind an `Arc`, so the derived `Clone` hands out
/// another handle to the same underlying connection object.
#[derive(Clone)]
pub enum Connection {
    /// Connection backed by `UdpTpuConnection`.
    Udp(Arc<UdpTpuConnection>),
    /// Connection backed by `QuicTpuConnection`.
    Quic(Arc<QuicTpuConnection>),
}

/// Atomic counters and timing accumulators describing connection-cache activity.
///
/// All fields start at zero (`Default`) and are updated with relaxed atomic operations.
#[derive(Default)]
pub struct ConnectionCacheStats {
    // Number of `get_connection` calls that found an existing entry.
    cache_hits: AtomicU64,
    // Number of `get_connection` calls that had to create a new entry.
    cache_misses: AtomicU64,
    // Number of entries removed to make room for new ones.
    cache_evictions: AtomicU64,
    // Accumulated time spent in the eviction loop.
    eviction_time_ms: AtomicU64,
    // Accumulated `num_packets` passed to `add_client_stats`.
    sent_packets: AtomicU64,
    // Number of `add_client_stats` calls (one per send operation).
    total_batches: AtomicU64,
    // Number of send operations whose result was `Ok`.
    batch_success: AtomicU64,
    // Number of send operations whose result was `Err`.
    batch_failure: AtomicU64,
    // Accumulated total time spent inside `get_connection`.
    get_connection_ms: AtomicU64,
    // Accumulated time spent waiting for the connection-map lock(s).
    get_connection_lock_ms: AtomicU64,
    // Accumulated map lookup time for calls that were cache hits.
    get_connection_hit_ms: AtomicU64,
    // Accumulated map lookup/creation time for calls that were cache misses.
    get_connection_miss_ms: AtomicU64,

    // Need to track these separately per-connection
    // because we need to track the base stat value from quinn
    pub total_client_stats: ClientStats,
}

/// Interval passed to `AtomicInterval::should_update` to decide when the cache stats are
/// reported (presumably milliseconds, i.e. every 2 seconds; confirm against `AtomicInterval`).
const CONNECTION_STAT_SUBMISSION_INTERVAL: u64 = 2000;

impl ConnectionCacheStats {
    /// Folds the stats gathered during one send operation into the cache-wide totals.
    ///
    /// * `client_stats` - per-operation counters filled in by the connection while sending.
    /// * `num_packets` - number of packets the operation attempted to send.
    /// * `is_success` - whether the send operation returned `Ok`.
    ///
    /// Every call counts as one batch, which is additionally recorded as either a success
    /// or a failure.
    fn add_client_stats(&self, client_stats: &ClientStats, num_packets: usize, is_success: bool) {
        self.total_client_stats.total_connections.fetch_add(
            client_stats.total_connections.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.total_client_stats.connection_reuse.fetch_add(
            client_stats.connection_reuse.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.total_client_stats.connection_errors.fetch_add(
            client_stats.connection_errors.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.total_client_stats.zero_rtt_accepts.fetch_add(
            client_stats.zero_rtt_accepts.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.total_client_stats.zero_rtt_rejects.fetch_add(
            client_stats.zero_rtt_rejects.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.total_client_stats.make_connection_ms.fetch_add(
            client_stats.make_connection_ms.load(Ordering::Relaxed),
            Ordering::Relaxed,
        );
        self.sent_packets
            .fetch_add(num_packets as u64, Ordering::Relaxed);
        self.total_batches.fetch_add(1, Ordering::Relaxed);
        if is_success {
            self.batch_success.fetch_add(1, Ordering::Relaxed);
        } else {
            self.batch_failure.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Submits a `quic-client-connection-stats` datapoint with the accumulated values.
    ///
    /// Each `AtomicU64` counter is read with `swap(0, ..)`, so reporting also resets it and
    /// every datapoint covers only the activity since the previous report. The remaining
    /// `total_client_stats` fields are read through their `load_and_reset()` method.
    fn report(&self) {
        datapoint_info!(
            "quic-client-connection-stats",
            (
                "cache_hits",
                self.cache_hits.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "cache_misses",
                self.cache_misses.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "cache_evictions",
                self.cache_evictions.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "eviction_time_ms",
                self.eviction_time_ms.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "get_connection_ms",
                self.get_connection_ms.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "get_connection_lock_ms",
                self.get_connection_lock_ms.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "get_connection_hit_ms",
                self.get_connection_hit_ms.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "get_connection_miss_ms",
                self.get_connection_miss_ms.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "make_connection_ms",
                self.total_client_stats
                    .make_connection_ms
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "total_connections",
                self.total_client_stats
                    .total_connections
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "connection_reuse",
                self.total_client_stats
                    .connection_reuse
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "connection_errors",
                self.total_client_stats
                    .connection_errors
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "zero_rtt_accepts",
                self.total_client_stats
                    .zero_rtt_accepts
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "zero_rtt_rejects",
                self.total_client_stats
                    .zero_rtt_rejects
                    .swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "congestion_events",
                self.total_client_stats.congestion_events.load_and_reset(),
                i64
            ),
            (
                "tx_streams_blocked_uni",
                self.total_client_stats
                    .tx_streams_blocked_uni
                    .load_and_reset(),
                i64
            ),
            (
                "tx_data_blocked",
                self.total_client_stats.tx_data_blocked.load_and_reset(),
                i64
            ),
            (
                "tx_acks",
                self.total_client_stats.tx_acks.load_and_reset(),
                i64
            ),
            (
                "num_packets",
                self.sent_packets.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "total_batches",
                self.total_batches.swap(0, Ordering::Relaxed),
                i64
            ),
            // `batch_success` is incremented alongside `batch_failure`, so it must be
            // reported (and thereby reset) as well; otherwise it would grow without bound
            // and never reach the metrics.
            (
                "batch_success",
                self.batch_success.swap(0, Ordering::Relaxed),
                i64
            ),
            (
                "batch_failure",
                self.batch_failure.swap(0, Ordering::Relaxed),
                i64
            ),
        );
    }
}

/// State of the connection cache: the cached connections plus the settings and stats that
/// go with them.
struct ConnectionMap {
    // Cached connections keyed by peer address; evictions remove entries by index.
    map: IndexMap<SocketAddr, Connection>,
    // Shared stats handle; cloned into every new connection and returned to callers.
    stats: Arc<ConnectionCacheStats>,
    // Gates how often the stats are reported (see CONNECTION_STAT_SUBMISSION_INTERVAL).
    last_stats: AtomicInterval,
    // When true, newly created connections use QUIC; otherwise UDP.
    use_quic: bool,
}

impl ConnectionMap {
    /// Creates an empty cache with room pre-allocated for `MAX_CONNECTIONS` entries,
    /// zeroed stats, and UDP selected as the transport for new connections.
    pub fn new() -> Self {
        let map = IndexMap::with_capacity(MAX_CONNECTIONS);
        let stats = Arc::new(ConnectionCacheStats::default());
        let last_stats = AtomicInterval::default();
        Self {
            map,
            stats,
            last_stats,
            use_quic: false,
        }
    }

    /// Chooses the transport for connections created after this call:
    /// QUIC when `use_quic` is true, UDP otherwise.
    pub fn set_use_quic(&mut self, use_quic: bool) {
        self.use_quic = use_quic;
    }
}

// Global connection cache used by every function in this module. It is guarded by an
// `RwLock`: lookups take the read lock, while inserting a connection or changing the
// transport setting takes the write lock.
lazy_static! {
    static ref CONNECTION_MAP: RwLock<ConnectionMap> = RwLock::new(ConnectionMap::new());
}

/// Selects the transport for connections added to the global cache from now on:
/// QUIC when `use_quic` is true, UDP otherwise.
///
/// Only the flag is changed; connections already in the cache are left as they are.
///
/// # Panics
///
/// Panics if the cache lock is poisoned.
pub fn set_use_quic(use_quic: bool) {
    // The write guard is a temporary and is released as soon as the statement ends.
    CONNECTION_MAP.write().unwrap().set_use_quic(use_quic);
}

/// Everything `get_or_add_connection` hands back: the connection itself plus the
/// bookkeeping that `get_connection` turns into stats.
struct GetConnectionResult {
    // The cached or newly created connection for the requested address.
    connection: Connection,
    // True when the address was already present in the map.
    cache_hit: bool,
    // True when the stats should be reported by the caller.
    report_stats: bool,
    // Time spent looking up (and, on a miss, creating/inserting) the connection.
    map_timing_ms: u64,
    // Time spent acquiring the map lock(s).
    lock_timing_ms: u64,
    // Handle to the cache-wide stats.
    connection_cache_stats: Arc<ConnectionCacheStats>,
    // Number of entries evicted to make room (0 on a hit).
    num_evictions: u64,
    // Time spent in the eviction loop (0 on a hit).
    eviction_timing_ms: u64,
}

/// Returns the connection cached for `addr`, creating and inserting one if none exists.
///
/// The lookup is first done under the read lock. On a miss the read lock is dropped, the
/// write lock is taken and the lookup is repeated; only if the address is still missing is
/// a new connection (QUIC or UDP, depending on `use_quic`) created. Before the insert,
/// randomly chosen entries are removed while the map holds at least `MAX_CONNECTIONS`
/// entries.
///
/// The returned `GetConnectionResult` also carries lock/lookup/eviction timings, the
/// eviction count, and whether the caller should report stats now. A hit found by the
/// second lookup (under the write lock) is reported as `cache_hit == true`.
///
/// Panics if the cache lock is poisoned.
fn get_or_add_connection(addr: &SocketAddr) -> GetConnectionResult {
    let mut get_connection_map_lock_measure = Measure::start("get_connection_map_lock_measure");
    let map = (*CONNECTION_MAP).read().unwrap();
    get_connection_map_lock_measure.stop();

    // Accumulates the read-lock wait and, on a miss, the write-lock wait as well.
    let mut lock_timing_ms = get_connection_map_lock_measure.as_ms();

    // Evaluated on every call while the read lock is held; presumably true once the
    // submission interval has elapsed (confirm against `AtomicInterval::should_update`).
    let report_stats = map
        .last_stats
        .should_update(CONNECTION_STAT_SUBMISSION_INTERVAL);

    // Despite the "hit" label this measure covers both paths: on a miss it also spans the
    // write-lock acquisition, connection creation, eviction and insert below.
    let mut get_connection_map_measure = Measure::start("get_connection_hit_measure");

    let (connection, cache_hit, connection_cache_stats, num_evictions, eviction_timing_ms) =
        match map.map.get(addr) {
            Some(connection) => (connection.clone(), true, map.stats.clone(), 0, 0),
            None => {
                // Upgrade to write access by dropping read lock and acquire write lock
                drop(map);
                let mut get_connection_map_lock_measure =
                    Measure::start("get_connection_map_lock_measure");
                let mut map = (*CONNECTION_MAP).write().unwrap();
                get_connection_map_lock_measure.stop();

                lock_timing_ms =
                    lock_timing_ms.saturating_add(get_connection_map_lock_measure.as_ms());

                // Read again, as it is possible that between read lock dropped and the write lock acquired
                // another thread could have setup the connection.
                match map.map.get(addr) {
                    Some(connection) => (connection.clone(), true, map.stats.clone(), 0, 0),
                    None => {
                        // The new connection gets a handle to the shared cache stats.
                        let connection = if map.use_quic {
                            Connection::Quic(Arc::new(QuicTpuConnection::new(
                                *addr,
                                map.stats.clone(),
                            )))
                        } else {
                            Connection::Udp(Arc::new(UdpTpuConnection::new(
                                *addr,
                                map.stats.clone(),
                            )))
                        };

                        // evict a connection if the cache is reaching upper bounds
                        let mut num_evictions = 0;
                        let mut get_connection_cache_eviction_measure =
                            Measure::start("get_connection_cache_eviction_measure");
                        // Remove randomly chosen entries until there is room for one more.
                        while map.map.len() >= MAX_CONNECTIONS {
                            let mut rng = thread_rng();
                            let n = rng.gen_range(0, MAX_CONNECTIONS);
                            map.map.swap_remove_index(n);
                            num_evictions += 1;
                        }
                        get_connection_cache_eviction_measure.stop();

                        // The map keeps one handle; the caller receives the other.
                        map.map.insert(*addr, connection.clone());
                        (
                            connection,
                            false,
                            map.stats.clone(),
                            num_evictions,
                            get_connection_cache_eviction_measure.as_ms(),
                        )
                    }
                }
            }
        };

    get_connection_map_measure.stop();

    GetConnectionResult {
        connection,
        cache_hit,
        report_stats,
        map_timing_ms: get_connection_map_measure.as_ms(),
        lock_timing_ms,
        connection_cache_stats,
        num_evictions,
        eviction_timing_ms,
    }
}

// TODO: see https://github.com/solana-labs/solana/issues/23661
// remove lazy_static and optimize and refactor this
/// Fetches (or creates) the cached connection for `addr` and records the lookup in the
/// cache stats.
///
/// Returns the connection together with the shared `ConnectionCacheStats` handle so the
/// caller can add the outcome of its send operation.
fn get_connection(addr: &SocketAddr) -> (Connection, Arc<ConnectionCacheStats>) {
    let mut overall_measure = Measure::start("get_connection_measure");
    let lookup = get_or_add_connection(addr);
    let stats = lookup.connection_cache_stats;

    // Reporting happens before this call's own numbers are added, so they end up in the
    // next report.
    if lookup.report_stats {
        stats.report();
    }

    // Pick the counter/timer pair that matches the lookup outcome.
    let (outcome_counter, outcome_timer) = if lookup.cache_hit {
        (&stats.cache_hits, &stats.get_connection_hit_ms)
    } else {
        (&stats.cache_misses, &stats.get_connection_miss_ms)
    };
    outcome_counter.fetch_add(1, Ordering::Relaxed);
    outcome_timer.fetch_add(lookup.map_timing_ms, Ordering::Relaxed);

    // Evictions can only happen when a new entry had to be inserted.
    if !lookup.cache_hit {
        stats
            .cache_evictions
            .fetch_add(lookup.num_evictions, Ordering::Relaxed);
        stats
            .eviction_time_ms
            .fetch_add(lookup.eviction_timing_ms, Ordering::Relaxed);
    }

    overall_measure.stop();
    stats
        .get_connection_lock_ms
        .fetch_add(lookup.lock_timing_ms, Ordering::Relaxed);
    stats
        .get_connection_ms
        .fetch_add(overall_measure.as_ms(), Ordering::Relaxed);

    (lookup.connection, stats)
}

// TODO: see https://github.com/solana-labs/solana/issues/23851
// use enum_dispatch and get rid of this tedious code.
// The main blocker to using enum_dispatch right now is that
// it doesn't work with static methods like TpuConnection::new,
// which is used by thin_client. This will be eliminated soon,
// once thin_client is moved to using this connection cache.
// Once that is done, we will migrate to using enum_dispatch.
// This will be done in a followup to
// https://github.com/solana-labs/solana/pull/23817
/// Sends a batch of already-serialized transactions to `addr` over the cached connection.
///
/// The per-call client stats, the number of packets and the success flag are merged into
/// the cache stats before the transport result is returned unchanged.
pub fn send_wire_transaction_batch(
    packets: &[&[u8]],
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let (connection, cache_stats) = get_connection(addr);
    let call_stats = ClientStats::default();
    let outcome = match connection {
        Connection::Udp(udp) => udp.send_wire_transaction_batch(packets, &call_stats),
        Connection::Quic(quic) => quic.send_wire_transaction_batch(packets, &call_stats),
    };
    let succeeded = outcome.is_ok();
    cache_stats.add_client_stats(&call_stats, packets.len(), succeeded);
    outcome
}

/// Hands a single serialized transaction to the cached connection's
/// `send_wire_transaction_async` and records the call as one packet in the cache stats.
///
/// NOTE(review): the per-call stats are merged right after the call returns; if the
/// connection updates them later from asynchronous work, those updates are presumably not
/// captured here. Confirm against the connection implementations.
pub fn send_wire_transaction_async(
    packets: Vec<u8>,
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let (connection, cache_stats) = get_connection(addr);
    let call_stats = Arc::new(ClientStats::default());
    let outcome = match connection {
        Connection::Udp(udp) => udp.send_wire_transaction_async(packets, Arc::clone(&call_stats)),
        Connection::Quic(quic) => {
            quic.send_wire_transaction_async(packets, Arc::clone(&call_stats))
        }
    };
    let succeeded = outcome.is_ok();
    cache_stats.add_client_stats(&call_stats, 1, succeeded);
    outcome
}

/// Hands a batch of serialized transactions to the cached connection's
/// `send_wire_transaction_batch_async` and records the batch in the cache stats.
///
/// The packet count is captured before `packets` is moved into the connection call.
pub fn send_wire_transaction_batch_async(
    packets: Vec<Vec<u8>>,
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let (connection, cache_stats) = get_connection(addr);
    let call_stats = Arc::new(ClientStats::default());
    let packet_count = packets.len();
    let outcome = match connection {
        Connection::Udp(udp) => {
            udp.send_wire_transaction_batch_async(packets, Arc::clone(&call_stats))
        }
        Connection::Quic(quic) => {
            quic.send_wire_transaction_batch_async(packets, Arc::clone(&call_stats))
        }
    };
    let succeeded = outcome.is_ok();
    cache_stats.add_client_stats(&call_stats, packet_count, succeeded);
    outcome
}

/// Sends one serialized transaction to `addr` by wrapping it in a single-element batch and
/// delegating to `send_wire_transaction_batch`.
pub fn send_wire_transaction(
    wire_transaction: &[u8],
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let single_packet_batch: [&[u8]; 1] = [wire_transaction];
    send_wire_transaction_batch(&single_packet_batch, addr)
}

/// Passes `transaction` to the cached connection's `serialize_and_send_transaction` and
/// records the call as one packet in the cache stats.
///
/// The transport result is returned unchanged.
pub fn serialize_and_send_transaction(
    transaction: &VersionedTransaction,
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let (connection, cache_stats) = get_connection(addr);
    let call_stats = ClientStats::default();
    let outcome = match connection {
        Connection::Udp(udp) => udp.serialize_and_send_transaction(transaction, &call_stats),
        Connection::Quic(quic) => quic.serialize_and_send_transaction(transaction, &call_stats),
    };
    let succeeded = outcome.is_ok();
    cache_stats.add_client_stats(&call_stats, 1, succeeded);
    outcome
}

/// Passes `transactions` to the cached connection's
/// `par_serialize_and_send_transaction_batch` and records the batch in the cache stats,
/// counting one packet per transaction.
///
/// The transport result is returned unchanged.
pub fn par_serialize_and_send_transaction_batch(
    transactions: &[VersionedTransaction],
    addr: &SocketAddr,
) -> Result<(), TransportError> {
    let (connection, cache_stats) = get_connection(addr);
    let call_stats = ClientStats::default();
    let outcome = match connection {
        Connection::Udp(udp) => {
            udp.par_serialize_and_send_transaction_batch(transactions, &call_stats)
        }
        Connection::Quic(quic) => {
            quic.par_serialize_and_send_transaction_batch(transactions, &call_stats)
        }
    };
    let succeeded = outcome.is_ok();
    cache_stats.add_client_stats(&call_stats, transactions.len(), succeeded);
    outcome
}

#[cfg(test)]
mod tests {
    use {
        crate::{
            connection_cache::{get_connection, Connection, CONNECTION_MAP, MAX_CONNECTIONS},
            tpu_connection::TpuConnection,
        },
        rand::{Rng, SeedableRng},
        rand_chacha::ChaChaRng,
        std::net::{IpAddr, SocketAddr},
    };

    /// Builds a pseudorandom IPv4 address on port 80, with every octet drawn from `rng`.
    fn get_addr(rng: &mut ChaChaRng) -> SocketAddr {
        let a = rng.gen_range(1, 255);
        let b = rng.gen_range(1, 255);
        let c = rng.gen_range(1, 255);
        let d = rng.gen_range(1, 255);

        let addr_str = format!("{}.{}.{}.{}:80", a, b, c, d);

        addr_str.parse().expect("Invalid address")
    }

    /// Returns the peer IP address of a cached connection, whichever transport it uses.
    fn ip(conn: Connection) -> IpAddr {
        match conn {
            Connection::Udp(conn) => conn.tpu_addr().ip(),
            Connection::Quic(conn) => conn.tpu_addr().ip(),
        }
    }

    /// Fills the global cache to capacity, checks every address is present, then adds one
    /// more address and checks the cache stays at capacity while holding the newcomer.
    #[test]
    fn test_connection_cache() {
        solana_logger::setup();
        // Allow the test to run deterministically
        // with the same pseudorandom sequence between runs
        // and on different platforms - the cryptographic security
        // property isn't important here but ChaChaRng provides a way
        // to get the same pseudorandom sequence on different platforms
        let mut rng = ChaChaRng::seed_from_u64(42);

        // Generate a bunch of random addresses and create TPUConnections to them
        // Since TPUConnection::new is infallible, it shouldn't matter whether or not
        // we can actually connect to those addresses - TPUConnection implementations should either
        // be lazy and not connect until first use or handle connection errors somehow
        // (without crashing, as would be required in a real practical validator)
        let addrs = (0..MAX_CONNECTIONS)
            .map(|_| {
                let addr = get_addr(&mut rng);
                get_connection(&addr);
                addr
            })
            .collect::<Vec<_>>();
        {
            // Scoped so the read guard is released before `get_connection` below needs
            // the write lock.
            let map = (*CONNECTION_MAP).read().unwrap();
            assert_eq!(map.map.len(), MAX_CONNECTIONS);
            addrs.iter().for_each(|a| {
                let conn = map.map.get(a).expect("Address not found");
                assert_eq!(a.ip(), ip(conn.clone()));
            });
        }

        // One more address than the cache can hold: an existing entry must be evicted.
        let addr = get_addr(&mut rng);
        get_connection(&addr);

        let map = (*CONNECTION_MAP).read().unwrap();
        assert_eq!(map.map.len(), MAX_CONNECTIONS);
        let _conn = map.map.get(&addr).expect("Address not found");
    }
}