1#[cfg(feature = "proxy-protocol")]
2use crate::proxy_protocol::ProxyProtocolAcceptor;
3use crate::{
4 accept::{Accept, DefaultAcceptor},
5 addr_incoming_config::AddrIncomingConfig,
6 handle::Handle,
7 http_config::HttpConfig,
8 service::{MakeServiceRef, SendService},
9};
10use futures_util::future::poll_fn;
11use http::Request;
12use hyper::server::{
13 accept::Accept as HyperAccept,
14 conn::{AddrIncoming, AddrStream},
15};
16#[cfg(feature = "proxy-protocol")]
17use std::time::Duration;
18use std::{
19 io::{self, ErrorKind},
20 net::SocketAddr,
21 pin::Pin,
22};
23use tokio::{
24 io::{AsyncRead, AsyncWrite},
25 net::TcpListener,
26};
27
/// HTTP server that takes TCP connections from a listener and serves them with hyper.
///
/// `A` is the acceptor applied to every accepted connection before it is
/// served; it defaults to [`DefaultAcceptor`].
#[derive(Debug)]
pub struct Server<A = DefaultAcceptor> {
    /// Acceptor applied to each accepted [`AddrStream`] in [`Server::serve`].
    acceptor: A,
    /// Where the TCP listener comes from when [`Server::serve`] runs.
    listener: Listener,
    /// Settings applied to the `AddrIncoming` (accept-error sleeping, TCP keepalive, nodelay).
    addr_incoming_conf: AddrIncomingConfig,
    /// Handle used to report the listening address and to coordinate shutdown.
    handle: Handle,
    /// hyper connection configuration used to serve each connection.
    http_conf: HttpConfig,
    /// Set by `enable_proxy_protocol`; while set, `acceptor()` panics.
    #[cfg(feature = "proxy-protocol")]
    proxy_acceptor_set: bool,
}
39
/// Source of the TCP listener used by [`Server::serve`].
#[derive(Debug)]
enum Listener {
    /// Bind a new tokio listener to this address when serving starts.
    Bind(SocketAddr),
    /// Reuse an existing std listener; it is switched to non-blocking mode when serving starts.
    Std(std::net::TcpListener),
}
46
47pub fn bind(addr: SocketAddr) -> Server {
49 Server::bind(addr)
50}
51
52pub fn from_tcp(listener: std::net::TcpListener) -> Server {
54 Server::from_tcp(listener)
55}
56
57impl Server {
58 pub fn bind(addr: SocketAddr) -> Self {
60 let acceptor = DefaultAcceptor::new();
61 let handle = Handle::new();
62
63 Self {
64 acceptor,
65 listener: Listener::Bind(addr),
66 addr_incoming_conf: AddrIncomingConfig::default(),
67 handle,
68 http_conf: HttpConfig::default(),
69 #[cfg(feature = "proxy-protocol")]
70 proxy_acceptor_set: false,
71 }
72 }
73
74 pub fn from_tcp(listener: std::net::TcpListener) -> Self {
76 let acceptor = DefaultAcceptor::new();
77 let handle = Handle::new();
78
79 Self {
80 acceptor,
81 listener: Listener::Std(listener),
82 addr_incoming_conf: AddrIncomingConfig::default(),
83 handle,
84 http_conf: HttpConfig::default(),
85 #[cfg(feature = "proxy-protocol")]
86 proxy_acceptor_set: false,
87 }
88 }
89}
90
91impl<A> Server<A> {
92 pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<Acceptor> {
94 #[cfg(feature = "proxy-protocol")]
95 if self.proxy_acceptor_set {
96 panic!("Overwriting the acceptor after proxy protocol is enabled is not supported. Configure the acceptor first in the builder, then enable proxy protocol.");
97 }
98
99 Server {
100 acceptor,
101 listener: self.listener,
102 addr_incoming_conf: self.addr_incoming_conf,
103 handle: self.handle,
104 http_conf: self.http_conf,
105 #[cfg(feature = "proxy-protocol")]
106 proxy_acceptor_set: self.proxy_acceptor_set,
107 }
108 }
109
110 #[cfg(feature = "proxy-protocol")]
111 pub fn enable_proxy_protocol(
114 self,
115 parsing_timeout: Option<Duration>,
116 ) -> Server<ProxyProtocolAcceptor<A>> {
117 let initial_acceptor = self.acceptor;
118 let mut acceptor = ProxyProtocolAcceptor::new(initial_acceptor);
119
120 if let Some(val) = parsing_timeout {
121 acceptor = acceptor.parsing_timeout(val);
122 }
123
124 Server {
125 acceptor,
126 listener: self.listener,
127 addr_incoming_conf: self.addr_incoming_conf,
128 handle: self.handle,
129 http_conf: self.http_conf,
130 proxy_acceptor_set: true,
131 }
132 }
133
134 pub fn map<Acceptor, F>(self, acceptor: F) -> Server<Acceptor>
136 where
137 F: FnOnce(A) -> Acceptor,
138 {
139 Server {
140 acceptor: acceptor(self.acceptor),
141 listener: self.listener,
142 addr_incoming_conf: self.addr_incoming_conf,
143 handle: self.handle,
144 http_conf: self.http_conf,
145 #[cfg(feature = "proxy-protocol")]
146 proxy_acceptor_set: self.proxy_acceptor_set,
147 }
148 }
149
150 pub fn get_ref(&self) -> &A {
152 &self.acceptor
153 }
154
155 pub fn get_mut(&mut self) -> &mut A {
157 &mut self.acceptor
158 }
159
160 pub fn handle(mut self, handle: Handle) -> Self {
162 self.handle = handle;
163 self
164 }
165
166 pub fn http_config(mut self, config: HttpConfig) -> Self {
168 self.http_conf = config;
169 self
170 }
171
172 pub fn addr_incoming_config(mut self, config: AddrIncomingConfig) -> Self {
174 self.addr_incoming_conf = config;
175 self
176 }
177
178 pub async fn serve<M>(self, mut make_service: M) -> io::Result<()>
195 where
196 M: MakeServiceRef<AddrStream, Request<hyper::Body>>,
197 A: Accept<AddrStream, M::Service> + Clone + Send + Sync + 'static,
198 A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
199 A::Service: SendService<Request<hyper::Body>> + Send,
200 A::Future: Send,
201 {
202 let acceptor = self.acceptor;
204 let addr_incoming_conf = self.addr_incoming_conf;
205 let handle = self.handle;
206 let http_conf = self.http_conf;
207
208 let mut incoming = match bind_incoming(self.listener, addr_incoming_conf).await {
210 Ok(v) => v,
211 Err(e) => {
212 handle.notify_listening(None);
213 return Err(e);
214 }
215 };
216
217 handle.notify_listening(Some(incoming.local_addr()));
219
220 let accept_loop_future = async {
222 loop {
223 let addr_stream = tokio::select! {
225 biased;
226 result = accept(&mut incoming) => result?,
227 _ = handle.wait_graceful_shutdown() => return Ok(()),
228 };
229
230 poll_fn(|cx| make_service.poll_ready(cx))
232 .await
233 .map_err(io_other)?;
234
235 let service = match make_service.make_service(&addr_stream).await {
237 Ok(service) => service,
238 Err(_) => continue, };
240
241 let acceptor = acceptor.clone();
243 let watcher = handle.watcher();
244 let http_conf = http_conf.clone();
245
246 tokio::spawn(async move {
248 if let Ok((stream, send_service)) = acceptor.accept(addr_stream, service).await
249 {
250 let service = send_service.into_service();
251
252 let mut serve_future = http_conf
253 .inner
254 .serve_connection(stream, service)
255 .with_upgrades();
256
257 tokio::select! {
259 biased;
260 _ = watcher.wait_graceful_shutdown() => {
261 Pin::new(&mut serve_future).graceful_shutdown();
263 tokio::select! {
264 biased;
265 _ = watcher.wait_shutdown() => (),
266 _ = &mut serve_future => (),
267 }
268 }
269 _ = watcher.wait_shutdown() => (),
270 _ = &mut serve_future => (),
271 }
272 }
273 });
275 }
276 };
277
278 let result = tokio::select! {
280 biased;
281 _ = handle.wait_shutdown() => return Ok(()),
282 result = accept_loop_future => result,
283 };
284
285 #[allow(clippy::question_mark)]
288 if let Err(e) = result {
289 return Err(e);
290 }
291
292 handle.wait_connections_end().await;
294
295 Ok(())
296 }
297}
298
299async fn bind_incoming(
319 listener: Listener,
320 addr_incoming_conf: AddrIncomingConfig,
321) -> io::Result<AddrIncoming> {
322 let listener = match listener {
323 Listener::Bind(addr) => TcpListener::bind(addr).await?,
324 Listener::Std(std_listener) => {
325 std_listener.set_nonblocking(true)?;
326 TcpListener::from_std(std_listener)?
327 }
328 };
329 let mut incoming = AddrIncoming::from_listener(listener).map_err(io_other)?;
330
331 incoming.set_sleep_on_errors(addr_incoming_conf.tcp_sleep_on_accept_errors);
333 incoming.set_keepalive(addr_incoming_conf.tcp_keepalive);
334 incoming.set_keepalive_interval(addr_incoming_conf.tcp_keepalive_interval);
335 incoming.set_keepalive_retries(addr_incoming_conf.tcp_keepalive_retries);
336 incoming.set_nodelay(addr_incoming_conf.tcp_nodelay);
337
338 Ok(incoming)
339}
340
341pub(crate) async fn accept(incoming: &mut AddrIncoming) -> io::Result<AddrStream> {
358 let mut incoming = Pin::new(incoming);
359
360 poll_fn(|cx| incoming.as_mut().poll_accept(cx))
363 .await
364 .unwrap()
365}
366
/// Boxed, thread-safe error accepted by [`io_other`].
type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Wraps any error convertible into [`BoxError`] as an [`io::Error`] of kind
/// [`ErrorKind::Other`], keeping the original error as the inner source.
pub(crate) fn io_other<E: Into<BoxError>>(error: E) -> io::Error {
    let boxed: BoxError = error.into();
    io::Error::new(ErrorKind::Other, boxed)
}
385
#[cfg(test)]
mod tests {
    use crate::{handle::Handle, server::Server};
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tower::{Service, ServiceExt};

    /// A plain request against a running server returns the handler's body.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    /// After `Handle::shutdown`, a request on an already-open connection fails
    /// and the client connection task finishes within one second.
    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        let response_future_result = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;

        assert!(response_future_result.is_err());

        // Panics (test failure) if the connection task does not end within 1 s.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    /// After `Handle::graceful_shutdown(None)`, an already-open connection is
    /// still served. Once the client aborts its connection task, the server
    /// task completes with `Ok(())` within one second.
    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        conn.abort();

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Graceful shutdown with a 250 ms limit: the open connection is still
    /// served, and the server task completes with `Ok(())` within one second.
    /// The client connection is never aborted here, so presumably the
    /// server's completion relies on the 250 ms limit expiring.
    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;

        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");

        let server_result = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();

        assert!(server_result.is_ok());
    }

    /// Spawns a server on 127.0.0.1 with an OS-assigned port, serving
    /// "Hello, world!" at `/`. Waits until the handle reports the listening
    /// address, then returns the handle, the server task, and that address.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(|| async { "Hello, world!" }));

            // Port 0: let the OS pick a free port.
            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            Server::bind(addr)
                .handle(server_handle)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Opens a raw HTTP/1 client connection to `addr`. Returns the request
    /// sender and the spawned task that drives the connection.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let stream = TcpStream::connect(addr).await.unwrap();

        let (send_request, connection) = handshake(stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task)
    }

    /// Sends a default request with an empty body (`Request::new` defaults
    /// to GET `/`) and returns the response parts plus the fully read body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }
}