use crate::*;
impl Default for ServerData {
fn default() -> Self {
Self {
server_config: ServerConfigData::default(),
hook: vec![],
task_panic: vec![],
read_error: vec![],
}
}
}
impl Default for ServerControlHook {
    /// Builds a control handle whose hooks do nothing: awaiting either the
    /// shutdown or the wait action resolves immediately.
    fn default() -> Self {
        Self {
            shutdown_hook: Arc::new(|| Box::pin(async move {})),
            wait_hook: Arc::new(|| Box::pin(async move {})),
        }
    }
}
impl ServerData {
    /// Borrows the current server configuration.
    pub(crate) fn get_config(&self) -> &ServerConfigData {
        &self.server_config
    }

    /// Borrows the hooks that are run for each accepted connection.
    pub(crate) fn get_hook(&self) -> &ServerHookList {
        &self.hook
    }

    /// Borrows the handlers invoked when a connection hook panics.
    pub(crate) fn get_task_panic(&self) -> &ServerHookList {
        &self.task_panic
    }

    /// Borrows the handlers invoked when reading from a stream fails.
    pub(crate) fn get_read_error(&self) -> &ServerHookList {
        &self.read_error
    }

    /// Mutably borrows the server configuration so it can be replaced.
    pub(crate) fn get_mut_server_config(&mut self) -> &mut ServerConfigData {
        &mut self.server_config
    }

    /// Mutably borrows the connection-hook list (e.g. to register a hook).
    pub(crate) fn get_mut_hook(&mut self) -> &mut ServerHookList {
        &mut self.hook
    }

    /// Mutably borrows the panic-handler list.
    pub(crate) fn get_mut_task_panic(&mut self) -> &mut ServerHookList {
        &mut self.task_panic
    }

    /// Mutably borrows the read-error handler list.
    pub(crate) fn get_mut_read_error(&mut self) -> &mut ServerHookList {
        &mut self.read_error
    }
}
impl Default for Server {
fn default() -> Self {
Self(Arc::new(RwLock::new(ServerData::default())))
}
}
impl Server {
pub fn new() -> Self {
Self::default()
}
pub(crate) async fn read(&self) -> ArcRwLockReadGuard<'_, ServerData> {
self.0.read().await
}
pub(crate) async fn write(&self) -> ArcRwLockWriteGuard<'_, ServerData> {
self.0.write().await
}
pub async fn server_config(&self, config: ServerConfig) -> &Self {
*self.write().await.get_mut_server_config() = config.get_data().await;
self
}
#[inline(always)]
pub fn get_bind_addr<H>(host: H, port: u16) -> String
where
H: AsRef<str>,
{
format!("{}{}{}", host.as_ref(), COLON, port)
}
pub async fn hook<H>(&self) -> &Self
where
H: ServerHook,
{
self.write()
.await
.get_mut_hook()
.push(server_hook_factory::<H>());
self
}
pub async fn task_panic<H>(&self) -> &Self
where
H: ServerHook,
{
self.write()
.await
.get_mut_task_panic()
.push(server_hook_factory::<H>());
self
}
pub async fn read_error<H>(&self) -> &Self
where
H: ServerHook,
{
self.write()
.await
.get_mut_read_error()
.push(server_hook_factory::<H>());
self
}
async fn create_tcp_listener(&self) -> Result<TcpListener, ServerError> {
let config: ServerConfigData = self.read().await.get_config().clone();
let host: String = config.host;
let port: u16 = config.port;
let addr: String = Self::get_bind_addr(&host, port);
TcpListener::bind(&addr)
.await
.map_err(|e| ServerError::TcpBind(e.to_string()))
}
async fn spawn_connection_handler(&self, stream: ArcRwLockStream) {
let server: Server = self.clone();
let hook: ServerHookList = self.read().await.get_hook().clone();
let task_panic: ServerHookList = self.read().await.get_task_panic().clone();
let buffer_size: usize = self.read().await.get_config().buffer_size;
spawn(async move {
server
.handle_connection(stream, hook, task_panic, buffer_size)
.await;
});
}
async fn handle_connection(
&self,
stream: ArcRwLockStream,
hook: ServerHookList,
task_panic: ServerHookList,
buffer_size: usize,
) {
let request: Request = match self.read_stream(&stream, buffer_size).await {
Ok(data) => data,
Err(e) => {
self.read_error_handle(e.to_string()).await;
return;
}
};
let ctx: Context = self.create_context(stream, request).await;
for h in hook.iter() {
let ctx_clone: Context = ctx.clone();
let h_clone: ServerHookHandler = Arc::clone(h);
let join_handle: JoinHandle<()> = spawn(async move {
h_clone(ctx_clone).await;
});
match join_handle.await {
Ok(()) => {}
Err(e) if e.is_panic() => {
for panic_handler in task_panic.iter() {
panic_handler(ctx.clone()).await;
}
break;
}
Err(_) => break,
}
}
}
async fn read_stream(
&self,
stream: &ArcRwLockStream,
buffer_size: usize,
) -> Result<Request, ServerError> {
let mut buffer: Vec<u8> = Vec::new();
let mut tmp_buf: Vec<u8> = vec![0u8; buffer_size];
let mut stream_guard: ArcRwLockWriteGuard<'_, TcpStream> = stream.write().await;
loop {
match stream_guard.read(&mut tmp_buf).await {
Ok(0) => break,
Ok(n) => {
buffer.extend_from_slice(&tmp_buf[..n]);
if tmp_buf[..n].ends_with(SPLIT_REQUEST_BYTES) {
let end_pos: usize = buffer.len().saturating_sub(SPLIT_REQUEST_BYTES.len());
buffer.truncate(end_pos);
break;
}
if n < tmp_buf.len() {
break;
}
}
Err(e) => {
return Err(ServerError::TcpRead(e.to_string()));
}
}
}
Ok(buffer)
}
async fn create_context(&self, stream: ArcRwLockStream, request: Request) -> Context {
let mut data: ContextData = ContextData::new();
data.stream = Some(stream);
data.request = request;
Context::from(data)
}
async fn read_error_handle(&self, error: String) {
let error_handlers: ServerHookList = self.read().await.get_read_error().clone();
let ctx: Context = Context::new();
ctx.set_data("error", error).await;
for handler in error_handlers.iter() {
handler(ctx.clone()).await;
}
}
pub async fn run(&self) -> Result<ServerControlHook, ServerError> {
let tcp_listener: TcpListener = self.create_tcp_listener().await?;
let server: Server = self.clone();
let (wait_sender, wait_receiver) = channel(());
let (shutdown_sender, mut shutdown_receiver) = channel(());
let accept_connections: JoinHandle<()> = spawn(async move {
loop {
tokio::select! {
result = tcp_listener.accept() => {
match result {
Ok((stream, _)) => {
let stream: ArcRwLockStream = ArcRwLockStream::from_stream(stream);
server.spawn_connection_handler(stream).await;
}
Err(_) => break,
}
}
_ = shutdown_receiver.changed() => {
break;
}
}
}
let _ = wait_sender.send(());
});
let wait_hook = Arc::new(move || {
let mut wait_receiver_clone = wait_receiver.clone();
Box::pin(async move {
let _ = wait_receiver_clone.changed().await;
}) as Pin<Box<dyn Future<Output = ()> + Send + 'static>>
});
let shutdown_hook = Arc::new(move || {
let shutdown_sender_clone: Sender<()> = shutdown_sender.clone();
Box::pin(async move {
let _ = shutdown_sender_clone.send(());
}) as Pin<Box<dyn Future<Output = ()> + Send + 'static>>
});
spawn(async move {
let _ = accept_connections.await;
});
Ok(ServerControlHook {
wait_hook,
shutdown_hook,
})
}
}
impl ServerControlHook {
    /// Awaits the future produced by the stored `wait_hook`.
    ///
    /// For a handle returned by `Server::run`, this completes once the accept
    /// loop has exited; for the `Default` handle it completes immediately.
    pub async fn wait(&self) {
        let pending = (self.wait_hook)();
        pending.await;
    }

    /// Awaits the future produced by the stored `shutdown_hook`.
    ///
    /// For a handle returned by `Server::run`, this signals the accept loop to stop.
    pub async fn shutdown(&self) {
        let pending = (self.shutdown_hook)();
        pending.await;
    }
}