blocking_http_server/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::ops::Deref;
4use std::ops::DerefMut;
5
6use bytes::BytesMut;
7pub use http::*;
8use io::Read;
9use io::Write;
10use std::io;
11use std::net::SocketAddr;
12use std::net::TcpListener;
13use std::net::TcpStream;
14use std::net::ToSocketAddrs;
15
16pub struct Server {
17    listener: TcpListener,
18    req_size_limit: usize,
19
20    buf: BytesMut,
21}
22
23impl Server {
24    const DEFAULT_REQ_SIZE_LIMIT: usize = 4096;
25    const HEADER_COUNT_LIMIT: usize = 64;
26
27    pub fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
28        let listener = TcpListener::bind(addr)?;
29        Ok(Self {
30            listener,
31            req_size_limit: Self::DEFAULT_REQ_SIZE_LIMIT,
32            buf: BytesMut::with_capacity(Self::DEFAULT_REQ_SIZE_LIMIT),
33        })
34    }
35
36    pub fn set_request_size_limit(&mut self, limit: usize) {
37        self.buf = BytesMut::with_capacity(limit);
38        self.req_size_limit = limit;
39    }
40
41    pub fn incoming(&mut self) -> Incoming {
42        Incoming { server: self }
43    }
44
45    pub fn recv(&mut self) -> io::Result<HttpRequest> {
46        self.incoming().next().unwrap()
47    }
48}
49
50#[derive(Debug)]
51pub struct HttpRequest {
52    pub peer_addr: SocketAddr,
53
54    header_buf: BytesMut,
55    request: Request<BytesMut>,
56    stream: TcpStream,
57}
58
59impl HttpRequest {
60    pub fn header_bytes(&self) -> &[u8] {
61        &self.header_buf
62    }
63
64    pub fn respond<T: AsRef<[u8]>>(
65        &self,
66        response: impl std::borrow::Borrow<Response<T>>,
67    ) -> io::Result<()> {
68        let version = self.version();
69        let mut stream = &self.stream;
70
71        let response: &Response<T> = response.borrow();
72        // let version = response.version();
73        let status = response.status();
74        let headers = response.headers();
75        let body = response.body().as_ref();
76
77        write!(
78            stream,
79            "{:?} {} {}\r\n",
80            version,
81            status.as_str(),
82            status.canonical_reason().unwrap_or("Unknown"),
83        )?;
84
85        // println!("write_response: {}", text);
86
87        // if !headers.contains_key(header::DATE) {
88        //     let date = time::strftime("%a, %d %b %Y %H:%M:%S GMT", &time::now_utc()).unwrap();
89        //     write!(stream, "date: {}\r\n", date)?;
90        // }
91        if !headers.contains_key(header::CONNECTION) {
92            write!(stream, "connection: close\r\n")?;
93        }
94        if !headers.contains_key(header::CONTENT_LENGTH) {
95            write!(stream, "content-length: {}\r\n", body.len())?;
96        }
97        for (k, v) in headers.iter() {
98            write!(
99                stream,
100                "{}: {}\r\n",
101                k.as_str(),
102                v.to_str().unwrap_or("unknown")
103            )?;
104        }
105
106        stream.write_all(b"\r\n")?;
107        stream.write_all(body)?;
108        stream.flush()?;
109
110        Ok(())
111    }
112}
113
114impl Deref for HttpRequest {
115    type Target = Request<BytesMut>;
116    fn deref(&self) -> &Self::Target {
117        &self.request
118    }
119}
120
121impl DerefMut for HttpRequest {
122    fn deref_mut(&mut self) -> &mut Self::Target {
123        &mut self.request
124    }
125}
126
127pub struct Incoming<'a> {
128    server: &'a mut Server,
129}
130
131impl Iterator for Incoming<'_> {
132    type Item = io::Result<HttpRequest>;
133    fn next(&mut self) -> Option<Self::Item> {
134        let (mut stream, addr) = match self.server.listener.accept() {
135            Ok((stream, addr)) => {
136                let _ = stream.set_nodelay(true);
137                (stream, addr)
138            }
139            Err(e) => return Some(Err(e)),
140        };
141
142        {
143            // prepare the buffer
144            let buf = &mut self.server.buf;
145            buf.clear();
146            if self.server.req_size_limit > buf.capacity() {
147                // This will not cause reallocation, because the `split_off`ed header_buf and body_buf are dropped at this point.
148                buf.reserve(self.server.req_size_limit - buf.capacity());
149            }
150        }
151
152        let mut header_buf = self.server.buf.split_off(0);
153
154        loop {
155            let mut tmp = header_buf.split_off(header_buf.len());
156            unsafe { tmp.set_len(tmp.capacity()) };
157
158            match stream.read(&mut tmp) {
159                Ok(0) => {
160                    tmp.clear();
161                    header_buf.unsplit(tmp);
162                    return Some(Err(io::Error::other("uncomplete request header")));
163                }
164                Ok(n) => {
165                    unsafe { tmp.set_len(n) };
166                    header_buf.unsplit(tmp);
167
168                    let mut headers = [httparse::EMPTY_HEADER; Server::HEADER_COUNT_LIMIT];
169                    let mut req = httparse::Request::new(&mut headers);
170
171                    let offset = match req.parse(&header_buf) {
172                        Ok(httparse::Status::Complete(offset)) => offset,
173                        Ok(httparse::Status::Partial) => continue,
174                        Err(e) => {
175                            // eprintln!("error: {e}");
176                            return Some(Err(io::Error::other(e)));
177                        }
178                    };
179
180                    let version = match req.version {
181                        Some(0) => Version::HTTP_10,
182                        Some(1) => Version::HTTP_11,
183                        Some(_) => Version::HTTP_11,
184                        None => Version::HTTP_11,
185                    };
186
187                    let mut uri = Uri::builder()
188                        .scheme(uri::Scheme::HTTP)
189                        .path_and_query(req.path.unwrap_or("/"));
190
191                    let mut builder = Request::builder()
192                        .method(req.method.unwrap_or("GET"))
193                        .version(version);
194
195                    let mut content_len = 0;
196                    for header in req.headers {
197                        builder = builder.header(header.name, header.value);
198                        if header.name.eq_ignore_ascii_case("host") {
199                            let host = header.value;
200                            uri = uri.authority(host);
201                        }
202
203                        if header.name.eq_ignore_ascii_case(header::CONTENT_LENGTH.as_str()) {
204                            content_len = std::str::from_utf8(header.value).unwrap_or("0").parse::<usize>().unwrap_or(0);
205                            if content_len > header_buf.capacity() - offset {
206                                return Some(Err(io::Error::other("body too large")));
207                            }
208                        }
209                    }
210
211                    let mut body_buf = header_buf.split_off(offset);
212                    if body_buf.capacity() < content_len {
213                        return Some(Err(io::Error::other("body too large")));
214                    }
215
216                    if body_buf.len() >= content_len {
217                        body_buf.truncate(content_len);
218                    } else {
219                        let size = content_len - body_buf.len();
220    
221                        let mut tmp = body_buf.split_off(body_buf.len());
222                        unsafe { tmp.set_len(size) };
223    
224                        if let Err(e) = stream.read_exact(&mut tmp) {
225                            return Some(Err(e));
226                        }
227                        body_buf.unsplit(tmp);
228                    }
229
230                    builder = builder.uri(uri.build().unwrap_or_default());
231
232                    let request = match builder.body(body_buf) {
233                        Ok(req) => req,
234                        Err(e) => return Some(Err(io::Error::other(e))),
235                    };
236
237                    return Some(Ok(HttpRequest {
238                        peer_addr: addr,
239                        header_buf,
240                        request,
241                        stream,
242                    }));
243                }
244                Err(e) => {
245                    if e.kind() == io::ErrorKind::Interrupted
246                        || e.kind() == io::ErrorKind::WouldBlock
247                    {
248                        tmp.clear();
249                        header_buf.unsplit(tmp);
250                        continue;
251                    }
252                    // eprintln!("error: {e}");
253                    return Some(Err(e));
254                }
255            };
256        }
257    }
258}
259
260