use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::{
self,
net::TcpListener,
sync::mpsc::Receiver,
time::{timeout, Instant, MissedTickBehavior},
};
use tokio_tungstenite::{tungstenite::Error, WebSocketStream};
#[cfg(feature = "bebop")]
use crate::helpers::common::get_data_schema;
#[cfg(feature = "bebop")]
use crate::schema::{Category, Ping};
use crate::{
helpers::{
client_sender::ClientSendersTrait,
common::{make_disconnect_message, make_pong_message},
},
log_debug, log_error,
};
#[cfg(feature = "bebop")]
use bebop::Record;
use futures_util::{stream::SplitStream, SinkExt, StreamExt};
use tokio::sync::mpsc::{self, Sender};
use tokio_tungstenite::{
accept_async,
tungstenite::{self, Message},
};
use tokio_util::sync::CancellationToken;
use super::{
client_sender::ClientSenders,
metrics::Metrics,
middleware::{MessageMiddleware, MiddlewareResult},
types::RwClientSenders,
};
/// A running WebSocket server handle: owns the shared client registry and the
/// token used to stop the server's background tasks.
pub struct AtomicServer {
    /// Registry of connected clients' outgoing queues; connection tasks register
    /// their per-connection `Sender` here and the handler reads forwarded messages
    /// through it.
    pub client_senders: RwClientSenders,
    /// Cancelled by `shutdown` to stop the accept loop and the client checker.
    cancel_token: CancellationToken,
}
/// Configuration for [`AtomicServer`] and the connections it serves.
///
/// A clone of the whole struct is installed on the client registry with
/// `set_options`, and each connection reads it back through `options()`.
#[derive(Clone)]
pub struct ServerOptions {
    /// Bebop builds: when `true`, Ping-category messages from clients are answered
    /// with a pong by the server itself instead of being forwarded to the handler.
    pub use_ping: bool,
    /// Bebop builds: category code for proxied pings. Only values `> 0` enable it
    /// (the default `-1` disables it). When enabled, an identifying first message
    /// whose category equals this code is accepted, and, with `use_ping` off, an
    /// identifying Ping is forwarded to the handler with its category rewritten to
    /// this code.
    pub proxy_ping: i16,
    /// Not read in this module; presumably consumed by the client registry's
    /// `check_client_send_time` (TODO confirm).
    pub client_timeout_seconds: u64,
    /// Period, in seconds, of the background client check (values below 1 act as 1).
    pub client_check_interval_secs: u64,
    /// Capacity of each connection's outgoing message queue (a bounded tokio mpsc
    /// channel). Keep it at least 1: `tokio::sync::mpsc::channel` panics on 0.
    pub per_connection_buffer_size: usize,
    /// Passed to `ClientSenders::new_with_buffer_size` when `AtomicServer::new`
    /// has to create the registry itself.
    pub handler_buffer_size: usize,
    /// Passed to `ClientSenders::new_with_buffer_size` alongside
    /// `handler_buffer_size`.
    pub spillover_buffer_size: usize,
    /// Per-connection hooks, run in order: `on_connect` (a `false` result stops the
    /// chain and the connection's messages are not processed), `on_message`
    /// (`MiddlewareResult::Stop` keeps the message from being forwarded) and
    /// `on_disconnect`.
    pub middlewares: Vec<Arc<dyn MessageMiddleware>>,
    /// When set, connections are accepted over TLS using this rustls configuration.
    #[cfg(feature = "rustls")]
    pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
impl Default for ServerOptions {
    /// Default configuration: server-side pong replies on, ping proxying off
    /// (`proxy_ping = -1`), `client_timeout_seconds` 30, `client_check_interval_secs`
    /// 15, an 8-message outgoing queue per connection, 1024-slot handler and
    /// spillover buffers, no middlewares and (with `rustls`) no TLS.
    fn default() -> Self {
        Self {
            middlewares: Vec::new(),
            #[cfg(feature = "rustls")]
            tls_config: None,
            use_ping: true,
            proxy_ping: -1,
            per_connection_buffer_size: 8,
            handler_buffer_size: 1024,
            spillover_buffer_size: 1024,
            client_timeout_seconds: 30,
            client_check_interval_secs: 15,
        }
    }
}
impl AtomicServer {
pub async fn new(
addr: &str,
option: ServerOptions,
client_senders: Option<RwClientSenders>,
) -> std::io::Result<Self> {
let listener = TcpListener::bind(&addr).await?;
let check_interval = option.client_check_interval_secs;
let cancel_token = CancellationToken::new();
let client_senders = match client_senders {
Some(client_senders) => client_senders,
None => Arc::new(ClientSenders::new_with_buffer_size(
option.handler_buffer_size,
option.spillover_buffer_size,
)),
};
#[cfg(feature = "rustls")]
let use_tls = option.tls_config.is_some();
#[cfg(not(feature = "rustls"))]
let use_tls = false;
client_senders.set_options(option.clone());
if use_tls {
#[cfg(feature = "rustls")]
if let Some(tls_config) = option.tls_config {
let acceptor = tokio_rustls::TlsAcceptor::from(tls_config);
tokio::spawn(handle_accept_tls(
listener,
client_senders.clone(),
cancel_token.clone(),
acceptor,
));
}
} else {
tokio::spawn(handle_accept(
listener,
client_senders.clone(),
cancel_token.clone(),
));
}
tokio::spawn(loop_client_checker(
client_senders.clone(),
check_interval,
cancel_token.clone(),
));
Ok(Self {
client_senders,
cancel_token,
})
}
pub async fn get_handle_message_receiver(&self) -> Option<Receiver<(Vec<u8>, String)>> {
self.client_senders.get_handle_message_receiver().await
}
pub async fn shutdown(&self) {
self.cancel_token.cancel();
}
pub fn metrics(&self) -> Arc<Metrics> {
self.client_senders.metrics.clone()
}
pub fn accept_upgraded<S>(&self, peer: SocketAddr, ws_stream: WebSocketStream<S>)
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let cs = self.client_senders.clone();
tokio::spawn(async move {
if let Err(e) = handle_upgraded_connection(cs, peer, ws_stream).await {
match e {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8(_) => (),
err => log_error!("Error processing upgraded connection: {}", err),
}
}
});
}
}
/// Periodically calls the registry's `check_client_send_time` until
/// `cancel_token` is cancelled.
///
/// The period is `check_interval_secs` seconds, clamped to at least one second,
/// and the first check happens one full period after start-up. Missed ticks are
/// delayed rather than fired in a burst (`MissedTickBehavior::Delay`).
pub async fn loop_client_checker(
    server_sender: RwClientSenders,
    check_interval_secs: u64,
    cancel_token: CancellationToken,
) {
    let period = Duration::from_secs(check_interval_secs.max(1));
    let first_tick = Instant::now() + period;
    let mut ticker = tokio::time::interval_at(first_tick, period);
    ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
    loop {
        tokio::select! {
            _ = ticker.tick() => {
                server_sender.check_client_send_time();
                log_debug!("loop client cheker finish");
            }
            _ = cancel_token.cancelled() => {
                log_debug!("Client checker shutting down");
                return;
            }
        }
    }
}
pub async fn handle_accept(
listener: TcpListener,
client_senders: RwClientSenders,
cancel_token: CancellationToken,
) {
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
log_debug!("Accept loop shutting down");
break;
}
result = listener.accept() => {
match result {
Ok((stream, _)) => {
let peer = match stream.peer_addr() {
Ok(addr) => addr,
Err(e) => {
log_error!("Failed to get peer address: {:?}", e);
continue;
}
};
log_debug!("Peer address: {}", peer);
tokio::spawn(accept_connection(client_senders.clone(), peer, stream));
}
Err(e) => {
log_error!("Error accepting connection: {:?}", e);
}
}
}
}
}
}
/// TLS counterpart of `handle_accept`: accepts TCP connections until
/// `cancel_token` is cancelled and, for each one, performs the TLS handshake with
/// `tls_acceptor` on a spawned task before handing the encrypted stream to
/// [`accept_connection`].
///
/// A failed handshake is logged and the connection dropped. A failed `accept` is
/// logged and followed by a short pause, so a persistent error (for example running
/// out of file descriptors) cannot turn the loop into a busy spin.
#[cfg(feature = "rustls")]
pub async fn handle_accept_tls(
    listener: TcpListener,
    client_senders: RwClientSenders,
    cancel_token: CancellationToken,
    tls_acceptor: tokio_rustls::TlsAcceptor,
) {
    // Errors such as EMFILE usually repeat immediately; back off before retrying.
    const ACCEPT_ERROR_BACKOFF: Duration = Duration::from_millis(100);
    loop {
        tokio::select! {
            _ = cancel_token.cancelled() => {
                log_debug!("TLS accept loop shutting down");
                break;
            }
            result = listener.accept() => {
                match result {
                    // accept() already reports the peer address; a second peer_addr()
                    // syscall could fail (peer already reset) and drop the client.
                    Ok((stream, peer)) => {
                        log_debug!("Peer address (TLS): {}", peer);
                        let acceptor = tls_acceptor.clone();
                        let cs = client_senders.clone();
                        tokio::spawn(async move {
                            match acceptor.accept(stream).await {
                                Ok(tls_stream) => {
                                    accept_connection(cs, peer, tls_stream).await;
                                }
                                Err(e) => {
                                    log_error!("TLS handshake failed for {}: {:?}", peer, e);
                                }
                            }
                        });
                    }
                    Err(e) => {
                        log_error!("Error accepting connection: {:?}", e);
                        tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    }
                }
            }
        }
    }
}
pub async fn accept_connection<S>(client_senders: RwClientSenders, peer: SocketAddr, stream: S)
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
if let Err(e) = handle_connection(client_senders, peer, stream).await {
match e {
Error::ConnectionClosed | Error::Protocol(_) | Error::Utf8(_) => (),
err => log_error!("Error processing connection: {}", err),
}
}
}
/// Performs the server-side WebSocket handshake on `stream`, then serves the
/// session.
///
/// A failed handshake is only logged at debug level and reported as `Ok(())`;
/// errors from the session itself are returned to the caller.
#[cfg(feature = "bebop")]
pub async fn handle_connection<Io>(
    client_senders: RwClientSenders,
    peer: SocketAddr,
    stream: Io,
) -> tungstenite::Result<()>
where
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    let ws_stream = match accept_async(stream).await {
        Ok(upgraded) => upgraded,
        Err(handshake_err) => {
            log_debug!("Error accepting WebSocket connection: {:?}", handshake_err);
            return Ok(());
        }
    };
    inner_handle_ws(client_senders, peer, ws_stream).await
}
/// Serves a session whose WebSocket handshake has already been completed
/// elsewhere; errors from the session are returned unchanged.
#[cfg(feature = "bebop")]
pub async fn handle_upgraded_connection<Io>(
    client_senders: RwClientSenders,
    peer: SocketAddr,
    ws_stream: WebSocketStream<Io>,
) -> tungstenite::Result<()>
where
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // No handshake to perform: go straight to the session loop.
    inner_handle_ws(client_senders, peer, ws_stream).await
}
/// Serves one WebSocket session (bebop framing) until either side ends it.
///
/// The work is split in two:
/// * A spawned reader task. It waits for the identifying first message (see
///   `get_id_from_first_message`), runs every middleware's `on_connect`, then
///   handles incoming messages. WebSocket control frames are skipped. A
///   Ping-category message is answered with a pong when `use_ping` is on. A
///   Disconnect-category message ends the loop. Anything else is offered to the
///   middlewares and forwarded to the handler. On exit the task queues a
///   disconnect message for this connection.
/// * The calling task, which writes queued outgoing messages to the socket and
///   stops after writing a Disconnect-category message or one it cannot decode.
///
/// # Errors
/// Returns the first error from writing to the socket.
#[cfg(feature = "bebop")]
async fn inner_handle_ws<S>(
    client_senders: RwClientSenders,
    peer: SocketAddr,
    ws_stream: WebSocketStream<S>,
) -> tungstenite::Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    log_debug!("New WebSocket connection: {}", peer);
    let (mut ostream, mut istream) = ws_stream.split();
    let options = client_senders.options();
    // tokio's bounded mpsc::channel panics on a capacity of 0.
    let buf_size = options.per_connection_buffer_size.max(1);
    let (sx, mut rx) = mpsc::channel(buf_size);
    tokio::spawn(async move {
        let use_ping = options.use_ping;
        let middlewares = options.middlewares;
        let id = get_id_from_first_message(&mut istream, client_senders.clone(), sx.clone()).await;
        if let Some(id) = id {
            let mut connected = true;
            for mw in &middlewares {
                if !mw.on_connect(&id).await {
                    connected = false;
                    break;
                }
            }
            if connected {
                while let Some(Ok(message)) = istream.next().await {
                    // tungstenite replies to WebSocket pings and close frames itself, and
                    // their payloads are not schema data, so decoding them would only log
                    // spurious errors. After a Close, the next poll completes the close
                    // handshake and ends the stream.
                    if matches!(message, Message::Ping(_) | Message::Pong(_) | Message::Close(_)) {
                        continue;
                    }
                    let value = message.into_data();
                    let data = match get_data_schema(&value) {
                        Ok(data) => data,
                        Err(e) => {
                            log_error!("Error getting data schema: {:?}", e);
                            continue;
                        }
                    };
                    if data.category == Category::Ping as u16 && use_ping {
                        if let Ok(data) = Ping::deserialize(&data.datas) {
                            client_senders.send(data.peer, make_pong_message()).await;
                            continue;
                        }
                    }
                    if data.category == Category::Disconnect as u16 {
                        break;
                    }
                    let mut should_forward = true;
                    for mw in &middlewares {
                        if mw.on_message(&id, &value).await == MiddlewareResult::Stop {
                            should_forward = false;
                            break;
                        }
                    }
                    if should_forward {
                        client_senders.send_handle_message(data, &id).await;
                    }
                }
                for mw in &middlewares {
                    mw.on_disconnect(&id).await;
                }
            }
        }
        let _ = sx.send(make_disconnect_message(&peer.to_string())).await;
    });
    while let Some(message) = rx.recv().await {
        ostream.send(message.clone()).await?;
        let data = message.into_data();
        let data = match get_data_schema(&data) {
            Ok(data) => data,
            Err(e) => {
                log_error!("Error getting data schema: {:?}", e);
                rx.close();
                break;
            }
        };
        log_debug!("Server sending message: {:?}", data);
        if data.category == Category::Disconnect as u16 {
            rx.close();
            break;
        }
    }
    log_debug!("client: {} disconnected", peer);
    // Best effort: don't let a stalled peer hold this task for more than a second.
    let _ = timeout(Duration::from_secs(1), ostream.flush()).await;
    Ok(())
}
#[cfg(not(feature = "bebop"))]
pub async fn handle_connection<S>(
client_senders: RwClientSenders,
peer: SocketAddr,
stream: S,
) -> tungstenite::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
match accept_async(stream).await {
Ok(ws_stream) => {
inner_handle_ws(client_senders, peer, ws_stream).await?;
}
Err(e) => {
log_debug!("Error accepting WebSocket connection: {:?}", e);
}
}
Ok(())
}
/// Serves a session whose WebSocket handshake has already been completed
/// elsewhere; errors from the session are returned unchanged.
#[cfg(not(feature = "bebop"))]
pub async fn handle_upgraded_connection<Io>(
    client_senders: RwClientSenders,
    peer: SocketAddr,
    ws_stream: WebSocketStream<Io>,
) -> tungstenite::Result<()>
where
    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    // No handshake to perform: go straight to the session loop.
    inner_handle_ws(client_senders, peer, ws_stream).await
}
#[cfg(not(feature = "bebop"))]
async fn inner_handle_ws<S>(
client_senders: RwClientSenders,
peer: SocketAddr,
ws_stream: WebSocketStream<S>,
) -> tungstenite::Result<()>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
log_debug!("New WebSocket connection: {}", peer);
let (mut ostream, mut istream) = ws_stream.split();
let options = client_senders.options();
let buf_size = options.per_connection_buffer_size;
let (sx, mut rx) = mpsc::channel(buf_size);
let peer_str = peer.to_string();
client_senders.add(&peer_str, sx.clone()).await;
tokio::spawn(async move {
let middlewares = options.middlewares;
let mut connected = true;
for mw in &middlewares {
if !mw.on_connect(&peer_str).await {
connected = false;
break;
}
}
if connected {
while let Some(Ok(message)) = istream.next().await {
let value = message.into_data();
let mut should_forward = true;
for mw in &middlewares {
if mw.on_message(&peer_str, &value).await == MiddlewareResult::Stop {
should_forward = false;
break;
}
}
if should_forward {
client_senders
.send_handle_message(value.to_vec(), &peer_str)
.await;
}
}
for mw in &middlewares {
mw.on_disconnect(&peer_str).await;
}
}
let _ = sx.send(make_disconnect_message(&peer_str)).await;
});
while let Some(message) = rx.recv().await {
ostream.send(message).await?;
}
log_debug!("client: {} disconnected", peer);
let _ = timeout(Duration::from_secs(1), ostream.flush()).await;
Ok(())
}
/// Reads the client's identifying first data message and registers the client.
///
/// WebSocket control frames received before it are skipped. The message is then
/// decoded with `get_data_schema` and accepted only in these cases:
/// * a Ping-category message: the client is registered under the ping's `peer`.
///   It is then answered with a pong when `use_ping` is on. Otherwise it is
///   forwarded to the handler, with its category rewritten to `proxy_ping` when
///   that is `> 0`;
/// * a message whose category equals `proxy_ping` (when `> 0`): the client is
///   registered under the ping's `peer` and the message forwarded to the handler.
///
/// Returns the client id on success. Returns `None` if the stream ends or errors,
/// the message cannot be decoded, or it is not an acceptable identifying message.
#[cfg(feature = "bebop")]
async fn get_id_from_first_message<S>(
    istream: &mut SplitStream<WebSocketStream<S>>,
    client_senders: RwClientSenders,
    sx: Sender<Message>,
) -> Option<String>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let message = loop {
        match istream.next().await {
            // Control frames carry no identity and tungstenite replies to them itself.
            // After a Close, the next poll completes the handshake and ends the stream.
            Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Close(_))) => continue,
            Some(Ok(message)) => break message,
            _ => return None,
        }
    };
    log_debug!("receive first message from client: {:?}", message);
    let value = message.into_data();
    let mut data = match get_data_schema(&value) {
        Ok(data) => data,
        Err(e) => {
            log_error!("Error getting data schema: {:?}", e);
            return None;
        }
    };
    let options = client_senders.options();
    if data.category == Category::Ping as u16 {
        log_debug!("receive ping from client: {:?}", data);
        let ping = Ping::deserialize(&data.datas).ok()?;
        let peer_id: String = ping.peer.into();
        client_senders.add(&peer_id, sx).await;
        if options.use_ping {
            client_senders.send(&peer_id, make_pong_message()).await;
        } else {
            if options.proxy_ping > 0 {
                data.category = options.proxy_ping as u16;
            }
            client_senders.send_handle_message(data, &peer_id).await;
        }
        Some(peer_id)
    } else if options.proxy_ping > 0 && data.category == options.proxy_ping as u16 {
        // The category already equals `proxy_ping`, so it is forwarded unchanged.
        let ping = Ping::deserialize(&data.datas).ok()?;
        let peer_id: String = ping.peer.into();
        client_senders.add(&peer_id, sx).await;
        client_senders.send_handle_message(data, &peer_id).await;
        Some(peer_id)
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_options_default() {
        let options = ServerOptions::default();
        assert!(options.use_ping);
        assert_eq!(options.proxy_ping, -1);
        // Pin every other default too, so an accidental change is caught here.
        assert_eq!(options.client_timeout_seconds, 30);
        assert_eq!(options.client_check_interval_secs, 15);
        assert_eq!(options.per_connection_buffer_size, 8);
        assert_eq!(options.handler_buffer_size, 1024);
        assert_eq!(options.spillover_buffer_size, 1024);
        assert!(options.middlewares.is_empty());
        #[cfg(feature = "rustls")]
        assert!(options.tls_config.is_none());
    }

    #[test]
    fn test_server_options_custom() {
        let options = ServerOptions {
            use_ping: false,
            proxy_ping: 100,
            ..Default::default()
        };
        assert!(!options.use_ping);
        assert_eq!(options.proxy_ping, 100);
        // Fields not overridden keep their defaults.
        assert_eq!(options.per_connection_buffer_size, 8);
    }

    #[test]
    fn test_server_options_clone() {
        let options = ServerOptions {
            use_ping: false,
            proxy_ping: 50,
            ..Default::default()
        };
        let cloned = options.clone();
        assert!(!cloned.use_ping);
        assert_eq!(cloned.proxy_ping, 50);
        assert_eq!(cloned.client_check_interval_secs, options.client_check_interval_secs);
        assert_eq!(cloned.middlewares.len(), options.middlewares.len());
    }

    #[test]
    fn test_server_options_proxy_ping_disabled() {
        let options = ServerOptions {
            use_ping: true,
            proxy_ping: -1,
            ..Default::default()
        };
        assert!(options.proxy_ping < 0);
    }

    #[test]
    fn test_server_options_proxy_ping_enabled() {
        let options = ServerOptions {
            use_ping: false,
            proxy_ping: 200,
            ..Default::default()
        };
        assert!(options.proxy_ping > 0);
        assert_eq!(options.proxy_ping, 200);
    }
}