1use crate::{
2 accept::{Accept, DefaultAcceptor},
3 handle::Handle,
4 service::{MakeService, SendService},
5};
6use http::Request;
7use hyper::body::Incoming;
8use hyper_util::{
9 rt::{TokioExecutor, TokioIo},
10 server::conn::auto::Builder,
11 service::TowerToHyperService,
12};
13use std::{
14 fmt,
15 future::poll_fn,
16 io::{self, ErrorKind},
17 net::SocketAddr,
18 time::Duration,
19};
20use tokio::{
21 io::{AsyncRead, AsyncWrite},
22 net::{TcpListener, TcpStream},
23};
24
/// HTTP server.
///
/// Bundles everything `serve` needs: the connection acceptor, the connection builder,
/// the listener source and the [`Handle`] used for shutdown signalling.
pub struct Server<A = DefaultAcceptor> {
    /// Acceptor run on every accepted TCP stream before the connection is served.
    acceptor: A,
    /// Connection builder used to serve each accepted stream.
    builder: Builder<TokioExecutor>,
    /// Either an address that still has to be bound or an already-bound std listener.
    listener: Listener,
    /// Handle through which `serve` publishes the listening address and observes shutdown.
    handle: Handle,
}
32
33impl<A> fmt::Debug for Server<A>
35where
36 A: fmt::Debug,
37{
38 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39 f.debug_struct("Server")
40 .field("acceptor", &self.acceptor)
41 .field("listener", &self.listener)
42 .field("handle", &self.handle)
43 .finish_non_exhaustive()
44 }
45}
46
/// Where the server gets its listening socket from.
#[derive(Debug)]
enum Listener {
    /// Address that is bound when the server starts serving.
    Bind(SocketAddr),
    /// Already-bound standard-library listener supplied by the caller.
    Std(std::net::TcpListener),
}
52
/// Create a [`Server`] that will bind to `addr`.
///
/// Shorthand for [`Server::bind`].
pub fn bind(addr: SocketAddr) -> Server {
    Server::bind(addr)
}
57
/// Create a [`Server`] from an existing `std::net::TcpListener`.
///
/// Shorthand for [`Server::from_tcp`].
pub fn from_tcp(listener: std::net::TcpListener) -> Server {
    Server::from_tcp(listener)
}
62
63impl Server {
64 pub fn bind(addr: SocketAddr) -> Self {
66 let acceptor = DefaultAcceptor::new();
67 let builder = Builder::new(TokioExecutor::new());
68 let handle = Handle::new();
69
70 Self {
71 acceptor,
72 builder,
73 listener: Listener::Bind(addr),
74 handle,
75 }
76 }
77
78 pub fn from_tcp(listener: std::net::TcpListener) -> Self {
80 let acceptor = DefaultAcceptor::new();
81 let builder = Builder::new(TokioExecutor::new());
82 let handle = Handle::new();
83
84 Self {
85 acceptor,
86 builder,
87 listener: Listener::Std(listener),
88 handle,
89 }
90 }
91}
92
impl<A> Server<A> {
    /// Replace the acceptor with `acceptor`, keeping the builder, listener and handle.
    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<Acceptor> {
        Server {
            acceptor,
            builder: self.builder,
            listener: self.listener,
            handle: self.handle,
        }
    }

    /// Replace the acceptor with the result of applying the `acceptor` closure to the
    /// current one, keeping the builder, listener and handle.
    pub fn map<Acceptor, F>(self, acceptor: F) -> Server<Acceptor>
    where
        F: FnOnce(A) -> Acceptor,
    {
        Server {
            acceptor: acceptor(self.acceptor),
            builder: self.builder,
            listener: self.listener,
            handle: self.handle,
        }
    }

    /// Returns a shared reference to the acceptor.
    pub fn get_ref(&self) -> &A {
        &self.acceptor
    }

    /// Returns a mutable reference to the acceptor.
    pub fn get_mut(&mut self) -> &mut A {
        &mut self.acceptor
    }

    /// Returns a mutable reference to the connection builder so it can be configured
    /// before serving.
    pub fn http_builder(&mut self) -> &mut Builder<TokioExecutor> {
        &mut self.builder
    }

    /// Use `handle` instead of the handle the server was created with.
    pub fn handle(mut self, handle: Handle) -> Self {
        self.handle = handle;
        self
    }

    /// Obtain the listener and serve connections with services produced by `make_service`.
    ///
    /// For every accepted TCP connection, `make_service` is polled for readiness and asked
    /// for a service for the peer address. A task is then spawned that runs the acceptor on
    /// the stream and serves the result with the connection builder (upgrades enabled).
    ///
    /// The loop stops accepting when the handle signals graceful shutdown. The function then
    /// closes the listener and waits on `handle.wait_connections_end()`. When the handle
    /// signals (non-graceful) shutdown, the function returns `Ok(())` right away.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - obtaining the listener fails (the handle is first notified with `None` as address);
    /// - `make_service.poll_ready` fails (the error is wrapped with `io_other`).
    ///
    /// Failures of `make_service.make_service` and of the acceptor are not reported; the
    /// affected connection is simply not served.
    pub async fn serve<M>(self, mut make_service: M) -> io::Result<()>
    where
        M: MakeService<SocketAddr, Request<Incoming>>,
        A: Accept<TcpStream, M::Service> + Clone + Send + Sync + 'static,
        A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
        A::Service: SendService<Request<Incoming>> + Send,
        A::Future: Send,
    {
        let acceptor = self.acceptor;
        let handle = self.handle;
        // Shared by reference count with every connection task spawned below.
        let builder = std::sync::Arc::new(self.builder);

        let mut incoming = match bind_incoming(self.listener).await {
            Ok(v) => v,
            Err(e) => {
                // Publish "no address" before bailing out.
                handle.notify_listening(None);
                return Err(e);
            }
        };

        handle.notify_listening(incoming.local_addr().ok());

        let accept_loop_future = async {
            loop {
                // `biased`: branches are polled in the written order, so a ready connection
                // is taken before the graceful-shutdown signal is looked at.
                let (tcp_stream, socket_addr) = tokio::select! {
                    biased;
                    result = accept(&mut incoming) => result,
                    _ = handle.wait_graceful_shutdown() => return Ok(()),
                };

                // A readiness error ends the accept loop and is returned from `serve`.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .map_err(io_other)?;

                // A failure to build the service only skips this one connection.
                let service = match make_service.make_service(socket_addr).await {
                    Ok(service) => service,
                    Err(_) => continue,
                };

                let acceptor = acceptor.clone();
                let watcher = handle.watcher();
                let builder = builder.clone();

                tokio::spawn(async move {
                    // Acceptor errors are ignored; the task just ends without serving.
                    if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await {
                        let io = TokioIo::new(stream);
                        let service = send_service.into_service();
                        let service = TowerToHyperService::new(service);

                        let serve_future = builder.serve_connection_with_upgrades(io, service);
                        tokio::pin!(serve_future);

                        tokio::select! {
                            biased;
                            // Graceful shutdown: ask the connection to shut down gracefully,
                            // then wait for either shutdown or the connection to finish.
                            _ = watcher.wait_graceful_shutdown() => {
                                serve_future.as_mut().graceful_shutdown();
                                tokio::select! {
                                    biased;
                                    _ = watcher.wait_shutdown() => (),
                                    _ = &mut serve_future => (),
                                }
                            }
                            // Shutdown: stop driving the connection; it is dropped with the task.
                            _ = watcher.wait_shutdown() => (),
                            // The connection completed on its own.
                            _ = &mut serve_future => (),
                        }
                    }
                });
            }
        };

        // Shutdown takes priority over the accept loop and skips the connection wait below.
        let result = tokio::select! {
            biased;
            _ = handle.wait_shutdown() => return Ok(()),
            result = accept_loop_future => result,
        };

        // Close the listening socket before waiting for the remaining connections.
        drop(incoming);

        // NOTE(review): the explicit `if let` (instead of `?`) appears to be what pins the
        // accept loop's error type to `io::Error`; confirm before simplifying.
        #[allow(clippy::question_mark)]
        if let Err(e) = result {
            return Err(e);
        }

        handle.wait_connections_end().await;

        Ok(())
    }
}
245
246async fn bind_incoming(listener: Listener) -> io::Result<TcpListener> {
247 match listener {
248 Listener::Bind(addr) => TcpListener::bind(addr).await,
249 Listener::Std(std_listener) => {
250 std_listener.set_nonblocking(true)?;
251 TcpListener::from_std(std_listener)
252 }
253 }
254}
255
256pub(crate) async fn accept(listener: &mut TcpListener) -> (TcpStream, SocketAddr) {
257 loop {
258 match listener.accept().await {
259 Ok(value) => return value,
260 Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
261 }
262 }
263}
264
/// Boxed error that can be sent and shared across threads.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Wrap any boxable error in an [`io::Error`] of kind [`ErrorKind::Other`].
pub(crate) fn io_other<E>(error: E) -> io::Error
where
    E: Into<BoxError>,
{
    let source: BoxError = error.into();
    io::Error::new(ErrorKind::Other, source)
}
270
// End-to-end tests: each one starts a real server on an ephemeral local port and talks to
// it through a hyper HTTP/1 client connection.
#[cfg(test)]
mod tests {
    use crate::{handle::Handle, server::Server};
    use axum::body::Body;
    use axum::response::Response;
    use axum::routing::post;
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use futures_util::{stream, StreamExt};
    use http::{Method, Request, Uri};
    use http_body::Frame;
    use http_body_util::{BodyExt, StreamBody};
    use hyper::client::conn::http1::handshake;
    use hyper::client::conn::http1::SendRequest;
    use hyper_util::rt::TokioIo;
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::sync::oneshot;
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};

    // A running server answers both a plain GET and a slow streaming POST.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        do_empty_request(&mut client).await.unwrap();

        do_slow_request(&mut client, Duration::from_millis(50))
            .await
            .unwrap();
    }

    // After `Handle::shutdown` an already-open connection stops working.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        // The connection works before shutdown.
        do_empty_request(&mut client).await.unwrap();

        handle.shutdown();

        // The same connection must now fail.
        do_empty_request(&mut client).await.unwrap_err();

        // The client connection task must finish within one second.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    // Graceful shutdown without a timeout: the request already in flight completes, while
    // new requests on another connection fail.
    #[tokio::test]
    async fn test_graceful_shutdown_no_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Both connections work before shutdown.
        do_empty_request(&mut client1).await.unwrap();
        do_empty_request(&mut client2).await.unwrap();

        let start = tokio::time::Instant::now();

        // Signals that the response head for the slow request has been received.
        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let fut1 = async {
            // Slow request whose body is spread over 500 ms.
            let hdr1 = send_slow_request(&mut client1, Duration::from_millis(500))
                .await
                .unwrap();
            hdr1_tx.send(()).unwrap();
            recv_slow_response_body(hdr1).await.unwrap();

            assert!(start.elapsed() >= Duration::from_millis(500));
        };
        let fut2 = async {
            tokio::time::sleep(Duration::from_millis(250)).await;
            // Only shut down once the slow response has started.
            hdr1_rx.await.unwrap();
            handle.graceful_shutdown(None);

            // Requests on the other connection must be rejected from now on.
            do_empty_request(&mut client2).await.unwrap_err();
            do_empty_request(&mut client2).await.unwrap_err();
            do_empty_request(&mut client2).await.unwrap_err();
        };

        tokio::join!(fut1, fut2);

        // 500 ms for the request body plus 100 ms for the slowly streamed response body.
        assert!(start.elapsed() >= Duration::from_millis(500 + 100));

        // With the slow request done, the server task must finish successfully within a second.
        timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap()
            .unwrap();
    }

    // Graceful shutdown with a timeout: a short request completes, a very long one is cut off
    // once the timeout expires.
    #[tokio::test]
    async fn test_graceful_shutdown_timeout() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client1, _conn1) = connect(addr).await;
        let (mut client2, _conn2) = connect(addr).await;

        // Both connections work before shutdown.
        do_empty_request(&mut client1).await.unwrap();
        do_empty_request(&mut client2).await.unwrap();

        let start = tokio::time::Instant::now();

        // Signals that the response head for the first slow request has been received.
        let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();

        let task1 = async {
            // Request body spread over 222 ms; expected to complete normally.
            let hdr1 = send_slow_request(&mut client1, Duration::from_millis(222)).await;
            hdr1_tx.send(()).unwrap();

            let res1 = recv_slow_response_body(hdr1.unwrap()).await;
            res1.unwrap();
        };
        let task2 = async {
            // Request body spread over 5 555 ms, far longer than the shutdown timeout; it must fail.
            let hdr2 = send_slow_request(&mut client2, Duration::from_millis(5_555)).await;
            hdr2.unwrap_err();
        };
        let task3 = async {
            // Trigger shutdown only after the first response head has arrived.
            hdr1_rx.await.unwrap();

            handle.graceful_shutdown(Some(Duration::from_millis(333)));

            // The server task must finish successfully within a second.
            timeout(Duration::from_secs(1), server_task)
                .await
                .unwrap()
                .unwrap()
                .unwrap();

            // Lower bound: first request (222 ms) plus shutdown timeout (333 ms).
            assert!(start.elapsed() >= Duration::from_millis(222 + 333));
            // Upper bound: the server did not wait for the long request to finish.
            assert!(start.elapsed() <= Duration::from_millis(5_555));
        };

        tokio::join!(task1, task2, task3);
    }

    /// Spawn a server on `127.0.0.1` with an OS-assigned port.
    ///
    /// Routes: `GET /` returns "Hello, world!"; `POST /echo_slowly` reads the whole body and
    /// answers with a body of the same length streamed over 100 ms.
    ///
    /// Returns the handle, the server task and the address reported by `handle.listening()`.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(|| async { "Hello, world!" }))
                .route(
                    "/echo_slowly",
                    post(|body: Bytes| async move {
                        Response::new(slow_body(body.len(), Duration::from_millis(100)))
                    }),
                );

            // Port 0: let the OS pick a free port.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            Server::bind(addr)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        // Wait until the server reports the address it is listening on.
        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Open an HTTP/1 client connection to `addr`.
    ///
    /// Returns the request sender and the task that drives the connection.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (send_request, connection) = handshake(stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Send a default request with an empty body and check the "Hello, world!" response.
    async fn do_empty_request(client: &mut SendRequest<Body>) -> hyper::Result<()> {
        client.ready().await?;

        let body = client
            .send_request(Request::new(Body::empty()))
            .await?
            .into_body();

        let body = body.collect().await?.to_bytes();
        assert_eq!(body.as_ref(), b"Hello, world!");
        Ok(())
    }

    /// Send a slow request and read its complete response body.
    async fn do_slow_request(
        client: &mut SendRequest<Body>,
        duration: Duration,
    ) -> hyper::Result<()> {
        let response = send_slow_request(client, duration).await?;
        recv_slow_response_body(response).await
    }

    /// POST a 10-byte body, spread over `duration`, to `/echo_slowly`.
    ///
    /// Returns the response as soon as `send_request` resolves; the body is not read here.
    async fn send_slow_request(
        client: &mut SendRequest<Body>,
        duration: Duration,
    ) -> hyper::Result<http::Response<hyper::body::Incoming>> {
        let req_body_len: usize = 10;
        let mut req = Request::new(slow_body(req_body_len, duration));
        *req.method_mut() = Method::POST;
        *req.uri_mut() = Uri::from_static("/echo_slowly");

        client.ready().await?;
        client.send_request(req).await
    }

    /// Read the whole response body and check that it is 10 bytes long.
    async fn recv_slow_response_body(
        response: http::Response<hyper::body::Incoming>,
    ) -> hyper::Result<()> {
        let resp_body = response.into_body();
        let resp_body_bytes = resp_body.collect().await?.to_bytes();
        assert_eq!(10, resp_body_bytes.len());
        Ok(())
    }

    /// Build a body of `length` one-byte frames ("X"), sleeping `duration / length` before
    /// each frame so that the whole body takes roughly `duration`.
    fn slow_body(length: usize, duration: Duration) -> axum::body::Body {
        let frames =
            (0..length).map(move |_| Ok::<_, hyper::Error>(Frame::data(Bytes::from_static(b"X"))));

        // The division only runs once per frame, so `length == 0` never divides by zero.
        let stream = stream::iter(frames).then(move |frame| async move {
            tokio::time::sleep(duration / (length as u32)).await;
            frame
        });

        axum::body::Body::new(StreamBody::new(stream))
    }
}