actori_connect/ssl/
rustls.rs

1use std::fmt;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8pub use rust_tls::Session;
9pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
10
11use actori_codec::{AsyncRead, AsyncWrite};
12use actori_service::{Service, ServiceFactory};
13use futures::future::{ok, Ready};
14use tokio_rustls::{Connect, TlsConnector};
15use webpki::DNSNameRef;
16
17use crate::{Address, Connection};
18
19/// Rustls connector factory
20pub struct RustlsConnector<T, U> {
21    connector: Arc<ClientConfig>,
22    _t: PhantomData<(T, U)>,
23}
24
25impl<T, U> RustlsConnector<T, U> {
26    pub fn new(connector: Arc<ClientConfig>) -> Self {
27        RustlsConnector {
28            connector,
29            _t: PhantomData,
30        }
31    }
32}
33
34impl<T, U> RustlsConnector<T, U>
35where
36    T: Address,
37    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
38{
39    pub fn service(connector: Arc<ClientConfig>) -> RustlsConnectorService<T, U> {
40        RustlsConnectorService {
41            connector: connector,
42            _t: PhantomData,
43        }
44    }
45}
46
47impl<T, U> Clone for RustlsConnector<T, U> {
48    fn clone(&self) -> Self {
49        Self {
50            connector: self.connector.clone(),
51            _t: PhantomData,
52        }
53    }
54}
55
56impl<T: Address, U> ServiceFactory for RustlsConnector<T, U>
57where
58    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
59{
60    type Request = Connection<T, U>;
61    type Response = Connection<T, TlsStream<U>>;
62    type Error = std::io::Error;
63    type Config = ();
64    type Service = RustlsConnectorService<T, U>;
65    type InitError = ();
66    type Future = Ready<Result<Self::Service, Self::InitError>>;
67
68    fn new_service(&self, _: ()) -> Self::Future {
69        ok(RustlsConnectorService {
70            connector: self.connector.clone(),
71            _t: PhantomData,
72        })
73    }
74}
75
76pub struct RustlsConnectorService<T, U> {
77    connector: Arc<ClientConfig>,
78    _t: PhantomData<(T, U)>,
79}
80
81impl<T, U> Clone for RustlsConnectorService<T, U> {
82    fn clone(&self) -> Self {
83        Self {
84            connector: self.connector.clone(),
85            _t: PhantomData,
86        }
87    }
88}
89
90impl<T: Address, U> Service for RustlsConnectorService<T, U>
91where
92    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
93{
94    type Request = Connection<T, U>;
95    type Response = Connection<T, TlsStream<U>>;
96    type Error = std::io::Error;
97    type Future = ConnectAsyncExt<T, U>;
98
99    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        Poll::Ready(Ok(()))
101    }
102
103    fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
104        trace!("SSL Handshake start for: {:?}", stream.host());
105        let (io, stream) = stream.replace(());
106        let host = DNSNameRef::try_from_ascii_str(stream.host())
107            .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54");
108        ConnectAsyncExt {
109            fut: TlsConnector::from(self.connector.clone()).connect(host, io),
110            stream: Some(stream),
111        }
112    }
113}
114
115pub struct ConnectAsyncExt<T, U> {
116    fut: Connect<U>,
117    stream: Option<Connection<T, ()>>,
118}
119
120impl<T: Address, U> Future for ConnectAsyncExt<T, U>
121where
122    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
123{
124    type Output = Result<Connection<T, TlsStream<U>>, std::io::Error>;
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        let this = self.get_mut();
128        Poll::Ready(
129            futures::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| {
130                let s = this.stream.take().unwrap();
131                trace!("SSL Handshake success: {:?}", s.host());
132                s.replace(stream).1
133            }),
134        )
135    }
136}