blocking_http_server/
lib.rs1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use bytes::BytesMut;
7pub use http::*;
8use io::Read;
9use io::Write;
10use std::io;
11use std::net::SocketAddr;
12use std::net::TcpListener;
13use std::net::TcpStream;
14use std::net::ToSocketAddrs;
15
16pub struct Server {
17 listener: TcpListener,
18 req_size_limit: usize,
19
20 buf: BytesMut,
21}
22
23impl Server {
24 const DEFAULT_REQ_SIZE_LIMIT: usize = 4096;
25 const HEADER_COUNT_LIMIT: usize = 64;
26
27 pub fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
28 let listener = TcpListener::bind(addr)?;
29 Ok(Self {
30 listener,
31 req_size_limit: Self::DEFAULT_REQ_SIZE_LIMIT,
32 buf: BytesMut::with_capacity(Self::DEFAULT_REQ_SIZE_LIMIT),
33 })
34 }
35
36 pub fn set_request_size_limit(&mut self, limit: usize) {
37 self.buf = BytesMut::with_capacity(limit);
38 self.req_size_limit = limit;
39 }
40
41 pub fn incoming(&mut self) -> Incoming {
42 Incoming { server: self }
43 }
44
45 pub fn recv(&mut self) -> io::Result<HttpRequest> {
46 self.incoming().next().unwrap()
47 }
48}
49
50#[derive(Debug)]
51pub struct HttpRequest {
52 pub peer_addr: SocketAddr,
53
54 header_buf: BytesMut,
55 request: Request<BytesMut>,
56 stream: TcpStream,
57}
58
59impl HttpRequest {
60 pub fn header_bytes(&self) -> &[u8] {
61 &self.header_buf
62 }
63
64 pub fn respond<T: AsRef<[u8]>>(
65 &self,
66 response: impl std::borrow::Borrow<Response<T>>,
67 ) -> io::Result<()> {
68 let version = self.version();
69 let mut stream = &self.stream;
70
71 let response: &Response<T> = response.borrow();
72 let status = response.status();
74 let headers = response.headers();
75 let body = response.body().as_ref();
76
77 write!(
78 stream,
79 "{:?} {} {}\r\n",
80 version,
81 status.as_str(),
82 status.canonical_reason().unwrap_or("Unknown"),
83 )?;
84
85 if !headers.contains_key(header::CONNECTION) {
92 write!(stream, "connection: close\r\n")?;
93 }
94 if !headers.contains_key(header::CONTENT_LENGTH) {
95 write!(stream, "content-length: {}\r\n", body.len())?;
96 }
97 for (k, v) in headers.iter() {
98 write!(
99 stream,
100 "{}: {}\r\n",
101 k.as_str(),
102 v.to_str().unwrap_or("unknown")
103 )?;
104 }
105
106 stream.write_all(b"\r\n")?;
107 stream.write_all(body)?;
108 stream.flush()?;
109
110 Ok(())
111 }
112}
113
114impl Deref for HttpRequest {
115 type Target = Request<BytesMut>;
116 fn deref(&self) -> &Self::Target {
117 &self.request
118 }
119}
120
121impl DerefMut for HttpRequest {
122 fn deref_mut(&mut self) -> &mut Self::Target {
123 &mut self.request
124 }
125}
126
127pub struct Incoming<'a> {
128 server: &'a mut Server,
129}
130
131impl Iterator for Incoming<'_> {
132 type Item = io::Result<HttpRequest>;
133 fn next(&mut self) -> Option<Self::Item> {
134 let (mut stream, addr) = match self.server.listener.accept() {
135 Ok((stream, addr)) => {
136 let _ = stream.set_nodelay(true);
137 (stream, addr)
138 }
139 Err(e) => return Some(Err(e)),
140 };
141
142 {
143 let buf = &mut self.server.buf;
145 buf.clear();
146 if self.server.req_size_limit > buf.capacity() {
147 buf.reserve(self.server.req_size_limit - buf.capacity());
149 }
150 }
151
152 let mut header_buf = self.server.buf.split_off(0);
153
154 loop {
155 let mut tmp = header_buf.split_off(header_buf.len());
156 unsafe { tmp.set_len(tmp.capacity()) };
157
158 match stream.read(&mut tmp) {
159 Ok(0) => {
160 tmp.clear();
161 header_buf.unsplit(tmp);
162 return Some(Err(io::Error::other("uncomplete request header")));
163 }
164 Ok(n) => {
165 unsafe { tmp.set_len(n) };
166 header_buf.unsplit(tmp);
167
168 let mut headers = [httparse::EMPTY_HEADER; Server::HEADER_COUNT_LIMIT];
169 let mut req = httparse::Request::new(&mut headers);
170
171 let offset = match req.parse(&header_buf) {
172 Ok(httparse::Status::Complete(offset)) => offset,
173 Ok(httparse::Status::Partial) => continue,
174 Err(e) => {
175 return Some(Err(io::Error::other(e)));
177 }
178 };
179
180 let version = match req.version {
181 Some(0) => Version::HTTP_10,
182 Some(1) => Version::HTTP_11,
183 Some(_) => Version::HTTP_11,
184 None => Version::HTTP_11,
185 };
186
187 let mut uri = Uri::builder()
188 .scheme(uri::Scheme::HTTP)
189 .path_and_query(req.path.unwrap_or("/"));
190
191 let mut builder = Request::builder()
192 .method(req.method.unwrap_or("GET"))
193 .version(version);
194
195 let mut content_len = 0;
196 for header in req.headers {
197 builder = builder.header(header.name, header.value);
198 if header.name.eq_ignore_ascii_case("host") {
199 let host = header.value;
200 uri = uri.authority(host);
201 }
202
203 if header.name.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()) {
204 content_len = std::str::from_utf8(header.value).unwrap_or("0").parse::<usize>().unwrap_or(0);
205 if content_len > header_buf.capacity() - offset {
206 return Some(Err(io::Error::other("body too large")));
207 }
208 }
209 }
210
211 let mut body_buf = header_buf.split_off(offset);
212 if body_buf.capacity() < content_len {
213 return Some(Err(io::Error::other("body too large")));
214 }
215
216 if body_buf.len() >= content_len {
217 body_buf.truncate(content_len);
218 } else {
219 let size = content_len - body_buf.len();
220
221 let mut tmp = body_buf.split_off(body_buf.len());
222 unsafe { tmp.set_len(size) };
223
224 if let Err(e) = stream.read_exact(&mut tmp) {
225 return Some(Err(e));
226 }
227 body_buf.unsplit(tmp);
228 }
229
230 builder = builder.uri(uri.build().unwrap_or_default());
231
232 let request = match builder.body(body_buf) {
233 Ok(req) => req,
234 Err(e) => return Some(Err(io::Error::other(e))),
235 };
236
237 return Some(Ok(HttpRequest {
238 peer_addr: addr,
239 header_buf,
240 request,
241 stream,
242 }));
243 }
244 Err(e) => {
245 if e.kind() == io::ErrorKind::Interrupted
246 || e.kind() == io::ErrorKind::WouldBlock
247 {
248 tmp.clear();
249 header_buf.unsplit(tmp);
250 continue;
251 }
252 return Some(Err(e));
254 }
255 };
256 }
257 }
258}
259
260