use crate::{
errors::ErrorKind,
http::{
request::Request,
response::{Handled, Response},
},
limits::{ConnLimits, Http09Limits, ReqLimits, RespLimits, ServerLimits, WaitStrategy},
server::connection::{ConnectionData, HttpConnection},
ConnectionFilter, Version,
};
use crossbeam::queue::SegQueue;
use std::{
future::Future,
marker::{PhantomData, Send, Sync},
net::SocketAddr,
sync::Arc,
};
use tokio::{
net::{TcpListener, TcpStream},
task::yield_now,
time::sleep as tokio_sleep,
};
/// Application callback that produces a [`Response`] for one parsed [`Request`].
///
/// The server stores the handler in an `Arc` and gives a clone to every worker task it
/// spawns with `tokio::spawn`. That is why implementors must be `Send + Sync + 'static`.
///
/// `S` is the [`ConnectionData`] type passed to every call as `connection_data`. It
/// defaults to `()`.
pub trait Handler<S: ConnectionData = ()>: Send + Sync + 'static {
    /// Handles `request`, writing the reply into `response`.
    ///
    /// The returned future must be `Send`, because it is awaited inside a spawned task.
    /// It resolves to a [`Handled`] value.
    fn handle(
        &self,
        connection_data: &mut S,
        request: &Request,
        response: &mut Response,
    ) -> impl Future<Output = Handled> + Send;
}
/// A configured HTTP server.
///
/// It is created by [`ServerBuilder::build`], which also spawns the worker tasks, and is
/// driven by [`Server::launch`].
pub struct Server {
    /// Socket that new connections are accepted from.
    listener: TcpListener,
    /// Accepted connections waiting to be picked up by a worker task.
    stream_queue: TcpQueue,
    /// Connections accepted while `stream_queue` was at its cap.
    /// They are drained by the 503 / silent-drop tasks spawned in `build`.
    error_queue: TcpQueue,
    /// Read by `launch` for the `max_pending_connections` cap.
    server_limits: ServerLimits,
}
impl Server {
    /// Starts configuring a server.
    ///
    /// The returned builder has no listener and no handler, uses the unit `()` connection
    /// filter, and has every limit unset.
    #[inline]
    #[must_use]
    pub fn builder<H, S>() -> ServerBuilder<H, S, ()>
    where
        H: Handler<S>,
        S: ConnectionData,
    {
        ServerBuilder {
            listener: None,
            handler: None,
            connection_filter: Arc::new(()),
            _marker: PhantomData,
            server_limits: None,
            request_limits: None,
            response_limits: None,
            connection_limits: None,
            http_09_limits: None,
        }
    }

    /// Runs the accept loop forever.
    ///
    /// Each accepted connection is pushed onto the worker queue while that queue holds
    /// fewer than `max_pending_connections` entries. Otherwise it goes to the error queue.
    ///
    /// When `accept` fails, the loop backs off using the configured [`WaitStrategy`]
    /// before retrying. A persistent failure, such as running out of file descriptors,
    /// therefore no longer spins the loop at full speed.
    #[inline]
    pub async fn launch(self) {
        loop {
            let Ok(value) = self.listener.accept().await else {
                Self::wait(&self.server_limits.wait_strategy).await;
                continue;
            };
            // Workers pop concurrently, so `len()` is only a snapshot and the cap is soft.
            if self.stream_queue.len() < self.server_limits.max_pending_connections {
                self.stream_queue.push(value);
            } else {
                self.error_queue.push(value);
            }
        }
    }

    /// Waits until `queue` yields a connection, backing off between empty polls.
    #[inline]
    async fn get_stream(queue: &TcpQueue, wait: &WaitStrategy) -> (TcpStream, SocketAddr) {
        loop {
            if let Some(value) = queue.pop() {
                return value;
            }
            Self::wait(wait).await;
        }
    }

    /// Performs a single back-off step according to `strategy`.
    #[inline]
    async fn wait(strategy: &WaitStrategy) {
        match strategy {
            WaitStrategy::Yield => yield_now().await,
            WaitStrategy::Sleep(time) => tokio_sleep(*time).await,
        }
    }
}
/// Builder for [`Server`], obtained from [`Server::builder`].
///
/// Type parameters:
/// - `H` is the request [`Handler`].
/// - `S` is the handler's [`ConnectionData`] type.
/// - `F` is the [`ConnectionFilter`]. It is `()` unless replaced with
///   [`ServerBuilder::conn_filter`].
///
/// A listener and a handler must be set before calling [`ServerBuilder::build`]. Any limit
/// left unset falls back to its `Default` at build time. The HTTP/0.9 limits are the
/// exception: they stay `None` when unset.
pub struct ServerBuilder<H, S = (), F = ()>
where
    H: Handler<S>,
    S: ConnectionData,
    F: ConnectionFilter,
{
    listener: Option<TcpListener>,
    // Wrapped in `Arc` so every worker task can share the same handler.
    handler: Option<Arc<H>>,
    connection_filter: Arc<F>,
    // Ties the otherwise unused `S` parameter to the builder.
    _marker: PhantomData<S>,
    server_limits: Option<ServerLimits>,
    request_limits: Option<ReqLimits>,
    response_limits: Option<RespLimits>,
    connection_limits: Option<ConnLimits>,
    http_09_limits: Option<Http09Limits>,
}
impl<H, S, F> ServerBuilder<H, S, F>
where
H: Handler<S>,
S: ConnectionData,
F: ConnectionFilter,
{
#[inline(always)]
pub fn listener(mut self, listener: TcpListener) -> Self {
self.listener = Some(listener);
self
}
#[inline(always)]
pub fn handler(mut self, handler: H) -> Self {
self.handler = Some(Arc::new(handler));
self
}
#[inline(always)]
pub fn conn_filter<NewF>(self, filter: NewF) -> ServerBuilder<H, S, NewF>
where
NewF: ConnectionFilter,
{
ServerBuilder {
listener: self.listener,
handler: self.handler,
connection_filter: Arc::new(filter),
_marker: self._marker,
server_limits: self.server_limits,
request_limits: self.request_limits,
response_limits: self.response_limits,
connection_limits: self.connection_limits,
http_09_limits: self.http_09_limits,
}
}
#[inline(always)]
pub fn server_limits(mut self, limits: ServerLimits) -> Self {
self.server_limits = Some(limits);
self
}
#[inline(always)]
pub fn connection_limits(mut self, limits: ConnLimits) -> Self {
self.connection_limits = Some(limits);
self
}
#[inline(always)]
pub fn http_09_limits(mut self, limits: Http09Limits) -> Self {
self.http_09_limits = Some(limits);
self
}
#[inline(always)]
pub fn request_limits(mut self, limits: ReqLimits) -> Self {
self.request_limits = Some(limits);
self
}
#[inline(always)]
pub fn response_limits(mut self, limits: RespLimits) -> Self {
self.response_limits = Some(limits);
self
}
#[inline]
#[track_caller]
pub fn build(self) -> Server {
let (listener, handler, filter, limits) = self.get_all_parts();
let stream_queue = Arc::new(SegQueue::new());
let error_queue = Arc::new(SegQueue::new());
for _ in 0..limits.0.max_connections {
Self::spawn_worker(&stream_queue, &limits, &filter, &handler);
}
if limits.0.count_503_handlers != 0 {
for _ in 0..limits.0.count_503_handlers {
Self::spawn_alarmist(&error_queue, &limits);
}
} else {
Self::spawn_quiet_alarmist(&error_queue, &limits);
}
Server {
listener,
stream_queue,
error_queue,
server_limits: limits.0,
}
}
#[inline]
fn spawn_worker(queue: &TcpQueue, limits: &AllLimits, filter: &Arc<F>, handler: &Arc<H>) {
let queue = queue.clone();
let filter = filter.clone();
let mut conn = HttpConnection::new(handler.clone(), limits.clone());
tokio::spawn(async move {
loop {
let (mut stream, c_addr) =
Server::get_stream(&queue, &conn.server_limits.wait_strategy).await;
let Ok(s_addr) = stream.local_addr() else {
continue;
};
if filter.filter(c_addr, s_addr, &mut conn.response).is_err()
|| filter
.filter_async(c_addr, s_addr, &mut conn.response)
.await
.is_err()
{
let _ = conn
.conn_limits
.write_bytes(&mut stream, conn.response.buffer())
.await;
conn.response.reset(&conn.resp_limits);
continue;
}
let _ = conn.run(&mut stream, c_addr, s_addr).await;
}
});
}
#[inline]
fn spawn_alarmist(queue: &TcpQueue, limits: &AllLimits) {
let queue = queue.clone();
let (server_limits, conn_limits, ..) = limits.clone();
tokio::spawn(async move {
loop {
let (mut stream, _) =
Server::get_stream(&queue, &server_limits.wait_strategy).await;
let _ = conn_limits
.send_error(
&mut stream,
ErrorKind::ServiceUnavailable,
Version::Http11,
server_limits.json_errors,
)
.await;
}
});
}
#[inline]
fn spawn_quiet_alarmist(queue: &TcpQueue, limits: &AllLimits) {
let queue = queue.clone();
let (server_limits, ..) = limits.clone();
tokio::spawn(async move {
loop {
let (stream, _) = Server::get_stream(&queue, &server_limits.wait_strategy).await;
drop(stream);
}
});
}
#[inline]
#[track_caller]
fn get_all_parts(self) -> (TcpListener, Arc<H>, Arc<F>, AllLimits) {
(
self.listener
.expect("The `listener` method must be called to create"),
self.handler
.expect("The `handler` method must be called to create"),
self.connection_filter,
(
self.server_limits.clone().unwrap_or_default(),
self.connection_limits.clone().unwrap_or_default(),
self.http_09_limits.clone(),
self.request_limits
.clone()
.unwrap_or_default()
.precalculate(),
self.response_limits.clone().unwrap_or_default(),
),
)
}
}
/// Every limit set the server works with, resolved by the builder.
///
/// The element order is relied on positionally: it is indexed as `.0` for the server
/// limits and destructured by the spawned tasks. Do not rearrange it.
pub(crate) type AllLimits = (ServerLimits, ConnLimits, Option<Http09Limits>, ReqLimits, RespLimits);
/// Lock-free queue of accepted streams paired with their peer addresses.
///
/// It is shared via `Arc` between the accept loop and the tasks that drain it.
type TcpQueue = Arc<SegQueue<(TcpStream, SocketAddr)>>;