use std::io;
use std::net::SocketAddr;
use futures::sync::mpsc::unbounded;
use futures::{future, Future, IntoFuture};
use openssl::pkcs12::ParsedPkcs12;
use openssl::pkey::{PKeyRef, Private};
use openssl::ssl::{SslConnector, SslContextBuilder, SslMethod, SslOptions};
use openssl::stack::Stack;
use openssl::x509::store::X509StoreBuilder;
use openssl::x509::{X509, X509Ref};
use tokio_openssl::{SslConnectorExt, SslStream as TokioTlsStream};
use tokio_tcp::TcpStream as TokioTcpStream;
use trust_dns_proto::tcp::TcpStream;
use trust_dns_proto::xfer::BufStreamHandle;
pub trait TlsIdentityExt {
fn identity(&mut self, pkcs12: &ParsedPkcs12) -> io::Result<()> {
self.identity_parts(&pkcs12.cert, &pkcs12.pkey, pkcs12.chain.as_ref())
}
fn identity_parts(
&mut self,
cert: &X509Ref,
pkey: &PKeyRef<Private>,
chain: Option<&Stack<X509>>,
) -> io::Result<()>;
}
impl TlsIdentityExt for SslContextBuilder {
    /// Sets `cert` and `pkey` on this context and checks that they belong together.
    /// It then appends each certificate of `chain` (if any) as an extra chain certificate.
    ///
    /// # Errors
    ///
    /// The first OpenSSL failure is returned, converted into an `io::Error`; later steps are
    /// not attempted.
    fn identity_parts(
        &mut self,
        cert: &X509Ref,
        pkey: &PKeyRef<Private>,
        chain: Option<&Stack<X509>>,
    ) -> io::Result<()> {
        self.set_certificate(cert)?;
        self.set_private_key(pkey)?;
        // Reject a private key that does not correspond to the certificate set above.
        self.check_private_key()?;

        // A missing chain simply yields no iterations.
        for extra in chain.into_iter().flatten() {
            self.add_extra_chain_cert(extra.to_owned())?;
        }

        Ok(())
    }
}
/// A trust-dns `TcpStream` whose transport is an OpenSSL TLS stream (via `tokio_openssl`)
/// running over a Tokio TCP socket.
pub type TlsStream = TcpStream<TokioTlsStream<TokioTcpStream>>;
/// Builds the OpenSSL connector used for TLS client connections.
///
/// The connector is configured as follows:
/// - SSLv2, SSLv3, TLS 1.0 and TLS 1.1 are disabled.
/// - `certs` are collected into the store installed with `set_verify_cert_store`.
/// - When `pkcs12` is given, its certificate, key and chain are installed as the client
///   identity.
///
/// # Errors
///
/// Returns the underlying OpenSSL failure converted into an `io::Error`. No "tls error:"
/// context is added here: `TlsStreamBuilder::build`, the caller, wraps this error itself.
/// Wrapping in both places produced a doubled "tls error: tls error:" prefix.
fn new(certs: Vec<X509>, pkcs12: Option<ParsedPkcs12>) -> io::Result<SslConnector> {
    let mut tls = SslConnector::builder(SslMethod::tls())?;

    // `SslConnectorBuilder` derefs to `SslContextBuilder`, so the context setters apply directly.
    tls.set_options(
        SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1,
    );

    // Collect the supplied certificates into the store used to verify the peer's chain.
    let mut store = X509StoreBuilder::new()?;
    for cert in certs {
        store.add_cert(cert)?;
    }
    tls.set_verify_cert_store(store.build())?;

    if let Some(pkcs12) = pkcs12 {
        tls.identity(&pkcs12)?;
    }

    Ok(tls.build())
}
/// Wraps an already-established TLS stream in a trust-dns [`TcpStream`].
///
/// This creates a fresh unbounded channel. Its receiving half is handed to the returned
/// stream, and its sending half is returned as a [`BufStreamHandle`] for outbound messages.
pub fn tls_stream_from_existing_tls_stream(
    stream: TokioTlsStream<TokioTcpStream>,
    peer_addr: SocketAddr,
) -> (TlsStream, BufStreamHandle) {
    let (tx, rx) = unbounded();
    let handle = BufStreamHandle::new(tx);
    (
        TcpStream::from_stream_with_receiver(stream, peer_addr, rx),
        handle,
    )
}
/// Builder for TLS client connections to a name server.
///
/// `Default` yields the same empty configuration as `TlsStreamBuilder::new`: no extra
/// verification certificates and no client identity.
#[derive(Default)]
pub struct TlsStreamBuilder {
    /// Certificates placed in the store used to verify the server's certificate chain.
    ca_chain: Vec<X509>,
    /// Optional client identity installed on the connector; it can only be set when the
    /// `mtls` feature is enabled.
    identity: Option<ParsedPkcs12>,
}
impl TlsStreamBuilder {
    /// Creates a builder with no extra verification certificates and no client identity.
    pub fn new() -> Self {
        TlsStreamBuilder {
            ca_chain: vec![],
            identity: None,
        }
    }

    /// Adds `ca` to the certificates used to verify the server's certificate chain.
    pub fn add_ca(&mut self, ca: X509) {
        self.ca_chain.push(ca);
    }

    /// Sets the client identity (certificate, key and optional chain) for mutual TLS.
    #[cfg(feature = "mtls")]
    pub fn identity(&mut self, pkcs12: ParsedPkcs12) {
        self.identity = Some(pkcs12);
    }

    /// Opens a TCP connection to `name_server` and performs a TLS handshake over it.
    ///
    /// `dns_name` is passed to the handshake as the expected server name.
    ///
    /// Returns the connection future together with a [`BufStreamHandle`]. The receiving half
    /// of the handle's channel is given to the resulting stream.
    ///
    /// # Errors
    ///
    /// The future fails with an `io::ErrorKind::ConnectionRefused` error whose message is
    /// prefixed with "tls error:". This happens if the connector cannot be configured, the
    /// TCP connect fails, or the TLS handshake fails.
    pub fn build(
        self,
        name_server: SocketAddr,
        dns_name: String,
    ) -> (
        Box<Future<Item = TlsStream, Error = io::Error> + Send>,
        BufStreamHandle,
    ) {
        /// Wraps any displayable failure in the uniform error reported by this builder.
        fn tls_error<E: std::fmt::Display>(e: E) -> io::Error {
            io::Error::new(
                io::ErrorKind::ConnectionRefused,
                format!("tls error: {}", e),
            )
        }

        let (message_sender, outbound_messages) = unbounded();
        let message_sender = BufStreamHandle::new(message_sender);

        let tls_connector = match new(self.ca_chain, self.identity) {
            Ok(c) => c,
            Err(e) => {
                return (
                    Box::new(future::err(e).into_future().map_err(tls_error)),
                    message_sender,
                )
            }
        };

        // Map the connect error *before* `and_then`. That way the handshake error produced
        // inside the closure is wrapped once, rather than a second time by an outer
        // `map_err` ("tls error: tls error: ...").
        let tcp = TokioTcpStream::connect(&name_server).map_err(tls_error);
        let stream = tcp.and_then(move |tcp_stream| {
            tls_connector
                .connect_async(&dns_name, tcp_stream)
                .map(move |s| {
                    TcpStream::from_stream_with_receiver(s, name_server, outbound_messages)
                })
                .map_err(tls_error)
        });

        (Box::new(stream), message_sender)
    }
}