1use std::future::Future;
2use std::io;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use requiem_codec::{AsyncRead, AsyncWrite};
9use requiem_service::{Service, ServiceFactory};
10use requiem_utils::counter::{Counter, CounterGuard};
11use futures::future::{ok, Ready};
12use tokio_rustls::{Accept, TlsAcceptor};
13
14pub use rust_tls::{ServerConfig, Session};
15pub use tokio_rustls::server::TlsStream;
16pub use webpki_roots::TLS_SERVER_ROOTS;
17
18use crate::MAX_CONN_COUNTER;
19
20pub struct Acceptor<T> {
24 config: Arc<ServerConfig>,
25 io: PhantomData<T>,
26}
27
28impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
29 pub fn new(config: ServerConfig) -> Self {
31 Acceptor {
32 config: Arc::new(config),
33 io: PhantomData,
34 }
35 }
36}
37
38impl<T> Clone for Acceptor<T> {
39 fn clone(&self) -> Self {
40 Self {
41 config: self.config.clone(),
42 io: PhantomData,
43 }
44 }
45}
46
47impl<T: AsyncRead + AsyncWrite + Unpin> ServiceFactory for Acceptor<T> {
48 type Request = T;
49 type Response = TlsStream<T>;
50 type Error = io::Error;
51 type Service = AcceptorService<T>;
52
53 type Config = ();
54 type InitError = ();
55 type Future = Ready<Result<Self::Service, Self::InitError>>;
56
57 fn new_service(&self, _: ()) -> Self::Future {
58 MAX_CONN_COUNTER.with(|conns| {
59 ok(AcceptorService {
60 acceptor: self.config.clone().into(),
61 conns: conns.clone(),
62 io: PhantomData,
63 })
64 })
65 }
66}
67
68pub struct AcceptorService<T> {
70 acceptor: TlsAcceptor,
71 io: PhantomData<T>,
72 conns: Counter,
73}
74
75impl<T: AsyncRead + AsyncWrite + Unpin> Service for AcceptorService<T> {
76 type Request = T;
77 type Response = TlsStream<T>;
78 type Error = io::Error;
79 type Future = AcceptorServiceFut<T>;
80
81 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
82 if self.conns.available(cx) {
83 Poll::Ready(Ok(()))
84 } else {
85 Poll::Pending
86 }
87 }
88
89 fn call(&mut self, req: Self::Request) -> Self::Future {
90 AcceptorServiceFut {
91 _guard: self.conns.get(),
92 fut: self.acceptor.accept(req),
93 }
94 }
95}
96
97pub struct AcceptorServiceFut<T>
98where
99 T: AsyncRead + AsyncWrite + Unpin,
100{
101 fut: Accept<T>,
102 _guard: CounterGuard,
103}
104
105impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceFut<T> {
106 type Output = Result<TlsStream<T>, io::Error>;
107
108 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109 let this = self.get_mut();
110
111 let res = futures::ready!(Pin::new(&mut this.fut).poll(cx));
112 match res {
113 Ok(io) => Poll::Ready(Ok(io)),
114 Err(e) => Poll::Ready(Err(e)),
115 }
116 }
117}