use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use futures::{future, Future, IntoFuture};
use futures::sync::mpsc::unbounded;
use rustls::{Certificate, ClientConfig, ClientSession};
use tokio_core::net::TcpStream as TokioTcpStream;
use tokio_core::reactor::Handle;
use tokio_rustls::{ClientConfigExt, TlsStream as TokioTlsStream};
use trust_dns::BufStreamHandle;
use trust_dns::tcp::TcpStream;
/// trust-dns `TcpStream` whose underlying I/O is a tokio-rustls client TLS session
/// (`ClientSession`) layered over a tokio-core TCP socket.
pub type TlsStream = TcpStream<TokioTlsStream<TokioTcpStream, ClientSession>>;
fn tls_new(certs: &[Certificate] ) -> io::Result<Arc<ClientConfig>> {
let mut builder = ClientConfig::new();
{
let mut trust_store = &mut builder.root_store;
for cert in certs {
try!(trust_store
.add(cert)
.map_err(|e| {
io::Error::new(io::ErrorKind::ConnectionRefused,
format!("tls error: {:?}", e))
}));
}
}
Ok(Arc::new(builder))
}
/// Wraps an already-established TLS stream to `peer_addr` in a trust-dns `TcpStream`.
///
/// A new unbounded channel is created for this stream. The receiving half is handed
/// to the stream through `from_stream_with_receiver`. The sending half is returned
/// as the `BufStreamHandle` that callers use to queue outbound messages.
pub fn tls_from_stream(stream: TokioTlsStream<TokioTcpStream, ClientSession>,
                       peer_addr: SocketAddr)
                       -> (TlsStream, BufStreamHandle) {
    let (sender, receiver) = unbounded();
    (TcpStream::from_stream_with_receiver(stream, peer_addr, receiver), sender)
}
/// Builder for a TLS-wrapped TCP connection to a name server.
///
/// Collects the CA certificates to trust; `build` then opens the connection.
pub struct TlsStreamBuilder {
    /// Certificates placed in the rustls root store when `build` is called.
    ca_chain: Vec<Certificate>,
}
impl TlsStreamBuilder {
pub fn new() -> TlsStreamBuilder {
TlsStreamBuilder {
ca_chain: vec![],
}
}
pub fn add_ca(&mut self, ca: Certificate) {
self.ca_chain.push(ca);
}
#[cfg(feature = "mtls")]
pub fn identity(&mut self, pkcs12: Pkcs12) {
self.identity = Some(pkcs12);
}
pub fn build(self,
name_server: SocketAddr,
subject_name: String,
loop_handle: &Handle)
-> (Box<Future<Item = TlsStream, Error = io::Error>>, BufStreamHandle) {
let (message_sender, outbound_messages) = unbounded();
let tls_connector =
match ::tls_stream::tls_new(&self.ca_chain ) {
Ok(c) => c,
Err(e) => {
return (Box::new(future::err(e).into_future().map_err(|e| {
io::Error::new(io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e))
})),
message_sender)
}
};
let tcp = TokioTcpStream::connect(&name_server, &loop_handle);
let stream: Box<Future<Item = TlsStream, Error = io::Error>> =
Box::new(tcp.and_then(move |tcp_stream| {
tls_connector
.connect_async(&subject_name, tcp_stream)
.map(move |s| {
TcpStream::from_stream_with_receiver(s,
name_server,
outbound_messages)
})
.map_err(|e| {
io::Error::new(io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e))
})
})
.map_err(|e| {
io::Error::new(io::ErrorKind::ConnectionRefused,
format!("tls error: {}", e))
}));
(stream, message_sender)
}
}