solana-leader 0.4.0

Solana leader library
Documentation
use crate::url::ParsedUrl;
use rustls::crypto::aws_lc_rs;
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
use std::io::{self, Read, Write};
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
use std::sync::{Arc, OnceLock};
use std::time::Duration;

/// Reference-counted rustls client configuration.
///
/// Cloning only bumps the `Arc` refcount, so a single loaded configuration can be
/// shared by every connection that needs TLS.
pub(crate) type TlsClientConfig = Arc<ClientConfig>;

/// Returns the process-wide aws-lc-rs crypto provider, building it lazily.
///
/// The provider is constructed at most once (guarded by a `OnceLock`). Every later
/// call hands back a reference to that same `Arc`, so callers can cheaply clone it
/// into each rustls config they build.
pub(crate) fn shared_crypto_provider() -> &'static Arc<rustls::crypto::CryptoProvider> {
    type Provider = rustls::crypto::CryptoProvider;

    static PROVIDER: OnceLock<Arc<Provider>> = OnceLock::new();

    PROVIDER.get_or_init(|| {
        let provider: Provider = aws_lc_rs::default_provider();
        Arc::new(provider)
    })
}

/// A client-side byte stream: either a raw TCP connection or a rustls TLS
/// session layered over one.
///
/// NOTE(review): the TLS variant is boxed, presumably to keep the enum close to
/// `TcpStream`'s size since the rustls session state is much larger — confirm.
pub(crate) enum ClientStream {
    /// Unencrypted TCP connection, used as-is.
    Plain(TcpStream),
    /// rustls client session that owns both the connection state and the socket.
    Tls(Box<StreamOwned<ClientConnection, TcpStream>>),
}

/// Reads from whichever transport backs the stream.
///
/// For the TLS variant, bytes come out of the rustls `StreamOwned`, so callers
/// see plaintext application data.
impl Read for ClientStream {
    fn read(&mut self, buffer: &mut [u8]) -> io::Result<usize> {
        // Pick the concrete reader first, then issue a single read call.
        let source: &mut dyn Read = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        source.read(buffer)
    }
}

/// Writes to whichever transport backs the stream.
///
/// For the TLS variant, data goes through the rustls `StreamOwned` rather than
/// straight onto the socket.
impl Write for ClientStream {
    fn write(&mut self, buffer: &[u8]) -> io::Result<usize> {
        let sink: &mut dyn Write = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        sink.write(buffer)
    }

    fn flush(&mut self) -> io::Result<()> {
        let sink: &mut dyn Write = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        sink.flush()
    }
}

/// Builds a TLS client configuration that trusts the platform's native root
/// certificates.
///
/// The config uses the provider from `shared_crypto_provider`, rustls' safe
/// default protocol versions, and no client authentication.
///
/// Loading is best-effort: if at least one root certificate is usable, any
/// unparsable certificates and native-store load errors are tolerated.
///
/// # Errors
///
/// Returns an error if:
/// - no root certificate could be added to the store. The message keeps the
///   original prefix and appends how many certificates failed to parse, plus any
///   errors reported while reading the native store.
/// - the crypto provider rejects the default protocol versions.
pub(crate) fn load_tls_client_config() -> io::Result<TlsClientConfig> {
    let cert_result = rustls_native_certs::load_native_certs();
    let mut roots = RootCertStore::empty();
    let (added, rejected) = roots.add_parsable_certificates(cert_result.certs);
    if added == 0 {
        // Explain why the store is empty. Previously the per-source load errors
        // and parse failures were discarded, leaving this failure undiagnosable.
        let mut message = String::from("failed to load any native tls root certificates");
        if rejected > 0 {
            message.push_str(&format!(" ({rejected} certificates could not be parsed)"));
        }
        if !cert_result.errors.is_empty() {
            let details: Vec<String> =
                cert_result.errors.iter().map(ToString::to_string).collect();
            message.push_str(": ");
            message.push_str(&details.join("; "));
        }
        return Err(io::Error::other(message));
    }

    let config = ClientConfig::builder_with_provider(Arc::clone(shared_crypto_provider()))
        .with_safe_default_protocol_versions()
        .map_err(io::Error::other)?
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(Arc::new(config))
}

/// Resolves `host:port` and opens a TCP connection to the first resolved
/// address that accepts it.
///
/// The connected stream has `TCP_NODELAY` enabled and the given read and write
/// timeouts applied.
///
/// Addresses are tried in resolution order, so a host with several addresses
/// (e.g. an unreachable IPv6 entry ahead of a reachable IPv4 one) still connects.
/// `connect_timeout` applies to each attempt separately, so the worst-case total
/// connect time grows with the number of resolved addresses.
///
/// Returns `Ok(None)` when resolution succeeds but yields no addresses.
///
/// # Errors
///
/// Returns an error if:
/// - name resolution fails.
/// - every connection attempt fails; the error is the one from the last attempt.
///   A zero `connect_timeout` fails with `InvalidInput`.
/// - applying a socket option to the connected stream fails. In that case the
///   remaining addresses are not tried.
pub(crate) fn connect_tcp_stream(
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
    write_timeout: Duration,
) -> io::Result<Option<TcpStream>> {
    let mut last_error: Option<io::Error> = None;

    for addr in (host, port).to_socket_addrs()? {
        match TcpStream::connect_timeout(&addr, connect_timeout) {
            Ok(stream) => {
                stream.set_nodelay(true)?;
                stream.set_read_timeout(Some(read_timeout))?;
                stream.set_write_timeout(Some(write_timeout))?;
                return Ok(Some(stream));
            }
            // Remember the failure and fall through to the next candidate address.
            Err(error) => last_error = Some(error),
        }
    }

    match last_error {
        Some(error) => Err(error),
        None => Ok(None),
    }
}

pub(crate) fn wrap_client_stream(
    url: &ParsedUrl,
    stream: TcpStream,
    tls_config: Option<&TlsClientConfig>,
) -> io::Result<ClientStream> {
    if !url.uses_tls() {
        return Ok(ClientStream::Plain(stream));
    }

    let server_name = server_name(&url.host)?;
    let connection = ClientConnection::new(
        Arc::clone(tls_config.ok_or_else(|| io::Error::other("missing tls client config"))?),
        server_name,
    )
    .map_err(io::Error::other)?;

    Ok(ClientStream::Tls(Box::new(StreamOwned::new(
        connection, stream,
    ))))
}

/// Converts a URL host into the rustls `ServerName` used for certificate
/// verification.
///
/// Hosts that parse as an IP address become `ServerName::IpAddress`. Anything
/// else must be accepted by `ServerName::try_from`, otherwise an
/// `InvalidInput` error is returned.
fn server_name(host: &str) -> io::Result<ServerName<'static>> {
    match host.parse::<IpAddr>() {
        Ok(ip) => Ok(ServerName::IpAddress(ip.into())),
        Err(_) => ServerName::try_from(host.to_owned()).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidInput, "invalid tls server name in url host")
        }),
    }
}