use std::net::{Ipv4Addr, SocketAddrV4};
use std::sync::Arc;
use async_compat::Compat;
use futures_rustls::TlsAcceptor;
use rustls::ServerConfig;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use crate::*;
/// Server state: the sessions it tracks, the address it listens on and an
/// optional TLS configuration.
///
/// Instances are created through [`Server::builder`], [`Server::new`] or
/// [`Default::default`], and consumed by [`Server::run`].
pub struct Server {
    /// Session bookkeeping mutated by `handle_message`.
    active_sessions: Sessions,
    /// IPv4 address and port the TCP listener is bound to in `run`.
    listening_socket: SocketAddrV4,
    /// When `Some`, `run` turns it into a TLS acceptor and every accepted
    /// connection goes through a TLS handshake; when `None`, connections are
    /// served over plain TCP.
    tls_config: Option<ServerConfig>,
    /// Receiving half of the message channel. Only present in test builds,
    /// where the channel is created up front by `ServerBuilder::build` and
    /// taken out by `run`.
    #[cfg(any(test, feature = "__tests"))]
    receiver: Option<mpsc::Receiver<SessionMessage>>,
    /// Sending half of the message channel. Only present in test builds;
    /// cloned by `test_sender` and taken out by `run`.
    #[cfg(any(test, feature = "__tests"))]
    sender: Option<mpsc::Sender<SessionMessage>>,
}
impl Default for Server {
#[inline]
fn default() -> Self {
Self::builder().build()
}
}
impl Server {
#[inline]
pub fn builder() -> ServerBuilder {
ServerBuilder::default()
}
#[inline]
fn handle_message(&mut self, msg: SessionMessage) {
use SessionMessage::*;
match msg {
JoinSession(peer_id, session_id, sender) => {
let res = self.active_sessions.join(session_id, peer_id);
let _ = sender.send(res);
},
PeerLeft(peer_id, session_id) => {
self.active_sessions.remove_peer(session_id, peer_id);
},
StartSession(peer_id, sender) => {
let infos = self.active_sessions.start(peer_id);
let _ = sender.send(infos);
},
#[cfg(any(test, feature = "__tests"))]
Ping(sender) => {
let _ = sender.send(());
},
#[cfg(any(test, feature = "__tests"))]
Sessions(sender) => {
let _ = sender.send(self.active_sessions.clone());
},
}
}
#[inline]
pub fn new() -> Self {
Self::default()
}
#[inline]
pub async fn run(mut self) -> Result<(), anyhow::Error> {
let tcp_listener = TcpListener::bind(self.listening_socket).await?;
let tls_acceptor =
self.tls_config.take().map(Arc::new).map(TlsAcceptor::from);
#[cfg(not(any(test, feature = "__tests")))]
let (sender, mut receiver) = mpsc::channel(1024);
#[cfg(any(test, feature = "__tests"))]
let (sender, mut receiver) =
(self.sender.take().unwrap(), self.receiver.take().unwrap());
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
self.handle_message(msg);
}
});
loop {
let (tcp_stream, client_addr) = tcp_listener.accept().await?;
let tcp_stream = Compat::new(tcp_stream);
let sender = sender.clone();
if let Some(tls_acceptor) = &tls_acceptor {
let tls_stream = tls_acceptor.accept(tcp_stream).await?;
let run = Connection::run(tls_stream, client_addr, sender);
tokio::spawn(run);
} else {
let run = Connection::run(tcp_stream, client_addr, sender);
tokio::spawn(run);
}
}
}
#[cfg(any(test, feature = "__tests"))]
pub fn test_sender(&self) -> TestSender {
TestSender::new(self.sender.as_ref().unwrap().clone())
}
}
/// Builder collecting the settings a [`Server`] is created from.
///
/// Obtain one with [`Server::builder`] or [`Default::default`], adjust it
/// with the `with_*` methods and finish with [`ServerBuilder::build`].
#[derive(Clone)]
pub struct ServerBuilder {
    /// IPv4 address the server will listen on.
    ip: Ipv4Addr,
    /// TCP port the server will listen on.
    port: u16,
    /// Optional TLS configuration handed over to the server unchanged.
    tls_config: Option<ServerConfig>,
}
impl Default for ServerBuilder {
    /// Default settings: the unspecified address `0.0.0.0`, the port
    /// `common::SERVER_LISTENING_PORT` and no TLS configuration.
    #[inline]
    fn default() -> Self {
        let ip = Ipv4Addr::UNSPECIFIED;
        let port = common::SERVER_LISTENING_PORT;
        Self {
            ip,
            port,
            tls_config: None,
        }
    }
}
impl ServerBuilder {
    /// Consumes the builder and produces a [`Server`] with an empty session
    /// table, listening on the configured address and port.
    ///
    /// In test builds a message channel with capacity 1024 is created here
    /// and both halves are stored in the server.
    #[inline]
    pub fn build(self) -> Server {
        let Self {
            ip,
            port,
            tls_config,
        } = self;

        #[cfg(any(test, feature = "__tests"))]
        let (sender, receiver) = mpsc::channel(1024);

        Server {
            active_sessions: Sessions::new(),
            listening_socket: SocketAddrV4::new(ip, port),
            tls_config,
            #[cfg(any(test, feature = "__tests"))]
            sender: Some(sender),
            #[cfg(any(test, feature = "__tests"))]
            receiver: Some(receiver),
        }
    }

    /// Sets the IPv4 address to listen on.
    #[inline]
    pub fn with_ip(self, ip: Ipv4Addr) -> Self {
        Self { ip, ..self }
    }

    /// Sets the TCP port to listen on.
    #[inline]
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Supplies the TLS configuration, replacing any previously set one.
    #[inline]
    pub fn with_tls_config(self, config: ServerConfig) -> Self {
        Self {
            tls_config: Some(config),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A second server with the same default address must fail to start
    /// while the first one is still running.
    #[tokio::test]
    async fn server_start_twice() -> Result<(), anyhow::Error> {
        let first = Server::new();
        let probe = first.test_sender();
        tokio::spawn(first.run());

        // NOTE(review): `TestSender::start` is defined elsewhere; presumably
        // it only returns once the first server is processing messages, i.e.
        // after its listener has been bound. Confirm against TestSender.
        probe.start().await?;

        let second = Server::new();
        let outcome = second.run().await;
        assert!(
            outcome.is_err(),
            "running the server should fail if the socket is being used"
        );

        Ok(())
    }
}