tonic_rustls/server/incoming.rs

1use std::{
2    io,
3    net::{SocketAddr, TcpListener as StdTcpListener},
4    ops::ControlFlow,
5    pin::{pin, Pin},
6    task::{ready, Context, Poll},
7    time::Duration,
8};
9
10use tokio::{
11    io::{AsyncRead, AsyncWrite},
12    net::{TcpListener, TcpStream},
13};
14use tokio_stream::wrappers::TcpListenerStream;
15use tokio_stream::{Stream, StreamExt};
16use tracing::warn;
17
18use super::service::ServerIo;
19#[cfg(feature = "tls")]
20use super::service::TlsAcceptor;
21
22#[cfg(not(feature = "tls"))]
23pub(crate) fn tcp_incoming<IO, IE>(
24    incoming: impl Stream<Item = Result<IO, IE>>,
25) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
26where
27    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
28    IE: Into<crate::BoxError>,
29{
30    async_stream::try_stream! {
31        let mut incoming = pin!(incoming);
32
33        while let Some(item) = incoming.next().await {
34            yield match item {
35                Ok(_) => item.map(ServerIo::new_io)?,
36                Err(e) => match handle_tcp_accept_error(e) {
37                    ControlFlow::Continue(()) => continue,
38                    ControlFlow::Break(e) => Err(e)?,
39                }
40            }
41        }
42    }
43}
44
/// Turns a stream of accepted connections into a stream of [`ServerIo`] values,
/// optionally performing a TLS handshake on each one (build with the `tls`
/// feature).
///
/// With `tls == None`, connections are yielded as plain I/O. With an acceptor,
/// each handshake is spawned onto a `JoinSet`, and `select` races the set
/// against the listener, so new connections keep being accepted while
/// handshakes are in progress. Failed handshakes are logged at debug level and
/// skipped. Listener errors go through `handle_tcp_accept_error`.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    tls: Option<TlsAcceptor>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::BoxError>>
where
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    IE: Into<crate::BoxError>,
{
    async_stream::try_stream! {
        let mut listener = pin!(incoming);
        let mut handshakes = tokio::task::JoinSet::new();

        loop {
            match select(&mut listener, &mut handshakes).await {
                SelectOutput::Incoming(stream) => match &tls {
                    Some(acceptor) => {
                        let acceptor = acceptor.clone();
                        handshakes.spawn(async move {
                            let tls_io = acceptor.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(tls_io))
                        });
                    }
                    None => {
                        yield ServerIo::new_io(stream);
                    }
                },

                SelectOutput::Io(io) => {
                    yield io;
                }

                SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
                    ControlFlow::Continue(()) => {}
                    ControlFlow::Break(fatal) => Err(fatal)?,
                },

                // A failed handshake only concerns that one client.
                SelectOutput::TlsErr(e) => {
                    tracing::debug!(error = %e, "tls accept error");
                }

                // NOTE: leaving the loop drops `handshakes`; per tokio's `JoinSet`
                // docs, handshakes still in flight are aborted at that point.
                SelectOutput::Done => break,
            }
        }
    }
}
94
95fn handle_tcp_accept_error(e: impl Into<crate::error::BoxError>) -> ControlFlow<crate::error::BoxError> {
96    let e = e.into();
97    tracing::debug!(error = %e, "accept loop error");
98    if let Some(e) = e.downcast_ref::<io::Error>() {
99        if matches!(
100            e.kind(),
101            io::ErrorKind::ConnectionAborted
102                | io::ErrorKind::ConnectionReset
103                | io::ErrorKind::BrokenPipe
104                | io::ErrorKind::Interrupted
105                | io::ErrorKind::WouldBlock
106                | io::ErrorKind::TimedOut
107        ) {
108            return ControlFlow::Continue(());
109        }
110    }
111
112    ControlFlow::Break(e)
113}
114
/// Waits for whichever happens first: the listener producing its next item, or
/// an in-flight TLS handshake task finishing.
///
/// When `tasks` is empty there is nothing to race, so only the listener is
/// awaited. Because of that early return, `join_next` is only raced on a
/// non-empty set and cannot yield `None`, which the `expect` below relies on.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    // Maps one result of polling the listener onto a `SelectOutput`.
    fn from_accept<IO, IE: Into<crate::BoxError>>(
        accepted: Result<Option<IO>, IE>,
    ) -> SelectOutput<IO> {
        match accepted {
            Ok(Some(stream)) => SelectOutput::Incoming(stream),
            Ok(None) => SelectOutput::Done,
            Err(e) => SelectOutput::TcpErr(e.into()),
        }
    }

    if tasks.is_empty() {
        return from_accept(incoming.try_next().await);
    }

    tokio::select! {
        accepted = incoming.try_next() => from_accept(accepted),

        joined = tasks.join_next() => {
            match joined.expect("JoinSet should never end") {
                Ok(Ok(io)) => SelectOutput::Io(io),
                Ok(Err(e)) => SelectOutput::TlsErr(e),
                Err(e) => SelectOutput::TlsErr(e.into()),
            }
        }
    }
}
149
/// The event reported by `select`: either something from the TCP listener or
/// the completion of a spawned TLS handshake task.
#[cfg(feature = "tls")]
enum SelectOutput<A> {
    /// The listener accepted a new, not-yet-handshaken connection.
    Incoming(A),
    /// The listener reported an error.
    TcpErr(crate::BoxError),
    /// The listener stream has ended.
    Done,
    /// A handshake task finished successfully.
    Io(ServerIo<A>),
    /// A handshake task returned an error, or the task itself failed to join.
    TlsErr(crate::BoxError),
}
158
159/// Binds a socket address for a [Router](super::Router)
160///
161/// An incoming stream, usable with [Router::serve_with_incoming](super::Router::serve_with_incoming),
162/// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address.
163#[derive(Debug)]
164pub struct TcpIncoming {
165    inner: TcpListenerStream,
166    nodelay: bool,
167    keepalive: Option<Duration>,
168}
169
170impl TcpIncoming {
171    /// Creates an instance by binding (opening) the specified socket address
172    /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
173    /// Returns a TcpIncoming if the socket address was successfully bound.
174    ///
175    /// # Examples
176    /// ```no_run
177    /// # use tower_service::Service;
178    /// # use http::{request::Request, response::Response};
179    /// # use tonic::{body::Body, server::NamedService};
180    /// # use tonic_rustls::{Server, server::TcpIncoming};
181    /// # use core::convert::Infallible;
182    /// # use std::error::Error;
183    /// # fn main() { }  // Cannot have type parameters, hence instead define:
184    /// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
185    /// # where
186    /// #   S: Service<Request<Body>, Response = Response<Body>, Error = Infallible> + NamedService + Clone + Send + Sync + 'static,
187    /// #   S::Future: Send + 'static,
188    /// # {
189    /// // Find a free port
190    /// let mut port = 1322;
191    /// let tinc = loop {
192    ///    let addr = format!("127.0.0.1:{}", port).parse().unwrap();
193    ///    match TcpIncoming::new(addr, true, None) {
194    ///       Ok(t) => break t,
195    ///       Err(_) => port += 1
196    ///    }
197    /// };
198    /// Server::builder()
199    ///    .add_service(some_service)
200    ///    .serve_with_incoming(tinc);
201    /// # Ok(())
202    /// # }
203    pub fn new(
204        addr: SocketAddr,
205        nodelay: bool,
206        keepalive: Option<Duration>,
207    ) -> Result<Self, crate::BoxError> {
208        let std_listener = StdTcpListener::bind(addr)?;
209        std_listener.set_nonblocking(true)?;
210
211        let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
212        Ok(Self {
213            inner,
214            nodelay,
215            keepalive,
216        })
217    }
218
219    /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
220    pub fn from_listener(
221        listener: TcpListener,
222        nodelay: bool,
223        keepalive: Option<Duration>,
224    ) -> Result<Self, crate::BoxError> {
225        Ok(Self {
226            inner: TcpListenerStream::new(listener),
227            nodelay,
228            keepalive,
229        })
230    }
231}
232
233impl Stream for TcpIncoming {
234    type Item = Result<TcpStream, std::io::Error>;
235
236    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
238            Some(Ok(stream)) => {
239                set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
240                Some(Ok(stream)).into()
241            }
242            other => Poll::Ready(other),
243        }
244    }
245}
246
247// Consistent with hyper-0.14, this function does not return an error.
248fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
249    if nodelay {
250        if let Err(e) = stream.set_nodelay(true) {
251            warn!("error trying to set TCP nodelay: {}", e);
252        }
253    }
254
255    if let Some(timeout) = keepalive {
256        let sock_ref = socket2::SockRef::from(&stream);
257        let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
258
259        if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
260            warn!("error trying to set TCP keepalive: {}", e);
261        }
262    }
263}
264
#[cfg(test)]
mod tests {
    use crate::server::TcpIncoming;

    /// A second `TcpIncoming` on an address that is already bound must fail, and
    /// the address must be bindable again once the first one is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Ask the OS for a currently free port instead of hard-coding one, so the
        // test does not fail just because a fixed port is in use on the machine.
        // The probe listener is dropped inside the closure, releasing the port.
        let addr = std::net::TcpListener::bind("127.0.0.1:0")
            .and_then(|probe| probe.local_addr())
            .expect("failed to find a free local port");
        {
            let _t1 = TcpIncoming::new(addr, true, None).unwrap();
            let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
        }
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}