1use std::{
2 io,
3 net::{SocketAddr, TcpListener as StdTcpListener},
4 ops::ControlFlow,
5 pin::{pin, Pin},
6 task::{ready, Context, Poll},
7 time::Duration,
8};
9
10use tokio::{
11 io::{AsyncRead, AsyncWrite},
12 net::{TcpListener, TcpStream},
13};
14use tokio_stream::wrappers::TcpListenerStream;
15use tokio_stream::{Stream, StreamExt};
16use tracing::warn;
17
18use super::service::ServerIo;
19#[cfg(feature = "tls")]
20use super::service::TlsAcceptor;
21
22#[cfg(not(feature = "tls"))]
23pub(crate) fn tcp_incoming<IO, IE>(
24 incoming: impl Stream<Item = Result<IO, IE>>,
25) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
26where
27 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
28 IE: Into<crate::BoxError>,
29{
30 async_stream::try_stream! {
31 let mut incoming = pin!(incoming);
32
33 while let Some(item) = incoming.next().await {
34 yield match item {
35 Ok(_) => item.map(ServerIo::new_io)?,
36 Err(e) => match handle_tcp_accept_error(e) {
37 ControlFlow::Continue(()) => continue,
38 ControlFlow::Break(e) => Err(e)?,
39 }
40 }
41 }
42 }
43}
44
/// Adapts a stream of accepted connections into a stream of [`ServerIo`]s,
/// optionally performing a TLS handshake on each connection first.
///
/// When `tls` is `Some`, every handshake runs as its own task in a `JoinSet`, so
/// one slow client does not hold up accepting further connections; failed or
/// panicked handshakes are logged at `debug` level and discarded. When `tls` is
/// `None`, connections are yielded as plain `ServerIo::new_io` values.
///
/// TCP accept errors are classified by `handle_tcp_accept_error`: recoverable
/// ones are skipped, a non-recoverable one is yielded as `Err` and ends the
/// stream. Once `incoming` is exhausted the loop stops; handshakes still in
/// flight are aborted when the `JoinSet` is dropped.
///
/// NOTE(review): no timeout is applied to `accept` here, so a client that never
/// finishes its handshake keeps a task alive — confirm whether a timeout is
/// enforced elsewhere.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    tls: Option<TlsAcceptor>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
where
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    IE: Into<crate::BoxError>,
{
    async_stream::try_stream! {
        let mut conns = pin!(incoming);
        let mut handshakes = tokio::task::JoinSet::new();

        loop {
            match select(&mut conns, &mut handshakes).await {
                SelectOutput::Done => break,

                SelectOutput::Incoming(conn) => match &tls {
                    None => {
                        yield ServerIo::new_io(conn);
                    }
                    Some(acceptor) => {
                        let acceptor = acceptor.clone();
                        handshakes.spawn(async move {
                            let tls_io = acceptor.accept(conn).await?;
                            Ok(ServerIo::new_tls_io(tls_io))
                        });
                    }
                },

                SelectOutput::Io(ready_io) => {
                    yield ready_io;
                }

                SelectOutput::TlsErr(err) => {
                    tracing::debug!(error = %err, "tls accept error");
                }

                SelectOutput::TcpErr(err) => match handle_tcp_accept_error(err) {
                    ControlFlow::Continue(()) => continue,
                    ControlFlow::Break(fatal) => Err(fatal)?,
                },
            }
        }
    }
}
94
95fn handle_tcp_accept_error(e: impl Into<crate::error::BoxError>) -> ControlFlow<crate::error::BoxError> {
96 let e = e.into();
97 tracing::debug!(error = %e, "accept loop error");
98 if let Some(e) = e.downcast_ref::<io::Error>() {
99 if matches!(
100 e.kind(),
101 io::ErrorKind::ConnectionAborted
102 | io::ErrorKind::ConnectionReset
103 | io::ErrorKind::BrokenPipe
104 | io::ErrorKind::Interrupted
105 | io::ErrorKind::WouldBlock
106 | io::ErrorKind::TimedOut
107 ) {
108 return ControlFlow::Continue(());
109 }
110 }
111
112 ControlFlow::Break(e)
113}
114
/// Waits for the next event of the TLS accept loop: a new connection (or error,
/// or end of stream) from `incoming`, or the completion of one of the handshake
/// tasks in `tasks`.
///
/// While no handshake is pending only `incoming` is awaited, which also
/// guarantees that `join_next` inside the `select!` never resolves to `None`.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    // Shared translation of one `try_next` result into a loop event.
    let from_incoming = |next: Result<Option<IO>, IE>| match next {
        Err(err) => SelectOutput::TcpErr(err.into()),
        Ok(None) => SelectOutput::Done,
        Ok(Some(conn)) => SelectOutput::Incoming(conn),
    };

    if tasks.is_empty() {
        return from_incoming(incoming.try_next().await);
    }

    tokio::select! {
        next = incoming.try_next() => from_incoming(next),

        joined = tasks.join_next() => match joined.expect("JoinSet should never end") {
            Ok(Ok(io)) => SelectOutput::Io(io),
            Ok(Err(err)) => SelectOutput::TlsErr(err),
            Err(join_err) => SelectOutput::TlsErr(join_err.into()),
        },
    }
}
149
/// One event produced by `select` for the TLS accept loop.
#[cfg(feature = "tls")]
enum SelectOutput<A> {
    /// The incoming connection stream is exhausted.
    Done,
    /// A connection was accepted and has not been wrapped in a `ServerIo` yet.
    Incoming(A),
    /// A handshake task finished successfully with a ready-to-serve connection.
    Io(ServerIo<A>),
    /// The incoming connection stream yielded an error.
    TcpErr(crate::BoxError),
    /// A handshake task returned an error, or the task itself failed to complete.
    TlsErr(crate::BoxError),
}
158
/// A stream of TCP connections accepted from a bound listener.
///
/// For every connection it yields, the stream first tries to apply the
/// configured `nodelay` and `keepalive` socket options; failures to set them
/// are logged as warnings rather than returned.
#[derive(Debug)]
pub struct TcpIncoming {
    // Tokio listener wrapped as a `Stream` of accepted connections.
    inner: TcpListenerStream,
    // When `true`, TCP_NODELAY is enabled on each accepted connection.
    nodelay: bool,
    // When `Some`, TCP keepalive is configured with this idle time.
    keepalive: Option<Duration>,
}
169
170impl TcpIncoming {
171 pub fn new(
204 addr: SocketAddr,
205 nodelay: bool,
206 keepalive: Option<Duration>,
207 ) -> Result<Self, crate::BoxError> {
208 let std_listener = StdTcpListener::bind(addr)?;
209 std_listener.set_nonblocking(true)?;
210
211 let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
212 Ok(Self {
213 inner,
214 nodelay,
215 keepalive,
216 })
217 }
218
219 pub fn from_listener(
221 listener: TcpListener,
222 nodelay: bool,
223 keepalive: Option<Duration>,
224 ) -> Result<Self, crate::BoxError> {
225 Ok(Self {
226 inner: TcpListenerStream::new(listener),
227 nodelay,
228 keepalive,
229 })
230 }
231}
232
233impl Stream for TcpIncoming {
234 type Item = Result<TcpStream, std::io::Error>;
235
236 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
238 Some(Ok(stream)) => {
239 set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
240 Some(Ok(stream)).into()
241 }
242 other => Poll::Ready(other),
243 }
244 }
245}
246
247fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
249 if nodelay {
250 if let Err(e) = stream.set_nodelay(true) {
251 warn!("error trying to set TCP nodelay: {}", e);
252 }
253 }
254
255 if let Some(timeout) = keepalive {
256 let sock_ref = socket2::SockRef::from(&stream);
257 let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
258
259 if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
260 warn!("error trying to set TCP keepalive: {}", e);
261 }
262 }
263}
264
#[cfg(test)]
mod tests {
    use crate::server::TcpIncoming;

    /// Binding a second `TcpIncoming` to an address that is already in use must
    /// fail, and the address must be bindable again once the first is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Let the OS choose a free port instead of hard-coding one, so the test
        // cannot collide with other processes or concurrently running tests.
        let any_port = "127.0.0.1:0".parse().unwrap();
        let addr = {
            let t1 = TcpIncoming::new(any_port, true, None).unwrap();
            let listener: &tokio::net::TcpListener = t1.inner.as_ref();
            let addr = listener.local_addr().unwrap();
            let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
            addr
        };
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}