use crate::{
errors::ErrorKind,
http::{
request::{Parser, Request},
response::Response,
types::Version,
},
limits::{ConnLimits, Http09Limits, ReqLimits, RespLimits, ServerLimits},
server::server_impl::{AllLimits, Handler},
Handled,
};
use std::{io, net::SocketAddr, sync::Arc, time::Instant};
use tokio::{io::AsyncWriteExt, net::TcpStream, time::sleep};
/// Per-connection HTTP state: the user handler, the user's per-connection data,
/// connection bookkeeping, the reusable parser/request/response objects and the
/// limit sets they are built from.
pub(crate) struct HttpConnection<H, S>
where
    H: Handler<S>,
    S: ConnectionData,
{
    /// Request handler, shared through an `Arc`.
    handler: Arc<H>,
    /// User-defined per-connection state, passed mutably to the handler.
    connection_data: S,
    /// Creation instant and request counter for the connection being served.
    connection: Connection,
    /// Incoming-request parser.
    pub(crate) parser: Parser,
    /// Parsed request handed to the handler.
    pub(crate) request: Request,
    /// Response filled in by the handler and written to the socket.
    pub(crate) response: Response,
    /// Server-wide settings (e.g. `json_errors`).
    pub(crate) server_limits: ServerLimits,
    /// Socket timeouts and per-connection limits.
    pub(crate) conn_limits: ConnLimits,
    /// Limits for HTTP/0.9 traffic; `None` when none are configured.
    pub(crate) http_09_limits: Option<Http09Limits>,
    /// Request limits (used to build the parser and request).
    pub(crate) req_limits: ReqLimits,
    /// Response limits (used to build and reset the response).
    pub(crate) resp_limits: RespLimits,
}
impl<H, S> HttpConnection<H, S>
where
    H: Handler<S>,
    S: ConnectionData,
{
    /// Builds a connection object from the shared handler and the full limit set.
    ///
    /// `limits` is read positionally: `.0` server, `.1` connection, `.2` optional
    /// HTTP/0.9, `.3` request and `.4` response limits. The parser and request are
    /// constructed from the request limits, the response from the response limits.
    #[inline]
    pub(crate) fn new(handler: Arc<H>, limits: AllLimits) -> Self {
        let connection_data = S::new();
        let connection = Connection::new();
        let parser = Parser::new(&limits.3);
        let request = Request::new(&limits.3);
        let response = Response::new(&limits.4);
        Self {
            handler,
            connection_data,
            connection,
            parser,
            request,
            response,
            server_limits: limits.0,
            conn_limits: limits.1,
            http_09_limits: limits.2,
            req_limits: limits.3,
            resp_limits: limits.4,
        }
    }

    /// Calls `reset` on the parser, the request and the response (the latter with
    /// this connection's response limits) before the next request is read.
    #[inline]
    fn reset_request_response(&mut self) {
        // Destructure to borrow the fields independently of each other.
        let Self {
            parser,
            request,
            response,
            resp_limits,
            ..
        } = self;
        parser.reset();
        request.reset();
        response.reset(&*resp_limits);
    }
}
impl<H: Handler<S>, S: ConnectionData> HttpConnection<H, S> {
    /// Serves requests arriving on `stream` until the connection ends, then
    /// reports the outcome.
    ///
    /// Non-I/O failures from [`Self::impl_run`] are turned into an error response
    /// by [`ConnLimits::send_error`]. That response uses the current request's
    /// version and the server's `json_errors` setting, and is written to the client.
    /// The result of that write is what gets returned. I/O failures are returned
    /// as their inner `io::Error`, and no error response is attempted.
    ///
    /// # Errors
    /// The `io::Error` wrapped in an `ErrorKind::Io`, or the error from writing
    /// the error response.
    #[inline]
    pub(crate) async fn run(&mut self, stream: &mut TcpStream) -> Result<(), io::Error> {
        match self.impl_run(stream).await {
            Ok(()) => Ok(()),
            // Transport failure: propagate as-is, do not try to reply.
            Err(ErrorKind::Io(e)) => Err(e.0),
            Err(error) => {
                self.conn_limits
                    .send_error(
                        stream,
                        error,
                        self.request.version(),
                        self.server_limits.json_errors,
                    )
                    .await
            }
        }
    }

    /// Request/response loop for a single connection.
    ///
    /// Each iteration does the following, in order:
    /// 1. Resets the per-request state.
    /// 2. Fills the parser from `stream` under the socket read timeout.
    /// 3. Parses the request and runs the handler.
    /// 4. Writes the response under the socket write timeout.
    ///
    /// The loop stops when `fill_buffer` returns 0 or the response has
    /// `keep_alive` cleared. It also stops when [`Self::is_expired`] reports that
    /// the request-count or lifetime limit was reached after a keep-alive response.
    ///
    /// # Errors
    /// Propagates any `ErrorKind` produced while reading, parsing, writing or
    /// checking expiry (e.g. `UnsupportedVersion` for HTTP/0.9 without limits).
    #[inline]
    pub(crate) async fn impl_run(&mut self, stream: &mut TcpStream) -> Result<(), ErrorKind> {
        // This method resets the per-connection state on entry, so the same
        // `HttpConnection` may serve several connections in turn.
        self.connection.reset();
        self.connection_data.reset();
        loop {
            self.reset_request_response();
            // A zero-length fill presumably means the peer closed the stream.
            if self
                .parser
                .fill_buffer(stream, self.conn_limits.socket_read_timeout)
                .await?
                == 0
            {
                break;
            }
            self.response.version = self.parse()?;
            self.handler
                .handle(&mut self.connection_data, &self.request, &mut self.response)
                .await;
            self.conn_limits
                .write_bytes(stream, self.response.buffer())
                .await?;
            if !self.response.keep_alive {
                break;
            }
            self.connection.request_count += 1;
            // `is_expired` judges the response that was just written: its
            // `keep_alive` flag and its `version`, which selects the limit set.
            // Checking here rather than at the loop head keeps the first
            // iteration from reading a response left over from a previous
            // connection. Such a left-over `keep_alive == false` response used
            // to end the new connection before any byte was read.
            if self.is_expired()? {
                break;
            }
        }
        Ok(())
    }
}
impl ConnLimits {
#[inline]
pub(crate) async fn send_error(
&self,
stream: &mut TcpStream,
error: ErrorKind,
version: Version,
json_errors: bool,
) -> Result<(), io::Error> {
self.write_bytes(stream, error.as_http(version, json_errors))
.await
}
#[inline]
pub(crate) async fn write_bytes(
&self,
stream: &mut TcpStream,
response: &[u8],
) -> Result<(), io::Error> {
tokio::select! {
biased;
result = stream.write_all(response) => result,
_ = sleep(self.socket_write_timeout) => {
Err(io::Error::new(io::ErrorKind::TimedOut, "write timeout"))
},
}
}
}
// Evaluates to `Ok(true)` when no further request should be served.
// `$conn` must expose `response.keep_alive`, `connection.request_count` and
// `connection.created`; `$limits` must expose `max_requests_per_connection` and
// `connection_lifetime`. The connection is still usable only while keep-alive is
// set, the request budget is not used up, and the lifetime has not been exceeded.
macro_rules! is_expired {
    ($conn:expr, $limits:expr) => {
        Ok(!($conn.response.keep_alive
            && $conn.connection.request_count < $limits.max_requests_per_connection
            && $conn.connection.created.elapsed() <= $limits.connection_lifetime))
    };
}
impl<H, S> HttpConnection<H, S>
where
    H: Handler<S>,
    S: ConnectionData,
{
    /// Reports whether this connection should stop serving requests.
    ///
    /// The decision uses the current `response` (its `keep_alive` flag) and the
    /// limit set selected by its version. HTTP/0.9 uses `http_09_limits`; every
    /// other version uses `conn_limits`.
    ///
    /// # Errors
    /// `ErrorKind::UnsupportedVersion` when the response version is HTTP/0.9 but
    /// no HTTP/0.9 limits are configured.
    #[inline]
    fn is_expired(&self) -> Result<bool, ErrorKind> {
        if !matches!(self.response.version, Version::Http09) {
            return is_expired!(self, self.conn_limits);
        }
        match self.http_09_limits.as_ref() {
            Some(http09) => is_expired!(self, http09),
            None => Err(ErrorKind::UnsupportedVersion),
        }
    }
}
/// Bookkeeping for the connection being served: when it started and how many
/// requests have been counted on it.
#[derive(Debug)]
pub(crate) struct Connection {
    /// Instant the connection record was (re)started.
    created: Instant,
    /// Requests counted so far on this connection.
    request_count: usize,
}
impl Connection {
    /// Creates a record stamped with the current instant and a zero count.
    #[inline]
    pub(crate) fn new() -> Self {
        Self { created: Instant::now(), request_count: 0 }
    }
    /// Restarts the record in place, exactly as if it had just been created.
    #[inline]
    pub(crate) fn reset(&mut self) {
        *self = Self::new();
    }
}
/// User-defined per-connection state handed to the request handler.
///
/// Implementors must be `Send + Sync + 'static`. `new` builds the initial value;
/// `reset` restores it for reuse.
pub trait ConnectionData
where
    Self: Sync + Send + 'static,
{
    /// Creates the initial state.
    fn new() -> Self;
    /// Restores the state so it can be reused for another connection.
    fn reset(&mut self);
}
/// No per-connection state: construction and reset do nothing.
impl ConnectionData for () {
    #[inline(always)]
    fn new() -> Self {
        // The unit value is the whole state.
    }
    #[inline(always)]
    fn reset(&mut self) {
        // Nothing to clear.
    }
}
/// Hook that inspects a connection's socket addresses before it is served.
///
/// The filter receives the client and server addresses and a mutable response it
/// may fill in. An `Err(Handled)` result presumably rejects the connection; the
/// caller is not visible in this file. Implementors must be `Send + Sync + 'static`.
pub trait ConnectionFilter
where
    Self: Sync + Send + 'static,
{
    fn filter(
        &self,
        client_addr: SocketAddr,
        server_addr: SocketAddr,
        error_response: &mut Response,
    ) -> Result<(), Handled>;
}
/// Accept-everything filter: always returns `Ok(())`.
impl ConnectionFilter for () {
    fn filter(
        &self,
        _client_addr: SocketAddr,
        _server_addr: SocketAddr,
        _error_response: &mut Response,
    ) -> Result<(), Handled> {
        Ok(())
    }
}
#[cfg(test)]
mod def_handler {
    //! Test-only fixtures: a trivial handler and a constructor that wraps a raw
    //! request byte buffer in an `HttpConnection` with default limits.
    use super::*;
    use crate::{Handled, StatusCode};

    /// Handler that answers every request with status `Ok` and body `"test"`.
    pub(crate) struct DefHandler;

    impl Handler<()> for DefHandler {
        async fn handle(&self, _: &mut (), _: &Request, response: &mut Response) -> Handled {
            response.status(StatusCode::Ok).body("test")
        }
    }

    impl HttpConnection<DefHandler, ()> {
        /// Builds a connection whose parser is created from `value` via
        /// `Parser::from`. It uses default server/connection/request/response
        /// limits and no HTTP/0.9 limits.
        #[inline]
        pub(crate) fn from_req<V: AsRef<[u8]>>(value: V) -> Self {
            let req_limits = ReqLimits::default().precalculate();
            let resp_limits = RespLimits::default();
            let handler = Arc::new(DefHandler);
            let connection = Connection::new();
            let parser = Parser::from(&req_limits, value);
            let request = Request::new(&req_limits);
            let response = Response::new(&resp_limits);
            Self {
                handler,
                connection_data: (),
                connection,
                parser,
                request,
                response,
                server_limits: ServerLimits::default(),
                conn_limits: ConnLimits::default(),
                http_09_limits: None,
                req_limits,
                resp_limits,
            }
        }
    }
}