use crate::*;
impl Default for Server {
    /// Builds a server with default server and request configuration, an empty route
    /// matcher, no registered hooks of any kind, and a default [`Task`].
    #[inline(always)]
    fn default() -> Self {
        Self {
            server_config: Default::default(),
            request_config: Default::default(),
            task_panic: Vec::default(),
            request_error: Vec::default(),
            route_matcher: RouteMatcher::new(),
            request_middleware: Vec::default(),
            response_middleware: Vec::default(),
            task: Default::default(),
        }
    }
}
impl PartialEq for Server {
    /// Equality for servers.
    ///
    /// The server configuration, request configuration and route matcher are compared with
    /// `==`. Each hook list is compared by length and then position by position using `Arc`
    /// pointer identity, so two lists are equal only if they hold the very same handler
    /// allocations in the same order.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        // Same length, and at every index both sides point at the same `Arc` allocation.
        fn same_handlers<T: ?Sized>(left: &[Arc<T>], right: &[Arc<T>]) -> bool {
            left.len() == right.len()
                && left
                    .iter()
                    .zip(right)
                    .all(|(lhs, rhs)| Arc::ptr_eq(lhs, rhs))
        }

        let configs_match: bool = self.get_server_config() == other.get_server_config()
            && self.get_request_config() == other.get_request_config();
        configs_match
            && self.get_route_matcher() == other.get_route_matcher()
            && same_handlers(
                self.get_task_panic().as_slice(),
                other.get_task_panic().as_slice(),
            )
            && same_handlers(
                self.get_request_error().as_slice(),
                other.get_request_error().as_slice(),
            )
            && same_handlers(
                self.get_request_middleware().as_slice(),
                other.get_request_middleware().as_slice(),
            )
            && same_handlers(
                self.get_response_middleware().as_slice(),
                other.get_response_middleware().as_slice(),
            )
    }
}
impl Eq for Server {}
impl From<usize> for Server {
    /// Returns an owned clone of the `Server` found at `address`.
    ///
    /// The address is dereferenced through `From<usize> for &'static Server`, which performs
    /// no validation. NOTE(review): `address` must come from a live `Server` (for example via
    /// `From<&Server> for usize`); any other value is undefined behavior even though this
    /// conversion is safe to call.
    #[inline(always)]
    fn from(address: usize) -> Self {
        <&'static Server as From<usize>>::from(address).clone()
    }
}
impl From<usize> for &'static Server {
    /// Reinterprets `address` as a shared reference to the `Server` stored there.
    ///
    /// This safe conversion wraps an `unsafe` dereference, so the compiler cannot enforce its
    /// contract. `address` must have been produced from a live `Server` (for example by
    /// `From<&Server> for usize`). That `Server` must stay alive and unmoved for as long as the
    /// returned reference is used. In debug builds, null and misaligned addresses trigger an
    /// assertion failure instead of silent undefined behavior.
    #[inline(always)]
    fn from(address: usize) -> &'static Server {
        let pointer: *const Server = address as *const Server;
        debug_assert!(!pointer.is_null(), "null Server address");
        debug_assert!(pointer.is_aligned(), "misaligned Server address");
        // SAFETY: callers only pass addresses taken from a live, properly aligned `Server`
        // that outlives every use of the returned reference. The debug assertions above catch
        // the most obvious violations (null or misaligned) during development.
        unsafe { &*pointer }
    }
}
impl From<usize> for &'static mut Server {
    /// Reinterprets `address` as a mutable reference to the `Server` stored there.
    ///
    /// This safe conversion wraps an `unsafe` dereference, so the compiler cannot enforce its
    /// contract:
    /// - `address` must come from a live `Server` that stays alive and unmoved while the
    ///   returned reference is used;
    /// - no other reference to that `Server` may be used during that time, because Rust's
    ///   aliasing rules require `&mut` to be exclusive.
    ///
    /// In debug builds, null and misaligned addresses trigger an assertion failure.
    #[inline(always)]
    fn from(address: usize) -> &'static mut Server {
        let pointer: *mut Server = address as *mut Server;
        debug_assert!(!pointer.is_null(), "null Server address");
        debug_assert!(pointer.is_aligned(), "misaligned Server address");
        // SAFETY: callers only pass addresses taken from a live, properly aligned `Server`.
        // That `Server` outlives every use of the returned reference, and the caller has
        // exclusive access while the reference is in use. The debug assertions catch null or
        // misaligned addresses.
        unsafe { &mut *pointer }
    }
}
impl From<&Server> for usize {
    /// Returns the memory address of `server` as an integer.
    #[inline(always)]
    fn from(server: &Server) -> Self {
        let raw: *const Server = server;
        raw as Self
    }
}
impl From<&mut Server> for usize {
    /// Returns the memory address of `server` as an integer.
    #[inline(always)]
    fn from(server: &mut Server) -> Self {
        let raw: *mut Server = server;
        raw as Self
    }
}
impl AsRef<Server> for Server {
    /// Returns `self` unchanged.
    ///
    /// This is an identity conversion, so it hands back the borrowed reference directly with
    /// no address round-trip and no `unsafe`, keeping the reference's provenance intact.
    #[inline(always)]
    fn as_ref(&self) -> &Self {
        self
    }
}
impl AsMut<Server> for Server {
    /// Returns `self` unchanged.
    ///
    /// This is an identity conversion, so it reborrows the existing exclusive reference. It
    /// does not mint a new `&mut` from an integer address, which would need `unsafe` and
    /// would sidestep the borrow checker.
    #[inline(always)]
    fn as_mut(&mut self) -> &mut Self {
        self
    }
}
impl From<ServerConfig> for Server {
    /// Builds a default server whose server configuration is `server_config`.
    #[inline(always)]
    fn from(server_config: ServerConfig) -> Self {
        let mut server: Self = Self::default();
        server.server_config = server_config;
        server
    }
}
impl From<RequestConfig> for Server {
    /// Builds a default server whose request configuration is `request_config`.
    #[inline(always)]
    fn from(request_config: RequestConfig) -> Self {
        let mut server: Self = Self::default();
        server.request_config = request_config;
        server
    }
}
impl Lifetime for Server {
    /// Reinterprets `self` as `&'static Server` by round-tripping its address through `usize`.
    ///
    /// Despite the name, nothing is moved to the heap or leaked. The returned reference is
    /// valid only while the original `Server` is alive and not moved.
    ///
    /// NOTE(review): callers must guarantee this; the compiler no longer can.
    #[inline(always)]
    fn leak(&self) -> &'static Self {
        <&'static Self as From<usize>>::from(usize::from(self))
    }

    /// Reinterprets `self` as `&'static mut Server` by round-tripping its address through
    /// `usize`.
    ///
    /// NOTE(review): this produces a mutable reference from a shared one. Under Rust's aliasing
    /// rules that is undefined behavior whenever another reference to the same `Server` is
    /// live while the result is used. Confirm every caller has exclusive access.
    #[inline(always)]
    fn leak_mut(&self) -> &'static mut Self {
        <&'static mut Self as From<usize>>::from(usize::from(self))
    }
}
impl Server {
    /// Registers one hook, dispatching on its [`HookType`] variant.
    ///
    /// The factory carried by the variant is called once. Its handler is appended to the
    /// matching list (task panic, request error, request middleware or response middleware),
    /// or, for [`HookType::Route`], added to the route matcher under the variant's path.
    ///
    /// # Panics
    ///
    /// Panics if the route matcher's `add` returns an error for a route hook.
    #[inline]
    pub fn handle_hook(&mut self, hook: HookType) {
        match hook {
            HookType::TaskPanic(_, hook) => {
                self.get_mut_task_panic().push(hook());
            }
            HookType::RequestError(_, hook) => {
                self.get_mut_request_error().push(hook());
            }
            HookType::RequestMiddleware(_, hook) => {
                self.get_mut_request_middleware().push(hook());
            }
            HookType::Route(path, hook) => {
                self.get_mut_route_matcher()
                    .add(path, hook())
                    .expect("Server::handle_hook: failed to register route hook");
            }
            HookType::ResponseMiddleware(_, hook) => {
                self.get_mut_response_middleware().push(hook());
            }
        }
    }

    /// Replaces the server configuration with one deserialized from a JSON document.
    ///
    /// # Panics
    ///
    /// Panics if `json` cannot be deserialized into a [`ServerConfig`]. The method returns
    /// `&mut Self` for chaining, so the parse error cannot be handed back to the caller. The
    /// panic message names this method and includes the `serde_json` error.
    #[inline]
    pub fn config_from_json<C>(&mut self, json: C) -> &mut Self
    where
        C: AsRef<str>,
    {
        let config: ServerConfig = serde_json::from_str(json.as_ref())
            .expect("Server::config_from_json: JSON is not a valid ServerConfig");
        self.set_server_config(config);
        self
    }

    /// Replaces the server configuration with `config`.
    #[inline(always)]
    pub fn server_config(&mut self, config: ServerConfig) -> &mut Self {
        self.set_server_config(config);
        self
    }

    /// Replaces the request configuration with `config`.
    #[inline(always)]
    pub fn request_config(&mut self, config: RequestConfig) -> &mut Self {
        self.set_request_config(config);
        self
    }

    /// Appends a task-panic hook built from `S` by `server_hook_factory`.
    ///
    /// Task-panic hooks run in registration order after a connection task panics.
    #[inline(always)]
    pub fn task_panic<S>(&mut self) -> &mut Self
    where
        S: ServerHook,
    {
        self.get_mut_task_panic().push(server_hook_factory::<S>());
        self
    }

    /// Appends a request-error hook built from `S` by `server_hook_factory`.
    ///
    /// Request-error hooks run in registration order when reading a request from the
    /// connection stream fails.
    #[inline(always)]
    pub fn request_error<S>(&mut self) -> &mut Self
    where
        S: ServerHook,
    {
        self.get_mut_request_error()
            .push(server_hook_factory::<S>());
        self
    }

    /// Registers a route hook built from `S` by `server_hook_factory` for `path`.
    ///
    /// # Panics
    ///
    /// Panics if the route matcher's `add` returns an error for `path`.
    #[inline(always)]
    pub fn route<S>(&mut self, path: impl AsRef<str>) -> &mut Self
    where
        S: ServerHook,
    {
        self.get_mut_route_matcher()
            .add(path.as_ref(), server_hook_factory::<S>())
            .expect("Server::route: failed to register route hook");
        self
    }

    /// Appends a request middleware built from `S` by `server_hook_factory`.
    ///
    /// Request middleware runs in registration order before route resolution.
    #[inline(always)]
    pub fn request_middleware<S>(&mut self) -> &mut Self
    where
        S: ServerHook,
    {
        self.get_mut_request_middleware()
            .push(server_hook_factory::<S>());
        self
    }

    /// Appends a response middleware built from `S` by `server_hook_factory`.
    ///
    /// Response middleware runs in registration order after the route hook.
    #[inline(always)]
    pub fn response_middleware<S>(&mut self) -> &mut Self
    where
        S: ServerHook,
    {
        self.get_mut_response_middleware()
            .push(server_hook_factory::<S>());
        self
    }

    /// Formats `host` and `port` into a bind address, joined by the crate's `COLON` constant.
    #[inline(always)]
    pub fn format_bind_address<H>(host: H, port: u16) -> String
    where
        H: AsRef<str>,
    {
        format!("{}{COLON}{port}", host.as_ref())
    }

    /// Flushes standard output and returns any I/O error.
    #[inline(always)]
    pub fn try_flush_stdout() -> io::Result<()> {
        stdout().flush()
    }

    /// Flushes standard output.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails. Use [`Server::try_flush_stdout`] to handle the error instead.
    #[inline(always)]
    pub fn flush_stdout() {
        stdout().flush().unwrap();
    }

    /// Flushes standard error and returns any I/O error.
    #[inline(always)]
    pub fn try_flush_stderr() -> io::Result<()> {
        stderr().flush()
    }

    /// Flushes standard error.
    ///
    /// # Panics
    ///
    /// Panics if flushing fails. Use [`Server::try_flush_stderr`] to handle the error instead.
    #[inline(always)]
    pub fn flush_stderr() {
        stderr().flush().unwrap();
    }

    /// Flushes standard output, then standard error.
    ///
    /// Stops at the first failure and returns its error; stderr is not flushed if flushing
    /// stdout fails.
    #[inline(always)]
    pub fn try_flush_stdout_and_stderr() -> io::Result<()> {
        Self::try_flush_stdout()?;
        Self::try_flush_stderr()
    }

    /// Flushes standard output, then standard error.
    ///
    /// # Panics
    ///
    /// Panics if either flush fails.
    #[inline(always)]
    pub fn flush_stdout_and_stderr() {
        Self::flush_stdout();
        Self::flush_stderr();
    }

    /// Runs `hook` on its own task and handles a panic inside it.
    ///
    /// `ctx_address` is the address of the connection's heap-allocated [`Context`]
    /// (`tcp_accept` creates it with `Box::leak`). If `hook` panics:
    /// - the panic is stored on the context;
    /// - the response status is set from `HttpStatus::InternalServerError`;
    /// - the task-panic hooks run in order on a separate task until one of them aborts the
    ///   context, and a panic inside those hooks is printed to stderr;
    /// - finally, the context is freed unless it has been leaked.
    ///
    /// If `hook` finishes without panicking, nothing else happens here. `handle_connection`,
    /// the future `tcp_accept` passes in, frees the context itself.
    async fn task_handler<F>(&'static self, ctx_address: usize, hook: F)
    where
        F: Future<Output = ()> + Send + 'static,
    {
        if let Err(error) = spawn(hook).await
            && error.is_panic()
        {
            let ctx: &mut Context = ctx_address.into();
            let panic: PanicData = PanicData::from_join_error(error);
            ctx.set_task_panic(panic)
                .get_mut_response()
                .set_status_code(HttpStatus::InternalServerError.code());
            let panic_hook = async move {
                for hook in self.get_task_panic().iter() {
                    hook(ctx).await;
                    if ctx.get_aborted() {
                        return;
                    }
                }
            };
            // The panic hooks run on their own task, so a panicking panic hook is captured as a
            // `JoinError` instead of unwinding this task and skipping the free below.
            if let Err(error) = spawn(panic_hook).await
                && error.is_panic()
            {
                eprintln!("{}", error);
                // Best effort: a failed flush must not prevent freeing the context.
                let _ = Self::try_flush_stdout_and_stderr();
            }
            // `ctx` was moved into `panic_hook`, so re-derive the context from its address.
            let drop_ctx: &mut Context = ctx_address.into();
            if !drop_ctx.get_leaked() {
                drop_ctx.free();
            }
        }
    }

    /// Applies the configured `nodelay` and `ttl` socket options to `stream`, where set.
    ///
    /// Errors from setting either option are ignored.
    fn configure_stream(&self, stream: &TcpStream) {
        let config: &ServerConfig = self.get_server_config();
        if let Some(nodelay) = config.try_get_nodelay() {
            let _ = stream.set_nodelay(*nodelay);
        }
        if let Some(ttl) = config.try_get_ttl() {
            let _ = stream.set_ttl(*ttl);
        }
    }

    /// Runs the request middleware in registration order.
    ///
    /// Returns `true` as soon as a middleware leaves the context aborted, skipping the rest.
    /// Returns `false` if every middleware ran without aborting.
    pub(super) async fn handle_request_middleware(&self, ctx: &mut Context) -> bool {
        for hook in self.get_request_middleware().iter() {
            hook(ctx).await;
            if ctx.get_aborted() {
                return true;
            }
        }
        false
    }

    /// Resolves `path` with the route matcher and runs the matching route hook, if any.
    ///
    /// Returns `true` only when a hook ran and left the context aborted.
    pub(super) async fn handle_route_matcher(&self, ctx: &mut Context, path: &str) -> bool {
        if let Some(hook) = self.get_route_matcher().try_resolve_route(ctx, path) {
            hook(ctx).await;
            if ctx.get_aborted() {
                return true;
            }
        }
        false
    }

    /// Runs the response middleware in registration order.
    ///
    /// Returns `true` as soon as a middleware leaves the context aborted, skipping the rest.
    /// Returns `false` if every middleware ran without aborting.
    pub(super) async fn handle_response_middleware(&self, ctx: &mut Context) -> bool {
        for hook in self.get_response_middleware().iter() {
            hook(ctx).await;
            if ctx.get_aborted() {
                return true;
            }
        }
        false
    }

    /// Handles a failure to read a request.
    ///
    /// Clears the context's aborted and closed flags and stores a clone of `error` on it.
    /// Then runs the request-error hooks in order until one aborts the context.
    pub async fn handle_request_error(&self, ctx: &mut Context, error: &RequestError) {
        ctx.set_aborted(false)
            .set_closed(false)
            .set_request_error_data(error.clone());
        for hook in self.get_request_error().iter() {
            hook(ctx).await;
            if ctx.get_aborted() {
                return;
            }
        }
    }

    /// Processes one request on the connection and reports whether to keep reading.
    ///
    /// First resets the per-request context state: a fresh response carrying the request's
    /// HTTP version, cleared aborted/closed flags, empty route params and attributes. Then it
    /// runs request middleware, the route hook and response middleware in that order,
    /// stopping at the first stage that aborts the context. The return value is
    /// `ctx.is_keep_alive(..)`, given the request's own keep-alive setting.
    async fn request_hook(&self, ctx: &mut Context, request: &Request) -> bool {
        let mut response: Response = Response::default();
        response.set_version(request.get_version().clone());
        ctx.set_aborted(false)
            .set_closed(false)
            .set_response(response)
            .set_route_params(RouteParams::default())
            .set_attributes(ThreadSafeAttributeStore::default());
        let keep_alive: bool = request.is_enable_keep_alive();
        let route: &str = request.get_path();
        // Each stage returns `true` when it aborted the context. `&&` short-circuits, so an
        // aborting stage skips every later one. The keep-alive decision below is the same no
        // matter which stage stopped processing.
        if !self.handle_request_middleware(ctx).await
            && !self.handle_route_matcher(ctx, route).await
        {
            self.handle_response_middleware(ctx).await;
        }
        ctx.is_keep_alive(keep_alive)
    }

    /// Handles `request`, then keeps reading and handling requests from the same stream while
    /// `request_hook` reports keep-alive.
    ///
    /// A read error is routed to the request-error hooks and ends the loop.
    async fn handle_http_requests(&self, ctx: &mut Context, request: &Request) {
        if !self.request_hook(ctx, request).await {
            return;
        }
        loop {
            match ctx.http_from_stream().await {
                Ok(new_request) => {
                    if !self.request_hook(ctx, &new_request).await {
                        return;
                    }
                }
                Err(error) => {
                    self.handle_request_error(ctx, &error).await;
                    return;
                }
            }
        }
    }

    /// Serves one accepted connection.
    ///
    /// Reads the first request and hands it to `handle_http_requests`. If that first read
    /// fails, the request-error hooks run instead. Finally, the context is freed unless it
    /// has been leaked.
    async fn handle_connection(&self, ctx: &mut Context) {
        match ctx.http_from_stream().await {
            Ok(request) => {
                self.handle_http_requests(ctx, &request).await;
            }
            Err(error) => {
                self.handle_request_error(ctx, &error).await;
            }
        }
        if !ctx.get_leaked() {
            ctx.free();
        }
    }

    /// Accepts connections forever.
    ///
    /// For each accepted stream this:
    /// - applies the configured socket options;
    /// - allocates a [`Context`] on the heap with `Box::leak`, which is presumably reclaimed by
    ///   `Context::free`;
    /// - spawns a task that serves the stream via `handle_connection`, wrapped in
    ///   `task_handler`'s panic handling.
    ///
    /// NOTE(review): a failed `accept` is ignored and retried immediately. A persistent error
    /// (e.g. file-descriptor exhaustion) would make this loop spin; consider logging and
    /// backing off.
    async fn tcp_accept(&'static self, tcp_listener: &TcpListener) {
        loop {
            if let Ok((stream, _)) = tcp_listener.accept().await {
                self.configure_stream(&stream);
                let stream: ArcRwLockStream = ArcRwLockStream::from_stream(stream);
                let ctx: &'static mut Context = Box::leak(Box::new(Context::new(&stream, self)));
                spawn(self.task_handler(ctx.into(), self.handle_connection(ctx)));
            }
        }
    }

    /// Binds the configured address and starts serving connections on a background task.
    ///
    /// Returns a [`ServerControlHook`] carrying two handlers:
    /// - the wait hook resolves when `wait_sender` signals. `tcp_accept` never returns on its
    ///   own, so in practice this is when shutdown aborts the accept task and `wait_sender` is
    ///   dropped. This assumes the channel's `changed()` completes when its sender is dropped;
    ///   confirm.
    /// - the shutdown hook signals a watcher task. That task aborts the accept task and calls
    ///   `shutdown()` on the server's [`Task`].
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener to the configured address fails.
    ///
    /// NOTE(review): `self.leak()` only reinterprets `self` as `'static`. The server is not
    /// moved to the heap, so it must outlive every task spawned here.
    ///
    /// NOTE(review): the watcher ignores the result of `changed()`. If the shutdown sender is
    /// dropped (presumably once the returned hook and all clones of its shutdown handler are
    /// dropped), shutdown may trigger as well. Confirm this is intended.
    pub async fn run(&self) -> Result<ServerControlHook, ServerError> {
        let bind_address: &String = self.get_server_config().get_address();
        let tcp_listener: TcpListener = TcpListener::bind(bind_address).await?;
        let server: &'static Self = self.leak();
        let (wait_sender, wait_receiver) = channel(());
        let (shutdown_sender, mut shutdown_receiver) = channel(());
        let accept_connections: JoinHandle<()> = spawn(async move {
            server.tcp_accept(&tcp_listener).await;
            let _ = wait_sender.send(());
        });
        let wait_hook: ServerControlHookHandler<()> = Arc::new(move || {
            let mut wait_receiver_clone: Receiver<()> = wait_receiver.clone();
            Box::pin(async move {
                let _ = wait_receiver_clone.changed().await;
            })
        });
        let shutdown_hook: ServerControlHookHandler<()> = Arc::new(move || {
            let shutdown_sender_clone: Sender<()> = shutdown_sender.clone();
            Box::pin(async move {
                let _ = shutdown_sender_clone.send(());
            })
        });
        spawn(async move {
            let _ = shutdown_receiver.changed().await;
            accept_connections.abort();
            server.get_task().shutdown();
        });
        let mut server_control_hook: ServerControlHook = ServerControlHook::default();
        server_control_hook.set_shutdown_hook(shutdown_hook);
        server_control_hook.set_wait_hook(wait_hook);
        Ok(server_control_hook)
    }
}