blocking_http_server/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use bytes::BytesMut;
7pub use http::*;
8use io::Read;
9use io::Write;
10use std::io;
11use std::net::SocketAddr;
12use std::net::TcpListener;
13use std::net::TcpStream;
14use std::net::ToSocketAddrs;
15
16pub struct Server {
17    listener: TcpListener,
18    req_size_limit: usize,
19
20    buf: BytesMut,
21}
22
23impl Server {
24    const DEFAULT_REQ_SIZE_LIMIT: usize = 4096;
25    const HEADER_COUNT_LIMIT: usize = 64;
26
27    pub fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
28        let listener = TcpListener::bind(addr)?;
29        Ok(Self {
30            listener,
31            req_size_limit: Self::DEFAULT_REQ_SIZE_LIMIT,
32            buf: BytesMut::with_capacity(Self::DEFAULT_REQ_SIZE_LIMIT),
33        })
34    }
35
36    pub fn set_request_size_limit(&mut self, limit: usize) {
37        self.buf = BytesMut::with_capacity(limit);
38        self.req_size_limit = limit;
39    }
40
41    pub fn incoming(&mut self) -> Incoming {
42        Incoming { server: self }
43    }
44
45    pub fn recv(&mut self) -> io::Result<HttpRequest> {
46        self.incoming().next().unwrap()
47    }
48}
49
50#[derive(Debug)]
51pub struct HttpRequest {
52    pub peer_addr: SocketAddr,
53
54    header_buf: BytesMut,
55    request: Request<BytesMut>,
56    stream: TcpStream,
57}
58
59impl HttpRequest {
60    pub fn header_bytes(&self) -> &[u8] {
61        &self.header_buf
62    }
63
64    pub fn respond<T: AsRef<[u8]>>(
65        &self,
66        response: impl std::borrow::Borrow<Response<T>>,
67    ) -> io::Result<()> {
68        let version = self.version();
69        let mut stream = &self.stream;
70
71        let response: &Response<T> = response.borrow();
72        // let version = response.version();
73        let status = response.status();
74        let headers = response.headers();
75        let body = response.body().as_ref();
76
77        write!(
78            stream,
79            "{:?} {} {}\r\n",
80            version,
81            status.as_str(),
82            status.canonical_reason().unwrap_or("Unknown"),
83        )?;
84
85        // println!("write_response: {}", text);
86
87        // if !headers.contains_key(header::DATE) {
88        //     let date = time::strftime("%a, %d %b %Y %H:%M:%S GMT", &time::now_utc()).unwrap();
89        //     write!(stream, "date: {}\r\n", date)?;
90        // }
91        if !headers.contains_key(header::CONNECTION) {
92            write!(stream, "connection: close\r\n")?;
93        }
94        if !headers.contains_key(header::CONTENT_LENGTH) {
95            write!(stream, "content-length: {}\r\n", body.len())?;
96        }
97        for (k, v) in headers.iter() {
98            write!(
99                stream,
100                "{}: {}\r\n",
101                k.as_str(),
102                v.to_str().unwrap_or("unknown")
103            )?;
104        }
105
106        stream.write_all(b"\r\n")?;
107        stream.write_all(body)?;
108        stream.flush()?;
109
110        Ok(())
111    }
112}
113
114impl Deref for HttpRequest {
115    type Target = Request<BytesMut>;
116    fn deref(&self) -> &Self::Target {
117        &self.request
118    }
119}
120
121impl DerefMut for HttpRequest {
122    fn deref_mut(&mut self) -> &mut Self::Target {
123        &mut self.request
124    }
125}
126
127pub struct Incoming<'a> {
128    server: &'a mut Server,
129}
130
131impl Iterator for Incoming<'_> {
132    type Item = io::Result<HttpRequest>;
133    fn next(&mut self) -> Option<Self::Item> {
134        let (mut stream, addr) = match self.server.listener.accept() {
135            Ok((stream, addr)) => {
136                let _ = stream.set_nodelay(true);
137                (stream, addr)
138            }
139            Err(e) => return Some(Err(e)),
140        };
141
142        {
143            // prepare the buffer
144            let buf = &mut self.server.buf;
145            buf.clear();
146            if self.server.req_size_limit > buf.capacity() {
147                // This will not cause reallocation, because the `split_off`ed header_buf and body_buf are dropped at this point.
148                buf.reserve(self.server.req_size_limit - buf.capacity());
149            }
150        }
151
152        let mut header_buf = self.server.buf.split_off(0);
153
154        loop {
155            let mut tmp = header_buf.split_off(header_buf.len());
156            unsafe { tmp.set_len(tmp.capacity()) };
157
158            match stream.read(&mut tmp) {
159                Ok(0) => {
160                    tmp.clear();
161                    header_buf.unsplit(tmp);
162                    return Some(Err(io::Error::new(
163                        io::ErrorKind::Other,
164                        "uncomplete request header",
165                    )));
166                }
167                Ok(n) => {
168                    unsafe { tmp.set_len(n) };
169                    header_buf.unsplit(tmp);
170
171                    let mut headers = [httparse::EMPTY_HEADER; Server::HEADER_COUNT_LIMIT];
172                    let mut req = httparse::Request::new(&mut headers);
173
174                    let offset = match req.parse(&header_buf) {
175                        Ok(httparse::Status::Complete(offset)) => offset,
176                        Ok(httparse::Status::Partial) => continue,
177                        Err(e) => {
178                            // eprintln!("error: {e}");
179                            return Some(Err(io::Error::new(io::ErrorKind::Other, e)));
180                        }
181                    };
182
183                    let version = match req.version {
184                        Some(0) => Version::HTTP_10,
185                        Some(1) => Version::HTTP_11,
186                        Some(_) => Version::HTTP_11,
187                        None => Version::HTTP_11,
188                    };
189
190                    let mut uri = Uri::builder()
191                        .scheme(uri::Scheme::HTTP)
192                        .path_and_query(req.path.unwrap_or("/"));
193
194                    let mut builder = Request::builder()
195                        .method(req.method.unwrap_or("GET"))
196                        .version(version);
197
198                    let mut content_len = 0;
199                    for header in req.headers {
200                        builder = builder.header(header.name, header.value);
201                        if header.name.eq_ignore_ascii_case("host") {
202                            let host = header.value;
203                            uri = uri.authority(host);
204                        }
205
206                        if header.name.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()) {
207                            content_len = std::str::from_utf8(header.value).unwrap_or("0").parse::<usize>().unwrap_or(0);
208                            if content_len > header_buf.capacity() - offset {
209                                return Some(Err(io::Error::new(
210                                    io::ErrorKind::Other,
211                                    "body too large",
212                                )));
213                            }
214                        }
215                    }
216
217                    let mut body_buf = header_buf.split_off(offset);
218                    if body_buf.capacity() < content_len {
219                        return Some(Err(io::Error::new(io::ErrorKind::Other, "body too large")));
220                    }
221
222                    if body_buf.len() >= content_len {
223                        body_buf.truncate(content_len);
224                    } else {
225                        let size = content_len - body_buf.len();
226    
227                        let mut tmp = body_buf.split_off(body_buf.len());
228                        unsafe { tmp.set_len(size) };
229    
230                        if let Err(e) = stream.read_exact(&mut tmp) {
231                            return Some(Err(e));
232                        }
233                        body_buf.unsplit(tmp);
234                    }
235
236                    builder = builder.uri(uri.build().unwrap_or_default());
237
238                    let request = match builder.body(body_buf) {
239                        Ok(req) => req,
240                        Err(e) => return Some(Err(io::Error::new(io::ErrorKind::Other, e))),
241                    };
242
243                    return Some(Ok(HttpRequest {
244                        peer_addr: addr,
245                        header_buf,
246                        request,
247                        stream,
248                    }));
249                }
250                Err(e) => {
251                    if e.kind() == io::ErrorKind::Interrupted
252                        || e.kind() == io::ErrorKind::WouldBlock
253                    {
254                        tmp.clear();
255                        header_buf.unsplit(tmp);
256                        continue;
257                    }
258                    // eprintln!("error: {e}");
259                    return Some(Err(e));
260                }
261            };
262        }
263    }
264}
265
266