solana-net-utils 3.1.13

Solana Network Utilities
Documentation
use {
    crate::{bind_to_unspecified, HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH},
    log::*,
    serde::{Deserialize, Serialize},
    solana_serde::default_on_eof,
    std::{
        collections::HashSet,
        io,
        net::{IpAddr, SocketAddr},
        num::NonZeroUsize,
        sync::{Arc, Mutex},
        time::Duration,
    },
    tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
        runtime::{self, Runtime},
        time::{timeout_at, Instant},
    },
};

pub type IpEchoServer = Runtime;

// Enforce a minimum of two threads:
// - One thread to monitor the TcpListener and spawn async tasks
// - One thread to service the spawned tasks
pub const MINIMUM_IP_ECHO_SERVER_THREADS: NonZeroUsize = NonZeroUsize::new(2).unwrap();
// IP echo requests require little computation and come in fairly infrequently,
// so keep the number of server workers small to avoid overhead
pub const DEFAULT_IP_ECHO_SERVER_THREADS: NonZeroUsize = MINIMUM_IP_ECHO_SERVER_THREADS;
pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4;

const IO_TIMEOUT: Duration = Duration::from_secs(5);
// Non-loopback peers are limited to one active connection each; loopback is exempt.
const MAX_CONCURRENT_CONNECTIONS: usize = 2048;

struct ConnectionCleanup {
    active_ips: Arc<Mutex<HashSet<IpAddr>>>,
    ip: IpAddr,
}

impl ConnectionCleanup {
    fn new(active_ips: Arc<Mutex<HashSet<IpAddr>>>, ip: IpAddr) -> Self {
        Self { active_ips, ip }
    }
}

impl Drop for ConnectionCleanup {
    fn drop(&mut self) {
        let mut active_ips = self.active_ips.lock().expect("active_ips lock poisoned");
        release_active_ip(&mut active_ips, self.ip);
    }
}

#[derive(Serialize, Deserialize, Default, Debug)]
pub(crate) struct IpEchoServerMessage {
    tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde
    udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], // Fixed size list of ports to avoid vec serde
}

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct IpEchoServerResponse {
    // Public IP address of request echoed back to the node.
    pub(crate) address: IpAddr,
    // Cluster shred-version of the node running the server.
    #[serde(deserialize_with = "default_on_eof")]
    pub(crate) shred_version: Option<u16>,
}

impl IpEchoServerMessage {
    pub fn new(tcp_ports: &[u16], udp_ports: &[u16]) -> Self {
        let mut msg = Self::default();
        assert!(tcp_ports.len() <= msg.tcp_ports.len());
        assert!(udp_ports.len() <= msg.udp_ports.len());

        msg.tcp_ports[..tcp_ports.len()].copy_from_slice(tcp_ports);
        msg.udp_ports[..udp_ports.len()].copy_from_slice(udp_ports);
        msg
    }
}

pub(crate) fn ip_echo_server_request_length() -> usize {
    const REQUEST_TERMINUS_LENGTH: usize = 1;
    (HEADER_LENGTH + REQUEST_TERMINUS_LENGTH)
        .wrapping_add(bincode::serialized_size(&IpEchoServerMessage::default()).unwrap() as usize)
}

async fn process_connection(
    mut socket: TcpStream,
    peer_addr: SocketAddr,
    shred_version: Option<u16>,
) -> io::Result<()> {
    info!("connection from {peer_addr:?}");
    let deadline = Instant::now()
        .checked_add(IO_TIMEOUT)
        .ok_or_else(|| io::Error::other("failed to compute request deadline"))?;

    let mut data = vec![0u8; ip_echo_server_request_length()];

    let mut writer = {
        let (mut reader, writer) = socket.split();
        let _ = timeout_at(deadline, reader.read_exact(&mut data)).await??;
        writer
    };

    let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect();
    if request_header != "\0\0\0\0" {
        // Explicitly check for HTTP GET/POST requests to more gracefully handle
        // the case where a user accidentally tried to use a gossip entrypoint in
        // place of a JSON RPC URL:
        if request_header == "GET " || request_header == "POST" {
            // Send HTTP error response
            timeout_at(
                deadline,
                writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"),
            )
            .await??;
            return Ok(());
        }
        return Err(io::Error::other(format!(
            "Bad request header: {request_header}"
        )));
    }

    let msg =
        bincode::deserialize::<IpEchoServerMessage>(&data[HEADER_LENGTH..]).map_err(|err| {
            io::Error::other(format!(
                "Failed to deserialize IpEchoServerMessage: {err:?}"
            ))
        })?;

    trace!("request: {msg:?}");

    // Fire a datagram at each non-zero UDP port
    match bind_to_unspecified() {
        Ok(udp_socket) => {
            for udp_port in &msg.udp_ports {
                if *udp_port != 0 {
                    let result =
                        udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port)));
                    match result {
                        Ok(_) => debug!("Successful send_to udp/{udp_port}"),
                        Err(err) => info!("Failed to send_to udp/{udp_port}: {err}"),
                    }
                }
            }
        }
        Err(err) => {
            warn!("Failed to bind local udp socket: {err}");
        }
    }

    // Try to connect to each non-zero TCP port
    for tcp_port in &msg.tcp_ports {
        if *tcp_port != 0 {
            debug!("Connecting to tcp/{tcp_port}");

            let mut tcp_stream = timeout_at(
                deadline,
                TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)),
            )
            .await??;

            debug!("Connection established to tcp/{}", *tcp_port);
            tcp_stream.shutdown().await?;
        }
    }
    let response = IpEchoServerResponse {
        address: peer_addr.ip(),
        shred_version,
    };
    // "\0\0\0\0" header is added to ensure a valid response will never
    // conflict with the first four bytes of a valid HTTP response.
    let mut bytes = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH];
    bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &response).unwrap();
    trace!("response: {bytes:?}");
    timeout_at(deadline, writer.write_all(&bytes)).await?
}

fn release_active_ip(active_ips: &mut HashSet<IpAddr>, ip: IpAddr) {
    let removed = active_ips.remove(&ip);
    debug_assert!(removed, "cleanup for unknown IP {ip}");
}

async fn run_echo_server(tcp_listener: std::net::TcpListener, shred_version: Option<u16>) {
    info!("bound to {:?}", tcp_listener.local_addr().unwrap());
    let tcp_listener =
        TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener");
    let active_ips = Arc::new(Mutex::new(HashSet::new()));

    loop {
        let connection = tcp_listener.accept().await;
        match connection {
            Ok((socket, peer_addr)) => {
                let tracked_ip = (!peer_addr.ip().is_loopback()).then_some(peer_addr.ip());
                if let Some(ip) = tracked_ip {
                    let mut active_ip_set = active_ips
                        .lock()
                        .expect("active_ips lock poisoned while admitting");
                    if active_ip_set.len() >= MAX_CONCURRENT_CONNECTIONS {
                        debug!(
                            "dropping connection from {peer_addr:?}: max concurrent connections \
                             ({MAX_CONCURRENT_CONNECTIONS}) reached",
                        );
                        continue;
                    }
                    if !active_ip_set.insert(ip) {
                        debug!(
                            "dropping connection from {peer_addr:?}: max concurrent connections \
                             per IP (1) reached"
                        );
                        continue;
                    }
                }
                let cleanup =
                    tracked_ip.map(|ip| ConnectionCleanup::new(Arc::clone(&active_ips), ip));
                runtime::Handle::current().spawn(async move {
                    let cleanup = cleanup;
                    if let Err(err) = process_connection(socket, peer_addr, shred_version).await {
                        info!("session failed: {err:?}");
                    }
                    drop(cleanup);
                });
            }
            Err(err) => warn!("listener accept failed: {err:?}"),
        }
    }
}

/// Starts a simple TCP server that echos the IP address of any peer that connects
/// Used by functions like |get_public_ip_addr| and |get_cluster_shred_version|
pub fn ip_echo_server(
    tcp_listener: std::net::TcpListener,
    num_server_threads: NonZeroUsize,
    // Cluster shred-version of the node running the server.
    shred_version: Option<u16>,
) -> IpEchoServer {
    tcp_listener.set_nonblocking(true).unwrap();

    let runtime = tokio::runtime::Builder::new_multi_thread()
        .thread_name("solIpEchoSrvrRt")
        .worker_threads(num_server_threads.get())
        .enable_all()
        .build()
        .expect("new tokio runtime");
    runtime.spawn(run_echo_server(tcp_listener, shred_version));
    runtime
}