use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::{TcpListener, TcpStream};
use crate::config::Config;
use crate::platform::Platform;
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
/// Upper bound on how long a single `accept` may block in the server loop
/// before the loop goes around again and re-checks its running conditions.
const CONNECT_WAIT_TIMEOUT: Duration = Duration::from_micros(500_000);
/// One accepted client connection as tracked by the server.
///
/// `P` is caller-defined per-connection data; the server creates it with
/// `P::default()` when the connection is accepted.
pub struct Connection<P>
where
    P: Default + Send + Sync,
{
    /// Remote address rendered as text, or `"<unknown>"` if it could not be read.
    peer_address: String,
    /// Read through `is_active`; cleared by `quit` and by `Server::kill`.
    active: AtomicBool,
    /// Caller-defined data attached to this connection.
    payload: P,
}
/// Connections compare equal when their peer addresses are equal; the active
/// flag and the payload take no part in the comparison.
impl<P> PartialEq for Connection<P>
where
    P: Default + Send + Sync,
{
    fn eq(&self, other: &Self) -> bool {
        self.peer_address.as_str() == other.peer_address.as_str()
    }
}
impl<P> Connection<P>
where
    P: Default + Send + Sync,
{
    /// Returns `true` until the active flag has been cleared (by `quit` or
    /// `Server::kill`).
    pub fn is_active(&self) -> bool {
        let flag = &self.active;
        flag.load(Ordering::Acquire)
    }

    /// Clears the active flag.
    ///
    /// NOTE(review): nothing here closes the stream; presumably the client loop
    /// polls `is_active` and exits on its own — confirm with callers.
    pub fn quit(&self) {
        let flag = &self.active;
        flag.store(false, Ordering::Release)
    }

    /// Borrows the caller-defined data attached to this connection.
    pub fn payload(&self) -> &P {
        let Self { payload, .. } = self;
        payload
    }
}
/// Detached snapshot of a tracked connection, as returned by `Server::connections`.
pub struct ConnectionInfo<P>
where
    P: Default + Send + Sync,
{
    /// Remote address rendered as text.
    pub peer_address: String,
    /// Copy of the connection's payload at snapshot time.
    pub payload: P,
}
/// TCP server that runs a caller-supplied client loop for every accepted
/// stream and keeps a list of the connections currently being served.
pub struct Server<P>
where
    P: Default + Send + Sync,
{
    /// Raised when the event loop starts serving; lowered when the configured
    /// address changes so the listening socket gets re-opened.
    running: AtomicBool,
    /// Address of the most recently bound listener, if any.
    current_address: Mutex<Option<String>>,
    /// Platform used to look up the configuration and to check liveness.
    platform: Arc<Platform>,
    /// Connections whose client loop is currently running.
    connections: Mutex<Vec<Arc<Connection<P>>>>,
}
impl<P: 'static + Default + Send + Sync + Clone> Server<P> {
    /// Creates a server that is not yet listening and registers it with `platform`.
    ///
    /// The server is registered as `Server<P>`, which is how connection tasks
    /// look it up later (`platform.require::<Server<P>>()`). Start it with
    /// [`Server::fork`] or [`Server::fork_and_await`].
    pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
        let server = Arc::new(Server {
            running: AtomicBool::new(false),
            current_address: Mutex::new(None),
            platform: Arc::clone(platform),
            connections: Mutex::new(Vec::new()),
        });
        platform.register::<Server<P>>(Arc::clone(&server));
        server
    }

    /// Returns a snapshot of all tracked connections: each entry holds the
    /// peer address and a clone of the connection's payload.
    pub fn connections(&self) -> Vec<ConnectionInfo<P>> {
        self.connections
            .lock()
            .unwrap()
            .iter()
            .map(|connection| ConnectionInfo {
                peer_address: connection.peer_address.clone(),
                payload: connection.payload.clone(),
            })
            .collect()
    }

    /// Clears the active flag of the first connection whose peer address
    /// equals `peer_address`.
    ///
    /// Returns `true` if such a connection was found. The connection stays
    /// listed until its client loop finishes.
    pub fn kill(&self, peer_address: &str) -> bool {
        let connections = self.connections.lock().unwrap();
        match connections
            .iter()
            .find(|connection| connection.peer_address == peer_address)
        {
            Some(connection) => {
                connection.quit();
                true
            }
            None => false,
        }
    }

    /// Starts tracking `connection`.
    fn add_connection(&self, connection: Arc<Connection<P>>) {
        self.connections.lock().unwrap().push(connection);
    }

    /// Stops tracking exactly this connection instance.
    ///
    /// Matches by pointer identity rather than `PartialEq`: equality only
    /// compares peer addresses, and several connections can share one (every
    /// connection whose peer address could not be read is `"<unknown>"`).
    fn remove_connection(&self, connection: Arc<Connection<P>>) {
        // This runs from a drop guard, possibly while a panic unwinds; a second
        // panic there would abort the process. The list is only ever changed by
        // single push/remove calls, so its contents stay valid even if a lock
        // holder panicked, and recovering from poisoning is safe.
        let mut connections = self
            .connections
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(index) = connections
            .iter()
            .position(|other| Arc::ptr_eq(other, &connection))
        {
            let _ = connections.remove(index);
        }
    }

    /// Reads the `running` flag (raised by `event_loop`, lowered by
    /// `server_loop` when the configured address changes).
    fn is_running(&self) -> bool {
        self.running.load(Ordering::Acquire)
    }

    /// Resolves the `host:port` to listen on from `server.host` and
    /// `server.port` in the current [`Config`].
    ///
    /// Missing or invalid values fall back to host `0.0.0.0` and port `2410`
    /// (the port must lie in `1..=65535`). Without a registered `Config` the
    /// result is `0.0.0.0:2410`.
    fn address(&self) -> String {
        self.platform
            .find::<Config>()
            .map(|config| {
                let handle = config.current();
                let settings = handle.config();
                let host = settings["server"]["host"]
                    .as_str()
                    .unwrap_or("0.0.0.0");
                let port = settings["server"]["port"]
                    .as_i64()
                    .filter(|port| (1..=i64::from(u16::MAX)).contains(port))
                    .unwrap_or(2410);
                format!("{}:{}", host, port)
            })
            .unwrap_or_else(|| "0.0.0.0:2410".to_owned())
    }

    /// Spawns [`Server::event_loop`] as a detached tokio task and returns at once.
    pub fn fork<F>(
        server: &Arc<Server<P>>,
        client_loop: &'static (impl Fn(Arc<Platform>, Arc<Connection<P>>, TcpStream) -> F + Send + Sync),
    ) where
        F: Future<Output = anyhow::Result<()>> + Send + Sync,
    {
        let server = Arc::clone(server);
        // The JoinHandle is dropped on purpose: the event loop runs detached.
        let _ = tokio::spawn(async move {
            server.event_loop(client_loop).await;
        });
    }

    /// Like [`Server::fork`], but waits (polling once per second) until the
    /// event loop has raised the `running` flag.
    ///
    /// The flag is raised before the socket is bound, so this does not
    /// guarantee the server is already accepting connections. It also returns
    /// if the platform stops, because the event loop would then never raise
    /// the flag and the wait would never end.
    pub async fn fork_and_await<F>(
        server: &Arc<Server<P>>,
        client_loop: &'static (impl Fn(Arc<Platform>, Arc<Connection<P>>, TcpStream) -> F + Send + Sync),
    ) where
        F: Future<Output = anyhow::Result<()>> + Send + Sync,
    {
        Server::fork(server, client_loop);
        while !server.is_running() && server.platform.is_running() {
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    }

    /// Binds the configured address and serves it until the platform stops.
    ///
    /// The address is resolved again before every bind attempt. A listener is
    /// therefore re-opened on the new address after a config change, and a
    /// corrected config is picked up while a bad address keeps failing to bind.
    /// Bind failures are retried every 500ms and logged at most once every five
    /// seconds; the first failure is logged immediately.
    pub async fn event_loop<F>(
        &self,
        client_loop: impl Fn(Arc<Platform>, Arc<Connection<P>>, TcpStream) -> F
            + Send
            + Sync
            + Copy
            + 'static,
    ) where
        F: Future<Output = anyhow::Result<()>> + Send,
    {
        // Keep in sync with the "every 500ms" wording in the log message below.
        const BIND_RETRY_INTERVAL: Duration = Duration::from_millis(500);
        const BIND_ERROR_REPORT_INTERVAL: Duration = Duration::from_secs(5);

        let mut last_bind_error_reported: Option<Instant> = None;
        while self.platform.is_running() {
            let address = self.address();
            self.running.store(true, Ordering::Release);

            match TcpListener::bind(&address).await {
                Ok(listener) => {
                    log::info!("Opened server socket on {}...", &address);
                    *self.current_address.lock().unwrap() = Some(address.clone());
                    self.server_loop(&listener, client_loop).await;
                    log::info!("Closing server socket on {}.", &address);
                }
                Err(error) => {
                    let should_report = match last_bind_error_reported {
                        None => true,
                        Some(reported) => reported.elapsed() > BIND_ERROR_REPORT_INTERVAL,
                    };
                    if should_report {
                        log::error!(
                            "Cannot open server address: {} ({}). Retrying every 500ms...",
                            &address,
                            error
                        );
                        last_bind_error_reported = Some(Instant::now());
                    }
                    tokio::time::sleep(BIND_RETRY_INTERVAL).await;
                }
            }
        }
    }

    /// Accepts connections on `listener` and hands each one to
    /// `handle_new_connection`.
    ///
    /// Returns when:
    /// - the platform or this server stops running;
    /// - `accept` fails, in which case the caller re-binds;
    /// - a config notification reveals that the resolved address differs from
    ///   the bound one. `running` is then lowered so the caller re-resolves.
    ///
    /// `accept` is bounded by [`CONNECT_WAIT_TIMEOUT`] so the loop conditions
    /// are re-checked regularly.
    async fn server_loop<F>(
        &self,
        listener: &TcpListener,
        client_loop: impl Fn(Arc<Platform>, Arc<Connection<P>>, TcpStream) -> F
            + Copy
            + Send
            + Sync
            + 'static,
    ) where
        F: Future<Output = anyhow::Result<()>> + Send,
    {
        let mut config_changed_flag = self.platform.require::<Config>().notifier();
        while self.platform.is_running() && self.is_running() {
            tokio::select! {
                accepted = tokio::time::timeout(CONNECT_WAIT_TIMEOUT, listener.accept()) => {
                    match accepted {
                        Ok(Ok((stream, _))) => self.handle_new_connection(stream, client_loop),
                        Ok(Err(error)) => {
                            log::warn!("Failed to accept a connection: {}", error);
                            return;
                        }
                        // Timed out: loop around and re-check the running conditions.
                        Err(_) => {}
                    }
                }
                _ = config_changed_flag.recv() => {
                    let new_address = self.address();
                    let address_changed = matches!(
                        self.current_address.lock().unwrap().as_deref(),
                        Some(current) if current != new_address.as_str()
                    );
                    if address_changed {
                        log::info!("Server address has changed. Restarting server socket...");
                        self.running.store(false, Ordering::Release);
                        return;
                    }
                }
            }
        }
    }

    /// Spawns a task that serves one accepted `stream` with `client_loop`.
    ///
    /// The task:
    /// - enables `TCP_NODELAY` (best effort);
    /// - wraps the stream in a new [`Connection`] with a default payload;
    /// - tracks that connection for as long as the task is alive;
    /// - logs any error returned by `client_loop` at debug level.
    fn handle_new_connection<F>(
        &self,
        stream: TcpStream,
        client_loop: impl FnOnce(Arc<Platform>, Arc<Connection<P>>, TcpStream) -> F
            + 'static
            + Send
            + Sync
            + Copy,
    ) where
        F: Future<Output = anyhow::Result<()>> + Send,
    {
        /// Runs the stored callback exactly once when dropped.
        struct OnDrop<C: FnOnce()>(Option<C>);

        impl<C: FnOnce()> Drop for OnDrop<C> {
            fn drop(&mut self) {
                if let Some(callback) = self.0.take() {
                    callback();
                }
            }
        }

        let platform = Arc::clone(&self.platform);
        let _ = tokio::spawn(async move {
            // Best effort: the error is deliberately ignored.
            let _ = stream.set_nodelay(true);
            let server = platform.require::<Server<P>>();
            let peer_address = stream
                .peer_addr()
                .map(|addr| addr.to_string())
                .unwrap_or_else(|_| "<unknown>".to_owned());
            let connection = Arc::new(Connection {
                peer_address,
                active: AtomicBool::new(true),
                payload: P::default(),
            });
            log::debug!("Opened connection from {}...", connection.peer_address);
            server.add_connection(Arc::clone(&connection));

            // Deregister from a drop guard, so the entry is also removed when
            // `client_loop` panics or the task is cancelled. Otherwise a stale
            // connection would stay listed forever.
            let _deregister = {
                let connection = Arc::clone(&connection);
                OnDrop(Some(move || {
                    log::debug!("Closing connection to {}...", connection.peer_address);
                    server.remove_connection(connection);
                }))
            };

            if let Err(error) = client_loop(platform, Arc::clone(&connection), stream).await {
                log::debug!(
                    "An IO error occurred in connection {}: {}",
                    connection.peer_address,
                    error
                );
            }
        });
    }
}