use crate::common::{authentication::Verifier, Listener, Response, Transport};
use async_trait::async_trait;
use log::*;
use serde::{de::DeserializeOwned, Serialize};
use std::{io, sync::Arc, time::Duration};
use tokio::sync::{broadcast, RwLock};
mod builder;
pub use builder::*;
mod config;
pub use config::*;
mod connection;
use connection::*;
mod context;
pub use context::*;
mod r#ref;
pub use r#ref::*;
mod reply;
pub use reply::*;
mod state;
use state::*;
mod shutdown_timer;
use shutdown_timer::*;
/// A server that accepts connections from a [`Listener`] and dispatches their requests to a
/// handler of type `T` (see [`ServerHandler`]).
///
/// Build one with [`Server::new`] (or the `tcp` / `unix_socket` / `windows_pipe` builders),
/// attach a handler with [`Server::handler`], then call [`Server::start`].
pub struct Server<T> {
/// Settings read by the server task: shutdown policy plus per-connection sleep and heartbeat
/// durations.
config: ServerConfig,
/// Request handler; wrapped in an `Arc` by the server task and handed to each connection task
/// as a weak reference.
handler: T,
/// Verifier handed to each connection task (as a weak reference); presumably used to
/// authenticate new connections — confirm in the connection module.
verifier: Verifier,
}
/// Callbacks a [`Server`] invokes for its connections and for the requests they carry.
#[async_trait]
pub trait ServerHandler: Send {
    /// Payload type of incoming requests.
    type Request;

    /// Payload type of outgoing responses.
    type Response;

    /// Data exposed to the handler through [`ConnectionCtx`] and [`ServerCtx`]; presumably
    /// one value per connection — confirm in the connection module.
    type LocalData: Send;

    /// Hook invoked with a [`ConnectionCtx`] for an accepted connection.
    ///
    /// The default implementation ignores the context and returns `Ok(())`.
    /// NOTE(review): presumably returning `Err` rejects the connection — confirm in the
    /// connection module.
    async fn on_accept(&self, _ctx: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
        Ok(())
    }

    /// Handles one incoming request; replies are sent through the context (e.g. `ctx.reply`).
    async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>);
}
impl Server<()> {
pub fn new() -> Self {
Self {
config: Default::default(),
handler: (),
verifier: Verifier::empty(),
}
}
pub fn tcp() -> TcpServerBuilder<()> {
TcpServerBuilder::default()
}
#[cfg(unix)]
pub fn unix_socket() -> UnixSocketServerBuilder<()> {
UnixSocketServerBuilder::default()
}
#[cfg(windows)]
pub fn windows_pipe() -> WindowsPipeServerBuilder<()> {
WindowsPipeServerBuilder::default()
}
}
impl Default for Server<()> {
fn default() -> Self {
Self::new()
}
}
impl<T> Server<T> {
    /// Returns this server with its configuration replaced by `config`; the handler and
    /// verifier are carried over unchanged.
    pub fn config(self, config: ServerConfig) -> Self {
        Self { config, ..self }
    }

    /// Returns this server with its handler replaced by `handler`, changing the handler type
    /// to `U`.
    ///
    /// The configuration and verifier are carried over; the previous handler is dropped.
    pub fn handler<U>(self, handler: U) -> Server<U> {
        // Functional record update cannot change the generic parameter, so destructure instead.
        let Server {
            config, verifier, ..
        } = self;
        Server {
            config,
            handler,
            verifier,
        }
    }

    /// Returns this server with its verifier replaced by `verifier`; the configuration and
    /// handler are carried over unchanged.
    pub fn verifier(self, verifier: Verifier) -> Self {
        Self { verifier, ..self }
    }
}
impl<T> Server<T>
where
    T: ServerHandler + Sync + 'static,
    T::Request: DeserializeOwned + Send + Sync + 'static,
    T::Response: Serialize + Send + 'static,
    T::LocalData: Default + Send + Sync + 'static,
{
    /// Starts the server on `listener`.
    ///
    /// Spawns a background task that accepts connections and hands each one to its own
    /// connection task. Returns a boxed [`ServerRef`] holding the shutdown broadcast sender and
    /// the join handle of that task.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `io::Result` leaves room for
    /// reporting startup failures.
    ///
    /// NOTE(review): this calls `tokio::spawn`, which needs an active Tokio runtime — confirm
    /// that callers always start the server from within one.
    pub fn start<L>(self, listener: L) -> io::Result<Box<dyn ServerRef>>
    where
        L: Listener + 'static,
        L::Output: Transport + 'static,
    {
        let state = Arc::new(ServerState::new());
        let (tx, rx) = broadcast::channel(1);

        // `state` is only needed by the server task, so move it in instead of cloning the `Arc`.
        let task = tokio::spawn(self.task(state, listener, tx.clone(), rx));
        Ok(Box::new(GenericServerRef { shutdown: tx, task }))
    }

    /// Body of the task spawned by [`Server::start`].
    ///
    /// The task runs in two phases:
    /// 1. It accepts connections until either the listener returns an error or the shutdown
    ///    timer's notification fires.
    ///    - On a timer shutdown, `()` is broadcast on `shutdown_tx` so connection tasks can
    ///      stop.
    ///    - On a listener error, the timer is aborted and no shutdown signal is sent.
    /// 2. It waits until every connection task it spawned has finished.
    async fn task<L>(
        self,
        state: Arc<ServerState<Response<T::Response>>>,
        mut listener: L,
        shutdown_tx: broadcast::Sender<()>,
        shutdown_rx: broadcast::Receiver<()>,
    ) where
        L: Listener + 'static,
        L::Output: Transport + 'static,
    {
        let Server {
            config,
            handler,
            verifier,
        } = self;

        // Strong references live in this task; connection tasks only get weak references via
        // `Arc::downgrade` below.
        let handler = Arc::new(handler);
        let timer = ShutdownTimer::start(config.shutdown);
        let mut notification = timer.clone_notification();
        let timer = Arc::new(RwLock::new(timer));
        let verifier = Arc::new(verifier);

        let mut connection_tasks = Vec::new();

        loop {
            let transport = tokio::select! {
                result = listener.accept() => {
                    match result {
                        Ok(x) => x,
                        Err(x) => {
                            error!("Server no longer accepting connections: {x}");
                            timer.read().await.abort();
                            break;
                        }
                    }
                }
                _ = notification.wait() => {
                    info!(
                        "Server shutdown triggered after {}s",
                        config.shutdown.duration().unwrap_or_default().as_secs_f32(),
                    );
                    let _ = shutdown_tx.send(());
                    break;
                }
            };

            // A connection was accepted, so stop the shutdown timer.
            timer.read().await.stop();

            connection_tasks.push(
                ConnectionTask::build()
                    .handler(Arc::downgrade(&handler))
                    .state(Arc::downgrade(&state))
                    .keychain(state.keychain.clone())
                    .transport(transport)
                    .shutdown(shutdown_rx.resubscribe())
                    .shutdown_timer(Arc::downgrade(&timer))
                    .sleep_duration(config.connection_sleep)
                    .heartbeat_duration(config.connection_heartbeat)
                    .verifier(Arc::downgrade(&verifier))
                    .spawn(),
            );

            // Drop handles of connections that have already ended. Without this, a
            // long-running server would keep one handle for every connection it ever accepted.
            // This call must come after the `push` above so the element type is already
            // inferred.
            connection_tasks.retain(|task| !task.is_finished());
        }

        // We have stopped accepting new connections; wait for the existing ones to end.
        info!("Server waiting for active connections to terminate");
        loop {
            connection_tasks.retain(|task| !task.is_finished());
            if connection_tasks.is_empty() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        info!("Server task terminated");
    }
}
// Integration-style tests for `Server`. They use in-memory transports fed through an
// mpsc-backed listener. Most of them check shutdown behavior by sleeping on the wall clock.
#[cfg(test)]
mod tests {
use super::*;
use crate::common::{
authentication::{AuthenticationMethod, DummyAuthHandler, NoneAuthenticationMethod},
Connection, InmemoryTransport, MpscListener, Request, Response,
};
use async_trait::async_trait;
use std::time::Duration;
use test_log::test;
use tokio::sync::mpsc;
// Handler that accepts every connection and answers every request with "hello".
pub struct TestServerHandler;
#[async_trait]
impl ServerHandler for TestServerHandler {
type Request = u16;
type Response = String;
type LocalData = ();
async fn on_accept(&self, _: ConnectionCtx<'_, Self::LocalData>) -> io::Result<()> {
Ok(())
}
async fn on_request(&self, ctx: ServerCtx<Self::Request, Self::Response, Self::LocalData>) {
ctx.reply.send("hello".to_string()).await.unwrap();
}
}
// Builds a server with the test handler and a verifier whose only method is
// `NoneAuthenticationMethod`.
#[inline]
fn make_test_server(config: ServerConfig) -> Server<TestServerHandler> {
let methods: Vec<Box<dyn AuthenticationMethod>> =
vec![Box::new(NoneAuthenticationMethod::new())];
Server {
config,
handler: TestServerHandler,
verifier: Verifier::new(methods),
}
}
// Returns a sender used to feed connections to the server, plus the listener that yields them.
// The clippy allow is for the long (sender, listener) tuple return type.
#[allow(clippy::type_complexity)]
fn make_listener(
buffer: usize,
) -> (
mpsc::Sender<InmemoryTransport>,
MpscListener<InmemoryTransport>,
) {
MpscListener::channel(buffer)
}
// One connection sends request `123` and must receive the handler's "hello" response.
// The server ref is bound to `_server` (not `_`) so it stays alive for the whole test.
#[test(tokio::test)]
async fn should_invoke_handler_upon_receiving_a_request() {
let (tx, listener) = make_listener(100);
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let _server = make_test_server(ServerConfig::default())
.start(listener)
.expect("Failed to start server");
let mut connection = Connection::client(transport, DummyAuthHandler)
.await
.expect("Failed to connect to server");
connection
.write_frame(Request::new(123).to_vec().unwrap())
.await
.expect("Failed to send request");
let frame = connection.read_frame().await.unwrap().unwrap();
let response: Response<String> = Response::from_slice(frame.as_item()).unwrap();
assert_eq!(response.payload, "hello");
}
// Lonely(100ms) with no connection ever fed: the server task must be finished after 300ms.
// NOTE(review): these wall-clock sleeps may be sensitive to a heavily loaded CI machine.
#[test(tokio::test)]
async fn should_lonely_shutdown_if_no_connections_received_after_n_secs_when_config_set() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
// Lonely(100ms) with one connection whose client side is dropped right away: the server
// must be finished after 300ms.
#[test(tokio::test)]
async fn should_lonely_shutdown_if_last_connection_terminated_and_then_no_connections_after_n_secs(
) {
let (tx, listener) = make_listener(100);
let (transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
drop(transport);
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
// Lonely(100ms) while a connection stays open (`_transport` is held): the server must still
// be running after 300ms.
#[test(tokio::test)]
async fn should_not_lonely_shutdown_as_long_as_a_connection_exists() {
let (tx, listener) = make_listener(100);
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Lonely(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
// After(100ms) must shut the server down even while a connection is still open.
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_even_with_connections_if_config_set_to_after() {
let (tx, listener) = make_listener(100);
let (_transport, connection) = InmemoryTransport::pair(100);
tx.send(connection)
.await
.expect("Failed to feed listener a connection");
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
// After(100ms) with no connections: the server must be finished after 300ms.
#[test(tokio::test)]
async fn should_shutdown_after_n_seconds_if_config_set_to_after() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::After(Duration::from_millis(100)),
..Default::default()
})
.start(listener)
.expect("Failed to start server");
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(server.is_finished(), "Server shutdown not triggered!");
}
// Shutdown::Never: the server must still be running after 300ms with no connections.
#[test(tokio::test)]
async fn should_never_shutdown_if_config_set_to_never() {
let (_tx, listener) = make_listener(100);
let server = make_test_server(ServerConfig {
shutdown: Shutdown::Never,
..Default::default()
})
.start(listener)
.expect("Failed to start server");
tokio::time::sleep(Duration::from_millis(300)).await;
assert!(!server.is_finished(), "Server shutdown when it should not!");
}
}