cyfs_async_h1/server/
mod.rs

1//! Process HTTP connections on the server.
2
3use async_std::future::{timeout, Future, TimeoutError};
4use async_std::io::{self, Read, Write};
5use http_types::headers::{CONNECTION, UPGRADE};
6use http_types::upgrade::Connection;
7use http_types::{Request, Response, StatusCode};
8use std::{marker::PhantomData, time::Duration};
9mod body_reader;
10mod decode;
11mod encode;
12
13pub use decode::decode;
14pub use encode::Encoder;
15
/// Configure the server.
///
/// The only setting is how long the server waits for a request to be decoded
/// before giving up on the connection. Use
/// [`ServerOptions::with_headers_timeout`] or
/// [`ServerOptions::without_headers_timeout`] to change it.
#[derive(Debug, Clone)]
pub struct ServerOptions {
    /// Timeout to handle headers. Defaults to 60s; `None` disables the timeout.
    headers_timeout: Option<Duration>,
}

impl ServerOptions {
    /// Sets the maximum time allowed for a request to be decoded.
    ///
    /// Returns the updated options so calls can be chained.
    #[must_use]
    pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self {
        self.headers_timeout = Some(headers_timeout);
        self
    }

    /// Removes the headers timeout, so the server waits indefinitely for the
    /// next request on the connection.
    ///
    /// Returns the updated options so calls can be chained.
    #[must_use]
    pub fn without_headers_timeout(mut self) -> Self {
        self.headers_timeout = None;
        self
    }
}
22
23impl Default for ServerOptions {
24    fn default() -> Self {
25        Self {
26            headers_timeout: Some(Duration::from_secs(60)),
27        }
28    }
29}
30
31/// Accept a new incoming HTTP/1.1 connection.
32///
33/// Supports `KeepAlive` requests by default.
34pub async fn accept<RW, F, Fut>(io: RW, endpoint: F) -> http_types::Result<()>
35where
36    RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
37    F: Fn(Request) -> Fut,
38    Fut: Future<Output = http_types::Result<Response>>,
39{
40    Server::new(io, endpoint).accept().await
41}
42
43/// Accept a new incoming HTTP/1.1 connection.
44///
45/// Supports `KeepAlive` requests by default.
46pub async fn accept_with_opts<RW, F, Fut>(
47    io: RW,
48    endpoint: F,
49    opts: ServerOptions,
50) -> http_types::Result<()>
51where
52    RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
53    F: Fn(Request) -> Fut,
54    Fut: Future<Output = http_types::Result<Response>>,
55{
56    Server::new(io, endpoint).with_opts(opts).accept().await
57}
58
59/// struct for server
60#[derive(Debug)]
61pub struct Server<RW, F, Fut> {
62    io: RW,
63    endpoint: F,
64    opts: ServerOptions,
65    _phantom: PhantomData<Fut>,
66}
67
/// Outcome of serving one request: whether the connection may be reused for
/// a subsequent request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// No further request should be accepted on this connection.
    Close,

    /// Another request may be accepted on this connection.
    KeepAlive,
}
77
78impl<RW, F, Fut> Server<RW, F, Fut>
79where
80    RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
81    F: Fn(Request) -> Fut,
82    Fut: Future<Output = http_types::Result<Response>>,
83{
84    /// builds a new server
85    pub fn new(io: RW, endpoint: F) -> Self {
86        Self {
87            io,
88            endpoint,
89            opts: Default::default(),
90            _phantom: PhantomData,
91        }
92    }
93
94    /// with opts
95    pub fn with_opts(mut self, opts: ServerOptions) -> Self {
96        self.opts = opts;
97        self
98    }
99
100    /// accept in a loop
101    pub async fn accept(&mut self) -> http_types::Result<()> {
102        while ConnectionStatus::KeepAlive == self.accept_one().await? {}
103        Ok(())
104    }
105
106    /// accept one request
107    pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus>
108    where
109        RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
110        F: Fn(Request) -> Fut,
111        Fut: Future<Output = http_types::Result<Response>>,
112    {
113        // Decode a new request, timing out if this takes longer than the timeout duration.
114        let fut = decode(self.io.clone());
115
116        let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
117            match timeout(timeout_duration, fut).await {
118                Ok(Ok(Some(r))) => r,
119                Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
120                Ok(Err(e)) => return Err(e),
121            }
122        } else {
123            match fut.await? {
124                Some(r) => r,
125                None => return Ok(ConnectionStatus::Close), /* EOF */
126            }
127        };
128
129        let has_upgrade_header = req.header(UPGRADE).is_some();
130        let connection_header_as_str = req
131            .header(CONNECTION)
132            .map(|connection| connection.as_str())
133            .unwrap_or("");
134
135        let connection_header_is_upgrade = connection_header_as_str
136            .split(',')
137            .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
138        let mut close_connection = connection_header_as_str.eq_ignore_ascii_case("close");
139
140        let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
141
142        let method = req.method();
143
144        // Pass the request to the endpoint and encode the response.
145        let mut res = (self.endpoint)(req).await?;
146
147        close_connection |= res
148            .header(CONNECTION)
149            .map(|c| c.as_str().eq_ignore_ascii_case("close"))
150            .unwrap_or(false);
151
152        let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade();
153
154        let upgrade_sender = if upgrade_requested && upgrade_provided {
155            Some(res.send_upgrade())
156        } else {
157            None
158        };
159
160        let mut encoder = Encoder::new(res, method);
161
162        let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
163        log::trace!("wrote {} response bytes", bytes_written);
164
165        let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
166        log::trace!(
167            "discarded {} unread request body bytes",
168            body_bytes_discarded
169        );
170
171        if let Some(upgrade_sender) = upgrade_sender {
172            upgrade_sender.send(Connection::new(self.io.clone())).await;
173            Ok(ConnectionStatus::Close)
174        } else if close_connection {
175            Ok(ConnectionStatus::Close)
176        } else {
177            Ok(ConnectionStatus::KeepAlive)
178        }
179    }
180}