use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
use futures_util::future::BoxFuture;
#[cfg(not(feature = "rustls-platform-verifier"))]
use rustls::RootCertStore;
use rustls::{
ClientConfig,
crypto::{self, CryptoProvider},
pki_types::ServerName,
};
#[cfg(feature = "rustls-platform-verifier")]
use rustls_platform_verifier::BuilderVerifierExt;
use tokio::time::timeout;
use tokio_rustls::TlsConnector;
use tracing::debug;
use crate::{
error::NetError,
runtime::{
DnsTcpStream, RuntimeProvider, Spawn,
iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd},
},
tcp::{TcpClientStream, TcpStream},
xfer::{BufDnsStreamHandle, CONNECT_TIMEOUT, DnsExchange, DnsMultiplexer, StreamReceiver},
};
pub type TlsClientStream<S> =
TcpClientStream<AsyncIoTokioAsStd<tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>>>;
/// Connects to `remote_addr` over TCP, performs a TLS handshake on that connection, and
/// returns a [`DnsExchange`] for sending DNS requests over it.
///
/// The background task that drives the underlying [`DnsMultiplexer`] is spawned on
/// `provider`'s handle before this function returns.
///
/// # Parameters
///
/// * `remote_addr` - address of the DNS-over-TLS server.
/// * `server_name` - name handed to the TLS connector for the handshake.
/// * `config` - rustls client configuration. SNI is disabled on it before use.
/// * `timeout` - timeout applied to the multiplexer via `with_timeout`.
/// * `max_active_requests` - optional cap passed to `with_max_active_requests`.
/// * `provider` - runtime used to open the TCP connection and spawn the background task.
///
/// # Errors
///
/// Returns a [`NetError`] if the TCP connection cannot be opened, or if the TLS
/// handshake fails or times out.
pub async fn tls_exchange<P: RuntimeProvider<Tcp = S>, S: DnsTcpStream>(
    remote_addr: SocketAddr,
    server_name: ServerName<'static>,
    mut config: ClientConfig,
    timeout: Duration,
    max_active_requests: Option<usize>,
    provider: P,
) -> Result<DnsExchange<P>, NetError> {
    // Do not send the SNI extension in the ClientHello. Per the rustls docs, `enable_sni`
    // controls only that extension. `server_name` is still passed to the handshake below.
    config.enable_sni = false;

    // No local bind address and no explicit TCP connect timeout are passed here.
    let stream = provider.connect_tcp(remote_addr, None, None).await?;

    // `server_name` is already `'static` and not used again, so move it instead of cloning.
    let (future, sender) =
        tls_client_connect_with_future(stream, remote_addr, server_name, Arc::new(config));

    let mut multiplexer = DnsMultiplexer::new(future.await?, sender).with_timeout(timeout);
    if let Some(max) = max_active_requests {
        multiplexer = multiplexer.with_max_active_requests(max);
    }

    let (exchange, bg) = DnsExchange::<P>::from_stream(multiplexer);
    provider.create_handle().spawn_bg(bg);
    Ok(exchange)
}
/// Prepares a TLS connection to `name_server` without binding a specific local address.
///
/// This is [`tls_client_connect_with_bind_addr`] with `bind_addr` set to `None`. The
/// returned future performs the TCP connect and TLS handshake when polled. The returned
/// [`BufDnsStreamHandle`] is paired with that stream's outbound-message receiver.
// The (future, handle) return tuple is spelled out in full for callers.
#[allow(clippy::type_complexity)]
pub fn tls_client_connect<P: RuntimeProvider>(
    name_server: SocketAddr,
    server_name: ServerName<'static>,
    client_config: Arc<ClientConfig>,
    provider: P,
) -> (
    BoxFuture<'static, Result<TlsClientStream<P::Tcp>, NetError>>,
    BufDnsStreamHandle,
) {
    let no_local_bind: Option<SocketAddr> = None;
    tls_client_connect_with_bind_addr(
        name_server,
        no_local_bind,
        server_name,
        client_config,
        provider,
    )
}
/// Prepares a TLS connection to `name_server`, optionally binding the local socket to
/// `bind_addr`.
///
/// The outbound-message channel and its [`BufDnsStreamHandle`] are created immediately.
/// The TCP connect and TLS handshake happen only when the returned future is polled. Early
/// data on the connector follows `client_config.enable_early_data`.
///
/// The returned future resolves to a [`NetError`] if the TCP connect fails, or if the TLS
/// handshake fails or times out.
// The (future, handle) return tuple is spelled out in full for callers.
#[allow(clippy::type_complexity)]
pub fn tls_client_connect_with_bind_addr<P: RuntimeProvider>(
    name_server: SocketAddr,
    bind_addr: Option<SocketAddr>,
    server_name: ServerName<'static>,
    client_config: Arc<ClientConfig>,
    provider: P,
) -> (
    BoxFuture<'static, Result<TlsClientStream<P::Tcp>, NetError>>,
    BufDnsStreamHandle,
) {
    let (handle, receiver) = BufDnsStreamHandle::new(name_server);

    // Read the early-data flag before `client_config` is moved into the connector.
    let connector = {
        let early_data = client_config.enable_early_data;
        TlsConnector::from(client_config).early_data(early_data)
    };

    let connect = async move {
        let tcp = provider.connect_tcp(name_server, bind_addr, None).await?;
        let tls = connect_tls_stream(connector, tcp, name_server, server_name, receiver)
            .await?;
        Ok::<_, NetError>(TcpClientStream::from_stream(tls))
    };

    (Box::pin(connect), handle)
}
/// Wraps an already-connected `stream` in TLS.
///
/// Returns a future that performs the handshake and yields a [`TlsClientStream`],
/// together with the [`BufDnsStreamHandle`] paired with that stream's outbound-message
/// receiver. The handle is created immediately, while `stream` is not touched until the
/// future is polled. Early data on the connector follows `client_config.enable_early_data`.
fn tls_client_connect_with_future<S: DnsTcpStream>(
    stream: S,
    socket_addr: SocketAddr,
    server_name: ServerName<'static>,
    client_config: Arc<ClientConfig>,
) -> (
    impl Future<Output = Result<TlsClientStream<S>, NetError>> + Send + 'static,
    BufDnsStreamHandle,
) {
    let (handle, receiver) = BufDnsStreamHandle::new(socket_addr);

    // Read the early-data flag before `client_config` is moved into the connector.
    let connector = {
        let early_data = client_config.enable_early_data;
        TlsConnector::from(client_config).early_data(early_data)
    };

    let handshake = async move {
        let tls = connect_tls_stream(connector, stream, socket_addr, server_name, receiver)
            .await?;
        Ok::<_, NetError>(TcpClientStream::from_stream(tls))
    };

    (handshake, handle)
}
/// Runs the TLS client handshake over `stream` and wraps the resulting session as a DNS
/// [`TcpStream`] that is fed from `outbound_messages`.
///
/// The handshake is bounded by [`CONNECT_TIMEOUT`].
///
/// # Errors
///
/// Returns [`NetError::Timeout`] if the handshake does not complete within
/// [`CONNECT_TIMEOUT`]. This case also emits a debug log naming `name_server`. If the
/// handshake itself fails, returns the I/O error converted into a [`NetError`].
pub(super) async fn connect_tls_stream<S: DnsTcpStream>(
    tls_connector: TlsConnector,
    stream: S,
    name_server: SocketAddr,
    server_name: ServerName<'static>,
    outbound_messages: StreamReceiver,
) -> Result<TcpStream<AsyncIoTokioAsStd<TokioTlsClientStream<S>>>, NetError> {
    // Adapt the runtime stream to tokio's I/O traits so tokio-rustls can drive it.
    let handshake = tls_connector.connect(server_name, AsyncIoStdAsTokio(stream));

    let tls = match timeout(CONNECT_TIMEOUT, handshake).await {
        Err(_elapsed) => {
            debug!(%name_server, "TLS connect timeout");
            return Err(NetError::Timeout);
        }
        Ok(outcome) => outcome?,
    };

    Ok(TcpStream::from_stream_with_receiver(
        AsyncIoTokioAsStd(tls),
        name_server,
        outbound_messages,
    ))
}
/// Wraps an existing `stream` as a DNS [`TlsStream`] addressed to `peer_addr`.
///
/// Returns the stream together with the [`BufDnsStreamHandle`] paired with its
/// outbound-message receiver. No TLS handshake happens here.
/// NOTE(review): presumably `stream` is already TLS-encrypted by the caller; confirm.
pub fn tls_from_stream<S: DnsTcpStream>(
    stream: S,
    peer_addr: SocketAddr,
) -> (TlsStream<S>, BufDnsStreamHandle) {
    let (handle, receiver) = BufDnsStreamHandle::new(peer_addr);
    (
        TcpStream::from_stream_with_receiver(stream, peer_addr, receiver),
        handle,
    )
}
/// Builds a default rustls [`ClientConfig`] with no client authentication.
///
/// The config uses [`default_provider`] and the safe default protocol versions. Server
/// certificates are verified as follows:
///
/// * With the `rustls-platform-verifier` feature, the platform verifier is used.
/// * Otherwise, a root store is used. It holds the `webpki-roots` certificates when that
///   feature is enabled, and is empty when it is not.
///
/// # Errors
///
/// Returns a [`rustls::Error`] if rustls rejects the crypto provider for the safe default
/// protocol versions. With `rustls-platform-verifier`, it also returns an error if the
/// platform verifier cannot be set up.
pub fn client_config() -> Result<ClientConfig, rustls::Error> {
    // Propagate rather than panic: this function already reports `rustls::Error`.
    let builder = ClientConfig::builder_with_provider(Arc::new(default_provider()))
        .with_safe_default_protocol_versions()?;

    #[cfg(feature = "rustls-platform-verifier")]
    let builder = builder.with_platform_verifier()?;

    #[cfg(not(feature = "rustls-platform-verifier"))]
    let builder = builder.with_root_certificates({
        #[cfg_attr(not(feature = "webpki-roots"), allow(unused_mut))]
        let mut root_store = RootCertStore::empty();
        #[cfg(feature = "webpki-roots")]
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        root_store
    });

    Ok(builder.with_no_client_auth())
}
/// Returns rustls' *ring*-backed [`CryptoProvider`].
///
/// This variant is compiled whenever the `tls-ring` feature is enabled. It takes precedence
/// over `tls-aws-lc-rs` when both features are on.
#[cfg(feature = "tls-ring")]
pub fn default_provider() -> CryptoProvider {
    crypto::ring::default_provider()
}

/// Returns rustls' aws-lc-rs-backed [`CryptoProvider`].
///
/// This variant is compiled only when `tls-aws-lc-rs` is enabled and `tls-ring` is not.
#[cfg(all(feature = "tls-aws-lc-rs", not(feature = "tls-ring")))]
pub fn default_provider() -> CryptoProvider {
    crypto::aws_lc_rs::default_provider()
}
/// tokio-rustls client session running over the runtime stream `S`, which is adapted to
/// tokio's I/O traits.
pub type TokioTlsClientStream<S> = tokio_rustls::client::TlsStream<AsyncIoStdAsTokio<S>>;

/// Alias for the DNS [`TcpStream`] type returned by [`tls_from_stream`].
pub type TlsStream<S> = TcpStream<S>;