// hyper_server/server.rs

1#[cfg(feature = "proxy-protocol")]
2use crate::proxy_protocol::ProxyProtocolAcceptor;
3use crate::{
4    accept::{Accept, DefaultAcceptor},
5    addr_incoming_config::AddrIncomingConfig,
6    handle::Handle,
7    http_config::HttpConfig,
8    service::{MakeServiceRef, SendService},
9};
10use futures_util::future::poll_fn;
11use http::Request;
12use hyper::server::{
13    accept::Accept as HyperAccept,
14    conn::{AddrIncoming, AddrStream},
15};
16#[cfg(feature = "proxy-protocol")]
17use std::time::Duration;
18use std::{
19    io::{self, ErrorKind},
20    net::SocketAddr,
21    pin::Pin,
22};
23use tokio::{
24    io::{AsyncRead, AsyncWrite},
25    net::TcpListener,
26};
27
28/// Represents an HTTP server with customization capabilities for handling incoming requests.
29#[derive(Debug)]
30pub struct Server<A = DefaultAcceptor> {
31    acceptor: A,
32    listener: Listener,
33    addr_incoming_conf: AddrIncomingConfig,
34    handle: Handle,
35    http_conf: HttpConfig,
36    #[cfg(feature = "proxy-protocol")]
37    proxy_acceptor_set: bool,
38}
39
/// Source of the server's TCP listener.
///
/// Either an address that is bound when serving starts, or a standard-library
/// listener the caller already created.
#[derive(Debug)]
enum Listener {
    /// Bind a new listener to this address.
    Bind(SocketAddr),
    /// Reuse this already-created listener.
    Std(std::net::TcpListener),
}
46
47/// Creates a new [`Server`] instance that binds to the provided address.
48pub fn bind(addr: SocketAddr) -> Server {
49    Server::bind(addr)
50}
51
52/// Creates a new [`Server`] instance using an existing `std::net::TcpListener`.
53pub fn from_tcp(listener: std::net::TcpListener) -> Server {
54    Server::from_tcp(listener)
55}
56
57impl Server {
58    /// Constructs a server bound to the provided address.
59    pub fn bind(addr: SocketAddr) -> Self {
60        let acceptor = DefaultAcceptor::new();
61        let handle = Handle::new();
62
63        Self {
64            acceptor,
65            listener: Listener::Bind(addr),
66            addr_incoming_conf: AddrIncomingConfig::default(),
67            handle,
68            http_conf: HttpConfig::default(),
69            #[cfg(feature = "proxy-protocol")]
70            proxy_acceptor_set: false,
71        }
72    }
73
74    /// Constructs a server from an existing `std::net::TcpListener`.
75    pub fn from_tcp(listener: std::net::TcpListener) -> Self {
76        let acceptor = DefaultAcceptor::new();
77        let handle = Handle::new();
78
79        Self {
80            acceptor,
81            listener: Listener::Std(listener),
82            addr_incoming_conf: AddrIncomingConfig::default(),
83            handle,
84            http_conf: HttpConfig::default(),
85            #[cfg(feature = "proxy-protocol")]
86            proxy_acceptor_set: false,
87        }
88    }
89}
90
91impl<A> Server<A> {
92    /// Replace the current acceptor with a new one.
93    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<Acceptor> {
94        #[cfg(feature = "proxy-protocol")]
95        if self.proxy_acceptor_set {
96            panic!("Overwriting the acceptor after proxy protocol is enabled is not supported. Configure the acceptor first in the builder, then enable proxy protocol.");
97        }
98
99        Server {
100            acceptor,
101            listener: self.listener,
102            addr_incoming_conf: self.addr_incoming_conf,
103            handle: self.handle,
104            http_conf: self.http_conf,
105            #[cfg(feature = "proxy-protocol")]
106            proxy_acceptor_set: self.proxy_acceptor_set,
107        }
108    }
109
110    #[cfg(feature = "proxy-protocol")]
111    /// Enable proxy protocol header parsing.
112    /// Note has to be called after initial acceptor is set.
113    pub fn enable_proxy_protocol(
114        self,
115        parsing_timeout: Option<Duration>,
116    ) -> Server<ProxyProtocolAcceptor<A>> {
117        let initial_acceptor = self.acceptor;
118        let mut acceptor = ProxyProtocolAcceptor::new(initial_acceptor);
119
120        if let Some(val) = parsing_timeout {
121            acceptor = acceptor.parsing_timeout(val);
122        }
123
124        Server {
125            acceptor,
126            listener: self.listener,
127            addr_incoming_conf: self.addr_incoming_conf,
128            handle: self.handle,
129            http_conf: self.http_conf,
130            proxy_acceptor_set: true,
131        }
132    }
133
134    /// Maps the current acceptor to a new type.
135    pub fn map<Acceptor, F>(self, acceptor: F) -> Server<Acceptor>
136    where
137        F: FnOnce(A) -> Acceptor,
138    {
139        Server {
140            acceptor: acceptor(self.acceptor),
141            listener: self.listener,
142            addr_incoming_conf: self.addr_incoming_conf,
143            handle: self.handle,
144            http_conf: self.http_conf,
145            #[cfg(feature = "proxy-protocol")]
146            proxy_acceptor_set: self.proxy_acceptor_set,
147        }
148    }
149
150    /// Retrieves a reference to the server's acceptor.
151    pub fn get_ref(&self) -> &A {
152        &self.acceptor
153    }
154
155    /// Retrieves a mutable reference to the server's acceptor.
156    pub fn get_mut(&mut self) -> &mut A {
157        &mut self.acceptor
158    }
159
160    /// Provides the server with a handle for extra utilities.
161    pub fn handle(mut self, handle: Handle) -> Self {
162        self.handle = handle;
163        self
164    }
165
166    /// Replaces the current HTTP configuration.
167    pub fn http_config(mut self, config: HttpConfig) -> Self {
168        self.http_conf = config;
169        self
170    }
171
172    /// Replaces the current incoming address configuration.
173    pub fn addr_incoming_config(mut self, config: AddrIncomingConfig) -> Self {
174        self.addr_incoming_conf = config;
175        self
176    }
177
178    /// Serves the provided `MakeService`.
179    ///
180    /// The `MakeService` is responsible for constructing services for each incoming connection.
181    /// Each service is then used to handle requests from that specific connection.
182    ///
183    /// # Arguments
184    /// - `make_service`: A mutable reference to a type implementing the `MakeServiceRef` trait.
185    ///   This will be used to produce a service for each incoming connection.
186    ///
187    /// # Errors
188    ///
189    /// This method can return errors in the following scenarios:
190    /// - When binding to an address fails.
191    /// - If the `make_service` function encounters an error during its `poll_ready` call.
192    ///   It's worth noting that this error scenario doesn't typically occur with `axum` make services.
193    ///
194    pub async fn serve<M>(self, mut make_service: M) -> io::Result<()>
195    where
196        M: MakeServiceRef<AddrStream, Request<hyper::Body>>,
197        A: Accept<AddrStream, M::Service> + Clone + Send + Sync + 'static,
198        A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
199        A::Service: SendService<Request<hyper::Body>> + Send,
200        A::Future: Send,
201    {
202        // Extract relevant fields from `self` for easier access.
203        let acceptor = self.acceptor;
204        let addr_incoming_conf = self.addr_incoming_conf;
205        let handle = self.handle;
206        let http_conf = self.http_conf;
207
208        // Bind the incoming connections. Notify the handle if an error occurs during binding.
209        let mut incoming = match bind_incoming(self.listener, addr_incoming_conf).await {
210            Ok(v) => v,
211            Err(e) => {
212                handle.notify_listening(None);
213                return Err(e);
214            }
215        };
216
217        // Notify the handle about the server's listening state.
218        handle.notify_listening(Some(incoming.local_addr()));
219
220        // This is the main loop that accepts incoming connections and spawns tasks to handle them.
221        let accept_loop_future = async {
222            loop {
223                // Wait for a new connection or for the server to be signaled to shut down.
224                let addr_stream = tokio::select! {
225                    biased;
226                    result = accept(&mut incoming) => result?,
227                    _ = handle.wait_graceful_shutdown() => return Ok(()),
228                };
229
230                // Ensure the `make_service` is ready to produce another service.
231                poll_fn(|cx| make_service.poll_ready(cx))
232                    .await
233                    .map_err(io_other)?;
234
235                // Create a service for this connection.
236                let service = match make_service.make_service(&addr_stream).await {
237                    Ok(service) => service,
238                    Err(_) => continue, // TODO: Consider logging or handling this error in a more detailed manner.
239                };
240
241                // Clone necessary objects for the spawned task.
242                let acceptor = acceptor.clone();
243                let watcher = handle.watcher();
244                let http_conf = http_conf.clone();
245
246                // Spawn a new task to handle the connection.
247                tokio::spawn(async move {
248                    if let Ok((stream, send_service)) = acceptor.accept(addr_stream, service).await
249                    {
250                        let service = send_service.into_service();
251
252                        let mut serve_future = http_conf
253                            .inner
254                            .serve_connection(stream, service)
255                            .with_upgrades();
256
257                        // Wait for either the server to be shut down or the connection to finish.
258                        tokio::select! {
259                            biased;
260                            _ = watcher.wait_graceful_shutdown() => {
261                                // Initiate a graceful shutdown.
262                                Pin::new(&mut serve_future).graceful_shutdown();
263                                tokio::select! {
264                                    biased;
265                                    _ = watcher.wait_shutdown() => (),
266                                    _ = &mut serve_future => (),
267                                }
268                            }
269                            _ = watcher.wait_shutdown() => (),
270                            _ = &mut serve_future => (),
271                        }
272                    }
273                    // TODO: Consider logging or handling any errors that occur during acceptance.
274                });
275            }
276        };
277
278        // Wait for either the server to be fully shut down or an error to occur.
279        let result = tokio::select! {
280            biased;
281            _ = handle.wait_shutdown() => return Ok(()),
282            result = accept_loop_future => result,
283        };
284
285        // Handle potential errors.
286        // TODO: Consider removing the Clippy annotation by restructuring this error handling.
287        #[allow(clippy::question_mark)]
288        if let Err(e) = result {
289            return Err(e);
290        }
291
292        // Wait for all connections to end.
293        handle.wait_connections_end().await;
294
295        Ok(())
296    }
297}
298
299/// Binds the listener based on the provided configuration and returns an [`AddrIncoming`]
300/// which will produce [`AddrStream`]s for incoming connections.
301///
302/// The function takes into account different ways the listener might be set up,
303/// either by binding to a provided address or by using an existing standard listener.
304///
305/// # Arguments
306///
307/// - `listener`: The listener configuration. Can be either a direct bind address or an existing standard listener.
308/// - `addr_incoming_conf`: Configuration for the incoming connections, such as TCP keepalive settings.
309///
310/// # Errors
311///
312/// Returns an `io::Error` if:
313/// - Binding the listener fails.
314/// - Setting the listener to non-blocking mode fails.
315/// - The listener cannot be converted to a [`TcpListener`].
316/// - An error occurs when creating the [`AddrIncoming`].
317///
318async fn bind_incoming(
319    listener: Listener,
320    addr_incoming_conf: AddrIncomingConfig,
321) -> io::Result<AddrIncoming> {
322    let listener = match listener {
323        Listener::Bind(addr) => TcpListener::bind(addr).await?,
324        Listener::Std(std_listener) => {
325            std_listener.set_nonblocking(true)?;
326            TcpListener::from_std(std_listener)?
327        }
328    };
329    let mut incoming = AddrIncoming::from_listener(listener).map_err(io_other)?;
330
331    // Apply configuration settings to the incoming connection handler.
332    incoming.set_sleep_on_errors(addr_incoming_conf.tcp_sleep_on_accept_errors);
333    incoming.set_keepalive(addr_incoming_conf.tcp_keepalive);
334    incoming.set_keepalive_interval(addr_incoming_conf.tcp_keepalive_interval);
335    incoming.set_keepalive_retries(addr_incoming_conf.tcp_keepalive_retries);
336    incoming.set_nodelay(addr_incoming_conf.tcp_nodelay);
337
338    Ok(incoming)
339}
340
341/// Awaits and accepts a new incoming connection.
342///
343/// This function will poll the given `incoming` object until a new connection is ready to be accepted.
344///
345/// # Arguments
346///
347/// - `incoming`: The incoming connection handler from which new connections will be accepted.
348///
349/// # Returns
350///
351/// Returns the accepted [`AddrStream`] which represents a specific incoming connection.
352///
353/// # Panics
354///
355/// This function will panic if the `poll_accept` method returns `None`, which should never happen as per the Hyper documentation.
356///
357pub(crate) async fn accept(incoming: &mut AddrIncoming) -> io::Result<AddrStream> {
358    let mut incoming = Pin::new(incoming);
359
360    // Always [`Option::Some`].
361    // According to: https://docs.rs/hyper/0.14.14/src/hyper/server/tcp.rs.html#165
362    poll_fn(|cx| incoming.as_mut().poll_accept(cx))
363        .await
364        .unwrap()
365}
366
/// A type-erased, heap-allocated error that is `Send + Sync`, so it can cross threads.
type BoxError = Box<dyn std::error::Error + Sync + Send>;
369
370/// Converts any error into an `io::Error` of kind `Other`.
371///
372/// This function can be used to create a uniform `io::Error` response for various error types.
373///
374/// # Arguments
375///
376/// - `error`: The error to be converted.
377///
378/// # Returns
379///
380/// Returns an `io::Error` with the kind set to `Other` and the provided error as its cause.
381///
382pub(crate) fn io_other<E: Into<BoxError>>(error: E) -> io::Error {
383    io::Error::new(ErrorKind::Other, error)
384}
385
#[cfg(test)]
mod tests {
    use crate::{handle::Handle, server::Server};
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
    use tower::{Service, ServiceExt};

    /// Body served by the test route at `/`.
    const GREETING: &str = "Hello, world!";

    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, addr) = start_server().await;
        let (mut client, _conn) = connect(addr).await;

        let (_, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), GREETING.as_bytes());
    }

    #[tokio::test]
    async fn test_shutdown() {
        let (handle, _server_task, addr) = start_server().await;
        let (mut client, conn) = connect(addr).await;

        handle.shutdown();

        // After a hard shutdown, the request on the existing connection must fail.
        let outcome = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await;
        assert!(outcome.is_err());

        // The client-side connection task should end shortly afterwards.
        let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
    }

    #[tokio::test]
    async fn test_graceful_shutdown() {
        let (handle, server_task, addr) = start_server().await;
        let (mut client, conn) = connect(addr).await;

        handle.graceful_shutdown(None);

        // An already-open connection is still served during a graceful shutdown.
        let (_, body) = send_empty_request(&mut client).await;
        assert_eq!(body.as_ref(), GREETING.as_bytes());

        // Drop the client side so the last open connection can end.
        // NOTE(review): aborting the client task is not a clean, graceful close from the
        // client's perspective (original TODO: "This does not shut down gracefully").
        conn.abort();

        // With no connections left, the server task should finish soon.
        let server_outcome = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();
        assert!(server_outcome.is_ok());
    }

    #[tokio::test]
    async fn test_graceful_shutdown_timed() {
        let (handle, server_task, addr) = start_server().await;
        let (mut client, _conn) = connect(addr).await;

        handle.graceful_shutdown(Some(Duration::from_millis(250)));

        let (_, body) = send_empty_request(&mut client).await;
        assert_eq!(body.as_ref(), GREETING.as_bytes());

        // The client connection stays open, so the server relies on the timeout to finish.
        let server_outcome = timeout(Duration::from_secs(1), server_task)
            .await
            .unwrap()
            .unwrap();
        assert!(server_outcome.is_ok());
    }

    /// Spawns a server on an ephemeral localhost port and returns once it is listening.
    async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_task = tokio::spawn({
            let handle = handle.clone();
            async move {
                let app = Router::new().route("/", get(|| async { GREETING }));
                let addr = SocketAddr::from(([127, 0, 0, 1], 0));

                Server::bind(addr)
                    .handle(handle)
                    .serve(app.into_make_service())
                    .await
            }
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Opens an HTTP/1 client connection to `addr`, driving it on a background task.
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let (sender, connection) = handshake(tcp).await.unwrap();

        // The connection's own result is irrelevant to the tests.
        let driver = tokio::spawn(async move {
            let _ = connection.await;
        });

        (sender, driver)
    }

    /// Sends `GET /` with an empty body and returns the response head and collected body.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let response = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap();
        let (parts, body) = response.into_parts();
        let collected = hyper::body::to_bytes(body).await.unwrap();

        (parts, collected)
    }
}