1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io};
6
7pub use open_ssl::ssl::{Error as SslError, SslConnector, SslMethod};
8pub use tokio_openssl::{HandshakeError, SslStream};
9
10use requiem_codec::{AsyncRead, AsyncWrite};
11use requiem_rt::net::TcpStream;
12use requiem_service::{Service, ServiceFactory};
13use futures::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
14use trust_dns_resolver::AsyncResolver;
15
16use crate::{
17 Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection,
18};
19
20pub struct OpensslConnector<T, U> {
22 connector: SslConnector,
23 _t: PhantomData<(T, U)>,
24}
25
26impl<T, U> OpensslConnector<T, U> {
27 pub fn new(connector: SslConnector) -> Self {
28 OpensslConnector {
29 connector,
30 _t: PhantomData,
31 }
32 }
33}
34
35impl<T, U> OpensslConnector<T, U>
36where
37 T: Address + 'static,
38 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
39{
40 pub fn service(connector: SslConnector) -> OpensslConnectorService<T, U> {
41 OpensslConnectorService {
42 connector,
43 _t: PhantomData,
44 }
45 }
46}
47
48impl<T, U> Clone for OpensslConnector<T, U> {
49 fn clone(&self) -> Self {
50 Self {
51 connector: self.connector.clone(),
52 _t: PhantomData,
53 }
54 }
55}
56
57impl<T, U> ServiceFactory for OpensslConnector<T, U>
58where
59 T: Address + 'static,
60 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
61{
62 type Request = Connection<T, U>;
63 type Response = Connection<T, SslStream<U>>;
64 type Error = io::Error;
65 type Config = ();
66 type Service = OpensslConnectorService<T, U>;
67 type InitError = ();
68 type Future = Ready<Result<Self::Service, Self::InitError>>;
69
70 fn new_service(&self, _: ()) -> Self::Future {
71 ok(OpensslConnectorService {
72 connector: self.connector.clone(),
73 _t: PhantomData,
74 })
75 }
76}
77
78pub struct OpensslConnectorService<T, U> {
79 connector: SslConnector,
80 _t: PhantomData<(T, U)>,
81}
82
83impl<T, U> Clone for OpensslConnectorService<T, U> {
84 fn clone(&self) -> Self {
85 Self {
86 connector: self.connector.clone(),
87 _t: PhantomData,
88 }
89 }
90}
91
92impl<T, U> Service for OpensslConnectorService<T, U>
93where
94 T: Address + 'static,
95 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
96{
97 type Request = Connection<T, U>;
98 type Response = Connection<T, SslStream<U>>;
99 type Error = io::Error;
100 type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
101
102 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
103 Poll::Ready(Ok(()))
104 }
105
106 fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
107 trace!("SSL Handshake start for: {:?}", stream.host());
108 let (io, stream) = stream.replace(());
109 let host = stream.host().to_string();
110
111 match self.connector.configure() {
112 Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
113 Ok(config) => Either::Left(ConnectAsyncExt {
114 fut: async move { tokio_openssl::connect(config, &host, io).await }
115 .boxed_local(),
116 stream: Some(stream),
117 _t: PhantomData,
118 }),
119 }
120 }
121}
122
123pub struct ConnectAsyncExt<T, U> {
124 fut: LocalBoxFuture<'static, Result<SslStream<U>, HandshakeError<U>>>,
125 stream: Option<Connection<T, ()>>,
126 _t: PhantomData<U>,
127}
128
129impl<T: Address, U> Future for ConnectAsyncExt<T, U>
130where
131 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
132{
133 type Output = Result<Connection<T, SslStream<U>>, io::Error>;
134
135 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 let this = self.get_mut();
137
138 match Pin::new(&mut this.fut).poll(cx) {
139 Poll::Ready(Ok(stream)) => {
140 let s = this.stream.take().unwrap();
141 trace!("SSL Handshake success: {:?}", s.host());
142 Poll::Ready(Ok(s.replace(stream).1))
143 }
144 Poll::Ready(Err(e)) => {
145 trace!("SSL Handshake error: {:?}", e);
146 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
147 }
148 Poll::Pending => Poll::Pending,
149 }
150 }
151}
152
153pub struct OpensslConnectServiceFactory<T> {
154 tcp: ConnectServiceFactory<T>,
155 openssl: OpensslConnector<T, TcpStream>,
156}
157
158impl<T> OpensslConnectServiceFactory<T> {
159 pub fn new(connector: SslConnector) -> Self {
161 OpensslConnectServiceFactory {
162 tcp: ConnectServiceFactory::default(),
163 openssl: OpensslConnector::new(connector),
164 }
165 }
166
167 pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self {
169 OpensslConnectServiceFactory {
170 tcp: ConnectServiceFactory::with_resolver(resolver),
171 openssl: OpensslConnector::new(connector),
172 }
173 }
174
175 pub fn service(&self) -> OpensslConnectService<T> {
177 OpensslConnectService {
178 tcp: self.tcp.service(),
179 openssl: OpensslConnectorService {
180 connector: self.openssl.connector.clone(),
181 _t: PhantomData,
182 },
183 }
184 }
185}
186
187impl<T> Clone for OpensslConnectServiceFactory<T> {
188 fn clone(&self) -> Self {
189 OpensslConnectServiceFactory {
190 tcp: self.tcp.clone(),
191 openssl: self.openssl.clone(),
192 }
193 }
194}
195
196impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
197 type Request = Connect<T>;
198 type Response = SslStream<TcpStream>;
199 type Error = ConnectError;
200 type Config = ();
201 type Service = OpensslConnectService<T>;
202 type InitError = ();
203 type Future = Ready<Result<Self::Service, Self::InitError>>;
204
205 fn new_service(&self, _: ()) -> Self::Future {
206 ok(self.service())
207 }
208}
209
210#[derive(Clone)]
211pub struct OpensslConnectService<T> {
212 tcp: ConnectService<T>,
213 openssl: OpensslConnectorService<T, TcpStream>,
214}
215
216impl<T: Address + 'static> Service for OpensslConnectService<T> {
217 type Request = Connect<T>;
218 type Response = SslStream<TcpStream>;
219 type Error = ConnectError;
220 type Future = OpensslConnectServiceResponse<T>;
221
222 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223 Poll::Ready(Ok(()))
224 }
225
226 fn call(&mut self, req: Connect<T>) -> Self::Future {
227 OpensslConnectServiceResponse {
228 fut1: Some(self.tcp.call(req)),
229 fut2: None,
230 openssl: self.openssl.clone(),
231 }
232 }
233}
234
235pub struct OpensslConnectServiceResponse<T: Address + 'static> {
236 fut1: Option<<ConnectService<T> as Service>::Future>,
237 fut2: Option<<OpensslConnectorService<T, TcpStream> as Service>::Future>,
238 openssl: OpensslConnectorService<T, TcpStream>,
239}
240
241impl<T: Address> Future for OpensslConnectServiceResponse<T> {
242 type Output = Result<SslStream<TcpStream>, ConnectError>;
243
244 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245 if let Some(ref mut fut) = self.fut1 {
246 match futures::ready!(Pin::new(fut).poll(cx)) {
247 Ok(res) => {
248 let _ = self.fut1.take();
249 self.fut2 = Some(self.openssl.call(res));
250 }
251 Err(e) => return Poll::Ready(Err(e)),
252 }
253 }
254
255 if let Some(ref mut fut) = self.fut2 {
256 match futures::ready!(Pin::new(fut).poll(cx)) {
257 Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)),
258 Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new(
259 io::ErrorKind::Other,
260 e,
261 )))),
262 }
263 } else {
264 Poll::Pending
265 }
266 }
267}