actix_connect/ssl/
openssl.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io};
6
7pub use open_ssl::ssl::{Error as SslError, SslConnector, SslMethod};
8pub use tokio_openssl::{HandshakeError, SslStream};
9
10use actix_codec::{AsyncRead, AsyncWrite};
11use actix_rt::net::TcpStream;
12use actix_service::{Service, ServiceFactory};
13use futures_util::future::{err, ok, Either, FutureExt, LocalBoxFuture, Ready};
14use trust_dns_resolver::TokioAsyncResolver as AsyncResolver;
15
16use crate::{
17    Address, Connect, ConnectError, ConnectService, ConnectServiceFactory, Connection,
18};
19
20/// OpenSSL connector factory
21pub struct OpensslConnector<T, U> {
22    connector: SslConnector,
23    _t: PhantomData<(T, U)>,
24}
25
26impl<T, U> OpensslConnector<T, U> {
27    pub fn new(connector: SslConnector) -> Self {
28        OpensslConnector {
29            connector,
30            _t: PhantomData,
31        }
32    }
33}
34
35impl<T, U> OpensslConnector<T, U>
36where
37    T: Address + 'static,
38    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
39{
40    pub fn service(connector: SslConnector) -> OpensslConnectorService<T, U> {
41        OpensslConnectorService {
42            connector,
43            _t: PhantomData,
44        }
45    }
46}
47
48impl<T, U> Clone for OpensslConnector<T, U> {
49    fn clone(&self) -> Self {
50        Self {
51            connector: self.connector.clone(),
52            _t: PhantomData,
53        }
54    }
55}
56
57impl<T, U> ServiceFactory for OpensslConnector<T, U>
58where
59    T: Address + 'static,
60    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
61{
62    type Request = Connection<T, U>;
63    type Response = Connection<T, SslStream<U>>;
64    type Error = io::Error;
65    type Config = ();
66    type Service = OpensslConnectorService<T, U>;
67    type InitError = ();
68    type Future = Ready<Result<Self::Service, Self::InitError>>;
69
70    fn new_service(&self, _: ()) -> Self::Future {
71        ok(OpensslConnectorService {
72            connector: self.connector.clone(),
73            _t: PhantomData,
74        })
75    }
76}
77
78pub struct OpensslConnectorService<T, U> {
79    connector: SslConnector,
80    _t: PhantomData<(T, U)>,
81}
82
83impl<T, U> Clone for OpensslConnectorService<T, U> {
84    fn clone(&self) -> Self {
85        Self {
86            connector: self.connector.clone(),
87            _t: PhantomData,
88        }
89    }
90}
91
92impl<T, U> Service for OpensslConnectorService<T, U>
93where
94    T: Address + 'static,
95    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
96{
97    type Request = Connection<T, U>;
98    type Response = Connection<T, SslStream<U>>;
99    type Error = io::Error;
100    #[allow(clippy::type_complexity)]
101    type Future = Either<ConnectAsyncExt<T, U>, Ready<Result<Self::Response, Self::Error>>>;
102
103    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        Poll::Ready(Ok(()))
105    }
106
107    fn call(&mut self, stream: Connection<T, U>) -> Self::Future {
108        trace!("SSL Handshake start for: {:?}", stream.host());
109        let (io, stream) = stream.replace(());
110        let host = stream.host().to_string();
111
112        match self.connector.configure() {
113            Err(e) => Either::Right(err(io::Error::new(io::ErrorKind::Other, e))),
114            Ok(config) => Either::Left(ConnectAsyncExt {
115                fut: async move { tokio_openssl::connect(config, &host, io).await }
116                    .boxed_local(),
117                stream: Some(stream),
118                _t: PhantomData,
119            }),
120        }
121    }
122}
123
124pub struct ConnectAsyncExt<T, U> {
125    fut: LocalBoxFuture<'static, Result<SslStream<U>, HandshakeError<U>>>,
126    stream: Option<Connection<T, ()>>,
127    _t: PhantomData<U>,
128}
129
130impl<T: Address, U> Future for ConnectAsyncExt<T, U>
131where
132    U: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
133{
134    type Output = Result<Connection<T, SslStream<U>>, io::Error>;
135
136    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        let this = self.get_mut();
138
139        match Pin::new(&mut this.fut).poll(cx) {
140            Poll::Ready(Ok(stream)) => {
141                let s = this.stream.take().unwrap();
142                trace!("SSL Handshake success: {:?}", s.host());
143                Poll::Ready(Ok(s.replace(stream).1))
144            }
145            Poll::Ready(Err(e)) => {
146                trace!("SSL Handshake error: {:?}", e);
147                Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
148            }
149            Poll::Pending => Poll::Pending,
150        }
151    }
152}
153
154pub struct OpensslConnectServiceFactory<T> {
155    tcp: ConnectServiceFactory<T>,
156    openssl: OpensslConnector<T, TcpStream>,
157}
158
159impl<T> OpensslConnectServiceFactory<T> {
160    /// Construct new OpensslConnectService factory
161    pub fn new(connector: SslConnector) -> Self {
162        OpensslConnectServiceFactory {
163            tcp: ConnectServiceFactory::default(),
164            openssl: OpensslConnector::new(connector),
165        }
166    }
167
168    /// Construct new connect service with custom DNS resolver
169    pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self {
170        OpensslConnectServiceFactory {
171            tcp: ConnectServiceFactory::with_resolver(resolver),
172            openssl: OpensslConnector::new(connector),
173        }
174    }
175
176    /// Construct OpenSSL connect service
177    pub fn service(&self) -> OpensslConnectService<T> {
178        OpensslConnectService {
179            tcp: self.tcp.service(),
180            openssl: OpensslConnectorService {
181                connector: self.openssl.connector.clone(),
182                _t: PhantomData,
183            },
184        }
185    }
186}
187
188impl<T> Clone for OpensslConnectServiceFactory<T> {
189    fn clone(&self) -> Self {
190        OpensslConnectServiceFactory {
191            tcp: self.tcp.clone(),
192            openssl: self.openssl.clone(),
193        }
194    }
195}
196
197impl<T: Address + 'static> ServiceFactory for OpensslConnectServiceFactory<T> {
198    type Request = Connect<T>;
199    type Response = SslStream<TcpStream>;
200    type Error = ConnectError;
201    type Config = ();
202    type Service = OpensslConnectService<T>;
203    type InitError = ();
204    type Future = Ready<Result<Self::Service, Self::InitError>>;
205
206    fn new_service(&self, _: ()) -> Self::Future {
207        ok(self.service())
208    }
209}
210
211#[derive(Clone)]
212pub struct OpensslConnectService<T> {
213    tcp: ConnectService<T>,
214    openssl: OpensslConnectorService<T, TcpStream>,
215}
216
217impl<T: Address + 'static> Service for OpensslConnectService<T> {
218    type Request = Connect<T>;
219    type Response = SslStream<TcpStream>;
220    type Error = ConnectError;
221    type Future = OpensslConnectServiceResponse<T>;
222
223    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
224        Poll::Ready(Ok(()))
225    }
226
227    fn call(&mut self, req: Connect<T>) -> Self::Future {
228        OpensslConnectServiceResponse {
229            fut1: Some(self.tcp.call(req)),
230            fut2: None,
231            openssl: self.openssl.clone(),
232        }
233    }
234}
235
236pub struct OpensslConnectServiceResponse<T: Address + 'static> {
237    fut1: Option<<ConnectService<T> as Service>::Future>,
238    fut2: Option<<OpensslConnectorService<T, TcpStream> as Service>::Future>,
239    openssl: OpensslConnectorService<T, TcpStream>,
240}
241
242impl<T: Address> Future for OpensslConnectServiceResponse<T> {
243    type Output = Result<SslStream<TcpStream>, ConnectError>;
244
245    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
246        if let Some(ref mut fut) = self.fut1 {
247            match futures_util::ready!(Pin::new(fut).poll(cx)) {
248                Ok(res) => {
249                    let _ = self.fut1.take();
250                    self.fut2 = Some(self.openssl.call(res));
251                }
252                Err(e) => return Poll::Ready(Err(e)),
253            }
254        }
255
256        if let Some(ref mut fut) = self.fut2 {
257            match futures_util::ready!(Pin::new(fut).poll(cx)) {
258                Ok(connect) => Poll::Ready(Ok(connect.into_parts().0)),
259                Err(e) => Poll::Ready(Err(ConnectError::Io(io::Error::new(
260                    io::ErrorKind::Other,
261                    e,
262                )))),
263            }
264        } else {
265            Poll::Pending
266        }
267    }
268}