acril_http/server/
mod.rs

1//! Process HTTP connections on the server.
2
3use futures::io::{self, AsyncRead as Read, AsyncWrite as Write};
4use http_types::headers::{CONNECTION, UPGRADE};
5use http_types::upgrade::Connection;
6use http_types::{Request, Response, StatusCode};
7use std::{future::Future, time::Duration};
8mod body_reader;
9mod decode;
10mod encode;
11
12pub use decode::decode;
13pub use encode::Encoder;
14
/// Configuration for a [`Server`].
///
/// Construct with [`Default::default`] and adjust with the `with_*` builder
/// methods, then hand the result to `Server::with_opts`.
#[derive(Debug, Clone)]
pub struct ServerOptions {
    /// Maximum time to wait for a request's headers to be decoded.
    /// `None` waits indefinitely. Defaults to 60s.
    headers_timeout: Option<Duration>,
}

impl ServerOptions {
    /// Sets how long the server waits for a request's headers to be decoded.
    ///
    /// Passing `None` disables the timeout entirely.
    pub fn with_headers_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.headers_timeout = timeout;
        self
    }

    /// Returns the configured headers timeout (`None` means no timeout).
    pub fn headers_timeout(&self) -> Option<Duration> {
        self.headers_timeout
    }
}

impl Default for ServerOptions {
    /// Default options: a 60 second headers timeout.
    fn default() -> Self {
        Self {
            headers_timeout: Some(Duration::from_secs(60)),
        }
    }
}
29
30/// struct for server
31#[derive(Debug)]
32pub struct Server<RW> {
33    io: RW,
34    opts: ServerOptions,
35}
36
/// Outcome of serving one request: whether the connection may be reused
/// for a subsequent request.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum ConnectionStatus {
    /// Stop serving this connection; no further request should be accepted.
    Close,

    /// The connection may be kept open and another request accepted on it.
    KeepAlive,
}
46
47impl<RW> Server<RW>
48where
49    RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
50{
51    /// builds a new server
52    pub fn new(io: RW) -> Self {
53        Self {
54            io,
55            opts: Default::default(),
56        }
57    }
58
59    /// with opts
60    pub fn with_opts(mut self, opts: ServerOptions) -> Self {
61        self.opts = opts;
62        self
63    }
64
65    /// accept one request
66    pub async fn accept_one<
67        F: FnOnce(Request) -> Fut,
68        Fut: Future<Output = Result<Response, Error>>,
69        Error: From<http_types::Error> + From<std::io::Error>,
70    >(
71        &mut self,
72        callback: F,
73    ) -> Result<ConnectionStatus, Error> {
74        // Decode a new request, timing out if this takes longer than the timeout duration.
75        let fut = decode(self.io.clone());
76
77        let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
78            match async_std::future::timeout(timeout_duration, fut).await {
79                Ok(Ok(Some(r))) => r,
80                Err(_) | Ok(Ok(None)) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
81                Ok(Err(e)) => return Err(e.into()),
82            }
83        } else {
84            match fut.await? {
85                Some(r) => r,
86                None => return Ok(ConnectionStatus::Close), /* EOF */
87            }
88        };
89
90        let has_upgrade_header = req.header(UPGRADE).is_some();
91        let connection_header_as_str = req
92            .header(CONNECTION)
93            .map(|connection| connection.as_str())
94            .unwrap_or("");
95
96        let connection_header_is_upgrade = connection_header_as_str
97            .split(',')
98            .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
99        let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
100
101        let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
102
103        let method = req.method();
104
105        // Pass the request to the endpoint and encode the response.
106        let mut res = (callback)(req).await?;
107
108        close_connection |= res
109            .header(CONNECTION)
110            .map(|c| c.as_str().eq_ignore_ascii_case("close"))
111            .unwrap_or(false);
112
113        let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade();
114
115        let upgrade_sender = if upgrade_requested && upgrade_provided {
116            Some(res.send_upgrade())
117        } else {
118            None
119        };
120
121        let mut encoder = Encoder::new(res, method);
122
123        let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
124        log::trace!("wrote {} response bytes", bytes_written);
125
126        let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
127        log::trace!(
128            "discarded {} unread request body bytes",
129            body_bytes_discarded
130        );
131
132        if let Some(upgrade_sender) = upgrade_sender {
133            upgrade_sender.send(Connection::new(self.io.clone())).await;
134            Ok(ConnectionStatus::Close)
135        } else if close_connection {
136            Ok(ConnectionStatus::Close)
137        } else {
138            Ok(ConnectionStatus::KeepAlive)
139        }
140    }
141}