use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::{future::Future, marker::PhantomData};
use futures_util::TryFutureExt;
use native_tls::Protocol::Tlsv12;
use native_tls::{Certificate, Identity, TlsConnector};
use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream};
use crate::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd};
use crate::tcp::Connect;
use crate::tcp::TcpStream;
use crate::xfer::{BufDnsStreamHandle, StreamReceiver};
/// A `tokio_native_tls` TLS stream over `S`, carried in the crate's `TcpStream` wrapper.
///
/// The layering, innermost first: `S` is adapted to tokio I/O (`AsyncIoStdAsTokio`),
/// wrapped by `tokio_native_tls::TlsStream`, adapted back (`AsyncIoTokioAsStd`), and
/// finally wrapped in the crate's `TcpStream`.
pub type TlsStream<S> = TcpStream<AsyncIoTokioAsStd<TokioTlsStream<AsyncIoStdAsTokio<S>>>>;
/// Builds a `native_tls` connector that refuses anything older than TLS 1.2.
///
/// Every certificate in `certs` is added as a trusted root certificate, and
/// `pkcs12`, when present, is installed as the connector's identity.
///
/// # Errors
///
/// If the underlying builder fails, the failure is reported as an
/// `io::ErrorKind::ConnectionRefused` error whose message is prefixed with
/// `tls error: `.
fn tls_new(certs: Vec<Certificate>, pkcs12: Option<Identity>) -> io::Result<TlsConnector> {
    let mut connector_builder = TlsConnector::builder();
    connector_builder.min_protocol_version(Some(Tlsv12));

    for root in certs {
        connector_builder.add_root_certificate(root);
    }

    if let Some(client_identity) = pkcs12 {
        connector_builder.identity(client_identity);
    }

    match connector_builder.build() {
        Ok(connector) => Ok(connector),
        Err(build_err) => Err(io::Error::new(
            io::ErrorKind::ConnectionRefused,
            format!("tls error: {}", build_err),
        )),
    }
}
/// Wraps an already-established `tokio_native_tls` stream in the crate's
/// `TcpStream` framing.
///
/// A `BufDnsStreamHandle` and its paired receiver are created for `peer_addr`.
/// The receiver is handed to the wrapped stream, and the handle is returned
/// alongside it.
pub fn tls_from_stream<S: Connect>(
    stream: TokioTlsStream<AsyncIoStdAsTokio<S>>,
    peer_addr: SocketAddr,
) -> (TlsStream<S>, BufDnsStreamHandle) {
    let (handle, receiver) = BufDnsStreamHandle::new(peer_addr);
    let std_io = AsyncIoTokioAsStd(stream);
    (
        TcpStream::from_stream_with_receiver(std_io, peer_addr, receiver),
        handle,
    )
}
/// Builder for TLS connections over a stream type `S`.
///
/// Collects extra trusted CA certificates, an optional client identity and an
/// optional local bind address before a connection is started.
pub struct TlsStreamBuilder<S> {
    /// Additional root certificates to trust.
    ca_chain: Vec<Certificate>,
    /// Optional identity handed to the TLS connector.
    identity: Option<Identity>,
    /// Optional local address to bind the underlying connection to.
    bind_addr: Option<SocketAddr>,
    /// Ties the builder to the stream type `S` without storing one.
    marker: PhantomData<S>,
}

// Implemented by hand: `#[derive(Default)]` would add an `S: Default` bound.
// The stream types used as `S` need not satisfy that bound, and none of the
// fields actually require it.
impl<S> Default for TlsStreamBuilder<S> {
    fn default() -> Self {
        Self {
            ca_chain: Vec::new(),
            identity: None,
            bind_addr: None,
            marker: PhantomData,
        }
    }
}
impl<S: Connect> TlsStreamBuilder<S> {
pub fn new() -> Self {
Self {
ca_chain: vec![],
identity: None,
bind_addr: None,
marker: PhantomData,
}
}
pub fn add_ca(&mut self, ca: Certificate) {
self.ca_chain.push(ca);
}
#[cfg(feature = "mtls")]
pub fn identity(&mut self, identity: Identity) {
self.identity = Some(identity);
}
pub fn bind_addr(&mut self, bind_addr: SocketAddr) {
self.bind_addr = Some(bind_addr);
}
#[allow(clippy::type_complexity)]
pub fn build(
self,
name_server: SocketAddr,
dns_name: String,
) -> (
// TODO: change to impl?
Pin<Box<dyn Future<Output = Result<TlsStream<S>, io::Error>> + Send>>,
BufDnsStreamHandle,
) {
let (message_sender, outbound_messages) = BufDnsStreamHandle::new(name_server);
let stream = self.inner_build(name_server, dns_name, outbound_messages);
(Box::pin(stream), message_sender)
}
async fn inner_build(
self,
name_server: SocketAddr,
dns_name: String,
outbound_messages: StreamReceiver,
) -> Result<TlsStream<S>, io::Error> {
use crate::native_tls::tls_stream;
let ca_chain = self.ca_chain.clone();
let identity = self.identity;
let tcp_stream = S::connect_with_bind(name_server, self.bind_addr).await;
let tcp_stream = match tcp_stream {
Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream),
Err(err) => return Err(err),
};
let tls_connector = tls_stream::tls_new(ca_chain, identity)
.map(TokioTlsConnector::from)
.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e),
)
})?;
let tls_connected = tls_connector
.connect(&dns_name, tcp_stream)
.map_err(|e| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e),
)
})
.await?;
Ok(TcpStream::from_stream_with_receiver(
AsyncIoTokioAsStd(tls_connected),
name_server,
outbound_messages,
))
}
}