use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use std::{borrow::Cow, collections::HashMap, io, net::SocketAddr};
use crate::{Method, Url};
#[cfg(feature = "json")]
use crate::ResponseLike;
/// A parsed HTTP request together with the address of the peer that sent it.
///
/// Instances are built by [`Request::new`] from raw request bytes, or read
/// directly from a stream with [`Request::read_from`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "json", derive(serde::Serialize))]
pub struct Request {
/// Socket address of the peer, as supplied by the caller of the constructor.
pub ip: SocketAddr,
/// Request target exactly as it appeared in the request line (second token),
/// decoded as UTF-8; see [`Request::parse_url`] for a structured view.
pub url: String,
/// HTTP method, converted from the first token of the request line.
pub method: Method,
/// Raw body bytes that follow the header section.
pub body: Vec<u8>,
/// Header fields. Names are stored exactly as received (no case normalisation);
/// names and values are whitespace-trimmed.
pub headers: HashMap<String, String>,
}
impl Request {
    /// Parses a raw HTTP/1.x request that was received from `ip`.
    ///
    /// The request line is taken from the first line of `bytes` only. It must contain
    /// at least three space-separated tokens (method, target, version). Header lines
    /// follow until the first empty line. Every byte after the first `\r\n\r\n`
    /// becomes the body as-is; no `Content-Length` or chunked decoding is applied here.
    ///
    /// Returns `None` in any of these cases:
    /// - the request line is incomplete;
    /// - the target is not valid UTF-8;
    /// - a header line has no `:` separator.
    pub fn new(bytes: &[u8], ip: SocketAddr) -> Option<Self> {
        // Restrict the request-line parse to the first line. Splitting the whole
        // buffer on spaces let a version-less request line absorb the following
        // header line (including its CRLF) into the URL.
        let request_line = bytes.split(|b| *b == b'\n').next()?;
        let request_line = request_line.strip_suffix(b"\r").unwrap_or(request_line);
        let mut words = request_line.split(|b| *b == b' ');
        let method = Method::from(words.next()?);
        let url = String::from_utf8(words.next()?.to_vec()).ok()?;
        // The protocol version must be present but is not otherwise used.
        words.next()?;

        let mut headers = HashMap::with_capacity(12);
        for line in bytes.split(|b| *b == b'\n').skip(1) {
            // An empty line (`\r\n` or a bare `\n`) ends the header section.
            if line == b"\r" || line.is_empty() {
                break;
            }
            let (key, value) = Self::parse_header(line)?;
            headers.insert(key, value);
        }

        let body = bytes
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|position| bytes[position + 4..].to_vec())
            .unwrap_or_default();

        Some(Self {
            ip,
            url,
            method,
            body,
            headers,
        })
    }

    /// Splits a single `Name: value` header line at its first `:`.
    ///
    /// Both halves are decoded lossily as UTF-8 and trimmed. Trimming also removes
    /// the trailing `\r` of a CRLF-terminated line. Returns `None` when the line
    /// contains no `:`.
    fn parse_header(line: &[u8]) -> Option<(String, String)> {
        let pos = line.iter().position(|&byte| byte == b':')?;
        let (key, rest) = line.split_at(pos);
        // `rest` starts with the `:`. Do not chop the final byte blindly: that
        // panicked on `Name:` with LF-only line endings and cut the last character
        // of LF-terminated values. `trim()` below already strips any `\r`.
        let value = &rest[1..];
        Some((
            String::from_utf8_lossy(key).trim().to_string(),
            String::from_utf8_lossy(value).trim().to_string(),
        ))
    }

    /// Looks up a header value by name.
    ///
    /// An exact key match is preferred. If there is none, the lookup falls back to an
    /// ASCII case-insensitive match, since HTTP field names are case-insensitive.
    fn find_header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(name)
            .or_else(|| {
                self.headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// Returns the value of header `key` (exact, case-sensitive key match), if present.
    pub fn get_header(&self, key: &str) -> Option<&str> {
        self.headers.get(key).map(|s| s.as_str())
    }

    /// Returns the value of header `key` (exact key match), or `default` if it is absent.
    pub fn get_header_or(&self, key: &str, default: &'static str) -> &str {
        self.get_header(key).unwrap_or(default)
    }

    /// Returns `true` if header `key` is present (exact key match).
    pub fn has_header(&self, key: &str) -> bool {
        self.headers.contains_key(key)
    }

    /// Inserts or replaces header `k` with value `v`, both converted via `ToString`.
    pub fn set_header<T: ToString, K: ToString>(&mut self, k: T, v: K) {
        self.headers.insert(k.to_string(), v.to_string());
    }

    /// Length of the body in bytes.
    pub fn len(&self) -> usize {
        self.body.len()
    }

    /// Returns `true` if the body is empty.
    pub fn is_empty(&self) -> bool {
        self.body.is_empty()
    }

    /// The body decoded as UTF-8. Invalid sequences are replaced with `U+FFFD`.
    pub fn text(&self) -> Cow<'_, str> {
        String::from_utf8_lossy(&self.body)
    }

    /// Deserializes the body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the body is not valid JSON for `T`.
    #[cfg(feature = "json")]
    pub fn json<T>(&self) -> serde_json::Result<T>
    where
        T: for<'a> serde::de::Deserialize<'a>,
    {
        serde_json::from_slice(&self.body)
    }

    /// Like [`Request::json`], but converts a failure into a ready-to-send response.
    ///
    /// The conversion goes through the error's `ResponseLike::to_response`.
    #[cfg(feature = "json")]
    pub fn expect_json<T>(&self) -> Result<T, crate::Response>
    where
        T: for<'a> serde::de::Deserialize<'a>,
    {
        self.json().map_err(|e| e.to_response())
    }

    /// Converts the raw request target into a [`Url`] that borrows from `self.url`.
    pub fn parse_url(&self) -> Url<'_> {
        self.url.as_str().into()
    }

    /// The peer address formatted by `crate::util::format_addr`.
    pub fn pretty_ip(&self) -> String {
        crate::util::format_addr(self.ip)
    }

    /// Whether the connection should be kept open after this request.
    ///
    /// Returns `false` when the `Connection` header (name matched case-insensitively)
    /// lists the `close` option. It also returns `false` for the non-standard value
    /// `false`, which earlier versions of this method checked for. Otherwise it
    /// returns `true`.
    pub fn keep_alive(&self) -> bool {
        match self.find_header("connection") {
            None => true,
            // `Connection` is a comma-separated list of options, e.g. "close, TE".
            Some(value) => !value.split(',').map(str::trim).any(|option| {
                option.eq_ignore_ascii_case("close") || option.eq_ignore_ascii_case("false")
            }),
        }
    }

    /// Reads one request from `stream`, parses it, and returns it.
    ///
    /// Data is read in 1 KiB chunks until the end of the header section
    /// (`\r\n\r\n`) has arrived. If a `Content-Length` header is present, reading
    /// continues until that many body bytes are buffered. Any bytes beyond that
    /// length are discarded. Without `Content-Length`, the body is whatever
    /// followed the headers in the data already read. Chunked transfer coding is
    /// not decoded.
    ///
    /// `buffer_size` is the initial buffer capacity. It also bounds growth: if the
    /// buffer exceeds `buffer_size` while the request is still incomplete, reading
    /// stops with an error.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::InvalidInput` when:
    /// - the peer closes the stream before the request is complete (a bad-request
    ///   response is sent first);
    /// - the size limit is exceeded;
    /// - the request cannot be parsed;
    /// - `Content-Length` is not a valid number.
    ///
    /// Other I/O errors are propagated.
    pub async fn read_from<T: AsyncRead + Unpin + AsyncWrite>(
        mut stream: &mut T,
        addr: SocketAddr,
        buffer_size: usize,
    ) -> io::Result<Request> {
        let invalid = || io::Error::from(io::ErrorKind::InvalidInput);
        let mut buffer: Vec<u8> = Vec::with_capacity(buffer_size);
        let mut chunk = vec![0u8; 1024];

        // Phase 1: read until the blank line that terminates the header section.
        let body_start = loop {
            let n = stream.read(&mut chunk).await?;
            if n == 0 {
                crate::response!(bad_request).send_to(&mut stream).await?;
                return Err(invalid());
            }
            // Scan only the new bytes plus a 3-byte overlap (for a terminator split
            // across reads) rather than rescanning the whole buffer each time.
            let scan_from = buffer.len().saturating_sub(3);
            buffer.extend_from_slice(&chunk[..n]);
            if let Some(offset) = buffer[scan_from..]
                .windows(4)
                .position(|w| w == b"\r\n\r\n")
            {
                break scan_from + offset + 4;
            }
            if buffer.len() > buffer_size {
                return Err(invalid());
            }
        };

        let mut request = Request::new(&buffer[..body_start], addr).ok_or_else(invalid)?;

        // Phase 2: the header terminator usually arrives before the whole body.
        // Keep reading until the declared number of body bytes is buffered, instead
        // of silently truncating the body to whatever the last read happened to hold.
        let content_length = request
            .find_header("content-length")
            .map(str::parse::<usize>)
            .transpose()
            .map_err(|_| invalid())?;

        if let Some(length) = content_length {
            let body_end = body_start.checked_add(length).ok_or_else(invalid)?;
            while buffer.len() < body_end {
                if buffer.len() > buffer_size {
                    return Err(invalid());
                }
                let n = stream.read(&mut chunk).await?;
                if n == 0 {
                    crate::response!(bad_request).send_to(&mut stream).await?;
                    return Err(invalid());
                }
                buffer.extend_from_slice(&chunk[..n]);
            }
            // Bytes past the declared length are not part of this request's body.
            buffer.truncate(body_end);
        }

        request.body = buffer.split_off(body_start);
        Ok(request)
    }
}