1use crate::{
2 common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, ResolvedTarget,
3 RewindStream, Ssl, SslError, StreamUpgrade, TlsDriver, TlsServerParameterProvider,
4 UpgradableStream,
5};
6use futures::{FutureExt, StreamExt};
7use std::{
8 future::Future,
9 pin::Pin,
10 task::{ready, Poll},
11};
12use std::{net::SocketAddr, path::Path};
13
14use super::Connection;
15
16pub struct Acceptor {
17 resolved_target: ResolvedTarget,
18 tls_provider: Option<TlsServerParameterProvider>,
19 should_upgrade: bool,
20 ignore_missing_tls_close_notify: bool,
21}
22
23impl Acceptor {
24 pub fn new(target: ResolvedTarget) -> Self {
25 Self {
26 resolved_target: target,
27 tls_provider: None,
28 should_upgrade: false,
29 ignore_missing_tls_close_notify: false,
30 }
31 }
32
33 pub fn new_tls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
34 Self {
35 resolved_target: target,
36 tls_provider: Some(provider),
37 should_upgrade: true,
38 ignore_missing_tls_close_notify: false,
39 }
40 }
41
42 pub fn new_starttls(target: ResolvedTarget, provider: TlsServerParameterProvider) -> Self {
43 Self {
44 resolved_target: target,
45 tls_provider: Some(provider),
46 should_upgrade: false,
47 ignore_missing_tls_close_notify: false,
48 }
49 }
50
51 pub fn new_tcp(addr: SocketAddr) -> Self {
52 Self {
53 resolved_target: ResolvedTarget::SocketAddr(addr),
54 tls_provider: None,
55 should_upgrade: false,
56 ignore_missing_tls_close_notify: false,
57 }
58 }
59
60 pub fn new_tcp_tls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
61 Self {
62 resolved_target: ResolvedTarget::SocketAddr(addr),
63 tls_provider: Some(provider),
64 should_upgrade: true,
65 ignore_missing_tls_close_notify: false,
66 }
67 }
68
69 pub fn new_tcp_starttls(addr: SocketAddr, provider: TlsServerParameterProvider) -> Self {
70 Self {
71 resolved_target: ResolvedTarget::SocketAddr(addr),
72 tls_provider: Some(provider),
73 should_upgrade: false,
74 ignore_missing_tls_close_notify: false,
75 }
76 }
77
78 pub fn new_unix_path(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
79 #[cfg(unix)]
80 {
81 Ok(Self {
82 resolved_target: ResolvedTarget::from(
83 std::os::unix::net::SocketAddr::from_pathname(path)?,
84 ),
85 tls_provider: None,
86 should_upgrade: false,
87 ignore_missing_tls_close_notify: false,
88 })
89 }
90 #[cfg(not(unix))]
91 {
92 Err(std::io::Error::new(
93 std::io::ErrorKind::Unsupported,
94 "Unix domain sockets are not supported on this platform",
95 ))
96 }
97 }
98
99 pub fn new_unix_domain(domain: impl AsRef<[u8]>) -> Result<Self, std::io::Error> {
100 #[cfg(any(target_os = "linux", target_os = "android"))]
101 {
102 use std::os::linux::net::SocketAddrExt;
103 Ok(Self {
104 resolved_target: ResolvedTarget::from(
105 std::os::unix::net::SocketAddr::from_abstract_name(domain)?,
106 ),
107 tls_provider: None,
108 should_upgrade: false,
109 ignore_missing_tls_close_notify: false,
110 })
111 }
112 #[cfg(not(any(target_os = "linux", target_os = "android")))]
113 {
114 Err(std::io::Error::new(
115 std::io::ErrorKind::Unsupported,
116 "Unix domain sockets are not supported on this platform",
117 ))
118 }
119 }
120
121 pub async fn bind(
122 self,
123 ) -> Result<
124 impl ::futures::Stream<Item = Result<Connection, ConnectionError>> + LocalAddress,
125 ConnectionError,
126 > {
127 let stream = self.resolved_target.listen_raw().await?;
128 Ok(AcceptedStream {
129 stream,
130 should_upgrade: self.should_upgrade,
131 ignore_missing_tls_close_notify: self.ignore_missing_tls_close_notify,
132 upgrade_future: None,
133 tls_provider: self.tls_provider,
134 _phantom: None,
135 })
136 }
137
138 #[allow(private_bounds)]
139 pub async fn bind_explicit<D: TlsDriver>(
140 self,
141 ) -> Result<
142 impl ::futures::Stream<Item = Result<Connection<D>, ConnectionError>> + LocalAddress,
143 ConnectionError,
144 > {
145 let stream = self.resolved_target.listen_raw().await?;
146 Ok(AcceptedStream {
147 stream,
148 ignore_missing_tls_close_notify: self.ignore_missing_tls_close_notify,
149 should_upgrade: self.should_upgrade,
150 upgrade_future: None,
151 tls_provider: self.tls_provider,
152 _phantom: None,
153 })
154 }
155
156 pub async fn accept_one(self) -> Result<Connection, std::io::Error> {
157 let mut stream = self.resolved_target.listen().await?;
158 let (stream, _target) = stream.next().await.unwrap()?;
159 let mut stm = UpgradableStream::new_server(
160 RewindStream::new(stream),
161 None::<TlsServerParameterProvider>,
162 );
163 if self.ignore_missing_tls_close_notify {
164 stm.ignore_missing_close_notify();
165 }
166 Ok(stm)
167 }
168}
169
170struct AcceptedStream<D: TlsDriver = Ssl> {
171 stream: TokioListenerStream,
172 should_upgrade: bool,
173 ignore_missing_tls_close_notify: bool,
174 tls_provider: Option<TlsServerParameterProvider>,
175 #[allow(clippy::type_complexity)]
176 upgrade_future:
177 Option<Pin<Box<dyn Future<Output = Result<Connection<D>, SslError>> + Send + 'static>>>,
178 _phantom: Option<&'static D>,
180}
181
182impl<D: TlsDriver> LocalAddress for AcceptedStream<D> {
183 fn local_address(&self) -> std::io::Result<ResolvedTarget> {
184 self.stream.local_address()
185 }
186}
187
188impl<D: TlsDriver> futures::Stream for AcceptedStream<D> {
189 type Item = Result<Connection<D>, ConnectionError>;
190
191 fn poll_next(
192 mut self: std::pin::Pin<&mut Self>,
193 cx: &mut std::task::Context<'_>,
194 ) -> Poll<Option<Self::Item>> {
195 if let Some(mut upgrade_future) = self.upgrade_future.take() {
196 match upgrade_future.poll_unpin(cx) {
197 Poll::Ready(Ok(conn)) => {
198 return Poll::Ready(Some(Ok(conn)));
199 }
200 Poll::Ready(Err(e)) => {
201 return Poll::Ready(Some(Err(e.into())));
202 }
203 Poll::Pending => {
204 self.upgrade_future = Some(upgrade_future);
205 return Poll::Pending;
206 }
207 }
208 }
209 let r = ready!(self.stream.poll_next_unpin(cx));
210 let Some(r) = r else {
211 return Poll::Ready(None);
212 };
213 let (stream, _target) = r?;
214 let mut stream =
215 UpgradableStream::new_server(RewindStream::new(stream), self.tls_provider.clone());
216 if self.ignore_missing_tls_close_notify {
217 stream.ignore_missing_close_notify();
218 }
219 if self.should_upgrade {
220 let mut upgrade_future = Box::pin(async move {
221 stream.secure_upgrade().await?;
222 Ok::<_, SslError>(stream)
223 });
224 match upgrade_future.poll_unpin(cx) {
225 Poll::Ready(Ok(stream)) => {
226 return Poll::Ready(Some(Ok(stream)));
227 }
228 Poll::Ready(Err(e)) => {
229 return Poll::Ready(Some(Err(e.into())));
230 }
231 Poll::Pending => {
232 self.upgrade_future = Some(upgrade_future);
233 return Poll::Pending;
234 }
235 }
236 }
237 Poll::Ready(Some(Ok(stream)))
238 }
239}