use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
use futures::{Future, TryFutureExt};
use rustls::ClientConfig;
use tokio;
use tokio::net::TcpStream as TokioTcpStream;
use tokio_rustls::TlsConnector;
use webpki::{DNSName, DNSNameRef};
use trust_dns_proto::iocompat::AsyncIo02As03;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::{BufStreamHandle, SerialMessage};
/// Client-side `tokio_rustls` TLS stream layered over a Tokio TCP stream.
pub type TokioTlsClientStream = tokio_rustls::client::TlsStream<TokioTcpStream>;
/// Server-side `tokio_rustls` TLS stream layered over a Tokio TCP stream.
pub type TokioTlsServerStream = tokio_rustls::server::TlsStream<TokioTcpStream>;
/// DNS stream over TLS; an alias for the `trust_dns_proto` `TcpStream` wrapping an
/// inner stream of type `S`.
pub type TlsStream<S> = TcpStream<S>;
/// Builds a [`TlsStream`] around an already-established `stream`.
///
/// A fresh unbounded channel is created: its receiving end is handed to the stream together
/// with `peer_addr`, and its sending end is wrapped in the returned [`BufStreamHandle`].
///
/// # Returns
///
/// A tuple of the wrapped stream and the handle holding the channel's sender.
pub fn tls_from_stream<S: futures::io::AsyncRead + futures::io::AsyncWrite>(
    stream: S,
    peer_addr: SocketAddr,
) -> (TlsStream<S>, BufStreamHandle) {
    let (sender, receiver) = unbounded();
    let tls_stream = TcpStream::from_stream_with_receiver(stream, peer_addr, receiver);
    (tls_stream, BufStreamHandle::new(sender))
}
/// Prepares a TLS connection to `name_server`.
///
/// No I/O happens in this function itself: it builds a `TlsConnector` from `client_config`
/// (with early data set to the configuration's `enable_early_data` flag), creates the
/// unbounded message channel, and returns a boxed future that performs the actual connect
/// when polled.
///
/// # Arguments
///
/// * `name_server` - socket address to connect to
/// * `dns_name` - name passed to the TLS connector for the remote endpoint
/// * `client_config` - rustls client configuration used to build the connector
///
/// # Returns
///
/// A tuple of the connecting future (resolving to the [`TlsStream`] or an `io::Error`) and
/// the [`BufStreamHandle`] wrapping the sending end of the channel whose receiver is given
/// to the stream.
// The boxed-future return type is inherently verbose, hence the clippy allowance.
#[allow(clippy::type_complexity)]
pub fn tls_connect(
    name_server: SocketAddr,
    dns_name: String,
    client_config: Arc<ClientConfig>,
) -> (
    Pin<
        Box<
            dyn Future<Output = Result<TlsStream<AsyncIo02As03<TokioTlsClientStream>>, io::Error>>
                + Send,
        >,
    >,
    BufStreamHandle,
) {
    // Read the flag before `client_config` is moved into the connector.
    let use_early_data = client_config.enable_early_data;
    let connector = TlsConnector::from(client_config).early_data(use_early_data);

    let (sender, receiver) = unbounded();
    let handle = BufStreamHandle::new(sender);

    let connecting = Box::pin(connect_tls(connector, name_server, dns_name, receiver));
    (connecting, handle)
}
/// Connects to `name_server` over TCP, performs the TLS handshake for `dns_name`, and wraps
/// the resulting stream in a `TcpStream` fed by `outbound_messages`.
///
/// # Errors
///
/// * `io::ErrorKind::InvalidInput` ("bad dns_name") if `dns_name` is rejected by
///   `DNSNameRef::try_from_ascii_str`; this is checked before any network I/O.
/// * Whatever error `TokioTcpStream::connect` reports if the TCP connection fails.
/// * `io::ErrorKind::ConnectionRefused` ("tls error: ...") if the TLS connect fails.
async fn connect_tls(
    tls_connector: TlsConnector,
    name_server: SocketAddr,
    dns_name: String,
    outbound_messages: UnboundedReceiver<SerialMessage>,
) -> io::Result<TcpStream<AsyncIo02As03<TokioTlsClientStream>>> {
    // Validate the name first so a malformed `dns_name` fails fast instead of opening a
    // TCP connection that would be dropped immediately afterwards.
    let dns_name = DNSNameRef::try_from_ascii_str(&dns_name)
        .map(DNSName::from)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name"))?;

    let tcp = TokioTcpStream::connect(&name_server).await?;

    let s = tls_connector
        .connect(dns_name.as_ref(), tcp)
        .map_err(|e| {
            io::Error::new(
                io::ErrorKind::ConnectionRefused,
                format!("tls error: {}", e),
            )
        })
        .await?;

    Ok(TcpStream::from_stream_with_receiver(
        AsyncIo02As03(s),
        name_server,
        outbound_messages,
    ))
}