zero-trust-rps 0.0.5

Online Multiplayer Rock Paper Scissors
Documentation
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;

use futures::future::select_ok;
use quinn::rustls::RootCertStore;
use quinn::{rustls, ClientConfig, ConnectError, Connection, ConnectionError, Endpoint};

use crate::common::client::connection::init::initialize_connection_private;
use crate::common::connection::get_default_transport_config;
use crate::common::constants::{LOCALHOST, ZERO_ADDR4, ZERO_ADDR6};
use crate::common::dns::{resolve_to_ips, DomainResolveError};
use crate::common::tls_debug::get_public_key;
use crate::common::{
    connection::{Reader, Writer},
    result::DynResult,
};

use super::ClientOptions;

/// Connects to the server described by `options` and finishes connection setup.
///
/// A QUIC connection is opened with [`create_quic_connection`] using the
/// domain, port and optional IP override taken from `options`; that connection
/// is then handed to `initialize_connection_private`, whose result is returned
/// unchanged.
///
/// # Errors
///
/// Returns the error from [`create_quic_connection`] if the connection cannot
/// be opened, or whatever error `initialize_connection_private` reports.
pub async fn initialize_connection(
    options: &ClientOptions,
) -> DynResult<(Connection, impl Writer, impl Reader)> {
    let domain = &options.domain;
    let port = options.port;

    let quic = create_quic_connection(domain.as_str(), port, options.ip).await?;

    // NOTE(review): the meaning of the trailing `true` flag is defined by
    // `initialize_connection_private` and is not visible from here.
    initialize_connection_private(quic, domain, port, true).await
}

/// Errors that [`create_quic_connection`] can return.
///
/// Every variant wraps exactly one underlying error type and is constructible
/// from it through `From` (via `#[from]`), so the `?` operator converts
/// automatically.
#[derive(thiserror::Error, Debug)]
// Every variant name ends in `Error`, which this lint would otherwise flag.
#[allow(clippy::enum_variant_names)]
pub enum CreateQuicConnectionError {
    /// A `rustls::Error`.
    #[error("RustlsError {}", .0)]
    RustlsError(#[from] rustls::Error),
    /// The domain could not be resolved, or resolution produced no addresses.
    #[error("Could not resolve domain: {}", .0)]
    DomainResolveError(#[from] DomainResolveError),
    /// A quinn `ConnectionError`: the connection attempt was started but did
    /// not complete.
    #[error("Could not establish connection: {}", .0)]
    QuinnConnectionError(#[from] ConnectionError),
    /// A quinn `ConnectError`: the parameters used to start the connection
    /// attempt were rejected.
    #[error("Errors in the parameters being used to create a new connection: {}", .0)]
    QuinnConnectError(#[from] ConnectError),
    /// A `std::io::Error`.
    #[error("Io error: {}", .0)]
    IoError(#[from] std::io::Error),
    /// A rustls `VerifierBuilderError` raised while building the TLS client
    /// configuration.
    #[error("Tls config error: {}", .0)]
    RustlsVerifierBuilderError(#[from] rustls::client::VerifierBuilderError),
}

/// Opens a QUIC connection to `domain` on `port`.
///
/// When `ip_addr` is `Some`, that address is the only candidate; otherwise the
/// candidates come from `resolve_to_ips(domain, port)`. In both cases `domain`
/// is passed to the handshake as the server name.
///
/// TLS trust roots are chosen as follows:
/// * `domain == LOCALHOST`: only the value returned by `get_public_key()`;
/// * Android targets: the certificates loaded by `rustls_native_certs`;
/// * everything else: the platform verifier.
///
/// One attempt is made per candidate address, each from its own client
/// endpoint bound to `ZERO_ADDR4` or `ZERO_ADDR6` (matching the candidate's IP
/// family) on port 0. The attempt at index `i` sleeps `i` seconds before
/// starting, and all attempts are raced with `select_ok`.
///
/// # Errors
///
/// * Building the TLS client configuration fails (`RustlsError`,
///   `RustlsVerifierBuilderError` or `IoError`).
/// * Resolution fails or produces no addresses (`DomainResolveError`).
/// * Every attempt fails: the error reported by `select_ok` is converted into
///   `IoError`, `QuinnConnectError` or `QuinnConnectionError`.
pub async fn create_quic_connection(
    domain: &str,
    port: u16,
    ip_addr: Option<IpAddr>,
) -> Result<Connection, CreateQuicConnectionError> {
    let mut client_config = if domain == LOCALHOST {
        let mut certs = RootCertStore::empty();
        certs.add(get_public_key())?;
        ClientConfig::with_root_certificates(Arc::new(certs))?
    } else if cfg!(target_os = "android") {
        // Do not use platform verifier as that requires java
        let mut certs = RootCertStore::empty();
        certs.add_parsable_certificates(rustls_native_certs::load_native_certs()?.into_iter());
        ClientConfig::with_root_certificates(Arc::new(certs))?
    } else {
        ClientConfig::with_platform_verifier()
    };

    client_config.transport_config(get_default_transport_config());

    /// Why a single connection attempt failed. None of these aborts the other
    /// attempts; `select_ok` keeps polling the remaining ones.
    enum AttemptError {
        /// The local client endpoint could not be created.
        Bind(std::io::Error),
        /// The connection attempt could not be started.
        Connect(ConnectError),
        /// The connection attempt was started but did not complete.
        Handshake(ConnectionError),
    }

    let socket_addrs = match ip_addr {
        Some(addr) => vec![SocketAddr::new(addr, port)].into_boxed_slice(),
        None => resolve_to_ips(domain, port).await?,
    };

    // `select_ok` must not be given an empty set of futures, so bail out early.
    if socket_addrs.is_empty() {
        return Err(CreateQuicConnectionError::DomainResolveError(
            DomainResolveError::NoIpsFound,
        ));
    }

    let attempts = Box::into_iter(socket_addrs).enumerate().map(|(i, addr)| {
        // Each attempt owns its own copy of the config and hands it to
        // `connect_with` by value, so no further clone is needed inside.
        let client_config = client_config.clone();
        Box::pin(async move {
            // wait to not spam requests
            tokio::time::sleep(Duration::from_secs(i as u64)).await;

            // Bind a local socket of the same IP family as the remote address.
            let bind_ip = if addr.is_ipv4() {
                ZERO_ADDR4
            } else {
                ZERO_ADDR6
            };
            let endpoint =
                Endpoint::client(SocketAddr::new(bind_ip, 0)).map_err(AttemptError::Bind)?;

            let connecting = endpoint
                .connect_with(client_config, addr, domain)
                .map_err(AttemptError::Connect)?;

            connecting.await.map_err(AttemptError::Handshake)
        })
    });

    select_ok(attempts)
        .await
        .map(|(connection, _)| connection)
        .map_err(|err| match err {
            AttemptError::Bind(err) => CreateQuicConnectionError::from(err),
            AttemptError::Connect(err) => CreateQuicConnectionError::from(err),
            AttemptError::Handshake(err) => CreateQuicConnectionError::from(err),
        })
}