axum_server/
server.rs

1use crate::{
2    accept::{Accept, DefaultAcceptor},
3    handle::Handle,
4    service::{MakeService, SendService},
5};
6use http::Request;
7use hyper::body::Incoming;
8use hyper_util::{
9    rt::{TokioExecutor, TokioIo},
10    server::conn::auto::Builder,
11    service::TowerToHyperService,
12};
13use std::{
14    fmt,
15    future::poll_fn,
16    io::{self, ErrorKind},
17    net::SocketAddr,
18    time::Duration,
19};
20use tokio::{
21    io::{AsyncRead, AsyncWrite},
22    net::{TcpListener, TcpStream},
23};
24
/// HTTP server.
///
/// Create one with [`bind()`] / [`from_tcp()`] (or the matching associated functions),
/// configure it with the builder-style methods, then run it with [`Server::serve`].
pub struct Server<A = DefaultAcceptor> {
    // Applied to every accepted TCP stream and its service before the connection is served.
    acceptor: A,
    // Used (behind an `Arc`) to serve every accepted connection.
    builder: Builder<TokioExecutor>,
    // Address to bind, or an existing listener; turned into a tokio listener by `serve`.
    listener: Listener,
    // Receives the listening address and is watched for shutdown signals.
    handle: Handle,
}
32
33// Builder doesn't implement Debug or Clone right now
34impl<A> fmt::Debug for Server<A>
35where
36    A: fmt::Debug,
37{
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        f.debug_struct("Server")
40            .field("acceptor", &self.acceptor)
41            .field("listener", &self.listener)
42            .field("handle", &self.handle)
43            .finish_non_exhaustive()
44    }
45}
46
/// Source of the server's socket; converted into a tokio listener by `bind_incoming`.
#[derive(Debug)]
enum Listener {
    /// Bind a new listener to this address.
    Bind(SocketAddr),
    /// Use an already-created standard library listener.
    Std(std::net::TcpListener),
}
52
53/// Create a [`Server`] that will bind to provided address.
54pub fn bind(addr: SocketAddr) -> Server {
55    Server::bind(addr)
56}
57
58/// Create a [`Server`] from existing `std::net::TcpListener`.
59pub fn from_tcp(listener: std::net::TcpListener) -> Server {
60    Server::from_tcp(listener)
61}
62
63impl Server {
64    /// Create a server that will bind to provided address.
65    pub fn bind(addr: SocketAddr) -> Self {
66        let acceptor = DefaultAcceptor::new();
67        let builder = Builder::new(TokioExecutor::new());
68        let handle = Handle::new();
69
70        Self {
71            acceptor,
72            builder,
73            listener: Listener::Bind(addr),
74            handle,
75        }
76    }
77
78    /// Create a server from existing `std::net::TcpListener`.
79    pub fn from_tcp(listener: std::net::TcpListener) -> Self {
80        let acceptor = DefaultAcceptor::new();
81        let builder = Builder::new(TokioExecutor::new());
82        let handle = Handle::new();
83
84        Self {
85            acceptor,
86            builder,
87            listener: Listener::Std(listener),
88            handle,
89        }
90    }
91}
92
93impl<A> Server<A> {
94    /// Overwrite acceptor.
95    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<Acceptor> {
96        Server {
97            acceptor,
98            builder: self.builder,
99            listener: self.listener,
100            handle: self.handle,
101        }
102    }
103
104    /// Map acceptor.
105    pub fn map<Acceptor, F>(self, acceptor: F) -> Server<Acceptor>
106    where
107        F: FnOnce(A) -> Acceptor,
108    {
109        Server {
110            acceptor: acceptor(self.acceptor),
111            builder: self.builder,
112            listener: self.listener,
113            handle: self.handle,
114        }
115    }
116
117    /// Returns a reference to the acceptor.
118    pub fn get_ref(&self) -> &A {
119        &self.acceptor
120    }
121
122    /// Returns a mutable reference to the acceptor.
123    pub fn get_mut(&mut self) -> &mut A {
124        &mut self.acceptor
125    }
126
127    /// Returns a mutable reference to the Http builder.
128    pub fn http_builder(&mut self) -> &mut Builder<TokioExecutor> {
129        &mut self.builder
130    }
131
132    /// Provide a handle for additional utilities.
133    pub fn handle(mut self, handle: Handle) -> Self {
134        self.handle = handle;
135        self
136    }
137
138    /// Serve provided [`MakeService`].
139    ///
140    /// To create [`MakeService`] easily, `Shared` from [`tower`] can be used.
141    ///
142    /// # Errors
143    ///
144    /// An error will be returned when:
145    ///
146    /// - Binding to an address fails.
147    /// - `make_service` returns an error when `poll_ready` is called. This never happens on
148    ///   [`axum`] make services.
149    ///
150    /// [`axum`]: https://docs.rs/axum/0.3
151    /// [`tower`]: https://docs.rs/tower
152    /// [`MakeService`]: https://docs.rs/tower/0.4/tower/make/trait.MakeService.html
153    pub async fn serve<M>(self, mut make_service: M) -> io::Result<()>
154    where
155        M: MakeService<SocketAddr, Request<Incoming>>,
156        A: Accept<TcpStream, M::Service> + Clone + Send + Sync + 'static,
157        A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
158        A::Service: SendService<Request<Incoming>> + Send,
159        A::Future: Send,
160    {
161        let acceptor = self.acceptor;
162        let handle = self.handle;
163        let builder = std::sync::Arc::new(self.builder);
164
165        let mut incoming = match bind_incoming(self.listener).await {
166            Ok(v) => v,
167            Err(e) => {
168                handle.notify_listening(None);
169                return Err(e);
170            }
171        };
172
173        handle.notify_listening(incoming.local_addr().ok());
174
175        let accept_loop_future = async {
176            loop {
177                let (tcp_stream, socket_addr) = tokio::select! {
178                    biased;
179                    result = accept(&mut incoming) => result,
180                    _ = handle.wait_graceful_shutdown() => return Ok(()),
181                };
182
183                poll_fn(|cx| make_service.poll_ready(cx))
184                    .await
185                    .map_err(io_other)?;
186
187                let service = match make_service.make_service(socket_addr).await {
188                    Ok(service) => service,
189                    Err(_) => continue,
190                };
191
192                let acceptor = acceptor.clone();
193                let watcher = handle.watcher();
194                let builder = builder.clone();
195
196                tokio::spawn(async move {
197                    if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await {
198                        let io = TokioIo::new(stream);
199                        let service = send_service.into_service();
200                        let service = TowerToHyperService::new(service);
201
202                        let serve_future = builder.serve_connection_with_upgrades(io, service);
203                        tokio::pin!(serve_future);
204
205                        tokio::select! {
206                            biased;
207                            _ = watcher.wait_graceful_shutdown() => {
208                                serve_future.as_mut().graceful_shutdown();
209                                tokio::select! {
210                                    biased;
211                                    _ = watcher.wait_shutdown() => (),
212                                    _ = &mut serve_future => (),
213                                }
214                            }
215                            _ = watcher.wait_shutdown() => (),
216                            _ = &mut serve_future => (),
217                        }
218                    }
219                });
220            }
221        };
222
223        let result = tokio::select! {
224            biased;
225            _ = handle.wait_shutdown() => return Ok(()),
226            result = accept_loop_future => result,
227        };
228
229        // Tokio internally accepts TCP connections while the TCPListener is active;
230        // drop the listener to immediately refuse connections rather than letting
231        // them hang.
232        drop(incoming);
233
234        // attempting to do a "result?;" requires us to specify the type of result which is annoying
235        #[allow(clippy::question_mark)]
236        if let Err(e) = result {
237            return Err(e);
238        }
239
240        handle.wait_connections_end().await;
241
242        Ok(())
243    }
244}
245
246async fn bind_incoming(listener: Listener) -> io::Result<TcpListener> {
247    match listener {
248        Listener::Bind(addr) => TcpListener::bind(addr).await,
249        Listener::Std(std_listener) => {
250            std_listener.set_nonblocking(true)?;
251            TcpListener::from_std(std_listener)
252        }
253    }
254}
255
256pub(crate) async fn accept(listener: &mut TcpListener) -> (TcpStream, SocketAddr) {
257    loop {
258        match listener.accept().await {
259            Ok(value) => return value,
260            Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
261        }
262    }
263}
264
/// Boxed, thread-safe error type accepted by [`io_other`].
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Wrap an arbitrary error in an [`io::Error`] of kind [`ErrorKind::Other`].
///
/// Used, for example, to report `make_service` readiness errors from `serve` as `io::Error`s.
pub(crate) fn io_other<E>(error: E) -> io::Error
where
    E: Into<BoxError>,
{
    let boxed: BoxError = error.into();
    io::Error::new(ErrorKind::Other, boxed)
}
270
#[cfg(test)]
mod tests {
    // End-to-end tests: each test starts a real server on 127.0.0.1 with an OS-assigned port
    // (see `start_server`) and drives it with a hyper HTTP/1 client.
    use crate::{handle::Handle, server::Server};
    use axum::body::Body;
    use axum::response::Response;
    use axum::routing::post;
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use futures_util::{stream, StreamExt};
    use http::{Method, Request, Uri};
    use http_body::Frame;
    use http_body_util::{BodyExt, StreamBody};
    use hyper::client::conn::http1::handshake;
    use hyper::client::conn::http1::SendRequest;
    use hyper_util::rt::TokioIo;
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::sync::oneshot;
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};

    // The server answers both a plain request and a slowly streamed one on one connection.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        // Client can send requests

        do_empty_request(&mut client).await.unwrap();

        do_slow_request(&mut client, Duration::from_millis(50))
            .await
            .unwrap();
    }

    // After a hard shutdown, requests on an existing connection fail and the connection ends.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        // Client can send request before shutdown.
        do_empty_request(&mut client).await.unwrap();

        handle.shutdown();

        // After shutdown, all client requests should fail.
        do_empty_request(&mut client).await.unwrap_err();

        // Connection should finish soon.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    // Test graceful shutdown with no timeout.
    #[tokio::test]
    async fn test_graceful_shutdown_no_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Clients can send request before graceful shutdown.
        do_empty_request(&mut client1).await.unwrap();
        do_empty_request(&mut client2).await.unwrap();

        let start = tokio::time::Instant::now();

        // Signals that request 1's response headers have arrived.
        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let fut1 = async {
            // A slow request made before graceful shutdown is handled.
            // Since there's no request timeout, this can take as long as it
            // needs.
            let hdr1 = send_slow_request(&mut client1, Duration::from_millis(500))
                .await
                .unwrap();
            hdr1_tx.send(()).unwrap();
            recv_slow_response_body(hdr1).await.unwrap();

            assert!(start.elapsed() >= Duration::from_millis(500));
        };
        let fut2 = async {
            // Begin graceful shutdown while request 1's response body is still streaming
            // (i.e. after its headers have arrived).
            tokio::time::sleep(Duration::from_millis(250)).await;
            hdr1_rx.await.unwrap();
            handle.graceful_shutdown(None);

            // Any new requests after graceful shutdown begins will fail
            do_empty_request(&mut client2).await.unwrap_err();
            do_empty_request(&mut client2).await.unwrap_err();
            do_empty_request(&mut client2).await.unwrap_err();
        };

        tokio::join!(fut1, fut2);

        // At this point, graceful shutdown must have occurred, and the slow
        // request must have finished. Since there was no timeout, the elapsed
        // time should be at least 600 ms: 500 ms to stream the request body plus
        // 100 ms for the server to stream the echoed response back.
        assert!(start.elapsed() >= Duration::from_millis(500 + 100));

        // Server task should finish soon.
        timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap()
            .unwrap();
    }

    // Test graceful shutdown with a timeout.
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Clients can send request before graceful shutdown.
        do_empty_request(&mut client1).await.unwrap();
        do_empty_request(&mut client2).await.unwrap();

        let start = tokio::time::Instant::now();

        // Signals that request 1's response headers have arrived.
        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let task1 = async {
            // A slow request made before graceful shutdown is handled.
            // This one is shorter than the timeout, so it should succeed.
            let hdr1 = send_slow_request(&mut client1, Duration::from_millis(222)).await;
            hdr1_tx.send(()).unwrap();

            let res1 = recv_slow_response_body(hdr1.unwrap()).await;
            res1.unwrap();
        };
        let task2 = async {
            // A slow request made before graceful shutdown is handled.
            // This one is much longer than the timeout; it should fail sometime
            // after the graceful shutdown timeout.

            let hdr2 = send_slow_request(&mut client2, Duration::from_millis(5_555)).await;
            hdr2.unwrap_err();
        };
        let task3 = async {
            // Begin graceful shutdown after we receive response headers for (1).
            hdr1_rx.await.unwrap();

            // Set a timeout on requests to finish before we drop them.
            handle.graceful_shutdown(Some(Duration::from_millis(333)));

            // Server task should finish soon.
            timeout(Duration::from_secs(1), server_task)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // At this point, graceful shutdown must have occurred: it started after request 1's
            // headers (>= 222 ms) and lasted the full 333 ms timeout because request 2 was
            // still in flight, well before request 2 could have finished on its own.
            assert!(start.elapsed() >= Duration::from_millis(222 + 333));
            assert!(start.elapsed() <= Duration::from_millis(5_555));
        };

        tokio::join!(task1, task2, task3);
    }

    // Spawn a server on 127.0.0.1 with an OS-assigned port (port 0) serving `GET /`
    // ("Hello, world!") and `POST /echo_slowly`, and wait until it reports its address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(|| async { "Hello, world!" }))
                .route(
                    "/echo_slowly",
                    post(|body: Bytes| async move {
                        // Stream a response slowly, byte-by-byte, over 100ms
                        Response::new(slow_body(body.len(), Duration::from_millis(100)))
                    }),
                );

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            Server::bind(addr)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    // Open a TCP connection to `addr`, perform the HTTP/1 client handshake and drive the
    // connection on a spawned task; that task finishes when the connection closes.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (send_request, connection) = handshake(stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    // Send a basic `GET /` request and check the "Hello, world!" body.
    async fn do_empty_request(client: &mut SendRequest<Body>) -> hyper::Result<()> {
        client.ready().await?;

        let body = client
            .send_request(Request::new(Body::empty()))
            .await?
            .into_body();

        let body = body.collect().await?.to_bytes();
        assert_eq!(body.as_ref(), b"Hello, world!");
        Ok(())
    }

    // Send a request with a body streamed byte-by-byte, over a given duration,
    // then wait for the full response.
    async fn do_slow_request(
        client: &mut SendRequest<Body>,
        duration: Duration,
    ) -> hyper::Result<()> {
        let response = send_slow_request(client, duration).await?;
        recv_slow_response_body(response).await
    }

    // Send `POST /echo_slowly` with a 10-byte body streamed over `duration`. Resolves once
    // the response headers arrive; the handler only responds after reading the whole body.
    async fn send_slow_request(
        client: &mut SendRequest<Body>,
        duration: Duration,
    ) -> hyper::Result<http::Response<hyper::body::Incoming>> {
        let req_body_len: usize = 10;
        let mut req = Request::new(slow_body(req_body_len, duration));
        *req.method_mut() = Method::POST;
        *req.uri_mut() = Uri::from_static("/echo_slowly");

        client.ready().await?;
        client.send_request(req).await
    }

    // Collect the response body and check that it has the 10 bytes sent by
    // `send_slow_request`.
    async fn recv_slow_response_body(
        response: http::Response<hyper::body::Incoming>,
    ) -> hyper::Result<()> {
        let resp_body = response.into_body();
        let resp_body_bytes = resp_body.collect().await?.to_bytes();
        assert_eq!(10, resp_body_bytes.len());
        Ok(())
    }

    // A stream of n response data `Frame`s, where n = `length`, and each frame
    // consists of a single byte. The whole response is smeared out over
    // a `duration` length of time: every frame is preceded by a sleep of
    // `duration / length`.
    fn slow_body(length: usize, duration: Duration) -> axum::body::Body {
        let frames =
            (0..length).map(move |_| Ok::<_, hyper::Error>(Frame::data(Bytes::from_static(b"X"))));

        let stream = stream::iter(frames).then(move |frame| async move {
            tokio::time::sleep(duration / (length as u32)).await;
            frame
        });

        axum::body::Body::new(StreamBody::new(stream))
    }
}