Skip to main content

requiem_connect/ssl/
openssl.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io};
6
7pub use open_ssl::ssl::{Error as SslError, SslConnector, SslMethod};
8pub use tokio_openssl::{HandshakeError, SslStream};
9
10use requiem_codec::{AsyncRead, AsyncWrite};
11use requiem_rt::net::TcpStream;
12use requiem_service::{Service, ServiceFactory};
13use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
14use trust_dns_resolver::AsyncResolver;
15
16use crate::{
17    Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection,
18};
19
20/// Openssl connector factory
21pub struct OpensslConnector<T, U> {
22    connector: SslConnector,
23    _t: PhantomData<(T, U)>,
24}
25
26impl<T, U> OpensslConnector<T, U> {
27    pub fn new(connector: SslConnector) -> Self {
28        OpensslConnector {
29            connector,
30            _t: PhantomData,
31        }
32    }
33}
34
35impl<T, U> OpensslConnector<T, U>
36where
37    T: Address + 'static,
38    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
39{
40    pub fn service(connector: SslConnector) -> OpensslConnectorService<T, U> {
41        OpensslConnectorService {
42            connector,
43            _t: PhantomData,
44        }
45    }
46}
47
48impl<T, U> Clone for OpensslConnector<T, U> {
49    fn clone(&self) -> Self {
50        Self {
51            connector: self.connector.clone(),
52            _t: PhantomData,
53        }
54    }
55}
56
57impl<T, U> ServiceFactory for OpensslConnector<T, U>
58where
59    T: Address + 'static,
60    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
61{
62    type Request = Connection<T, U>;
63    type Response = Connection<T, SslStream<U>>;
64    type Error = io::Error;
65    type Config = ();
66    type Service = OpensslConnectorService<T, U>;
67    type InitError = ();
68    type Future = Ready<Result<Self::Service, Self::InitError>>;
69
70    fn new_service(&self, _: ()) -> Self::Future {
71        ok(OpensslConnectorService {
72            connector: self.connector.clone(),
73            _t: PhantomData,
74        })
75    }
76}
77
78pub struct OpensslConnectorService<T, U> {
79    connector: SslConnector,
80    _t: PhantomData<(T, U)>,
81}
82
83impl<T, U> Clone for OpensslConnectorService<T, U> {
84    fn clone(&self) -> Self {
85        Self {
86            connector: self.connector.clone(),
87            _t: PhantomData,
88        }
89    }
90}
91
92impl<T, U> Service for OpensslConnectorService<T, U>
93where
94    T: Address + 'static,
95    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
96{
97    type Request = Connection<T, U>;
98    type Response = Connection<T, SslStream<U>>;
99    type Error = io::Error;
100    type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
101
102    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103        Poll::Ready(Ok(()))
104    }
105
106    fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
107        trace!("SSL Handshake start for: {:?}", stream.host());
108        let (io, stream) = stream.replace(());
109        let host = stream.host().to_string();
110
111        match self.connector.configure() {
112            Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
113            Ok(config) => Either::Left(ConnectAsyncExt {
114                fut: async move { tokio_openssl::connect(config, &host, io).await }
115                    .boxed_local(),
116                stream: Some(stream),
117                _t: PhantomData,
118            }),
119        }
120    }
121}
122
123pub struct ConnectAsyncExt<T, U> {
124    fut: LocalBoxFuture<'static, Result<SslStream<U>, HandshakeError<U>>>,
125    stream: Option<Connection<T, ()>>,
126    _t: PhantomData<U>,
127}
128
129impl<T: Address, U> Future for ConnectAsyncExt<T, U>
130where
131    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
132{
133    type Output = Result<Connection<T, SslStream<U>>, io::Error>;
134
135    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136        let this = self.get_mut();
137
138        match Pin::new(&mut this.fut).poll(cx) {
139            Poll::Ready(Ok(stream)) => {
140                let s = this.stream.take().unwrap();
141                trace!("SSL Handshake success: {:?}", s.host());
142                Poll::Ready(Ok(s.replace(stream).1))
143            }
144            Poll::Ready(Err(e)) => {
145                trace!("SSL Handshake error: {:?}", e);
146                Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
147            }
148            Poll::Pending => Poll::Pending,
149        }
150    }
151}
152
153pub struct OpensslConnectServiceFactory<T> {
154    tcp: ConnectServiceFactory<T>,
155    openssl: OpensslConnector<T, TcpStream>,
156}
157
158impl<T> OpensslConnectServiceFactory<T> {
159    /// Construct new OpensslConnectService factory
160    pub fn new(connector: SslConnector) -> Self {
161        OpensslConnectServiceFactory {
162            tcp: ConnectServiceFactory::default(),
163            openssl: OpensslConnector::new(connector),
164        }
165    }
166
167    /// Construct new connect service with custom dns resolver
168    pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self {
169        OpensslConnectServiceFactory {
170            tcp: ConnectServiceFactory::with_resolver(resolver),
171            openssl: OpensslConnector::new(connector),
172        }
173    }
174
175    /// Construct openssl connect service
176    pub fn service(&self) -> OpensslConnectService<T> {
177        OpensslConnectService {
178            tcp: self.tcp.service(),
179            openssl: OpensslConnectorService {
180                connector: self.openssl.connector.clone(),
181                _t: PhantomData,
182            },
183        }
184    }
185}
186
187impl<T> Clone for OpensslConnectServiceFactory<T> {
188    fn clone(&self) -> Self {
189        OpensslConnectServiceFactory {
190            tcp: self.tcp.clone(),
191            openssl: self.openssl.clone(),
192        }
193    }
194}
195
196impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
197    type Request = Connect<T>;
198    type Response = SslStream<TcpStream>;
199    type Error = ConnectError;
200    type Config = ();
201    type Service = OpensslConnectService<T>;
202    type InitError = ();
203    type Future = Ready<Result<Self::Service, Self::InitError>>;
204
205    fn new_service(&self, _: ()) -> Self::Future {
206        ok(self.service())
207    }
208}
209
210#[derive(Clone)]
211pub struct OpensslConnectService<T> {
212    tcp: ConnectService<T>,
213    openssl: OpensslConnectorService<T, TcpStream>,
214}
215
216impl<T: Address + 'static> Service for OpensslConnectService<T> {
217    type Request = Connect<T>;
218    type Response = SslStream<TcpStream>;
219    type Error = ConnectError;
220    type Future = OpensslConnectServiceResponse<T>;
221
222    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223        Poll::Ready(Ok(()))
224    }
225
226    fn call(&mut self, req: Connect<T>) -> Self::Future {
227        OpensslConnectServiceResponse {
228            fut1: Some(self.tcp.call(req)),
229            fut2: None,
230            openssl: self.openssl.clone(),
231        }
232    }
233}
234
235pub struct OpensslConnectServiceResponse<T: Address + 'static> {
236    fut1: Option<<ConnectService<T> as Service>::Future>,
237    fut2: Option<<OpensslConnectorService<T, TcpStream> as Service>::Future>,
238    openssl: OpensslConnectorService<T, TcpStream>,
239}
240
241impl<T: Address> Future for OpensslConnectServiceResponse<T> {
242    type Output = Result<SslStream<TcpStream>, ConnectError>;
243
244    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245        if let Some(ref mut fut) = self.fut1 {
246            match futures::ready!(Pin::new(fut).poll(cx)) {
247                Ok(res) => {
248                    let _ = self.fut1.take();
249                    self.fut2 = Some(self.openssl.call(res));
250                }
251                Err(e) => return Poll::Ready(Err(e)),
252            }
253        }
254
255        if let Some(ref mut fut) = self.fut2 {
256            match futures::ready!(Pin::new(fut).poll(cx)) {
257                Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)),
258                Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new(
259                    io::ErrorKind::Other,
260                    e,
261                )))),
262            }
263        } else {
264            Poll::Pending
265        }
266    }
267}