1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io};
6
7pub use open_ssl::ssl::{Error as SslError, SslConnector, SslMethod};
8pub use tokio_openssl::{HandshakeError, SslStream};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{Service, ServiceFactory};
13use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
14use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
15
16use crate::{
17 Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection,
18};
19
20pub struct OpensslConnector<T, U> {
22 connector: SslConnector,
23 _t: PhantomData<(T, U)>,
24}
25
26impl<T, U> OpensslConnector<T, U> {
27 pub fn new(connector: SslConnector) -> Self {
28 OpensslConnector {
29 connector,
30 _t: PhantomData,
31 }
32 }
33}
34
35impl<T, U> OpensslConnector<T, U>
36where
37 T: Address + 'static,
38 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
39{
40 pub fn service(connector: SslConnector) -> OpensslConnectorService<T, U> {
41 OpensslConnectorService {
42 connector,
43 _t: PhantomData,
44 }
45 }
46}
47
48impl<T, U> Clone for OpensslConnector<T, U> {
49 fn clone(&self) -> Self {
50 Self {
51 connector: self.connector.clone(),
52 _t: PhantomData,
53 }
54 }
55}
56
57impl<T, U> ServiceFactory for OpensslConnector<T, U>
58where
59 T: Address + 'static,
60 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
61{
62 type Request = Connection<T, U>;
63 type Response = Connection<T, SslStream<U>>;
64 type Error = io::Error;
65 type Config = ();
66 type Service = OpensslConnectorService<T, U>;
67 type InitError = ();
68 type Future = Ready<Result<Self::Service, Self::InitError>>;
69
70 fn new_service(&self, _: ()) -> Self::Future {
71 ok(OpensslConnectorService {
72 connector: self.connector.clone(),
73 _t: PhantomData,
74 })
75 }
76}
77
78pub struct OpensslConnectorService<T, U> {
79 connector: SslConnector,
80 _t: PhantomData<(T, U)>,
81}
82
83impl<T, U> Clone for OpensslConnectorService<T, U> {
84 fn clone(&self) -> Self {
85 Self {
86 connector: self.connector.clone(),
87 _t: PhantomData,
88 }
89 }
90}
91
92impl<T, U> Service for OpensslConnectorService<T, U>
93where
94 T: Address + 'static,
95 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
96{
97 type Request = Connection<T, U>;
98 type Response = Connection<T, SslStream<U>>;
99 type Error = io::Error;
100 #[allow(clippy::type_complexity)]
101 type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
102
103 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104 Poll::Ready(Ok(()))
105 }
106
107 fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
108 trace!("SSL Handshake start for: {:?}", stream.host());
109 let (io, stream) = stream.replace(());
110 let host = stream.host().to_string();
111
112 match self.connector.configure() {
113 Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
114 Ok(config) => Either::Left(ConnectAsyncExt {
115 fut: async move { tokio_openssl::connect(config, &host, io).await }
116 .boxed_local(),
117 stream: Some(stream),
118 _t: PhantomData,
119 }),
120 }
121 }
122}
123
124pub struct ConnectAsyncExt<T, U> {
125 fut: LocalBoxFuture<'static, Result<SslStream<U>, HandshakeError<U>>>,
126 stream: Option<Connection<T, ()>>,
127 _t: PhantomData<U>,
128}
129
130impl<T: Address, U> Future for ConnectAsyncExt<T, U>
131where
132 U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
133{
134 type Output = Result<Connection<T, SslStream<U>>, io::Error>;
135
136 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 let this = self.get_mut();
138
139 match Pin::new(&mut this.fut).poll(cx) {
140 Poll::Ready(Ok(stream)) => {
141 let s = this.stream.take().unwrap();
142 trace!("SSL Handshake success: {:?}", s.host());
143 Poll::Ready(Ok(s.replace(stream).1))
144 }
145 Poll::Ready(Err(e)) => {
146 trace!("SSL Handshake error: {:?}", e);
147 Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
148 }
149 Poll::Pending => Poll::Pending,
150 }
151 }
152}
153
154pub struct OpensslConnectServiceFactory<T> {
155 tcp: ConnectServiceFactory<T>,
156 openssl: OpensslConnector<T, TcpStream>,
157}
158
159impl<T> OpensslConnectServiceFactory<T> {
160 pub fn new(connector: SslConnector) -> Self {
162 OpensslConnectServiceFactory {
163 tcp: ConnectServiceFactory::default(),
164 openssl: OpensslConnector::new(connector),
165 }
166 }
167
168 pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self {
170 OpensslConnectServiceFactory {
171 tcp: ConnectServiceFactory::with_resolver(resolver),
172 openssl: OpensslConnector::new(connector),
173 }
174 }
175
176 pub fn service(&self) -> OpensslConnectService<T> {
178 OpensslConnectService {
179 tcp: self.tcp.service(),
180 openssl: OpensslConnectorService {
181 connector: self.openssl.connector.clone(),
182 _t: PhantomData,
183 },
184 }
185 }
186}
187
188impl<T> Clone for OpensslConnectServiceFactory<T> {
189 fn clone(&self) -> Self {
190 OpensslConnectServiceFactory {
191 tcp: self.tcp.clone(),
192 openssl: self.openssl.clone(),
193 }
194 }
195}
196
197impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
198 type Request = Connect<T>;
199 type Response = SslStream<TcpStream>;
200 type Error = ConnectError;
201 type Config = ();
202 type Service = OpensslConnectService<T>;
203 type InitError = ();
204 type Future = Ready<Result<Self::Service, Self::InitError>>;
205
206 fn new_service(&self, _: ()) -> Self::Future {
207 ok(self.service())
208 }
209}
210
211#[derive(Clone)]
212pub struct OpensslConnectService<T> {
213 tcp: ConnectService<T>,
214 openssl: OpensslConnectorService<T, TcpStream>,
215}
216
217impl<T: Address + 'static> Service for OpensslConnectService<T> {
218 type Request = Connect<T>;
219 type Response = SslStream<TcpStream>;
220 type Error = ConnectError;
221 type Future = OpensslConnectServiceResponse<T>;
222
223 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224 Poll::Ready(Ok(()))
225 }
226
227 fn call(&mut self, req: Connect<T>) -> Self::Future {
228 OpensslConnectServiceResponse {
229 fut1: Some(self.tcp.call(req)),
230 fut2: None,
231 openssl: self.openssl.clone(),
232 }
233 }
234}
235
236pub struct OpensslConnectServiceResponse<T: Address + 'static> {
237 fut1: Option<<ConnectService<T> as Service>::Future>,
238 fut2: Option<<OpensslConnectorService<T, TcpStream> as Service>::Future>,
239 openssl: OpensslConnectorService<T, TcpStream>,
240}
241
242impl<T: Address> Future for OpensslConnectServiceResponse<T> {
243 type Output = Result<SslStream<TcpStream>, ConnectError>;
244
245 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
246 if let Some(ref mut fut) = self.fut1 {
247 match futures_util::ready!(Pin::new(fut).poll(cx)) {
248 Ok(res) => {
249 let _ = self.fut1.take();
250 self.fut2 = Some(self.openssl.call(res));
251 }
252 Err(e) => return Poll::Ready(Err(e)),
253 }
254 }
255
256 if let Some(ref mut fut) = self.fut2 {
257 match futures_util::ready!(Pin::new(fut).poll(cx)) {
258 Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)),
259 Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new(
260 io::ErrorKind::Other,
261 e,
262 )))),
263 }
264 } else {
265 Poll::Pending
266 }
267 }
268}