use std::{
collections::HashMap,
io,
};
use log::{info, warn};
use futures::{
prelude::*,
AsyncBufRead,
AsyncWrite,
};
pub mod websocket;
pub mod cookies;
/// CRLF line terminator written between lines of a response head.
const NEWLINE: &[u8] = &[b'\r', b'\n'];
/// Largest number of bytes one `read_until` call in [`http`] may consume; a longer
/// line is read in several chunks, each counted as a separate line.
const MAX_HEADER_LENGTH: usize = 1 << 10;
/// Most non-blank lines (request line included) accepted in one request head.
const MAX_HEADERS: usize = 1 << 7;
/// A parsed HTTP request head, as returned by [`http`].
#[derive(Debug)]
pub struct Request {
    /// Request method as sent on the request line (e.g. `GET`).
    pub method: String,
    /// Request target copied verbatim from the request line.
    pub path: String,
    /// Header values keyed by the header name exactly as sent (no case
    /// normalization); when a name repeats, the last value wins.
    pub headers: HashMap<String, Vec<u8>>,
    /// Bytes consumed from the stream for the request line, the header lines,
    /// and the terminating blank line.
    pub handshake_len: usize,
}
/// Status line and headers of an HTTP/1.1 response head, written by [`respond`].
#[derive(Debug)]
pub struct Response {
    /// Numeric status code (e.g. `200`).
    pub code: usize,
    /// Reason phrase written after the status code.
    pub reason: &'static str,
    /// Headers in the order they are written; names and values are emitted verbatim.
    pub headers: Vec<(String, Vec<u8>)>,
}

impl Default for Response {
    /// A `200 OK` response with no headers.
    fn default() -> Self {
        Self {
            code: 200,
            reason: "OK",
            headers: Vec::new(),
        }
    }
}
/// Reads and parses an HTTP/1.x request head (request line plus headers) from `stream`.
///
/// Bytes are consumed up to and including the first blank line (`\r\n` or `\n`).
/// Anything after it, such as a request body, is left unread in `stream`.
///
/// Each read consumes at most [`MAX_HEADER_LENGTH`] bytes and stops early at `\n`.
/// Every non-blank chunk read this way counts against [`MAX_HEADERS`], and so does
/// the request line. This bounds how much is buffered for a single request head.
/// A line longer than `MAX_HEADER_LENGTH` is accepted, but it counts as several lines.
///
/// # Errors
///
/// * Any I/O error from `stream` is returned unchanged.
/// * `InvalidInput` in any of these cases:
///   * the head has more than `MAX_HEADERS` lines;
///   * no request line was received (empty stream or a leading blank line);
///   * EOF arrived before the blank line;
///   * `httparse` rejects the head;
///   * the head declares an HTTP/1 minor version greater than 2.
pub async fn http<S>(stream: &mut S) -> std::io::Result<Request>
where
    S: AsyncBufRead + Unpin,
{
    let mut buff = Vec::new();
    // Total bytes consumed so far; becomes `Request::handshake_len`.
    let mut offset = 0;
    // Non-blank chunks read so far (request line + header lines).
    let mut lines: usize = 0;
    loop {
        // A fresh `take` per read caps how much a single `read_until` can buffer,
        // so a peer that never sends `\n` cannot grow `buff` without bound.
        let count = stream
            .take(MAX_HEADER_LENGTH as u64)
            .read_until(b'\n', &mut buff)
            .await?;
        if count == 0 {
            // EOF: hand whatever we have to httparse, which reports it as partial.
            break;
        }
        let line = &buff[offset..offset + count];
        offset += count;
        if line == b"\r\n" || line == b"\n" {
            // The blank line terminates the head.
            break;
        }
        lines += 1;
        if lines > MAX_HEADERS {
            warn!("Request had more than {} headers; rejected", MAX_HEADERS);
            return Err(io::ErrorKind::InvalidInput.into());
        }
    }

    // The first line is the request line, not a header. With zero lines there is
    // nothing to parse, and a plain `lines - 1` would underflow.
    let header_slots = lines.checked_sub(1).ok_or(io::ErrorKind::InvalidInput)?;
    let mut headers = vec![httparse::EMPTY_HEADER; header_slots];
    let mut req = httparse::Request::new(&mut headers);
    match req.parse(&buff).map_err(|_| io::ErrorKind::InvalidInput)? {
        httparse::Status::Complete(_) => {}
        // EOF arrived before the terminating blank line.
        httparse::Status::Partial => return Err(io::ErrorKind::InvalidInput.into()),
    }

    // `req.version` is the minor version of HTTP/1.x.
    let version = req.version.unwrap_or(1);
    if version > 2 {
        warn!("HTTP/1.{} request rejected; don't support that", version);
        return Err(io::ErrorKind::InvalidInput.into());
    }

    // Names are stored as sent (case preserved); a repeated header keeps its last value.
    let req_headers: HashMap<String, Vec<u8>> = req
        .headers
        .iter()
        .filter(|header| !header.name.is_empty())
        .map(|header| (header.name.to_owned(), header.value.to_vec()))
        .collect();

    let request = Request {
        method: String::from(req.method.unwrap_or("GET")),
        path: String::from(req.path.unwrap_or("/")),
        headers: req_headers,
        handshake_len: offset,
    };
    info!(
        "HTTP/1.{version} {method} {path}",
        version = version,
        method = request.method,
        path = request.path
    );
    Ok(request)
}
/// Writes an HTTP/1.1 response head: the status line, one `name: value` line per
/// header, and the blank line that ends the head.
///
/// Only the head is written; the caller writes any body afterwards. The stream
/// is not flushed, so callers using a buffered writer must flush it themselves.
///
/// Header names and values are written verbatim. They are not checked for `\r`
/// or `\n`, so callers must not pass untrusted data containing line breaks.
///
/// # Errors
///
/// Returns any I/O error raised while writing to `stream`.
pub async fn respond<S>(stream: &mut S, response: Response) -> io::Result<()>
where
    S: AsyncWrite + Unpin,
{
    // One `write_all` for the whole head. Otherwise an unbuffered stream
    // would see a separate small write for every fragment.
    stream.write_all(&encode_head(&response)).await
}

/// Serializes the status line, headers, and terminating blank line of `response`.
fn encode_head(response: &Response) -> Vec<u8> {
    let mut head = format!("HTTP/1.1 {} {}", response.code, response.reason).into_bytes();
    for (name, value) in &response.headers {
        head.extend_from_slice(NEWLINE);
        head.extend_from_slice(name.as_bytes());
        head.extend_from_slice(b": ");
        head.extend_from_slice(value);
    }
    // One CRLF ends the last line; the second forms the blank line ending the head.
    head.extend_from_slice(NEWLINE);
    head.extend_from_slice(NEWLINE);
    head
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use async_std::{
        io::{BufReader, BufWriter},
        net::TcpListener,
        task,
    };
    use super::*;

    /// End-to-end check: a real HTTP client talks to a one-shot server that
    /// parses the request with `http` and answers with `respond`.
    #[async_std::test]
    async fn test_hello_world() -> Result<(), Box<dyn Error>> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener.local_addr().unwrap().port();

        // Serve exactly one connection, then let the task finish.
        let server = task::spawn(async move {
            let mut connections = listener.incoming();
            if let Some(conn) = connections.next().await {
                let conn = conn.unwrap();
                let mut reader = BufReader::new(conn.clone());
                let mut writer = BufWriter::new(conn);

                let request = http(&mut reader).await.unwrap();
                assert_eq!(request.method, "GET");
                assert_eq!(request.path, "/");

                let headers = vec![(
                    "Content-Type".into(),
                    Vec::from("text/html; charset=utf-8".as_bytes()),
                )];
                respond(&mut writer, Response { code: 200, reason: "OK", headers })
                    .await
                    .unwrap();
                writer.write_all(b"<h1>Hello world!</h1>").await.unwrap();
                writer.flush().await.unwrap();
            }
        });

        let url = format!("http://localhost:{}", port);
        let reply = ureq::get(&url).call();
        server.await;

        assert_eq!(reply.status(), 200);
        assert_eq!(reply.header("Content-Type").unwrap(), "text/html; charset=utf-8");
        assert_eq!(reply.into_string().unwrap(), "<h1>Hello world!</h1>");
        Ok(())
    }
}