use std::{
collections::HashSet,
net::{AddrParseError, IpAddr, SocketAddr, UdpSocket},
sync::Arc,
};
use bevy::prelude::*;
use bytes::Bytes;
use quinn::{default_runtime, Endpoint as QuinnEndpoint, EndpointConfig, ServerConfig};
use tokio::{
runtime,
sync::{
broadcast::{self},
mpsc::{self},
},
};
use crate::{
server::{
certificate::{retrieve_certificate, CertificateRetrievalMode, ServerCertificate},
connection::ServerConnection,
endpoint::Endpoint,
},
shared::{
channels::{
tasks::{spawn_recv_channels_tasks, spawn_send_channels_tasks_spawner},
ChannelAsyncMessage, ChannelId, ChannelSyncMessage, SendChannelsConfiguration,
},
peer_connection::PeerConnection,
AsyncRuntime, ClientId, QuinnetSyncPreUpdate, DEFAULT_INTERNAL_MESSAGES_CHANNEL_SIZE,
DEFAULT_KEEP_ALIVE_INTERVAL_S, DEFAULT_KILL_MESSAGE_QUEUE_SIZE, DEFAULT_MESSAGE_QUEUE_SIZE,
DEFAULT_QCHANNEL_MESSAGES_CHANNEL_SIZE,
},
};
#[cfg(feature = "shared-client-id")]
mod client_id;
#[cfg(feature = "bincode-messages")]
pub mod messages;
pub mod connection;
pub mod endpoint;
pub mod error;
pub use error::*;
pub mod certificate;
/// Message written by `handle_server_events` once a new client connection has been
/// successfully handled by the endpoint.
#[derive(bevy::ecs::message::Message, Debug, Copy, Clone)]
pub struct ConnectionEvent {
/// Id of the newly connected client, as returned by the endpoint.
pub id: ClientId,
}
/// Message written by `handle_server_events` when the connection with a client is lost:
/// either its connection was reported closed, or one of its channels reported a lost
/// connection.
#[derive(bevy::ecs::message::Message, Debug, Copy, Clone)]
pub struct ConnectionLostEvent {
/// Id of the client whose connection was lost.
pub id: ClientId,
}
/// Address configuration of a server endpoint.
#[derive(Debug, Clone)]
pub struct EndpointAddrConfiguration {
    /// Local address and port the endpoint's UDP socket is bound to.
    pub local_bind_addr: SocketAddr,
}

impl EndpointAddrConfiguration {
    /// Builds a configuration from a textual socket address such as `"0.0.0.0:6000"`.
    ///
    /// # Errors
    /// Returns an [`AddrParseError`] if `local_bind_addr_str` is not a valid socket address
    /// (IP and port).
    pub fn from_string(local_bind_addr_str: &str) -> Result<Self, AddrParseError> {
        local_bind_addr_str
            .parse::<SocketAddr>()
            .map(|local_bind_addr| Self { local_bind_addr })
    }

    /// Builds a configuration from an IP address and a port.
    pub fn from_ip(local_bind_ip: impl Into<IpAddr>, local_bind_port: u16) -> Self {
        let ip: IpAddr = local_bind_ip.into();
        Self {
            local_bind_addr: SocketAddr::new(ip, local_bind_port),
        }
    }

    /// Builds a configuration from an already constructed socket address.
    pub fn from_addr(local_bind_addr: SocketAddr) -> Self {
        Self { local_bind_addr }
    }
}
/// Messages sent by the async connection tasks to the sync side of the endpoint,
/// where they are consumed by `handle_server_events`.
pub(crate) enum ServerAsyncMessage {
/// A client connected; carries the new peer connection to be handled by the endpoint.
ClientConnected(PeerConnection<ServerConnection>),
/// The connection of the given client was closed.
ClientConnectionClosed(ClientId), }
/// Messages received by an async client connection task from the sync side of the server.
#[derive(Debug, Clone)]
pub(crate) enum ServerSyncMessage {
/// Acknowledges a new connection and carries the id assigned to the client.
/// NOTE(review): presumably sent by the endpoint when it handles the new connection;
/// the sender is not visible in this file.
ClientConnectedAck(ClientId),
}
/// Bevy resource holding the server state: the async runtime handle and the
/// (optional) currently open endpoint.
#[derive(Resource)]
pub struct QuinnetServer {
/// Handle of the tokio runtime on which the endpoint task is spawned.
runtime: runtime::Handle,
/// The open endpoint, or `None` when no endpoint has been started (or it was stopped).
endpoint: Option<Endpoint>,
}
impl FromWorld for QuinnetServer {
    /// Builds the server resource from the world's [`AsyncRuntime`].
    ///
    /// When the world has no [`AsyncRuntime`] resource yet, a multi-threaded tokio
    /// runtime (built with `enable_all()`) is created and inserted first.
    ///
    /// # Panics
    /// Panics if the tokio runtime cannot be built.
    fn from_world(world: &mut World) -> Self {
        let runtime_missing = !world.contains_resource::<AsyncRuntime>();
        if runtime_missing {
            let mut builder = tokio::runtime::Builder::new_multi_thread();
            builder.enable_all();
            world.insert_resource(AsyncRuntime(builder.build().unwrap()));
        }

        let handle = world.resource::<AsyncRuntime>().handle().clone();
        Self::new(handle)
    }
}
/// Configuration used by `QuinnetServer::start_endpoint` to open a server endpoint.
#[derive(Debug, Clone)]
pub struct ServerEndpointConfiguration {
/// Local address the endpoint socket is bound to.
pub addr_config: EndpointAddrConfiguration,
/// How the server certificate is retrieved when the endpoint starts.
pub cert_mode: CertificateRetrievalMode,
/// Remaining settings, which all have a default value.
pub defaultables: ServerEndpointConfigurationDefaultables,
}
/// Endpoint settings that can be left to their `Default` value.
#[derive(Debug, Default, Clone)]
pub struct ServerEndpointConfigurationDefaultables {
/// Send channels opened on the endpoint when it starts.
pub send_channels_cfg: SendChannelsConfiguration,
/// Receive channels configuration handed to the endpoint and to each client connection.
/// Only available with the `recv_channels` feature.
#[cfg(feature = "recv_channels")]
pub recv_channels_cfg: crate::shared::peer_connection::RecvChannelsConfiguration,
}
impl QuinnetServer {
fn new(runtime: tokio::runtime::Handle) -> Self {
Self {
endpoint: None,
runtime,
}
}
pub fn endpoint(&self) -> &Endpoint {
self.endpoint.as_ref().unwrap()
}
pub fn endpoint_mut(&mut self) -> &mut Endpoint {
self.endpoint.as_mut().unwrap()
}
pub fn get_endpoint(&self) -> Option<&Endpoint> {
self.endpoint.as_ref()
}
pub fn get_endpoint_mut(&mut self) -> Option<&mut Endpoint> {
self.endpoint.as_mut()
}
pub fn start_endpoint(
&mut self,
config: ServerEndpointConfiguration,
) -> Result<ServerCertificate, EndpointStartError> {
let server_cert = retrieve_certificate(config.cert_mode)?;
let mut quinn_endpoint_config = ServerConfig::with_single_cert(
server_cert.cert_chain.clone(),
server_cert.priv_key.clone_key(),
)?;
Arc::get_mut(&mut quinn_endpoint_config.transport)
.ok_or(EndpointStartError::LockAcquisitionFailure)?
.keep_alive_interval(Some(DEFAULT_KEEP_ALIVE_INTERVAL_S));
let (to_sync_endpoint_send, from_async_endpoint_recv) =
mpsc::channel::<ServerAsyncMessage>(DEFAULT_INTERNAL_MESSAGES_CHANNEL_SIZE);
let (endpoint_close_send, endpoint_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
let socket = std::net::UdpSocket::bind(config.addr_config.local_bind_addr)?;
info!(
"Starting endpoint on: {} ...",
config.addr_config.local_bind_addr
);
#[cfg(feature = "recv_channels")]
let recv_channels_cfg_clone = config.defaultables.recv_channels_cfg.clone();
self.runtime.spawn(async move {
endpoint_task(
socket,
quinn_endpoint_config,
to_sync_endpoint_send.clone(),
endpoint_close_recv,
#[cfg(feature = "recv_channels")]
recv_channels_cfg_clone,
)
.await;
});
let mut endpoint = Endpoint::new(
endpoint_close_send,
from_async_endpoint_recv,
config.addr_config,
#[cfg(feature = "recv_channels")]
config.defaultables.recv_channels_cfg,
);
for channel_type in config.defaultables.send_channels_cfg.configs() {
endpoint.unchecked_open_channel(*channel_type)?;
}
self.endpoint = Some(endpoint);
Ok(server_cert)
}
pub fn stop_endpoint(&mut self) -> Result<(), EndpointAlreadyClosed> {
match self.endpoint.take() {
Some(mut endpoint) => {
endpoint.disconnect_all_clients();
match endpoint.close_incoming_connections_handler() {
Ok(_) => Ok(()),
Err(_) => Err(EndpointAlreadyClosed),
}
}
None => Err(EndpointAlreadyClosed),
}
}
pub fn is_listening(&self) -> bool {
self.endpoint.is_some()
}
}
/// Async task owning the quinn endpoint and accepting incoming connections.
///
/// Builds a quinn endpoint on `socket` with `endpoint_config`, then accepts
/// connections until either the endpoint stops yielding them or a message (or a
/// closed channel) is observed on `endpoint_close_recv`. Every incoming connection
/// is handed to its own spawned task, which completes the handshake and then runs
/// [`client_connection_task`].
///
/// # Panics
/// Panics if no async runtime is available to quinn or if the quinn endpoint
/// cannot be created.
async fn endpoint_task(
    socket: UdpSocket,
    endpoint_config: ServerConfig,
    to_sync_endpoint_send: mpsc::Sender<ServerAsyncMessage>,
    mut endpoint_close_recv: broadcast::Receiver<()>,
    #[cfg(feature = "recv_channels")]
    recv_channels_cfg: crate::shared::peer_connection::RecvChannelsConfiguration,
) {
    let endpoint = QuinnEndpoint::new(
        EndpointConfig::default(),
        Some(endpoint_config),
        socket,
        default_runtime().expect("async runtime should be valid"),
    )
    .expect("should create quinn endpoint");

    tokio::select! {
        _ = endpoint_close_recv.recv() => {
            trace!("Endpoint incoming connection handler received a request to close")
        }
        _ = async {
            while let Some(connecting) = endpoint.accept().await {
                let to_sync_endpoint_send = to_sync_endpoint_send.clone();
                #[cfg(feature = "recv_channels")]
                let recv_channels_cfg = recv_channels_cfg.clone();
                // The handshake is awaited inside the spawned task, not in this loop:
                // a slow or stalled handshake must not prevent other clients from
                // being accepted in the meantime.
                tokio::spawn(async move {
                    match connecting.await {
                        Err(err) => error!("An incoming connection failed: {}", err),
                        Ok(connection) => {
                            client_connection_task(
                                connection,
                                to_sync_endpoint_send,
                                #[cfg(feature = "recv_channels")]
                                recv_channels_cfg,
                            )
                            .await
                        }
                    }
                });
            }
        } => {}
    }
}
async fn client_connection_task(
connection_handle: quinn::Connection,
to_sync_endpoint_send: mpsc::Sender<ServerAsyncMessage>,
#[cfg(feature = "recv_channels")]
recv_channels_cfg: crate::shared::peer_connection::RecvChannelsConfiguration,
) {
let (client_close_send, client_close_recv) =
broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);
let (bytes_from_client_send, bytes_from_client_recv) =
mpsc::channel::<(ChannelId, Bytes)>(DEFAULT_MESSAGE_QUEUE_SIZE);
let (to_connection_send, mut from_sync_server_recv) =
mpsc::channel::<ServerSyncMessage>(DEFAULT_INTERNAL_MESSAGES_CHANNEL_SIZE);
let (from_channels_send, from_channels_recv) =
mpsc::channel::<ChannelAsyncMessage>(DEFAULT_INTERNAL_MESSAGES_CHANNEL_SIZE);
let (to_channels_send, to_channels_recv) =
mpsc::channel::<ChannelSyncMessage>(DEFAULT_QCHANNEL_MESSAGES_CHANNEL_SIZE);
to_sync_endpoint_send
.send(ServerAsyncMessage::ClientConnected(PeerConnection::new(
ServerConnection::new(connection_handle.clone(), to_connection_send),
bytes_from_client_recv,
client_close_send.clone(),
from_channels_recv,
to_channels_send,
#[cfg(feature = "recv_channels")]
recv_channels_cfg,
)))
.await
.expect("Failed to signal connection to sync client");
match from_sync_server_recv.recv().await {
Some(ServerSyncMessage::ClientConnectedAck(client_id)) => {
info!(
"New connection from {}, client_id: {}",
connection_handle.remote_address(),
client_id
);
#[cfg(feature = "shared-client-id")]
client_id::spawn_client_id_sender(
connection_handle.clone(),
client_id,
from_channels_send.clone(),
);
{
let conn = connection_handle.clone();
let to_sync_server = to_sync_endpoint_send.clone();
tokio::spawn(async move {
let _conn_err = conn.closed().await;
info!("Connection {} closed: {}", client_id, _conn_err);
if !to_sync_server.is_closed() {
to_sync_server
.send(ServerAsyncMessage::ClientConnectionClosed(client_id))
.await
.expect("Failed to signal connection lost in async connection");
}
});
};
spawn_recv_channels_tasks(
connection_handle.clone(),
client_id,
client_close_recv.resubscribe(),
bytes_from_client_send,
);
spawn_send_channels_tasks_spawner(
connection_handle,
client_close_recv,
to_channels_recv,
from_channels_send,
);
}
_ => info!(
"Connection from {} refused",
connection_handle.remote_address()
),
}
}
/// System forwarding the events of the async server tasks to the Bevy world.
///
/// Does nothing when no endpoint is open. Otherwise:
/// 1. drains the messages coming from the async endpoint tasks, registering new
///    connections (writing a [`ConnectionEvent`]) and disconnecting clients whose
///    connection was closed (writing a [`ConnectionLostEvent`]);
/// 2. drains the messages coming from each client's channels, writing one
///    [`ConnectionLostEvent`] per client that reported a lost connection;
/// 3. disconnects every client collected during step 2.
pub fn handle_server_events(
    mut server: ResMut<QuinnetServer>,
    mut connection_events: MessageWriter<ConnectionEvent>,
    mut connection_lost_events: MessageWriter<ConnectionLostEvent>,
    mut lost_clients: Local<HashSet<ClientId>>,
) {
    let endpoint = match server.get_endpoint_mut() {
        Some(endpoint) => endpoint,
        None => return,
    };

    // Step 1: messages from the endpoint-level async tasks.
    while let Ok(async_message) = endpoint.try_recv_from_async() {
        match async_message {
            ServerAsyncMessage::ClientConnected(peer_connection) => {
                if let Ok(id) = endpoint.handle_new_connection(peer_connection) {
                    connection_events.write(ConnectionEvent { id });
                } else {
                    error!("Failed to handle connection of a client, already disconnected");
                }
            }
            ServerAsyncMessage::ClientConnectionClosed(id) => {
                // Clients that are no longer registered were already handled.
                if !endpoint.clients.contains_key(&id) {
                    continue;
                }
                endpoint.try_disconnect_closed_client(id);
                connection_lost_events.write(ConnectionLostEvent { id });
            }
        }
    }

    // Step 2: messages from the channels of every connected client.
    for (client_id, connection) in endpoint.clients.iter_mut() {
        while let Ok(channel_message) = connection.try_recv_from_channels() {
            match channel_message {
                ChannelAsyncMessage::LostConnection => {
                    // `insert` returns `true` only the first time a client is
                    // recorded, so the event is written once per client.
                    if lost_clients.insert(*client_id) {
                        connection_lost_events.write(ConnectionLostEvent { id: *client_id });
                    }
                }
            }
        }
    }

    // Step 3: disconnect lost clients, now that `clients` is no longer borrowed.
    for lost_id in lost_clients.drain() {
        endpoint.try_disconnect_client(lost_id);
    }
}
/// Receive channel error message, identified by the [`ClientId`] of the client involved.
/// Written by `dispatch_received_payloads`; only available with the `recv_channels` feature.
#[cfg(feature = "recv_channels")]
pub type ServerRecvChannelError = crate::shared::error::RecvChannelErrorEvent<ClientId>;
/// System asking the open endpoint, if any, to dispatch the payloads it received,
/// reporting errors through `recv_error_events`.
#[cfg(feature = "recv_channels")]
pub fn dispatch_received_payloads(
    mut server: ResMut<QuinnetServer>,
    mut recv_error_events: MessageWriter<ServerRecvChannelError>,
) {
    if let Some(open_endpoint) = server.get_endpoint_mut() {
        open_endpoint.dispatch_received_payloads(&mut recv_error_events);
    }
}
/// System clearing the payloads received from clients on the open endpoint, when
/// its receive channels configuration enables `clear_stale_received_payloads`.
/// Does nothing when no endpoint is open.
#[cfg(feature = "recv_channels")]
pub fn clear_stale_received_payloads(mut server: ResMut<QuinnetServer>) {
    if let Some(open_endpoint) = server.get_endpoint_mut() {
        let clearing_enabled = open_endpoint
            .recv_channels_cfg()
            .clear_stale_received_payloads;
        if clearing_enabled {
            open_endpoint.clear_payloads_from_clients();
        }
    }
}
/// Bevy plugin registering the server messages, systems and (optionally) the
/// `QuinnetServer` resource.
#[derive(Default)]
pub struct QuinnetServerPlugin {
/// When `true`, the plugin does not initialize the `QuinnetServer` resource while
/// building; the server systems only run once that resource exists.
pub initialize_later: bool,
}
impl Plugin for QuinnetServerPlugin {
    /// Registers the connection messages and the server systems on `app`.
    ///
    /// The [`QuinnetServer`] resource is initialized here unless `initialize_later`
    /// is set. Every system is gated on the existence of that resource.
    fn build(&self, app: &mut App) {
        app.add_message::<ConnectionEvent>();
        app.add_message::<ConnectionLostEvent>();

        if !self.initialize_later {
            app.init_resource::<QuinnetServer>();
        }

        let server_events_system = handle_server_events
            .in_set(QuinnetSyncPreUpdate)
            .run_if(resource_exists::<QuinnetServer>);
        app.add_systems(PreUpdate, server_events_system);

        #[cfg(feature = "recv_channels")]
        {
            app.add_message::<ServerRecvChannelError>();

            let dispatch_system = dispatch_received_payloads
                .in_set(QuinnetSyncPreUpdate)
                .run_if(resource_exists::<QuinnetServer>);
            app.add_systems(PreUpdate, dispatch_system);

            let clear_system = clear_stale_received_payloads
                .in_set(crate::shared::QuinnetSyncLast)
                .run_if(resource_exists::<QuinnetServer>);
            app.add_systems(Last, clear_system);
        }
    }
}
/// Run condition returning `true` when the [`QuinnetServer`] resource exists and
/// has an open endpoint.
pub fn server_listening(server: Option<Res<QuinnetServer>>) -> bool {
    server.is_some_and(|server| server.is_listening())
}
/// Run condition returning `true` on the first evaluation where the server is
/// listening while it was not on the previous evaluation.
///
/// A missing [`QuinnetServer`] resource counts as not listening.
pub fn server_just_opened(
    mut was_listening: Local<bool>,
    server: Option<Res<QuinnetServer>>,
) -> bool {
    let listening = match server {
        Some(server) => server.is_listening(),
        None => false,
    };
    // Store the current state and get back the one from the previous evaluation.
    let previously_listening = std::mem::replace(&mut *was_listening, listening);
    listening && !previously_listening
}
/// Run condition returning `true` on the first evaluation where the server is no
/// longer listening while it was on the previous evaluation.
///
/// A missing [`QuinnetServer`] resource counts as closed.
pub fn server_just_closed(
    mut was_listening: Local<bool>,
    server: Option<Res<QuinnetServer>>,
) -> bool {
    let listening = match server {
        Some(server) => server.is_listening(),
        None => false,
    };
    // Store the current state and get back the one from the previous evaluation.
    let previously_listening = std::mem::replace(&mut *was_listening, listening);
    previously_listening && !listening
}