use crate::*;
impl std::error::Error for RequestError {}
impl Default for RequestError {
#[inline(always)]
fn default() -> Self {
RequestError::Unknown(HttpStatus::InternalServerError)
}
}
/// Converts an I/O error into a `RequestError`, which is what allows `?` on
/// the socket reads performed while parsing a request.
///
/// `ConnectionReset` and `ConnectionAborted` map to `ClientDisconnected`;
/// every other error kind maps to `ReadConnection`. Both carry
/// `400 Bad Request`.
impl From<std::io::Error> for RequestError {
    #[inline(always)]
    fn from(error: std::io::Error) -> Self {
        let peer_went_away: bool = matches!(
            error.kind(),
            ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted
        );
        if peer_went_away {
            Self::ClientDisconnected(HttpStatus::BadRequest)
        } else {
            Self::ReadConnection(HttpStatus::BadRequest)
        }
    }
}
impl From<Elapsed> for RequestError {
#[inline(always)]
fn from(_: Elapsed) -> Self {
RequestError::ReadTimeout(HttpStatus::RequestTimeout)
}
}
impl From<ParseIntError> for RequestError {
#[inline(always)]
fn from(_: ParseIntError) -> Self {
RequestError::InvalidContentLength(HttpStatus::BadRequest)
}
}
impl From<ResponseError> for RequestError {
#[inline(always)]
fn from(_: ResponseError) -> Self {
RequestError::WriteTimeout(HttpStatus::InternalServerError)
}
}
impl RequestError {
    /// Returns the HTTP status associated with this error.
    ///
    /// Every variant stores its status directly, except `Request(_)`, which is
    /// always reported as `400 Bad Request`.
    pub fn get_http_status(&self) -> HttpStatus {
        match self {
            Self::Request(_) => HttpStatus::BadRequest,
            Self::HttpRead(status)
            | Self::GetTcpStream(status)
            | Self::GetTlsStream(status)
            | Self::ReadConnection(status)
            | Self::RequestAborted(status)
            | Self::TlsStreamConnect(status)
            | Self::NeedOpenRedirect(status)
            | Self::MaxRedirectTimes(status)
            | Self::MethodsNotSupport(status)
            | Self::RedirectInvalidUrl(status)
            | Self::ClientDisconnected(status)
            | Self::RedirectUrlDeadLoop(status)
            | Self::ClientClosedConnection(status)
            | Self::IncompleteWebSocketFrame(status)
            | Self::RequestTooLong(status)
            | Self::PathTooLong(status)
            | Self::QueryTooLong(status)
            | Self::HeaderLineTooLong(status)
            | Self::TooManyHeaders(status)
            | Self::HeaderKeyTooLong(status)
            | Self::HeaderValueTooLong(status)
            | Self::ContentLengthTooLarge(status)
            | Self::InvalidContentLength(status)
            | Self::InvalidUrlScheme(status)
            | Self::InvalidUrlHost(status)
            | Self::InvalidUrlPort(status)
            | Self::InvalidUrlPath(status)
            | Self::InvalidUrlQuery(status)
            | Self::InvalidUrlFragment(status)
            | Self::ReadTimeout(status)
            | Self::WriteTimeout(status)
            | Self::TcpConnectionFailed(status)
            | Self::TlsHandshakeFailed(status)
            | Self::TlsCertificateInvalid(status)
            | Self::WebSocketFrameTooLarge(status)
            | Self::WebSocketOpcodeUnsupported(status)
            | Self::WebSocketMaskMissing(status)
            | Self::WebSocketPayloadCorrupted(status)
            | Self::WebSocketInvalidUtf8(status)
            | Self::WebSocketInvalidCloseCode(status)
            | Self::WebSocketInvalidExtension(status)
            | Self::HttpRequestPartsInsufficient(status)
            | Self::TcpStreamConnect(status)
            | Self::TlsConnectorBuild(status)
            | Self::InvalidUrl(status)
            | Self::ConfigReadError(status)
            | Self::TcpStreamConnectString(status)
            | Self::TlsConnectorBuildString(status)
            | Self::NotFoundStream(status)
            | Self::Unknown(status) => *status,
        }
    }

    /// Returns the numeric status code of [`Self::get_http_status`].
    pub fn get_http_status_code(&self) -> ResponseStatusCode {
        let status: HttpStatus = self.get_http_status();
        status.code()
    }
}
/// The standard request limits, taken from the `DEFAULT_*` constants.
impl Default for RequestConfig {
    #[inline(always)]
    fn default() -> Self {
        Self {
            read_timeout_ms: DEFAULT_READ_TIMEOUT_MS,
            buffer_size: DEFAULT_BUFFER_SIZE,
            max_body_size: DEFAULT_MAX_BODY_SIZE,
            max_path_size: DEFAULT_MAX_PATH_SIZE,
            max_header_count: DEFAULT_MAX_HEADER_COUNT,
            max_header_key_size: DEFAULT_MAX_HEADER_KEY_SIZE,
            max_header_value_size: DEFAULT_MAX_HEADER_VALUE_SIZE,
        }
    }
}
impl RequestConfig {
    /// Deserializes a configuration from a JSON document.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json::Error` produced when `json` is not valid JSON
    /// or does not match the `RequestConfig` shape.
    pub fn from_json<C>(json: C) -> Result<RequestConfig, serde_json::Error>
    where
        C: AsRef<str>,
    {
        let text: &str = json.as_ref();
        serde_json::from_str(text)
    }

    /// Preset built from the `DEFAULT_LOW_SECURITY_*` constants.
    ///
    /// The `Request::check_http_*` helpers skip their path, header-count,
    /// header-key, header-value and body-size checks when the configured limit
    /// equals the matching low-security constant, so this preset effectively
    /// disables those limits.
    #[inline(always)]
    pub fn low_security() -> Self {
        Self {
            read_timeout_ms: DEFAULT_LOW_SECURITY_READ_TIMEOUT_MS,
            buffer_size: DEFAULT_LOW_SECURITY_BUFFER_SIZE,
            max_body_size: DEFAULT_LOW_SECURITY_MAX_BODY_SIZE,
            max_path_size: DEFAULT_LOW_SECURITY_MAX_PATH_SIZE,
            max_header_count: DEFAULT_LOW_SECURITY_MAX_HEADER_COUNT,
            max_header_key_size: DEFAULT_LOW_SECURITY_MAX_HEADER_KEY_SIZE,
            max_header_value_size: DEFAULT_LOW_SECURITY_MAX_HEADER_VALUE_SIZE,
        }
    }

    /// Preset built from the `DEFAULT_HIGH_SECURITY_*` constants.
    #[inline(always)]
    pub fn high_security() -> Self {
        Self {
            read_timeout_ms: DEFAULT_HIGH_SECURITY_READ_TIMEOUT_MS,
            buffer_size: DEFAULT_HIGH_SECURITY_BUFFER_SIZE,
            max_body_size: DEFAULT_HIGH_SECURITY_MAX_BODY_SIZE,
            max_path_size: DEFAULT_HIGH_SECURITY_MAX_PATH_SIZE,
            max_header_count: DEFAULT_HIGH_SECURITY_MAX_HEADER_COUNT,
            max_header_key_size: DEFAULT_HIGH_SECURITY_MAX_HEADER_KEY_SIZE,
            max_header_value_size: DEFAULT_HIGH_SECURITY_MAX_HEADER_VALUE_SIZE,
        }
    }
}
/// An empty request: default method and version, empty host and path, no query
/// parameters, no headers and an empty body.
impl Default for Request {
    #[inline(always)]
    fn default() -> Self {
        Self {
            method: Method::default(),
            version: HttpVersion::default(),
            host: String::new(),
            path: String::new(),
            querys: hash_map_xx_hash3_64(),
            headers: hash_map_xx_hash3_64(),
            body: Vec::new(),
        }
    }
}
impl Request {
    /// Splits an HTTP request line (`METHOD TARGET VERSION`) into its parts.
    ///
    /// The line is split on whitespace and only the first three parts are used.
    /// A method or version that fails to parse is kept verbatim as an
    /// `Unknown(..)` value rather than rejected.
    ///
    /// # Errors
    ///
    /// [`RequestError::HttpRequestPartsInsufficient`] (400) when the line has
    /// fewer than three parts.
    #[inline(always)]
    pub(crate) fn get_http_first_line(
        line: &str,
    ) -> Result<(RequestMethod, &str, RequestVersion), RequestError> {
        let mut parts: SplitWhitespace<'_> = line.split_whitespace();
        let method_str: &str = parts
            .next()
            .ok_or(RequestError::HttpRequestPartsInsufficient(
                HttpStatus::BadRequest,
            ))?;
        let full_path: &str = parts
            .next()
            .ok_or(RequestError::HttpRequestPartsInsufficient(
                HttpStatus::BadRequest,
            ))?;
        let version_str: &str = parts
            .next()
            .ok_or(RequestError::HttpRequestPartsInsufficient(
                HttpStatus::BadRequest,
            ))?;
        let method: RequestMethod = method_str
            .parse::<RequestMethod>()
            .unwrap_or_else(|_| Method::Unknown(method_str.to_string()));
        let version: RequestVersion = version_str
            .parse::<RequestVersion>()
            .unwrap_or_else(|_| RequestVersion::Unknown(version_str.to_string()));
        Ok((method, full_path, version))
    }

    /// Rejects a request target longer than `max_size` bytes.
    ///
    /// A `max_size` equal to `DEFAULT_LOW_SECURITY_MAX_PATH_SIZE` disables the check.
    ///
    /// # Errors
    ///
    /// [`RequestError::PathTooLong`] (414) when the limit is exceeded.
    #[inline(always)]
    pub(crate) fn check_http_path_size(path: &str, max_size: usize) -> Result<(), RequestError> {
        if path.len() > max_size && max_size != DEFAULT_LOW_SECURITY_MAX_PATH_SIZE {
            return Err(RequestError::PathTooLong(HttpStatus::URITooLong));
        }
        Ok(())
    }

    /// Returns the query component of `path` (without the leading `?`).
    ///
    /// `query_index` and `hash_index` are the byte positions of the `?` and
    /// `#` delimiters within `path`, if present. Per RFC 3986 the fragment
    /// starts at the first `#`. A `?` that appears after `#` therefore belongs
    /// to the fragment, and the query is empty in that case. Returns
    /// `EMPTY_STR` when there is no query.
    #[inline(always)]
    pub(crate) fn get_http_query(
        path: &str,
        query_index: Option<usize>,
        hash_index: Option<usize>,
    ) -> &str {
        query_index.map_or(EMPTY_STR, |i: usize| {
            let after_question_mark: &str = &path[i + 1..];
            match hash_index {
                None => after_question_mark,
                // `#` comes first: the `?` is part of the fragment, not a query.
                Some(hash_idx) if hash_idx <= i => EMPTY_STR,
                // Query runs up to (not including) the `#`.
                Some(hash_idx) => &after_question_mark[..hash_idx - i - 1],
            }
        })
    }

    /// Returns the path component of `path`, i.e. everything before the first
    /// `?` or `#` delimiter (whichever comes first).
    ///
    /// `query_index` and `hash_index` are byte positions within `path`, as for
    /// [`Self::get_http_query`].
    #[inline(always)]
    pub(crate) fn get_http_path(
        path: &str,
        query_index: Option<usize>,
        hash_index: Option<usize>,
    ) -> RequestPath {
        // The path ends at the earliest delimiter. Previously a `?` inside the
        // fragment (e.g. `a#b?c`) wrongly cut the path at the `?`.
        let end: Option<usize> = match (query_index, hash_index) {
            (Some(query), Some(hash)) => Some(query.min(hash)),
            (query, hash) => query.or(hash),
        };
        match end {
            Some(i) => path[..i].to_owned(),
            None => path.to_owned(),
        }
    }

    /// Parses a raw query string (`a=1&b&c=`) into a map.
    ///
    /// Pairs are split on `AND` and key/value on the first `EQUAL`. A bare key
    /// maps to an empty string, and pairs with an empty key are dropped. When a
    /// key repeats, the later value overwrites the earlier one. No
    /// percent-decoding is performed.
    #[inline(always)]
    pub(crate) fn get_http_querys(query: &str) -> RequestQuerys {
        let estimated_capacity: usize = query.matches(AND).count() + 1;
        let mut query_map: RequestQuerys = HashMapXxHash3_64::with_capacity_and_hasher(
            estimated_capacity,
            BuildHasherDefault::default(),
        );
        for pair in query.split(AND) {
            if let Some((key, value)) = pair.split_once(EQUAL) {
                if !key.is_empty() {
                    query_map.insert(key.to_string(), value.to_string());
                }
            } else if !pair.is_empty() {
                query_map.insert(pair.to_string(), String::new());
            }
        }
        query_map
    }

    /// Rejects a request with more than `max_count` header lines.
    ///
    /// A `max_count` equal to `DEFAULT_LOW_SECURITY_MAX_HEADER_COUNT` disables the check.
    ///
    /// # Errors
    ///
    /// [`RequestError::TooManyHeaders`] (431) when the limit is exceeded.
    #[inline(always)]
    pub(crate) fn check_http_header_count(
        count: usize,
        max_count: usize,
    ) -> Result<(), RequestError> {
        if count > max_count && max_count != DEFAULT_LOW_SECURITY_MAX_HEADER_COUNT {
            return Err(RequestError::TooManyHeaders(
                HttpStatus::RequestHeaderFieldsTooLarge,
            ));
        }
        Ok(())
    }

    /// Rejects a header name longer than `max_size` bytes.
    ///
    /// A `max_size` equal to `DEFAULT_LOW_SECURITY_MAX_HEADER_KEY_SIZE` disables the check.
    ///
    /// # Errors
    ///
    /// [`RequestError::HeaderKeyTooLong`] (431) when the limit is exceeded.
    #[inline(always)]
    pub(crate) fn check_http_header_key_size(
        key: &str,
        max_size: usize,
    ) -> Result<(), RequestError> {
        if key.len() > max_size && max_size != DEFAULT_LOW_SECURITY_MAX_HEADER_KEY_SIZE {
            return Err(RequestError::HeaderKeyTooLong(
                HttpStatus::RequestHeaderFieldsTooLarge,
            ));
        }
        Ok(())
    }

    /// Rejects a header value longer than `max_size` bytes.
    ///
    /// A `max_size` equal to `DEFAULT_LOW_SECURITY_MAX_HEADER_VALUE_SIZE` disables the check.
    ///
    /// # Errors
    ///
    /// [`RequestError::HeaderValueTooLong`] (431) when the limit is exceeded.
    #[inline(always)]
    pub(crate) fn check_http_header_value_size(
        value: &str,
        max_size: usize,
    ) -> Result<(), RequestError> {
        if value.len() > max_size && max_size != DEFAULT_LOW_SECURITY_MAX_HEADER_VALUE_SIZE {
            return Err(RequestError::HeaderValueTooLong(
                HttpStatus::RequestHeaderFieldsTooLarge,
            ));
        }
        Ok(())
    }

    /// Parses a `Content-Length` value and enforces `max_size`.
    ///
    /// A `max_size` equal to `DEFAULT_LOW_SECURITY_MAX_BODY_SIZE` disables the
    /// size check (but not the syntax check).
    ///
    /// # Errors
    ///
    /// - [`RequestError::InvalidContentLength`] (400) if `value` is not a
    ///   non-empty run of ASCII digits or does not fit in `usize`.
    /// - [`RequestError::ContentLengthTooLarge`] (413) if the length exceeds `max_size`.
    #[inline(always)]
    pub(crate) fn check_http_body_size(
        value: &str,
        max_size: usize,
    ) -> Result<usize, RequestError> {
        // RFC 9110 defines Content-Length as 1*DIGIT. `usize::from_str` also
        // accepts a leading `+`, which other hops may interpret differently
        // (a request-smuggling vector), so require plain ASCII digits.
        if value.is_empty() || !value.bytes().all(|byte: u8| byte.is_ascii_digit()) {
            return Err(RequestError::InvalidContentLength(HttpStatus::BadRequest));
        }
        // Only overflow can fail here; it maps to InvalidContentLength via `From`.
        let length: usize = value.parse::<usize>()?;
        if length > max_size && max_size != DEFAULT_LOW_SECURITY_MAX_BODY_SIZE {
            return Err(RequestError::ContentLengthTooLarge(
                HttpStatus::PayloadTooLarge,
            ));
        }
        Ok(length)
    }

    /// Reads header lines from `reader` until a blank line or end of input.
    ///
    /// Header names are trimmed and lowercased, and values are trimmed.
    /// Repeated headers accumulate in arrival order. Lines without a `:` or
    /// with an empty name are skipped, but they still count toward
    /// `max_header_count`. Returns `(headers, host, content_length)`, where
    /// `host` is the last `HOST` value (empty if absent) and `content_length`
    /// is the last `CONTENT_LENGTH` value (0 if absent).
    ///
    /// NOTE(review): `read_line` has no length limit, so a whole line is
    /// buffered before the key/value size checks run. Consider bounding it.
    ///
    /// # Errors
    ///
    /// I/O failures (converted via `From<std::io::Error>`) and any error from
    /// the `check_http_*` helpers.
    pub(crate) async fn get_http_headers<R>(
        reader: &mut R,
        config: &RequestConfig,
    ) -> Result<(RequestHeaders, RequestHost, usize), RequestError>
    where
        R: AsyncBufReadExt + Unpin,
    {
        let buffer_size: usize = config.get_buffer_size();
        let max_header_count: usize = config.get_max_header_count();
        let max_header_key_size: usize = config.get_max_header_key_size();
        let max_header_value_size: usize = config.get_max_header_value_size();
        let max_body_size: usize = config.get_max_body_size();
        let mut headers: RequestHeaders =
            HashMapXxHash3_64::with_capacity_and_hasher(B_16, BuildHasherDefault::default());
        let mut host: RequestHost = String::new();
        let mut content_size: usize = 0;
        let mut header_count: usize = 0;
        // One buffer reused for every line to avoid per-line allocation.
        let mut header_line_buffer: String = String::with_capacity(buffer_size);
        loop {
            header_line_buffer.clear();
            // At EOF nothing is read, the buffer stays empty and the loop ends below.
            AsyncBufReadExt::read_line(reader, &mut header_line_buffer).await?;
            let header_line: &str = header_line_buffer.trim();
            if header_line.is_empty() {
                break;
            }
            header_count += 1;
            Self::check_http_header_count(header_count, max_header_count)?;
            let (key_part, value_part): (&str, &str) = match header_line.split_once(COLON) {
                Some(parts) => parts,
                None => continue,
            };
            let key_trimmed: &str = key_part.trim();
            if key_trimmed.is_empty() {
                continue;
            }
            let key: String = key_trimmed.to_ascii_lowercase();
            Self::check_http_header_key_size(&key, max_header_key_size)?;
            let value: String = value_part.trim().to_string();
            Self::check_http_header_value_size(&value, max_header_value_size)?;
            match key.as_str() {
                HOST => host = value.clone(),
                CONTENT_LENGTH => {
                    content_size = Self::check_http_body_size(&value, max_body_size)?;
                }
                _ => {}
            }
            headers.entry(key).or_default().push_back(value);
        }
        Ok((headers, host, content_size))
    }

    /// Reads exactly `content_size` body bytes from `reader`.
    ///
    /// # Errors
    ///
    /// I/O failures, including the stream ending before `content_size` bytes
    /// arrive, converted via `From<std::io::Error>`.
    #[inline(always)]
    pub(crate) async fn get_http_body(
        reader: &mut BufReader<&mut TcpStream>,
        content_size: usize,
    ) -> Result<RequestBody, RequestError> {
        let mut body: RequestBody = vec![0; content_size];
        if !body.is_empty() {
            AsyncReadExt::read_exact(reader, &mut body).await?;
        }
        Ok(body)
    }

    /// Returns a clone of the query parameter `key`, if present.
    #[inline(always)]
    pub fn try_get_query<K>(&self, key: K) -> Option<RequestQuerysValue>
    where
        K: AsRef<str>,
    {
        self.querys.get(key.as_ref()).cloned()
    }

    /// Returns a clone of the query parameter `key`.
    ///
    /// # Panics
    ///
    /// Panics if the parameter is absent. Use [`Self::try_get_query`] to avoid this.
    #[inline(always)]
    pub fn get_query<K>(&self, key: K) -> RequestQuerysValue
    where
        K: AsRef<str>,
    {
        self.try_get_query(key).unwrap()
    }

    /// Returns a clone of all values of header `key`, if present.
    ///
    /// The lookup is an exact match. Headers parsed by `get_http_headers` are
    /// stored under lowercase names.
    #[inline(always)]
    pub fn try_get_header<K>(&self, key: K) -> Option<RequestHeadersValue>
    where
        K: AsRef<str>,
    {
        self.headers.get(key.as_ref()).cloned()
    }

    /// Returns a clone of all values of header `key`.
    ///
    /// # Panics
    ///
    /// Panics if the header is absent.
    #[inline(always)]
    pub fn get_header<K>(&self, key: K) -> RequestHeadersValue
    where
        K: AsRef<str>,
    {
        self.try_get_header(key).unwrap()
    }

    /// Returns the first value of header `key`, if any.
    #[inline(always)]
    pub fn try_get_header_front<K>(&self, key: K) -> Option<RequestHeadersValueItem>
    where
        K: AsRef<str>,
    {
        self.headers
            .get(key.as_ref())
            .and_then(|values| values.front().cloned())
    }

    /// Returns the first value of header `key`.
    ///
    /// # Panics
    ///
    /// Panics if the header is absent or has no values.
    #[inline(always)]
    pub fn get_header_front<K>(&self, key: K) -> RequestHeadersValueItem
    where
        K: AsRef<str>,
    {
        self.try_get_header_front(key).unwrap()
    }

    /// Returns the last value of header `key`, if any.
    #[inline(always)]
    pub fn try_get_header_back<K>(&self, key: K) -> Option<RequestHeadersValueItem>
    where
        K: AsRef<str>,
    {
        self.headers
            .get(key.as_ref())
            .and_then(|values| values.back().cloned())
    }

    /// Returns the last value of header `key`.
    ///
    /// # Panics
    ///
    /// Panics if the header is absent or has no values.
    #[inline(always)]
    pub fn get_header_back<K>(&self, key: K) -> RequestHeadersValueItem
    where
        K: AsRef<str>,
    {
        self.try_get_header_back(key).unwrap()
    }

    /// Returns how many values header `key` has, if it is present.
    #[inline(always)]
    pub fn try_get_header_size<K>(&self, key: K) -> Option<usize>
    where
        K: AsRef<str>,
    {
        self.headers.get(key.as_ref()).map(|values| values.len())
    }

    /// Returns how many values header `key` has.
    ///
    /// # Panics
    ///
    /// Panics if the header is absent.
    #[inline(always)]
    pub fn get_header_size<K>(&self, key: K) -> usize
    where
        K: AsRef<str>,
    {
        self.try_get_header_size(key).unwrap()
    }

    /// Total number of header values across all header names.
    #[inline(always)]
    pub fn get_headers_values_size(&self) -> usize {
        self.headers.values().map(|values| values.len()).sum()
    }

    /// Number of distinct header names.
    #[inline(always)]
    pub fn get_headers_size(&self) -> usize {
        self.headers.len()
    }

    /// Whether header `key` is present.
    #[inline(always)]
    pub fn has_header<K>(&self, key: K) -> bool
    where
        K: AsRef<str>,
    {
        self.headers.contains_key(key.as_ref())
    }

    /// Whether header `key` has a value exactly equal (case-sensitive) to `value`.
    #[inline(always)]
    pub fn has_header_value<K, V>(&self, key: K, value: V) -> bool
    where
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.headers
            .get(key.as_ref())
            .map_or(false, |values| values.iter().any(|v| v == value.as_ref()))
    }

    /// Parses the last `COOKIE` header, if present.
    #[inline(always)]
    pub fn try_get_cookies(&self) -> Option<Cookies> {
        self.try_get_header_back(COOKIE)
            .map(|cookie_header: String| Cookie::parse(cookie_header))
    }

    /// Parses the last `COOKIE` header.
    ///
    /// # Panics
    ///
    /// Panics if the request has no cookie header.
    #[inline(always)]
    pub fn get_cookies(&self) -> Cookies {
        self.try_get_cookies().unwrap()
    }

    /// Returns cookie `key` from the last `COOKIE` header, if present.
    #[inline(always)]
    pub fn try_get_cookie<K>(&self, key: K) -> Option<CookieValue>
    where
        K: AsRef<str>,
    {
        self.try_get_cookies()
            .and_then(|cookies: Cookies| cookies.get(key.as_ref()).cloned())
    }

    /// Returns cookie `key` from the last `COOKIE` header.
    ///
    /// # Panics
    ///
    /// Panics if there is no cookie header or it lacks `key`.
    #[inline(always)]
    pub fn get_cookie<K>(&self, key: K) -> CookieValue
    where
        K: AsRef<str>,
    {
        self.try_get_cookie(key).unwrap()
    }

    /// Parses the last `UPGRADE` header. Falls back to `UpgradeType::default()`
    /// when the header is missing or does not parse.
    #[inline(always)]
    pub fn get_upgrade_type(&self) -> UpgradeType {
        self.try_get_header_back(UPGRADE)
            .and_then(|data| data.parse::<UpgradeType>().ok())
            .unwrap_or_default()
    }

    /// Body decoded as UTF-8. Invalid sequences become U+FFFD.
    #[inline(always)]
    pub fn get_body_string(&self) -> String {
        String::from_utf8_lossy(self.get_body()).into_owned()
    }

    /// Deserializes the body as JSON.
    ///
    /// # Errors
    ///
    /// The `serde_json::Error` from deserialization.
    #[inline(always)]
    pub fn try_get_body_json<T>(&self) -> Result<T, serde_json::Error>
    where
        T: DeserializeOwned,
    {
        serde_json::from_slice(self.get_body())
    }

    /// Deserializes the body as JSON.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid JSON for `T`.
    #[inline(always)]
    pub fn get_body_json<T>(&self) -> T
    where
        T: DeserializeOwned,
    {
        self.try_get_body_json().unwrap()
    }

    /// Whether the upgrade type is WebSocket.
    #[inline(always)]
    pub fn is_ws_upgrade_type(&self) -> bool {
        self.get_upgrade_type().is_ws()
    }

    /// Whether the upgrade type is h2c.
    #[inline(always)]
    pub fn is_h2c_upgrade_type(&self) -> bool {
        self.get_upgrade_type().is_h2c()
    }

    /// Whether the upgrade type is TLS.
    #[inline(always)]
    pub fn is_tls_upgrade_type(&self) -> bool {
        self.get_upgrade_type().is_tls()
    }

    /// Whether the upgrade type is unknown.
    #[inline(always)]
    pub fn is_unknown_upgrade_type(&self) -> bool {
        self.get_upgrade_type().is_unknown()
    }

    /// Decides whether the connection should stay open after this request.
    ///
    /// The last `CONNECTION` header is read as a comma-separated list of
    /// case-insensitive tokens:
    /// - A `close` token wins: the connection stays open only for a WebSocket upgrade.
    /// - Otherwise, a `keep-alive` token keeps the connection open.
    ///
    /// With neither token present, the connection stays open for HTTP/1.1 or
    /// later, or for a WebSocket upgrade.
    #[inline(always)]
    pub fn is_enable_keep_alive(&self) -> bool {
        if let Some(connection_value) = self.try_get_header_back(CONNECTION) {
            // `Connection` carries a token list (e.g. "keep-alive, Upgrade"),
            // so compare individual tokens rather than the whole value.
            let mut wants_keep_alive: bool = false;
            for token in connection_value.split(',').map(str::trim) {
                if token.eq_ignore_ascii_case(CLOSE) {
                    return self.is_ws_upgrade_type();
                }
                if token.eq_ignore_ascii_case(KEEP_ALIVE) {
                    wants_keep_alive = true;
                }
            }
            if wants_keep_alive {
                return true;
            }
        }
        self.get_version().is_http1_1_or_higher() || self.is_ws_upgrade_type()
    }

    /// Negation of [`Self::is_enable_keep_alive`].
    #[inline(always)]
    pub fn is_disable_keep_alive(&self) -> bool {
        !self.is_enable_keep_alive()
    }
}