use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use futures::future::select_ok;
use quinn::rustls::RootCertStore;
use quinn::{rustls, ClientConfig, ConnectError, Connection, ConnectionError, Endpoint};
use crate::common::client::connection::init::initialize_connection_private;
use crate::common::connection::get_default_transport_config;
use crate::common::constants::{LOCALHOST, ZERO_ADDR4, ZERO_ADDR6};
use crate::common::dns::{resolve_to_ips, DomainResolveError};
use crate::common::tls_debug::get_public_key;
use crate::common::{
connection::{Reader, Writer},
result::DynResult,
};
use super::ClientOptions;
/// Opens a QUIC connection to the server described by `options`, then hands it to
/// `initialize_connection_private` for the rest of the setup.
///
/// The connection is dialed with `options.domain`, `options.port` and the optional fixed
/// `options.ip`. The result of the private initialization step is returned unchanged:
/// the connection together with its writer and reader halves.
///
/// # Errors
///
/// Propagates any failure from `create_quic_connection` or from
/// `initialize_connection_private`.
pub async fn initialize_connection(
    options: &ClientOptions,
) -> DynResult<(Connection, impl Writer, impl Reader)> {
    let server_name = options.domain.as_str();
    let quic_connection = create_quic_connection(server_name, options.port, options.ip).await?;
    // NOTE(review): the meaning of the trailing `true` flag is defined by
    // `initialize_connection_private` (not visible here) — confirm before changing.
    initialize_connection_private(quic_connection, &options.domain, options.port, true).await
}
#[derive(thiserror::Error, Debug)]
#[allow(clippy::enum_variant_names)]
pub enum CreateQuicConnectionError {
#[error("RustlsError {}", .0)]
RustlsError(#[from] rustls::Error),
#[error("Could not resolve domain: {}", .0)]
DomainResolveError(#[from] DomainResolveError),
#[error("Could not establish connection: {}", .0)]
QuinnConnectionError(#[from] ConnectionError),
#[error("Errors in the parameters being used to create a new connection: {}", .0)]
QuinnConnectError(#[from] ConnectError),
#[error("Io error: {}", .0)]
IoError(#[from] std::io::Error),
#[error("Tls config error: {}", .0)]
RustlsVerifierBuilderError(#[from] rustls::client::VerifierBuilderError),
}
/// Opens a QUIC connection to `domain:port` by racing staggered attempts to each
/// candidate address.
///
/// TLS trust is selected per target:
/// - for [`LOCALHOST`], only the certificate returned by `get_public_key()` is trusted;
/// - on Android (checked with `cfg!`, so every branch is still compiled everywhere), the
///   native root certificates from `rustls_native_certs` are loaded;
/// - otherwise the platform verifier is used.
///
/// The transport settings come from `get_default_transport_config()`.
///
/// If `ip_addr` is given, that single address is dialed. Otherwise `domain` is resolved
/// with `resolve_to_ips`. In both cases `domain` is passed to quinn as the TLS server
/// name.
///
/// Attempt `i` starts after `i` seconds. Each attempt uses its own client endpoint, bound
/// to `ZERO_ADDR4` or `ZERO_ADDR6` (port 0) to match the target's IP family. The first
/// attempt to complete its handshake wins. The remaining attempts are dropped, which
/// cancels them.
///
/// # Errors
///
/// - `RustlsError` / `RustlsVerifierBuilderError` / `IoError` if the TLS client
///   configuration cannot be built.
/// - `DomainResolveError` if resolution fails or yields no addresses.
/// - If every attempt fails, only the error of the last attempt to fail is reported, as
///   `IoError` (endpoint bind), `QuinnConnectError` (invalid connect parameters) or
///   `QuinnConnectionError` (handshake).
pub async fn create_quic_connection(
    domain: &str,
    port: u16,
    ip_addr: Option<IpAddr>,
) -> Result<Connection, CreateQuicConnectionError> {
    let mut client_config = if domain == LOCALHOST {
        let mut certs = RootCertStore::empty();
        certs.add(get_public_key())?;
        ClientConfig::with_root_certificates(Arc::new(certs))?
    } else if cfg!(target_os = "android") {
        let mut certs = RootCertStore::empty();
        certs.add_parsable_certificates(rustls_native_certs::load_native_certs()?.into_iter());
        ClientConfig::with_root_certificates(Arc::new(certs))?
    } else {
        ClientConfig::with_platform_verifier()
    };
    client_config.transport_config(get_default_transport_config());

    /// Why a single per-address attempt failed.
    ///
    /// None of these abort the race: `select_ok` keeps polling the other attempts and
    /// only reports an error once all of them have failed.
    enum AttemptError {
        /// Creating/binding the local client endpoint failed.
        Bind(std::io::Error),
        /// quinn rejected the connection parameters before any packet was sent.
        Connect(ConnectError),
        /// The QUIC/TLS handshake failed.
        Handshake(ConnectionError),
    }

    let socket_addrs = match ip_addr {
        Some(addr) => vec![SocketAddr::new(addr, port)].into_boxed_slice(),
        None => resolve_to_ips(domain, port).await?,
    };
    // `select_ok` panics on an empty set of futures, so reject that case up front.
    if socket_addrs.is_empty() {
        return Err(CreateQuicConnectionError::DomainResolveError(
            DomainResolveError::NoIpsFound,
        ));
    }

    let attempts = Box::into_iter(socket_addrs)
        .enumerate()
        .map(|(index, target)| {
            // Each attempt owns its own config; `connect_with` consumes it by value.
            let attempt_config = client_config.clone();
            Box::pin(async move {
                // Give earlier addresses a head start. All attempts are still polled
                // concurrently by `select_ok`.
                tokio::time::sleep(Duration::from_secs(index as u64)).await;
                let bind_ip = if target.is_ipv4() {
                    ZERO_ADDR4
                } else {
                    ZERO_ADDR6
                };
                let endpoint =
                    Endpoint::client(SocketAddr::new(bind_ip, 0)).map_err(AttemptError::Bind)?;
                endpoint
                    .connect_with(attempt_config, target, domain)
                    .map_err(AttemptError::Connect)?
                    .await
                    .map_err(AttemptError::Handshake)
            })
        });

    // The still-pending attempts returned alongside the winner are dropped here,
    // which cancels them.
    match select_ok(attempts).await {
        Ok((connection, _pending)) => Ok(connection),
        Err(AttemptError::Bind(err)) => Err(err.into()),
        Err(AttemptError::Connect(err)) => Err(err.into()),
        Err(AttemptError::Handshake(err)) => Err(err.into()),
    }
}