use crate::{
errors::ErrorKind,
http::{
request::Request,
response::{Handled, Response},
},
limits::{ConnLimits, Http09Limits, ReqLimits, RespLimits, ServerLimits, WaitStrategy},
server::connection::{writer, ConnectionData, HttpConnection},
Version,
};
use crossbeam::queue::{ArrayQueue, SegQueue};
use std::{
future::Future,
marker::{PhantomData, Send, Sync},
sync::{atomic::Ordering, Arc},
};
use tokio::{
net::{TcpListener, TcpStream},
task::yield_now,
time::sleep as tokio_sleep,
};
/// Application-level request handler invoked once per parsed HTTP request.
///
/// Implementors receive mutable per-connection state (`S`), the parsed
/// request, and a response object to populate; the returned future resolves
/// to a [`Handled`] marker once the response has been filled in.
///
/// The `Sync + Send + 'static` bound lets a single handler instance be
/// shared (via `Arc`) across all connection workers and spawned tasks.
pub trait Handler<S>
where
    Self: Sync + Send + 'static,
    S: ConnectionData,
{
    /// Handles one request.
    ///
    /// * `connection_data` – mutable state scoped to the current connection.
    /// * `req` – the parsed incoming request.
    /// * `resp` – response to fill in; mutated by the handler.
    ///
    /// The returned future must be `Send` so it can be driven from spawned
    /// tasks on a multi-threaded runtime.
    fn handle(
        &self,
        connection_data: &mut S,
        req: &Request,
        resp: &mut Response,
    ) -> impl Future<Output = Handled> + Send;
}
/// HTTP server: accepts TCP connections and dispatches them to a fixed pool
/// of pre-allocated connection workers.
pub struct Server<H: Handler<S>, S: ConnectionData> {
    /// Bound TCP listener producing incoming streams.
    listener: TcpListener,
    /// Ties the per-connection-data type parameter `S` to the struct
    /// without storing a value of it.
    _marker: PhantomData<S>,
    /// Fixed-capacity pool of idle connection workers; a worker is popped to
    /// serve a stream and pushed back when it finishes.
    worker_pool: Arc<ArrayQueue<HttpConnection<H, S>>>,
    /// Accepted-but-unserved TCP streams awaiting a free worker.
    incoming_streams: Arc<SegQueue<TcpStream>>,
    /// Server-wide tuning knobs (wait strategy, pending-connection cap, ...).
    server_limits: ServerLimits,
}
/// Generates an async helper that pops a value from the given queue type,
/// spinning with the configured [`WaitStrategy`] (yield or sleep) until a
/// value becomes available.
macro_rules! impl_get_value_queue {
    ($name:ident, $type:ident) => {
        #[inline]
        async fn $name<V>(pool: &Arc<$type<V>>, limits: &ServerLimits) -> V {
            loop {
                match pool.pop() {
                    Some(value) => break value,
                    None => match &limits.wait_strategy {
                        WaitStrategy::Yield => yield_now().await,
                        WaitStrategy::Sleep(time) => tokio_sleep(*time).await,
                    },
                }
            }
        }
    };
}
impl<H: Handler<S>, S: ConnectionData> Server<H, S> {
#[inline(always)]
pub fn builder() -> ServerBuilder<H, S> {
ServerBuilder {
listener: None,
handler: None,
_marker: PhantomData,
server_limits: None,
request_limits: None,
response_limits: None,
connection_limits: None,
http_09_limits: None,
}
}
#[inline]
pub async fn launch(self) {
let streams_rx = self.incoming_streams.clone();
let worker_pool_0 = self.worker_pool.clone();
let server_limits_0 = self.server_limits.clone();
tokio::spawn(async move {
let streams_tx = self.incoming_streams.clone();
loop {
let Ok((stream, _)) = self.listener.accept().await else {
continue;
};
Self::push_incoming_streams(stream, &streams_tx, &server_limits_0);
}
});
loop {
let mut worker = Self::get_value_array(&worker_pool_0, &self.server_limits).await;
let mut stream = Self::get_value_seq(&streams_rx, &self.server_limits).await;
let worker_pool = self.worker_pool.clone();
tokio::spawn(async move {
let _ = worker.run(&mut stream).await;
let _ = worker_pool.push(worker);
});
}
}
#[inline]
fn push_incoming_streams(
mut stream: TcpStream,
streams_tx: &Arc<SegQueue<TcpStream>>,
server_limits: &ServerLimits,
) {
if streams_tx.len() < server_limits.max_pending_connections {
streams_tx.push(stream);
} else {
tokio::spawn(async move {
let _ =
writer::send_error(&mut stream, Version::Http11, ErrorKind::ServiceUnavailable)
.await;
});
}
}
impl_get_value_queue! { get_value_array, ArrayQueue }
impl_get_value_queue! { get_value_seq, SegQueue }
}
/// Builder for [`Server`]. `listener` and `handler` are mandatory; every
/// limits field falls back to its `Default` when left unset.
pub struct ServerBuilder<H: Handler<S>, S: ConnectionData> {
    /// Bound TCP listener (required; `build` panics without it).
    listener: Option<TcpListener>,
    /// Request handler (required; `build` panics without it).
    handler: Option<H>,
    /// Carries the per-connection-data type parameter `S`.
    _marker: PhantomData<S>,
    /// Server-wide limits; defaulted when `None`.
    server_limits: Option<ServerLimits>,
    /// Per-request limits; defaulted when `None`.
    request_limits: Option<ReqLimits>,
    /// Per-response limits; defaulted when `None`.
    response_limits: Option<RespLimits>,
    /// Per-connection limits; defaulted when `None`.
    connection_limits: Option<ConnLimits>,
    /// HTTP/0.9 limits; stays `None` when unset (no default applied).
    http_09_limits: Option<Http09Limits>,
}
impl<H: Handler<S>, S: ConnectionData> ServerBuilder<H, S> {
    /// Sets the bound TCP listener (required).
    #[inline(always)]
    pub fn listener(mut self, listener: TcpListener) -> Self {
        self.listener = Some(listener);
        self
    }

    /// Sets the request handler (required).
    #[inline(always)]
    pub fn handler(mut self, handler: H) -> Self {
        self.handler = Some(handler);
        self
    }

    /// Overrides the default server-wide limits.
    #[inline(always)]
    pub fn server_limits(mut self, limits: ServerLimits) -> Self {
        self.server_limits = Some(limits);
        self
    }

    /// Overrides the default per-connection limits.
    #[inline(always)]
    pub fn connection_limits(mut self, limits: ConnLimits) -> Self {
        self.connection_limits = Some(limits);
        self
    }

    /// Sets the HTTP/0.9 limits (no default is applied when unset).
    #[inline(always)]
    pub fn http_09_limits(mut self, limits: Http09Limits) -> Self {
        self.http_09_limits = Some(limits);
        self
    }

    /// Overrides the default per-request limits.
    #[inline(always)]
    pub fn request_limits(mut self, limits: ReqLimits) -> Self {
        self.request_limits = Some(limits);
        self
    }

    /// Overrides the default per-response limits.
    #[inline(always)]
    pub fn response_limits(mut self, limits: RespLimits) -> Self {
        self.response_limits = Some(limits);
        self
    }

    /// Builds the [`Server`], pre-allocating `max_connections` workers.
    ///
    /// # Panics
    /// Panics if [`listener`](Self::listener) or [`handler`](Self::handler)
    /// was never called.
    #[inline(always)]
    #[track_caller]
    pub fn build(self) -> Server<H, S> {
        let (listener, handler, limits) = self.get_all_limits();
        Self::store_atomic(&limits);
        let handler = Arc::new(handler);
        // Pre-fill the pool with exactly `max_connections` workers; pushes
        // cannot fail because the queue capacity matches the loop count.
        let worker_pool = ArrayQueue::new(limits.0.max_connections);
        for _ in 0..limits.0.max_connections {
            let value = HttpConnection::new(handler.clone(), limits.clone());
            let _ = worker_pool.push(value);
        }
        Server {
            listener,
            _marker: PhantomData,
            worker_pool: Arc::new(worker_pool),
            incoming_streams: Arc::new(SegQueue::new()),
            server_limits: limits.0,
        }
    }

    /// Publishes limit-derived values into `writer`'s global atomics.
    #[inline(always)]
    fn store_atomic(limits: &AllTypesLimits) {
        // Saturate to u64::MAX if the timeout in microseconds overflows u64.
        writer::SOCKET_WRITE_TIMEOUT.store(
            limits
                .1
                .socket_write_timeout
                .as_micros()
                .try_into()
                .unwrap_or(u64::MAX),
            Ordering::Relaxed,
        );
        writer::JSON_ERRORS.store(limits.0.json_errors, Ordering::Relaxed);
    }

    /// Consumes the builder, returning the listener, handler, and the full
    /// limits tuple with defaults applied for any limit not explicitly set.
    ///
    /// # Panics
    /// Panics when `listener` or `handler` is missing.
    #[inline(always)]
    #[track_caller]
    fn get_all_limits(self) -> (TcpListener, H, AllTypesLimits) {
        (
            self.listener
                .expect("The `listener` method must be called to create"),
            self.handler
                .expect("The `handler` method must be called to create"),
            // `self` is taken by value, so the limit fields can be moved out
            // directly — the previous `.clone()` on each was redundant.
            (
                self.server_limits.unwrap_or_default(),
                self.connection_limits.unwrap_or_default(),
                self.http_09_limits,
                self.request_limits.unwrap_or_default().precalculate(),
                self.response_limits.unwrap_or_default(),
            ),
        )
    }
}
/// Bundle of every limits type a connection worker needs, in the order:
/// (server, connection, optional HTTP/0.9, request, response).
/// Accessed positionally elsewhere in this file (e.g. `limits.0` for
/// [`ServerLimits`], `limits.1` for [`ConnLimits`]).
pub(crate) type AllTypesLimits = (
    ServerLimits,
    ConnLimits,
    Option<Http09Limits>,
    ReqLimits,
    RespLimits,
);