requiem_tls/
openssl.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6pub use open_ssl::ssl::{AlpnError, SslAcceptor, SslAcceptorBuilder};
7pub use tokio_openssl::{HandshakeError, SslStream};
8
9use requiem_codec::{AsyncRead, AsyncWrite};
10use requiem_service::{Service, ServiceFactory};
11use requiem_utils::counter::{Counter, CounterGuard};
12use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
13
14use crate::MAX_CONN_COUNTER;
15
16/// Support `TLS` server connections via openssl package
17///
18/// `openssl` feature enables `Acceptor` type
19pub struct Acceptor<T: AsyncRead + AsyncWrite> {
20    acceptor: SslAcceptor,
21    io: PhantomData<T>,
22}
23
24impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
25    /// Create default `OpensslAcceptor`
26    pub fn new(acceptor: SslAcceptor) -> Self {
27        Acceptor {
28            acceptor,
29            io: PhantomData,
30        }
31    }
32}
33
34impl<T: AsyncRead + AsyncWrite> Clone for Acceptor<T> {
35    fn clone(&self) -> Self {
36        Self {
37            acceptor: self.acceptor.clone(),
38            io: PhantomData,
39        }
40    }
41}
42
43impl<T: AsyncRead + AsyncWrite + Unpin + 'static> ServiceFactory for Acceptor<T> {
44    type Request = T;
45    type Response = SslStream<T>;
46    type Error = HandshakeError<T>;
47    type Config = ();
48    type Service = AcceptorService<T>;
49    type InitError = ();
50    type Future = Ready<Result<Self::Service, Self::InitError>>;
51
52    fn new_service(&self, _: ()) -> Self::Future {
53        MAX_CONN_COUNTER.with(|conns| {
54            ok(AcceptorService {
55                acceptor: self.acceptor.clone(),
56                conns: conns.clone(),
57                io: PhantomData,
58            })
59        })
60    }
61}
62
63pub struct AcceptorService<T> {
64    acceptor: SslAcceptor,
65    conns: Counter,
66    io: PhantomData<T>,
67}
68
69impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Service for AcceptorService<T> {
70    type Request = T;
71    type Response = SslStream<T>;
72    type Error = HandshakeError<T>;
73    type Future = AcceptorServiceResponse<T>;
74
75    fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        if self.conns.available(ctx) {
77            Poll::Ready(Ok(()))
78        } else {
79            Poll::Pending
80        }
81    }
82
83    fn call(&mut self, req: Self::Request) -> Self::Future {
84        let acc = self.acceptor.clone();
85        AcceptorServiceResponse {
86            _guard: self.conns.get(),
87            fut: async move {
88                let acc = acc;
89                tokio_openssl::accept(&acc, req).await
90            }
91            .boxed_local(),
92        }
93    }
94}
95
96pub struct AcceptorServiceResponse<T>
97where
98    T: AsyncRead + AsyncWrite,
99{
100    fut: LocalBoxFuture<'static, Result<SslStream<T>, HandshakeError<T>>>,
101    _guard: CounterGuard,
102}
103
104impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
105    type Output = Result<SslStream<T>, HandshakeError<T>>;
106
107    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
108        let io = futures::ready!(Pin::new(&mut self.fut).poll(cx))?;
109        Poll::Ready(Ok(io))
110    }
111}