use crate::http::address::Address;
use crate::http::cookie::Cookie;
use crate::http::headers::{HeaderType, Headers};
use crate::http::method::Method;
use std::error::Error;
use std::net::SocketAddr;
#[cfg(not(feature = "tokio"))]
use crate::stream::Stream;
#[cfg(not(feature = "tokio"))]
use std::io::{BufRead, BufReader, ErrorKind, Read};
#[cfg(not(feature = "tokio"))]
use std::time::Duration;
#[cfg(feature = "tokio")]
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
/// A parsed HTTP request together with the address it is attributed to.
#[derive(Clone, Debug)]
pub struct Request {
/// The request method, parsed from the first space-separated token of the start line.
pub method: Method,
/// The request target up to (not including) the first `?`.
pub uri: String,
/// Everything after the first `?` of the request target; empty when there is no `?`.
pub query: String,
/// The third token of the start line with its trailing CRLF removed.
pub version: String,
/// The headers read from the lines following the start line.
pub headers: Headers,
/// The request body. `Some` when a `Content-Length` header was present (holding that
/// many bytes), `None` otherwise.
pub content: Option<Vec<u8>>,
/// The address produced by `Address::from_headers` from the headers and the peer socket address.
pub address: Address,
}
/// The ways in which reading a [`Request`] from a stream can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum RequestError {
/// The bytes received do not form a request this parser accepts (bad start line,
/// bad header line, invalid UTF-8, unparsable `Content-Length`, or an address error).
Request,
/// Reading from the stream failed after the first byte, or configuring its timeout failed.
Stream,
/// The very first byte of the request could not be read.
Disconnected,
/// Waiting for the first byte hit `ErrorKind::TimedOut` or `ErrorKind::WouldBlock`
/// (only produced by `from_stream_with_timeout`).
Timeout,
}
/// Private convenience trait for turning an absent value into a [`RequestError`].
trait OptionToRequestResult<T> {
/// Converts `self` into a `Result`, using `e` as the error when no value is present.
fn to_error(self, e: RequestError) -> Result<T, RequestError>;
}
/// Lets any `Option<T>` be converted into a `Result<T, RequestError>`.
impl<T> OptionToRequestResult<T> for Option<T> {
    /// Returns `Ok(value)` for `Some(value)` and `Err(e)` for `None`.
    fn to_error(self, e: RequestError) -> Result<T, RequestError> {
        self.ok_or(e)
    }
}
/// Textual form of a [`RequestError`]; every variant renders as the same fixed string.
impl std::fmt::Display for RequestError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("RequestError")
    }
}
// Marker impl: lets `RequestError` be used wherever `std::error::Error` is expected.
// The body is empty, so every trait method keeps its default implementation.
impl Error for RequestError {}
impl Request {
#[cfg(not(feature = "tokio"))]
pub fn from_stream<T>(stream: &mut T, address: SocketAddr) -> Result<Self, RequestError>
where
T: Read,
{
let mut first_buf: [u8; 1] = [0; 1];
stream
.read_exact(&mut first_buf)
.map_err(|_| RequestError::Disconnected)?;
Self::from_stream_inner(stream, address, first_buf[0])
}
#[cfg(feature = "tokio")]
pub async fn from_stream<T>(stream: &mut T, address: SocketAddr) -> Result<Self, RequestError>
where
T: AsyncReadExt + Unpin,
{
let mut first_buf: [u8; 1] = [0; 1];
stream
.read_exact(&mut first_buf)
.await
.map_err(|_| RequestError::Disconnected)?;
Self::from_stream_inner(stream, address, first_buf[0]).await
}
#[cfg(not(feature = "tokio"))]
pub fn from_stream_with_timeout(
stream: &mut Stream,
address: SocketAddr,
timeout: Duration,
) -> Result<Self, RequestError> {
stream
.set_timeout(Some(timeout))
.map_err(|_| RequestError::Stream)?;
let mut first_buf: [u8; 1] = [0; 1];
stream
.read_exact(&mut first_buf)
.map_err(|e| match e.kind() {
ErrorKind::TimedOut => RequestError::Timeout,
ErrorKind::WouldBlock => RequestError::Timeout,
_ => RequestError::Disconnected,
})?;
stream.set_timeout(None).map_err(|_| RequestError::Stream)?;
Self::from_stream_inner(stream, address, first_buf[0])
}
pub fn get_cookies(&self) -> Vec<Cookie> {
self.headers
.get(HeaderType::Cookie)
.map(|cookies| {
cookies
.split(';')
.filter_map(|cookie| {
let (k, v) = cookie.split_once('=')?;
Some(Cookie::new(k.trim(), v.trim()))
})
.collect()
})
.unwrap_or_default()
}
pub fn get_cookie(&self, name: impl AsRef<str>) -> Option<Cookie> {
self.get_cookies()
.into_iter()
.find(|cookie| cookie.name == name.as_ref())
}
#[cfg(not(feature = "tokio"))]
fn from_stream_inner<T>(
stream: &mut T,
address: SocketAddr,
first_byte: u8,
) -> Result<Self, RequestError>
where
T: Read,
{
let mut reader = BufReader::new(stream);
let mut start_line_buf: Vec<u8> = Vec::with_capacity(256);
reader
.read_until(0xA, &mut start_line_buf)
.map_err(|_| RequestError::Stream)?;
start_line_buf.insert(0, first_byte);
let start_line_string =
std::str::from_utf8(&start_line_buf).map_err(|_| RequestError::Request)?;
let mut start_line = start_line_string.split(' ');
let method = Method::from_name(start_line.next().to_error(RequestError::Request)?)?;
let mut uri_iter = start_line
.next()
.to_error(RequestError::Request)?
.splitn(2, '?');
let version = start_line
.next()
.to_error(RequestError::Request)?
.strip_suffix("\r\n")
.unwrap_or("")
.to_string();
safe_assert(!version.is_empty())?;
let uri = uri_iter.next().unwrap().to_string();
let query = uri_iter.next().unwrap_or("").to_string();
let mut headers = Headers::new();
loop {
let mut line_buf: Vec<u8> = Vec::with_capacity(256);
reader
.read_until(0xA, &mut line_buf)
.map_err(|_| RequestError::Stream)?;
let line = std::str::from_utf8(&line_buf).map_err(|_| RequestError::Request)?;
if line == "\r\n" {
break;
} else {
safe_assert(line.len() >= 2)?;
let line_without_crlf = &line[0..line.len() - 2];
let mut line_parts = line_without_crlf.splitn(2, ':');
headers.add(
HeaderType::from(line_parts.next().to_error(RequestError::Request)?),
line_parts
.next()
.to_error(RequestError::Request)?
.trim_start(),
);
}
}
let address =
Address::from_headers(&headers, address).map_err(|_| RequestError::Request)?;
if let Some(content_length) = headers.get(&HeaderType::ContentLength) {
let content_length: usize =
content_length.parse().map_err(|_| RequestError::Request)?;
let mut content_buf: Vec<u8> = vec![0u8; content_length];
reader
.read_exact(&mut content_buf)
.map_err(|_| RequestError::Stream)?;
Ok(Self {
method,
uri,
query,
version,
headers,
content: Some(content_buf),
address,
})
} else {
Ok(Self {
method,
uri,
query,
version,
headers,
content: None,
address,
})
}
}
#[cfg(feature = "tokio")]
async fn from_stream_inner<T>(
stream: &mut T,
address: SocketAddr,
first_byte: u8,
) -> Result<Self, RequestError>
where
T: AsyncReadExt + Unpin,
{
let mut reader = BufReader::new(stream);
let mut start_line_buf: Vec<u8> = Vec::with_capacity(256);
reader
.read_until(0xA, &mut start_line_buf)
.await
.map_err(|_| RequestError::Stream)?;
start_line_buf.insert(0, first_byte);
let start_line_string =
std::str::from_utf8(&start_line_buf).map_err(|_| RequestError::Request)?;
let mut start_line = start_line_string.split(' ');
let method = Method::from_name(start_line.next().to_error(RequestError::Request)?)?;
let mut uri_iter = start_line
.next()
.to_error(RequestError::Request)?
.splitn(2, '?');
let version = start_line
.next()
.to_error(RequestError::Request)?
.strip_suffix("\r\n")
.unwrap_or("")
.to_string();
safe_assert(!version.is_empty())?;
let uri = uri_iter.next().unwrap().to_string();
let query = uri_iter.next().unwrap_or("").to_string();
let mut headers = Headers::new();
loop {
let mut line_buf: Vec<u8> = Vec::with_capacity(256);
reader
.read_until(0xA, &mut line_buf)
.await
.map_err(|_| RequestError::Stream)?;
let line = std::str::from_utf8(&line_buf).map_err(|_| RequestError::Request)?;
if line == "\r\n" {
break;
} else {
safe_assert(line.len() >= 2)?;
let line_without_crlf = &line[0..line.len() - 2];
let mut line_parts = line_without_crlf.splitn(2, ':');
headers.add(
HeaderType::from(line_parts.next().to_error(RequestError::Request)?),
line_parts
.next()
.to_error(RequestError::Request)?
.trim_start(),
);
}
}
let address =
Address::from_headers(&headers, address).map_err(|_| RequestError::Request)?;
if let Some(content_length) = headers.get(&HeaderType::ContentLength) {
let content_length: usize =
content_length.parse().map_err(|_| RequestError::Request)?;
let mut content_buf: Vec<u8> = vec![0u8; content_length];
reader
.read_exact(&mut content_buf)
.await
.map_err(|_| RequestError::Stream)?;
Ok(Self {
method,
uri,
query,
version,
headers,
content: Some(content_buf),
address,
})
} else {
Ok(Self {
method,
uri,
query,
version,
headers,
content: None,
address,
})
}
}
}
/// Non-panicking assertion: yields `Ok(())` when `condition` holds and
/// `Err(RequestError::Request)` when it does not, so callers can use `?`.
fn safe_assert(condition: bool) -> Result<(), RequestError> {
    if condition {
        Ok(())
    } else {
        Err(RequestError::Request)
    }
}
/// Serialises a [`Request`] into its on-the-wire byte representation.
impl From<Request> for Vec<u8> {
    /// Produces the start line, one `Name: value` line per header, the blank line
    /// that ends the header section, and finally the body (if any).
    ///
    /// The query string is only appended (after a `?`) when it is non-empty.
    fn from(req: Request) -> Self {
        let start_line = if req.query.is_empty() {
            format!("{} {} {}", req.method, req.uri, req.version)
        } else {
            format!("{} {}?{} {}", req.method, req.uri, req.query, req.version)
        };

        let mut bytes: Vec<u8> = Vec::new();
        bytes.extend(start_line.as_bytes());
        bytes.extend(b"\r\n");

        // Each header line carries its own CRLF, so a request without any headers
        // is followed by exactly one blank line instead of a stray extra CRLF that
        // would be read as the start of the body.
        for header in req.headers.iter() {
            let line = format!("{}: {}\r\n", header.name.to_string(), header.value);
            bytes.extend(line.as_bytes());
        }
        bytes.extend(b"\r\n");

        if let Some(content) = req.content {
            bytes.extend(content);
        }

        bytes
    }
}