1use async_io::Timer;
4use futures_lite::io::{self, AsyncRead as Read, AsyncWrite as Write};
5use futures_lite::prelude::*;
6use http_types::upgrade::Connection;
12use http_types::{
13 headers::{CONNECTION, UPGRADE},
14 Version,
15};
16use http_types::{Request, Response, StatusCode};
17use std::{future::Future, marker::PhantomData, time::Duration};
18mod body_reader;
19mod decode;
20mod encode;
21
22pub use decode::decode;
23pub use encode::Encoder;
24
/// Tunable settings for an HTTP/1 [`Server`].
///
/// Build one with [`ServerOptions::new`] (or `Default`) and the `with_*` methods.
#[derive(Clone, Debug)]
pub struct ServerOptions {
    /// How long `accept_one` waits for a request head to be decoded before giving up
    /// and closing the connection; `None` waits indefinitely.
    headers_timeout: Option<Duration>,
    /// Host handed to the request decoder through these options.
    /// NOTE(review): presumably used when a request carries no host — confirm in `decode`.
    default_host: Option<String>,
}
32
33impl ServerOptions {
34 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn with_headers_timeout(mut self, headers_timeout: Duration) -> Self {
41 self.headers_timeout = Some(headers_timeout);
42 self
43 }
44
45 pub fn with_default_host(mut self, default_host: &str) -> Self {
56 self.default_host = Some(default_host.into());
57 self
58 }
59}
60
61impl Default for ServerOptions {
62 fn default() -> Self {
63 Self {
64 headers_timeout: Some(Duration::from_secs(60)),
65 default_host: None,
66 }
67 }
68}
69
70pub async fn accept<RW, F, Fut>(io: RW, endpoint: F) -> crate::Result<()>
74where
75 RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
76 F: Fn(Request) -> Fut,
77 Fut: Future<Output = Response>,
78{
79 Server::new(io, endpoint).accept().await
80}
81
82pub async fn accept_with_opts<RW, F, Fut>(
86 io: RW,
87 endpoint: F,
88 opts: ServerOptions,
89) -> crate::Result<()>
90where
91 RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
92 F: Fn(Request) -> Fut,
93 Fut: Future<Output = Response>,
94{
95 Server::new(io, endpoint).with_opts(opts).accept().await
96}
97
98#[derive(Debug)]
100pub struct Server<RW, F, Fut> {
101 io: RW,
102 endpoint: F,
103 opts: ServerOptions,
104 _phantom: PhantomData<Fut>,
105}
106
/// What to do with the connection after one request/response exchange.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Stop serving; no further requests are read from the connection.
    Close,

    /// Keep the connection open and read the next request.
    KeepAlive,
}
116
117impl<RW, F, Fut> Server<RW, F, Fut>
118where
119 RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
120 F: Fn(Request) -> Fut,
121 Fut: Future<Output = Response>,
122{
123 pub fn new(io: RW, endpoint: F) -> Self {
125 Self {
126 io,
127 endpoint,
128 opts: Default::default(),
129 _phantom: PhantomData,
130 }
131 }
132
133 pub fn with_opts(mut self, opts: ServerOptions) -> Self {
135 self.opts = opts;
136 self
137 }
138
139 pub async fn accept(&mut self) -> crate::Result<()> {
141 loop {
142 let result = self.accept_one().await;
143 match result {
144 Ok(status) => {
145 if status != ConnectionStatus::KeepAlive {
146 break;
147 }
148 },
149 Err(err) => {
150 log::warn!("async-h1 accept_one returns Err: {err:#?}");
151 return Err(err);
152 }
153 }
154 }
155 Ok(())
156 }
157
158 pub async fn accept_one(&mut self) -> crate::Result<ConnectionStatus>
160 where
161 RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
162 F: Fn(Request) -> Fut,
163 Fut: Future<Output = Response>,
164 {
165 let fut = decode(self.io.clone(), &self.opts);
167
168 let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
169 match fut
170 .or(async {
171 Timer::after(timeout_duration).await;
172 Ok(None)
173 })
174 .await
175 {
176 Ok(Some(r)) => r,
177 Ok(None) => return Ok(ConnectionStatus::Close), Err(e) => return Err(e),
179 }
180 } else {
181 match fut.await? {
182 Some(r) => r,
183 None => return Ok(ConnectionStatus::Close), }
185 };
186
187 let req_version = req.version();
188
189
190 let connection_header =
191 req.header(CONNECTION)
192 .map(|connection| connection.as_str())
193 .unwrap_or("")
194 .to_string();
195
196 let res_header_keepalive = {
197 let c = connection_header.to_ascii_lowercase();
198 if c == "keep-alive" || c.contains("keep-alive,") {
199 "keep-alive"
200 } else if c == "close" || c.contains("close") {
201 "close"
202 } else {
203 match req_version {
204 Some(Version::Http1_1) => "keep-alive",
205 Some(Version::Http1_0) => "close",
206 _ => { unreachable!(); }
207 }
208 }
209 };
210
211 let close_connection =
212 match res_header_keepalive {
213 "close" => true,
214 _ => false
215 };
216 let connection_header_is_upgrade = connection_header.split(',').any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
227 let has_upgrade_header = req.header(UPGRADE).is_some();
228 let upgrade_requested = has_upgrade_header && connection_header_is_upgrade;
229
230 let method = req.method();
231
232 let mut response = (self.endpoint)(req).await;
234 response.set_version(req_version);
235
236 let upgrade_provided =
244 response.status() == StatusCode::SwitchingProtocols && response.has_upgrade();
245
246 if ! upgrade_provided {
247 if let Some(hc) = response.header(CONNECTION) {
248 let tmp: Vec<_> = hc.iter().collect();
249 if tmp.len() != 1 {
250 return Err(crate::Error::UnexpectedHeader("should not have multi 'Connection' header"));
253 }
254
255 let mut new_hc = hc.last().as_str().to_string();
256 if new_hc.is_empty() {
257 new_hc = res_header_keepalive.to_string();
258 } else {
259 new_hc.push(',');
260 new_hc.push(' ');
261 new_hc.extend(res_header_keepalive.chars());
262 }
263 response.insert_header(CONNECTION, new_hc);
264 } else {
265 response.insert_header(CONNECTION, res_header_keepalive);
266 }
267 }
268
269 let upgrade_sender = if upgrade_requested && upgrade_provided {
270 Some(response.send_upgrade())
271 } else {
272 None
273 };
274
275 let mut encoder = Encoder::new(response, method);
276
277 let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
278 log::trace!("wrote {} response bytes", bytes_written);
279
280 let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
281 log::trace!(
282 "discarded {} unread request body bytes",
283 body_bytes_discarded
284 );
285
286 if let Some(upgrade_sender) = upgrade_sender {
287 upgrade_sender.send(Connection::new(self.io.clone())).await;
288 Ok(ConnectionStatus::Close)
289 } else if close_connection {
290 Ok(ConnectionStatus::Close)
291 } else {
292 Ok(ConnectionStatus::KeepAlive)
293 }
294 }
295}