actori_connect/ssl/
rustls.rs1use std::fmt;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8pub use rust_tls::Session;
9pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};
10
11use actori_codec::{AsyncRead, AsyncWrite};
12use actori_service::{Service, ServiceFactory};
13use futures::future::{ok, Ready};
14use tokio_rustls::{Connect, TlsConnector};
15use webpki::DNSNameRef;
16
17use crate::{Address, Connection};
18
19pub struct RustlsConnector<T, U> {
21 connector: Arc<ClientConfig>,
22 _t: PhantomData<(T, U)>,
23}
24
25impl<T, U> RustlsConnector<T, U> {
26 pub fn new(connector: Arc<ClientConfig>) -> Self {
27 RustlsConnector {
28 connector,
29 _t: PhantomData,
30 }
31 }
32}
33
34impl<T, U> RustlsConnector<T, U>
35where
36 T: Address,
37 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
38{
39 pub fn service(connector: Arc<ClientConfig>) -> RustlsConnectorService<T, U> {
40 RustlsConnectorService {
41 connector: connector,
42 _t: PhantomData,
43 }
44 }
45}
46
47impl<T, U> Clone for RustlsConnector<T, U> {
48 fn clone(&self) -> Self {
49 Self {
50 connector: self.connector.clone(),
51 _t: PhantomData,
52 }
53 }
54}
55
56impl<T: Address, U> ServiceFactory for RustlsConnector<T, U>
57where
58 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
59{
60 type Request = Connection<T, U>;
61 type Response = Connection<T, TlsStream<U>>;
62 type Error = std::io::Error;
63 type Config = ();
64 type Service = RustlsConnectorService<T, U>;
65 type InitError = ();
66 type Future = Ready<Result<Self::Service, Self::InitError>>;
67
68 fn new_service(&self, _: ()) -> Self::Future {
69 ok(RustlsConnectorService {
70 connector: self.connector.clone(),
71 _t: PhantomData,
72 })
73 }
74}
75
76pub struct RustlsConnectorService<T, U> {
77 connector: Arc<ClientConfig>,
78 _t: PhantomData<(T, U)>,
79}
80
81impl<T, U> Clone for RustlsConnectorService<T, U> {
82 fn clone(&self) -> Self {
83 Self {
84 connector: self.connector.clone(),
85 _t: PhantomData,
86 }
87 }
88}
89
90impl<T: Address, U> Service for RustlsConnectorService<T, U>
91where
92 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
93{
94 type Request = Connection<T, U>;
95 type Response = Connection<T, TlsStream<U>>;
96 type Error = std::io::Error;
97 type Future = ConnectAsyncExt<T, U>;
98
99 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 Poll::Ready(Ok(()))
101 }
102
103 fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
104 trace!("SSL Handshake start for: {:?}", stream.host());
105 let (io, stream) = stream.replace(());
106 let host = DNSNameRef::try_from_ascii_str(stream.host())
107 .expect("rustls currently only handles hostname-based connections. See https://github.com/briansmith/webpki/issues/54");
108 ConnectAsyncExt {
109 fut: TlsConnector::from(self.connector.clone()).connect(host, io),
110 stream: Some(stream),
111 }
112 }
113}
114
115pub struct ConnectAsyncExt<T, U> {
116 fut: Connect<U>,
117 stream: Option<Connection<T, ()>>,
118}
119
120impl<T: Address, U> Future for ConnectAsyncExt<T, U>
121where
122 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug,
123{
124 type Output = Result<Connection<T, TlsStream<U>>, std::io::Error>;
125
126 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127 let this = self.get_mut();
128 Poll::Ready(
129 futures::ready!(Pin::new(&mut this.fut).poll(cx)).map(|stream| {
130 let s = this.stream.take().unwrap();
131 trace!("SSL Handshake success: {:?}", s.host());
132 s.replace(stream).1
133 }),
134 )
135 }
136}