1use std::{
2 collections::HashMap,
3 io,
4};
5use log::{info, warn};
6
7use futures::{
8 prelude::*,
9 AsyncBufRead,
10 AsyncWrite,
11};
12
13pub mod websocket;
14pub mod cookies;
15
/// Line terminator used on the wire for HTTP/1.x (CR LF).
const NEWLINE: &[u8] = &[b'\r', b'\n'];
/// Upper bound, in bytes, on a single read while scanning the request head.
const MAX_HEADER_LENGTH: usize = 1_024;
/// Upper bound on the number of header lines accepted in one request head.
const MAX_HEADERS: usize = 128;
19
/// A parsed HTTP/1.x request head, as produced by `http`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request {
    /// Request method from the request line, e.g. `GET`.
    pub method: String,
    /// Request target exactly as it appeared on the request line.
    pub path: String,
    /// Header values keyed by header name as sent by the client (case preserved).
    /// When a header name is repeated, only the last occurrence is kept.
    pub headers: HashMap<String, Vec<u8>>,
    /// Number of bytes consumed from the stream while reading the head,
    /// including the blank line that terminates it.
    pub handshake_len: usize,
}
27
/// An HTTP response head to be written by `respond`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Response {
    /// Numeric status code, e.g. `200`.
    pub code: usize,
    /// Reason phrase written after the status code, e.g. `"OK"`.
    pub reason: &'static str,
    /// Headers in the order they are written; a name may appear more than once.
    pub headers: Vec<(String, Vec<u8>)>,
}

impl Default for Response {
    /// A `200 OK` response with no headers.
    fn default() -> Self {
        Self {
            code: 200,
            reason: "OK",
            headers: Vec::new(),
        }
    }
}
44
45pub async fn http<S>(stream: &mut S) -> std::io::Result<Request>
52where S: AsyncBufRead + Unpin
53{
54 let mut buff = Vec::new();
55 let mut offset = 0;
56 let mut lines = 0;
57 while let Ok(count) = stream.take(MAX_HEADER_LENGTH as u64).read_until(b'\n', &mut buff).await {
58 if count == 0 {
59 break;
60 }
61 if count < 3 && (&buff[offset..offset+count] == b"\r\n" || &buff[offset..offset+count] == b"\n") {
62 offset += count;
63 break;
64 }
65 lines += 1;
66 if lines > MAX_HEADERS {
67 warn!("Request had more than {} headers; rejected", MAX_HEADERS);
68 return Err(io::ErrorKind::InvalidInput.into());
69 }
70 offset += count;
71 }
72 let mut headers = vec![httparse::EMPTY_HEADER; lines - 1];
75 let mut req = httparse::Request::new(&mut headers);
76 let res = req.parse(&buff).or(Err(io::ErrorKind::InvalidInput))?;
77 match res {
78 httparse::Status::Complete(_) => {
79 },
81 httparse::Status::Partial => {
82 return Err(io::ErrorKind::InvalidInput.into());
84 }
85 }
86 if req.version.unwrap_or(1) > 2 {
88 warn!("HTTP/1.{} request rejected; don't support that", &req.version.unwrap_or(1));
90 return Err(io::ErrorKind::InvalidInput.into());
91 }
92 let mut req_headers = HashMap::default();
94 for header in req.headers {
95 if !header.name.is_empty() {
96 req_headers.insert(String::from(header.name), Vec::from(header.value));
97 }
98 }
99 let request = Request{
101 method: String::from(req.method.unwrap_or("GET")),
102 path: String::from(req.path.unwrap_or("/")),
103 headers: req_headers,
104 handshake_len: offset,
105 };
106 info!("HTTP/1.1 {method} {path}", method=request.method, path=request.path);
107 Ok(request)
108}
109
110pub async fn respond<S>(stream: &mut S, response: Response) -> io::Result<()>
114where S: AsyncWrite + Unpin
115{
116 let buf = format!("HTTP/1.1 {code} {reason}",
117 code=format!("{}", response.code),
118 reason=response.reason,
119 );
120 stream.write_all(&buf.as_bytes()).await?;
121 for (name, value) in &response.headers {
122 stream.write_all(NEWLINE).await?;
123 stream.write_all(name.as_bytes()).await?;
124 stream.write_all(b": ").await?;
125 stream.write_all(&value).await?;
126 }
127 stream.write_all(NEWLINE).await?;
129 stream.write_all(NEWLINE).await?;
130 Ok(())
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::{
        io::{BufReader, BufWriter},
        net::TcpListener,
        task,
    };
    use std::error::Error;

    /// End-to-end check: serve one request over a real TCP socket using `http`
    /// and `respond`, then verify what an HTTP client receives.
    #[async_std::test]
    async fn test_hello_world() -> Result<(), Box<dyn Error>> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener.local_addr().unwrap().port();

        // Server side: accept a single connection, check the parsed request,
        // and answer with a fixed HTML body.
        let server = task::spawn(async move {
            let mut incoming = listener.incoming();
            if let Some(conn) = incoming.next().await {
                let conn = conn.unwrap();
                let mut reader = BufReader::new(conn.clone());
                let mut writer = BufWriter::new(conn);

                let request = http(&mut reader).await.unwrap();
                assert_eq!(request.method, "GET");
                assert_eq!(request.path, "/");

                let reply = Response {
                    code: 200,
                    reason: "OK",
                    headers: vec![(
                        String::from("Content-Type"),
                        b"text/html; charset=utf-8".to_vec(),
                    )],
                };
                respond(&mut writer, reply).await.unwrap();
                writer.write_all(b"<h1>Hello world!</h1>").await.unwrap();
                writer.flush().await.unwrap();
            }
        });

        // Client side: a synchronous (not awaited) ureq request.
        let url = format!("http://localhost:{}", port);
        let res = ureq::get(&url).call();
        server.await;

        assert_eq!(res.status(), 200);
        assert_eq!(res.header("Content-Type").unwrap(), "text/html; charset=utf-8");
        assert_eq!(res.into_string().unwrap(), "<h1>Hello world!</h1>");
        Ok(())
    }
}