1use crate::{
2 accept::{Accept, DefaultAcceptor},
3 handle::Handle,
4 service::{MakeService, SendService},
5};
6use either::Either;
7use http::Request;
8use hyper::body::Incoming;
9use hyper_util::{
10 rt::{TokioExecutor, TokioIo},
11 server::conn::auto::Builder,
12 service::TowerToHyperService,
13};
14use std::{
15 fmt,
16 future::poll_fn,
17 io::{self, ErrorKind},
18 net::SocketAddr as IpSocketAddr,
19 time::Duration,
20};
21use tokio::{
22 io::{AsyncRead, AsyncWrite},
23 net::{TcpListener, TcpStream},
24};
25
26#[cfg(unix)]
27use {
28 std::os::unix::net::SocketAddr as UnixSocketAddr,
29 tokio::net::{UnixListener, UnixStream},
30};
31
32pub struct Server<Addr: Address, A = DefaultAcceptor> {
34 acceptor: A,
35 builder: Builder<TokioExecutor>,
36 listener: Listener<Addr>,
37 handle: Handle<Addr>,
38 http_version: Option<HttpVersion>,
39}
40
41impl<A: Address, B> fmt::Debug for Server<A, B>
43where
44 Listener<A>: fmt::Debug,
45 Handle<A>: fmt::Debug,
46 B: fmt::Debug,
47{
48 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49 f.debug_struct("Server")
50 .field("acceptor", &self.acceptor)
51 .field("listener", &self.listener)
52 .field("handle", &self.handle)
53 .finish_non_exhaustive()
54 }
55}
56
57#[derive(Debug)]
58enum Listener<A: Address> {
59 Bind(A),
60 Ready(A::Listener),
61}
62
63pub fn bind<A: Address>(addr: A) -> Server<A> {
65 Server::bind(addr)
66}
67
68pub fn from_tcp(listener: std::net::TcpListener) -> io::Result<Server<IpSocketAddr>> {
70 Ok(Server::from_listener(TcpListener::from_std(listener)?))
71}
72
73#[cfg(unix)]
75pub fn from_unix(listener: std::os::unix::net::UnixListener) -> io::Result<Server<UnixSocketAddr>> {
76 Ok(Server::from_listener(UnixListener::from_std(listener)?))
77}
78
79pub trait AddrListener<Stream, Addr: Address>: std::marker::Sized {
82 fn bind_to(addr: Addr) -> impl std::future::Future<Output = io::Result<Self>> + Send;
84
85 fn accept_stream(&self)
87 -> impl std::future::Future<Output = io::Result<(Stream, Addr)>> + Send;
88
89 fn get_local_addr(&self) -> io::Result<Addr>;
91}
92
93#[cfg(unix)]
94impl AddrListener<UnixStream, UnixSocketAddr> for UnixListener {
95 async fn bind_to(addr: UnixSocketAddr) -> io::Result<Self> {
96 UnixListener::bind(addr.as_pathname().ok_or_else(|| {
97 io::Error::new(
98 ErrorKind::InvalidInput,
99 "A UnixListener can only be bound to a path address!",
100 )
101 })?)
102 }
103
104 async fn accept_stream(&self) -> io::Result<(UnixStream, UnixSocketAddr)> {
105 let (stream, tokio_addr) = self.accept().await?;
106 Ok((stream, tokio_addr.into()))
107 }
108
109 fn get_local_addr(&self) -> io::Result<UnixSocketAddr> {
110 self.local_addr().map(tokio::net::unix::SocketAddr::into)
111 }
112}
113
114impl AddrListener<TcpStream, IpSocketAddr> for TcpListener {
115 async fn bind_to(addr: IpSocketAddr) -> io::Result<Self> {
116 TcpListener::bind(addr).await
117 }
118
119 async fn accept_stream(&self) -> io::Result<(TcpStream, IpSocketAddr)> {
120 self.accept().await
121 }
122
123 fn get_local_addr(&self) -> io::Result<IpSocketAddr> {
124 self.local_addr()
125 }
126}
127
128pub trait Address: std::marker::Sized + Clone {
130 type Stream;
132
133 type Listener: AddrListener<Self::Stream, Self>;
135}
136
137#[cfg(unix)]
138impl Address for UnixSocketAddr {
139 type Stream = UnixStream;
140 type Listener = UnixListener;
141}
142
143impl Address for IpSocketAddr {
144 type Stream = TcpStream;
145 type Listener = TcpListener;
146}
147
148impl<A: Address> Server<A> {
149 pub fn bind(addr: A) -> Self {
151 let acceptor = DefaultAcceptor::new();
152 let builder = Builder::new(TokioExecutor::new());
153 let handle = Handle::new();
154
155 Self {
156 acceptor,
157 builder,
158 listener: Listener::Bind(addr),
159 handle,
160 http_version: None,
161 }
162 }
163
164 pub fn from_listener(listener: A::Listener) -> Self {
166 let acceptor = DefaultAcceptor::new();
167 let builder = Builder::new(TokioExecutor::new());
168 let handle = Handle::new();
169
170 Self {
171 acceptor,
172 builder,
173 listener: Listener::Ready(listener),
174 handle,
175 http_version: None,
176 }
177 }
178}
179
/// Protocol version a server has been restricted to.
///
/// `Debug` is derived so the value can be printed in diagnostics and used in
/// assertions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HttpVersion {
    /// Only HTTP/1 connections are served.
    Http1,
    /// Only HTTP/2 connections are served.
    Http2,
}
185
186impl<A: Address, Acc> Server<A, Acc> {
187 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<A, Acceptor> {
189 Server {
190 acceptor,
191 builder: self.builder,
192 listener: self.listener,
193 handle: self.handle,
194 http_version: None,
195 }
196 }
197
198 pub fn map<Acceptor, F>(self, acceptor: F) -> Server<A, Acceptor>
200 where
201 F: FnOnce(Acc) -> Acceptor,
202 {
203 Server {
204 acceptor: acceptor(self.acceptor),
205 builder: self.builder,
206 listener: self.listener,
207 handle: self.handle,
208 http_version: None,
209 }
210 }
211
212 pub fn get_ref(&self) -> &Acc {
214 &self.acceptor
215 }
216
217 pub fn get_mut(&mut self) -> &mut Acc {
219 &mut self.acceptor
220 }
221
222 pub fn http_builder(&mut self) -> &mut Builder<TokioExecutor> {
224 &mut self.builder
225 }
226
227 pub fn http1_only(mut self) -> Self {
229 self.http_version = Some(HttpVersion::Http1);
230 self.builder = self.builder.http1_only();
231 self
232 }
233
234 pub fn http2_only(mut self) -> Self {
236 self.http_version = Some(HttpVersion::Http2);
237 self.builder = self.builder.http2_only();
238 self
239 }
240
241 pub fn handle(mut self, handle: Handle<A>) -> Self {
243 self.handle = handle;
244 self
245 }
246
247 pub async fn serve<M>(self, mut make_service: M) -> io::Result<()>
263 where
264 M: MakeService<A, Request<Incoming>>,
265 A: Send + 'static,
266 A::Stream: Send,
267 Acc: Accept<A::Stream, M::Service> + Clone + Send + Sync + 'static,
268 Acc::Stream: AsyncRead + AsyncWrite + Unpin + Send,
269 Acc::Service: SendService<Request<Incoming>> + Send,
270 Acc::Future: Send,
271 {
272 let acceptor = self.acceptor;
273 let handle = self.handle;
274 let builder = std::sync::Arc::new(self.builder);
275
276 let mut incoming = match bind_incoming(self.listener).await {
277 Ok(v) => v,
278 Err(e) => {
279 handle.notify_listening(None);
280 return Err(e);
281 }
282 };
283
284 handle.notify_listening(incoming.get_local_addr().ok());
285
286 let accept_loop_future = async {
287 loop {
288 let (tcp_stream, socket_addr) = tokio::select! {
289 biased;
290 result = accept(&mut incoming) => result,
291 _ = handle.wait_graceful_shutdown() => return Ok(()),
292 };
293
294 poll_fn(|cx| make_service.poll_ready(cx))
295 .await
296 .map_err(io_other)?;
297
298 let service = match make_service.make_service(socket_addr).await {
299 Ok(service) => service,
300 Err(_) => continue,
301 };
302
303 let acceptor = acceptor.clone();
304 let watcher = handle.watcher();
305 let builder = builder.clone();
306 let http_version = self.http_version;
307
308 tokio::spawn(async move {
309 if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await {
310 let io = TokioIo::new(stream);
311 let service = send_service.into_service();
312 let service = TowerToHyperService::new(service);
313 let serve_future = match http_version {
314 Some(_) => Either::Left(builder.serve_connection(io, service)),
315 _ => Either::Right(builder.serve_connection_with_upgrades(io, service)),
316 };
317 tokio::pin!(serve_future);
318 let mut serve_future = serve_future.as_pin_mut();
319 tokio::select! {
320 biased;
321 _ = watcher.wait_graceful_shutdown() => {
322 match &mut serve_future {
323 Either::Left(serve_future) => serve_future.as_mut().graceful_shutdown(),
324 Either::Right(serve_future) => serve_future.as_mut().graceful_shutdown(),
325 }
326 tokio::select! {
327 biased;
328 _ = watcher.wait_shutdown() => (),
329 _ = &mut serve_future => (),
330 }
331 }
332 _ = watcher.wait_shutdown() => (),
333 _ = &mut serve_future => (),
334 }
335 }
336 });
337 }
338 };
339
340 let result = tokio::select! {
341 biased;
342 _ = handle.wait_shutdown() => return Ok(()),
343 result = accept_loop_future => result,
344 };
345
346 drop(incoming);
350
351 #[allow(clippy::question_mark)]
353 if let Err(e) = result {
354 return Err(e);
355 }
356
357 handle.wait_connections_end().await;
358
359 Ok(())
360 }
361}
362
363async fn bind_incoming<A: Address>(listener: Listener<A>) -> io::Result<A::Listener> {
364 match listener {
365 Listener::Bind(addr) => A::Listener::bind_to(addr).await,
366 Listener::Ready(listener) => Ok(listener),
367 }
368}
369
370pub(crate) async fn accept<L: AddrListener<S, A>, S, A: Address>(listener: &mut L) -> (S, A) {
371 loop {
372 match listener.accept_stream().await {
373 Ok(value) => return value,
374 Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
375 }
376 }
377}
378
/// Boxed, thread-safe error accepted by [`io_other`].
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Wraps `error` in an [`io::Error`] whose kind is `Other`.
pub(crate) fn io_other<E>(error: E) -> io::Error
where
    E: Into<BoxError>,
{
    let boxed: BoxError = error.into();
    io::Error::other(boxed)
}
384
385#[cfg(test)]
386mod tests {
387 use crate::{
388 handle::Handle,
389 server::{HttpVersion, Server},
390 };
391 use axum::body::Body;
392 use axum::response::Response;
393 use axum::routing::post;
394 use axum::{routing::get, Router};
395 use bytes::Bytes;
396 use futures_util::{stream, StreamExt};
397 use http::{Method, Request, Uri};
398 use http_body::Frame;
399 use http_body_util::{BodyExt, StreamBody};
400 use hyper::client;
401 use hyper::client::conn::{http1, http2};
402 use hyper_util::rt::{TokioExecutor, TokioIo};
403 use std::{io, net::SocketAddr as IpSocketAddr, time::Duration};
404 use tokio::sync::oneshot;
405 use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
406
407 #[tokio::test]
408 async fn start_and_request() {
409 let (_handle, _server_task, addr) = start_server().await;
410
411 let (mut client, _conn) = connect(addr).await;
412
413 do_empty_request_h1(&mut client).await.unwrap();
416
417 do_slow_request(&mut client, Duration::from_millis(50))
418 .await
419 .unwrap();
420 }
421
422 #[tokio::test]
423 async fn test_shutdown() {
424 let (handle, _server_task, addr) = start_server().await;
425
426 let (mut client, conn) = connect(addr).await;
427
428 do_empty_request_h1(&mut client).await.unwrap();
430
431 handle.shutdown();
432
433 do_empty_request_h1(&mut client).await.unwrap_err();
435
436 let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
438 }
439
440 #[tokio::test]
442 async fn test_graceful_shutdown_no_timeout() {
443 let (handle, server_task, addr) = start_server().await;
444
445 let (mut client1, _conn1) = connect(addr).await;
446 let (mut client2, _conn2) = connect(addr).await;
447
448 do_empty_request_h1(&mut client1).await.unwrap();
450 do_empty_request_h1(&mut client2).await.unwrap();
451
452 let start = tokio::time::Instant::now();
453
454 let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();
455
456 let fut1 = async {
457 let hdr1 = send_slow_request(&mut client1, Duration::from_millis(500))
461 .await
462 .unwrap();
463 hdr1_tx.send(()).unwrap();
464 recv_slow_response_body(hdr1).await.unwrap();
465
466 assert!(start.elapsed() >= Duration::from_millis(500));
467 };
468 let fut2 = async {
469 tokio::time::sleep(Duration::from_millis(250)).await;
471 hdr1_rx.await.unwrap();
472 handle.graceful_shutdown(None);
473
474 do_empty_request_h1(&mut client2).await.unwrap_err();
476 do_empty_request_h1(&mut client2).await.unwrap_err();
477 do_empty_request_h1(&mut client2).await.unwrap_err();
478 };
479
480 tokio::join!(fut1, fut2);
481
482 assert!(start.elapsed() >= Duration::from_millis(500 + 100));
486
487 timeout(Duration::from_secs(1), server_task)
489 .await
490 .unwrap()
491 .unwrap()
492 .unwrap();
493 }
494
495 #[tokio::test]
497 async fn test_graceful_shutdown_timeout() {
498 let (handle, server_task, addr) = start_server().await;
499
500 let (mut client1, _conn1) = connect(addr).await;
501 let (mut client2, _conn2) = connect(addr).await;
502
503 do_empty_request_h1(&mut client1).await.unwrap();
505 do_empty_request_h1(&mut client2).await.unwrap();
506
507 let start = tokio::time::Instant::now();
508
509 let (hdr1_tx, hdr1_rx) = oneshot::channel::<()>();
510
511 let task1 = async {
512 let hdr1 = send_slow_request(&mut client1, Duration::from_millis(222)).await;
515 hdr1_tx.send(()).unwrap();
516
517 let res1 = recv_slow_response_body(hdr1.unwrap()).await;
518 res1.unwrap();
519 };
520 let task2 = async {
521 let hdr2 = send_slow_request(&mut client2, Duration::from_millis(5_555)).await;
526 hdr2.unwrap_err();
527 };
528 let task3 = async {
529 hdr1_rx.await.unwrap();
531
532 handle.graceful_shutdown(Some(Duration::from_millis(333)));
534
535 timeout(Duration::from_secs(1), server_task)
537 .await
538 .unwrap()
539 .unwrap()
540 .unwrap();
541
542 assert!(start.elapsed() >= Duration::from_millis(222 + 333));
544 assert!(start.elapsed() <= Duration::from_millis(5_555));
545 };
546
547 tokio::join!(task1, task2, task3);
548 }
549
550 #[tokio::test]
551 async fn test_http1_only() {
552 let (_handle, _server_task, addr) =
553 start_server_with_http_version(Some(HttpVersion::Http1)).await;
554
555 let (mut client, _conn) = connect_h1(addr).await;
556
557 do_empty_request_h1(&mut client).await.unwrap();
558
559 do_slow_request(&mut client, Duration::from_millis(50))
560 .await
561 .unwrap();
562
563 let (mut client, _conn) = connect_h2(addr).await;
564 do_empty_request_h2(&mut client).await.unwrap_err();
565 }
566
567 #[tokio::test]
568 async fn test_http2_only() {
569 let (_handle, _server_task, addr) =
570 start_server_with_http_version(Some(HttpVersion::Http2)).await;
571
572 let (mut client, _conn) = connect_h2(addr).await;
573
574 do_empty_request_h2(&mut client).await.unwrap();
575
576 do_slow_request_h2(&mut client, Duration::from_millis(50))
577 .await
578 .unwrap();
579
580 let (mut client, _conn) = connect_h1(addr).await;
581 do_empty_request_h1(&mut client).await.unwrap_err();
582 }
583
584 async fn start_server_with_http_version(
585 http_version: Option<HttpVersion>,
586 ) -> (
587 Handle<IpSocketAddr>,
588 JoinHandle<io::Result<()>>,
589 IpSocketAddr,
590 ) {
591 let handle = Handle::new();
592
593 let server_handle = handle.clone();
594 let server_task = tokio::spawn(async move {
595 let app = Router::new()
596 .route("/", get(|| async { "Hello, world!" }))
597 .route(
598 "/echo_slowly",
599 post(|body: Bytes| async move {
600 Response::new(slow_body(body.len(), Duration::from_millis(100)))
602 }),
603 );
604
605 let addr = IpSocketAddr::from(([127, 0, 0, 1], 0));
606 let server = Server::bind(addr);
607 let server = match http_version {
608 Some(HttpVersion::Http1) => server.http1_only(),
609 Some(HttpVersion::Http2) => server.http2_only(),
610 None => server,
611 };
612
613 server
614 .handle(server_handle)
615 .serve(app.into_make_service())
616 .await
617 });
618
619 let addr = handle.listening().await.unwrap();
620
621 (handle, server_task, addr)
622 }
623
624 async fn start_server() -> (
625 Handle<IpSocketAddr>,
626 JoinHandle<io::Result<()>>,
627 IpSocketAddr,
628 ) {
629 start_server_with_http_version(None).await
630 }
631
632 async fn connect(addr: IpSocketAddr) -> (http1::SendRequest<Body>, JoinHandle<()>) {
633 connect_h1(addr).await
634 }
635
636 async fn connect_h1(addr: IpSocketAddr) -> (http1::SendRequest<Body>, JoinHandle<()>) {
637 let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
638 let (send_request, connection) = client::conn::http1::handshake(stream).await.unwrap();
639
640 let task = tokio::spawn(async move {
641 let _ = connection.await;
642 });
643
644 (send_request, task)
645 }
646
647 async fn connect_h2(addr: IpSocketAddr) -> (http2::SendRequest<Body>, JoinHandle<()>) {
648 let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
649 let (send_request, connection) =
650 client::conn::http2::handshake(TokioExecutor::new(), stream)
651 .await
652 .unwrap();
653
654 let task = tokio::spawn(async move {
655 let _ = connection.await;
656 });
657
658 (send_request, task)
659 }
660
661 async fn do_empty_request_h1(client: &mut http1::SendRequest<Body>) -> hyper::Result<()> {
663 client.ready().await?;
664
665 let body = client
666 .send_request(Request::new(Body::empty()))
667 .await?
668 .into_body();
669
670 let body = body.collect().await?.to_bytes();
671 assert_eq!(body.as_ref(), b"Hello, world!");
672 Ok(())
673 }
674
675 async fn do_empty_request_h2(client: &mut http2::SendRequest<Body>) -> hyper::Result<()> {
677 client.ready().await?;
678
679 let body = client
680 .send_request(Request::new(Body::empty()))
681 .await?
682 .into_body();
683
684 let body = body.collect().await?.to_bytes();
685 assert_eq!(body.as_ref(), b"Hello, world!");
686 Ok(())
687 }
688
689 async fn do_slow_request(
692 client: &mut http1::SendRequest<Body>,
693 duration: Duration,
694 ) -> hyper::Result<()> {
695 let response = send_slow_request(client, duration).await?;
696 recv_slow_response_body(response).await
697 }
698
699 async fn do_slow_request_h2(
700 client: &mut http2::SendRequest<Body>,
701 duration: Duration,
702 ) -> hyper::Result<()> {
703 let response = send_slow_request_h2(client, duration).await?;
704 recv_slow_response_body(response).await
705 }
706
707 async fn send_slow_request(
708 client: &mut http1::SendRequest<Body>,
709 duration: Duration,
710 ) -> hyper::Result<http::Response<hyper::body::Incoming>> {
711 let req_body_len: usize = 10;
712 let mut req = Request::new(slow_body(req_body_len, duration));
713 *req.method_mut() = Method::POST;
714 *req.uri_mut() = Uri::from_static("/echo_slowly");
715
716 client.ready().await?;
717 client.send_request(req).await
718 }
719
720 async fn send_slow_request_h2(
721 client: &mut http2::SendRequest<Body>,
722 duration: Duration,
723 ) -> hyper::Result<http::Response<hyper::body::Incoming>> {
724 let req_body_len: usize = 10;
725 let mut req = Request::new(slow_body(req_body_len, duration));
726 *req.method_mut() = Method::POST;
727 *req.uri_mut() = Uri::from_static("/echo_slowly");
728
729 client.ready().await?;
730 client.send_request(req).await
731 }
732
733 async fn recv_slow_response_body(
734 response: http::Response<hyper::body::Incoming>,
735 ) -> hyper::Result<()> {
736 let resp_body = response.into_body();
737 let resp_body_bytes = resp_body.collect().await?.to_bytes();
738 assert_eq!(10, resp_body_bytes.len());
739 Ok(())
740 }
741
742 fn slow_body(length: usize, duration: Duration) -> axum::body::Body {
746 let frames =
747 (0..length).map(move |_| Ok::<_, hyper::Error>(Frame::data(Bytes::from_static(b"X"))));
748
749 let stream = stream::iter(frames).then(move |frame| async move {
750 tokio::time::sleep(duration / (length as u32)).await;
751 frame
752 });
753
754 axum::body::Body::new(StreamBody::new(stream))
755 }
756}