blocking_http_server/
lib.rs1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use bytes::BytesMut;
7pub use http::*;
8use io::Read;
9use io::Write;
10use std::io;
11use std::net::SocketAddr;
12use std::net::TcpListener;
13use std::net::TcpStream;
14use std::net::ToSocketAddrs;
15
16pub struct Server {
17 listener: TcpListener,
18 req_size_limit: usize,
19
20 buf: BytesMut,
21}
22
23impl Server {
24 const DEFAULT_REQ_SIZE_LIMIT: usize = 4096;
25 const HEADER_COUNT_LIMIT: usize = 64;
26
27 pub fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
28 let listener = TcpListener::bind(addr)?;
29 Ok(Self {
30 listener,
31 req_size_limit: Self::DEFAULT_REQ_SIZE_LIMIT,
32 buf: BytesMut::with_capacity(Self::DEFAULT_REQ_SIZE_LIMIT),
33 })
34 }
35
36 pub fn set_request_size_limit(&mut self, limit: usize) {
37 self.buf = BytesMut::with_capacity(limit);
38 self.req_size_limit = limit;
39 }
40
41 pub fn incoming(&mut self) -> Incoming {
42 Incoming { server: self }
43 }
44
45 pub fn recv(&mut self) -> io::Result<HttpRequest> {
46 self.incoming().next().unwrap()
47 }
48}
49
50#[derive(Debug)]
51pub struct HttpRequest {
52 pub peer_addr: SocketAddr,
53
54 header_buf: BytesMut,
55 request: Request<BytesMut>,
56 stream: TcpStream,
57}
58
59impl HttpRequest {
60 pub fn header_bytes(&self) -> &[u8] {
61 &self.header_buf
62 }
63
64 pub fn respond<T: AsRef<[u8]>>(
65 &self,
66 response: impl std::borrow::Borrow<Response<T>>,
67 ) -> io::Result<()> {
68 let version = self.version();
69 let mut stream = &self.stream;
70
71 let response: &Response<T> = response.borrow();
72 let status = response.status();
74 let headers = response.headers();
75 let body = response.body().as_ref();
76
77 write!(
78 stream,
79 "{:?} {} {}\r\n",
80 version,
81 status.as_str(),
82 status.canonical_reason().unwrap_or("Unknown"),
83 )?;
84
85 if !headers.contains_key(header::CONNECTION) {
92 write!(stream, "connection: close\r\n")?;
93 }
94 if !headers.contains_key(header::CONTENT_LENGTH) {
95 write!(stream, "content-length: {}\r\n", body.len())?;
96 }
97 for (k, v) in headers.iter() {
98 write!(
99 stream,
100 "{}: {}\r\n",
101 k.as_str(),
102 v.to_str().unwrap_or("unknown")
103 )?;
104 }
105
106 stream.write_all(b"\r\n")?;
107 stream.write_all(body)?;
108 stream.flush()?;
109
110 Ok(())
111 }
112}
113
114impl Deref for HttpRequest {
115 type Target = Request<BytesMut>;
116 fn deref(&self) -> &Self::Target {
117 &self.request
118 }
119}
120
121impl DerefMut for HttpRequest {
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 &mut self.request
124 }
125}
126
127pub struct Incoming<'a> {
128 server: &'a mut Server,
129}
130
131impl Iterator for Incoming<'_> {
132 type Item = io::Result<HttpRequest>;
133 fn next(&mut self) -> Option<Self::Item> {
134 let (mut stream, addr) = match self.server.listener.accept() {
135 Ok((stream, addr)) => {
136 let _ = stream.set_nodelay(true);
137 (stream, addr)
138 }
139 Err(e) => return Some(Err(e)),
140 };
141
142 {
143 let buf = &mut self.server.buf;
145 buf.clear();
146 if self.server.req_size_limit > buf.capacity() {
147 buf.reserve(self.server.req_size_limit - buf.capacity());
149 }
150 }
151
152 let mut header_buf = self.server.buf.split_off(0);
153
154 loop {
155 let mut tmp = header_buf.split_off(header_buf.len());
156 unsafe { tmp.set_len(tmp.capacity()) };
157
158 match stream.read(&mut tmp) {
159 Ok(0) => {
160 tmp.clear();
161 header_buf.unsplit(tmp);
162 return Some(Err(io::Error::new(
163 io::ErrorKind::Other,
164 "uncomplete request header",
165 )));
166 }
167 Ok(n) => {
168 unsafe { tmp.set_len(n) };
169 header_buf.unsplit(tmp);
170
171 let mut headers = [httparse::EMPTY_HEADER; Server::HEADER_COUNT_LIMIT];
172 let mut req = httparse::Request::new(&mut headers);
173
174 let offset = match req.parse(&header_buf) {
175 Ok(httparse::Status::Complete(offset)) => offset,
176 Ok(httparse::Status::Partial) => continue,
177 Err(e) => {
178 return Some(Err(io::Error::new(io::ErrorKind::Other, e)));
180 }
181 };
182
183 let version = match req.version {
184 Some(0) => Version::HTTP_10,
185 Some(1) => Version::HTTP_11,
186 Some(_) => Version::HTTP_11,
187 None => Version::HTTP_11,
188 };
189
190 let mut uri = Uri::builder()
191 .scheme(uri::Scheme::HTTP)
192 .path_and_query(req.path.unwrap_or("/"));
193
194 let mut builder = Request::builder()
195 .method(req.method.unwrap_or("GET"))
196 .version(version);
197
198 let mut content_len = 0;
199 for header in req.headers {
200 builder = builder.header(header.name, header.value);
201 if header.name.eq_ignore_ascii_case("host") {
202 let host = header.value;
203 uri = uri.authority(host);
204 }
205
206 if header.name.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()) {
207 content_len = std::str::from_utf8(header.value).unwrap_or("0").parse::<usize>().unwrap_or(0);
208 if content_len > header_buf.capacity() - offset {
209 return Some(Err(io::Error::new(
210 io::ErrorKind::Other,
211 "body too large",
212 )));
213 }
214 }
215 }
216
217 let mut body_buf = header_buf.split_off(offset);
218 if body_buf.capacity() < content_len {
219 return Some(Err(io::Error::new(io::ErrorKind::Other, "body too large")));
220 }
221
222 if body_buf.len() >= content_len {
223 body_buf.truncate(content_len);
224 } else {
225 let size = content_len - body_buf.len();
226
227 let mut tmp = body_buf.split_off(body_buf.len());
228 unsafe { tmp.set_len(size) };
229
230 if let Err(e) = stream.read_exact(&mut tmp) {
231 return Some(Err(e));
232 }
233 body_buf.unsplit(tmp);
234 }
235
236 builder = builder.uri(uri.build().unwrap_or_default());
237
238 let request = match builder.body(body_buf) {
239 Ok(req) => req,
240 Err(e) => return Some(Err(io::Error::new(io::ErrorKind::Other, e))),
241 };
242
243 return Some(Ok(HttpRequest {
244 peer_addr: addr,
245 header_buf,
246 request,
247 stream,
248 }));
249 }
250 Err(e) => {
251 if e.kind() == io::ErrorKind::Interrupted
252 || e.kind() == io::ErrorKind::WouldBlock
253 {
254 tmp.clear();
255 header_buf.unsplit(tmp);
256 continue;
257 }
258 return Some(Err(e));
260 }
261 };
262 }
263 }
264}
265
266