use bevy_ecs::prelude::*;
use tracing::debug;
use tracing::error;
use tracing::warn;
use tubes::prelude::*;
use zeroize::Zeroize;
use crate::ClientId;
use crate::Fault;
use crate::TerminateEvent;
use crate::events::ClientLeft;
use crate::{
Message,
events::{ClientJoined, MessageReceivedEvent},
};
use std::any::type_name;
use std::collections::HashSet;
use std::error::Error;
use std::{marker::PhantomData, net::IpAddr};
use tracing::trace;
/// Extra switches attached to a [`SessionConfig`].
#[derive(Clone, Debug, Default)]
pub struct SessionOptions {
    /// Flag named for allowing a client to be promoted to host.
    /// NOTE(review): not read anywhere in this file; confirm where it is enforced.
    pub host_promotion_from_client_allowed: bool,
}

/// Parameters used to open a networking session.
///
/// `Debug` is implemented by hand (instead of derived) so that the shared `key` is never
/// written to logs or panic messages; only whether a key is present is shown.
#[derive(Clone)]
pub enum SessionConfig {
    /// Direct connection described by an optional address and a port.
    Direct {
        /// Optional IP address; copied into the transport `Config::address`.
        addr: Option<IpAddr>,
        /// Port; copied into the transport `Config::port`.
        port: u16,
        /// Whether this end is meant to act as host.
        host: bool,
        /// Additional session switches.
        options: SessionOptions,
        /// Whether compression is requested; copied into the transport `Config::compress`.
        compress: bool,
        /// Optional shared key bytes. Treated as secret: redacted in `Debug`.
        key: Option<Vec<u8>>,
    },
}

impl std::fmt::Debug for SessionConfig {
    /// Formats like the derived implementation, except that `key` is shown as
    /// `Some("<redacted>")` / `None` rather than its raw bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::Direct {
            addr,
            port,
            host,
            options,
            compress,
            key,
        } = self;
        f.debug_struct("Direct")
            .field("addr", addr)
            .field("port", port)
            .field("host", host)
            .field("options", options)
            .field("compress", compress)
            .field("key", &key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Drop for SessionConfig {
fn drop(&mut self) {
let SessionConfig::Direct {
addr: _,
port: _,
host: _,
options: _,
compress: _,
key,
} = self;
let Some(key) = key else {
return;
};
key.zeroize();
}
}
impl From<&Config> for SessionConfig {
    /// Builds a `SessionConfig::Direct` from a transport `Config`.
    ///
    /// Address, port, compression flag and key are copied over; the result is always
    /// marked as host and carries default `SessionOptions`.
    fn from(cfg: &Config) -> Self {
        let (addr, port, compress) = (cfg.address, cfg.port, cfg.compress);
        let key = cfg.key.clone();
        Self::Direct {
            addr,
            port,
            host: true,
            options: SessionOptions::default(),
            compress,
            key,
        }
    }
}
impl From<&SessionConfig> for Config {
    /// Converts a session description into a transport `Config`.
    ///
    /// Address, port, compression flag and key are copied; `host` and `options` are not
    /// carried over, and every remaining `Config` field takes its default value.
    fn from(session: &SessionConfig) -> Self {
        match session {
            SessionConfig::Direct {
                addr,
                port,
                compress,
                key,
                ..
            } => Self {
                address: *addr,
                port: *port,
                compress: *compress,
                key: key.clone(),
                ..Self::default()
            },
        }
    }
}
/// Bevy resource wrapping one networking session that carries messages of type `T`.
///
/// `T` is only a type-level tag (held in `PhantomData`), so each message type gets its
/// own distinct `Channel<T>` resource type.
// NOTE(review): the reason for `allow(private_interfaces)` is not visible here —
// presumably some type in this item's interface is less visible than `Channel`; confirm.
#[allow(private_interfaces)]
#[derive(Resource)]
pub struct Channel<T: Message> {
    // Underlying transport session; created as either a server or a client.
    session: Session,
    // Known peer ids: seeded with this session's own id, then updated on join/leave.
    destinations: HashSet<ClientId>,
    _t: PhantomData<T>,
}
impl<T: Message> Channel<T> {
pub(crate) fn new_host(config: &SessionConfig) -> Result<Self, Fault> {
let config: Config = config.into();
let mut session = Session::new_server(config);
let uuid = session.uuid();
session
.start()
.map_err(|_| Fault::Terminate(std::io::Error::other("")))?;
Ok(Self {
session,
destinations: [uuid].into(),
_t: PhantomData,
})
}
pub(crate) fn try_new_client(config: &SessionConfig) -> Result<Self, Fault> {
let config: Config = config.into();
let mut session = Session::new_client(config);
let uuid = session.uuid();
session
.start()
.map_err(|_| Fault::Terminate(std::io::Error::other("")))?;
Ok(Self {
session,
destinations: [uuid].into(),
_t: PhantomData,
})
}
pub(crate) fn poll(
&mut self,
events: &mut MessageWriter<MessageReceivedEvent<T>>,
joinevent: &mut MessageWriter<ClientJoined<T>>,
leaveevent: &mut MessageWriter<ClientLeft<T>>,
termevent: &mut MessageWriter<TerminateEvent<T>>,
) -> Result<(), Box<dyn Error>> {
let self_uuid = self.uuid();
let type_name = type_name::<T>();
let res = self.session.read();
match res {
Ok(Some(m)) => {
match m {
MessageData::Broadcast { from: _, data: m } => {
events.write(MessageReceivedEvent::<T> {
message: to_message(m.as_slice()).map_err(|_| String::new())?,
_t: PhantomData,
});
}
MessageData::Send {
from: _,
to,
data: m,
} => {
if self_uuid != to {
warn!(
"Received a message that was not intended for this end. {} != {}",
self_uuid, to
);
return Ok(());
}
events.write(MessageReceivedEvent::<T> {
message: to_message(m.as_slice()).map_err(|_| String::new())?,
_t: PhantomData,
});
}
MessageData::ClientJoined(uuid) => {
debug!("[{}] Client joined uuid = {uuid}", self.uuid());
self.destinations.insert(uuid);
joinevent.write(ClientJoined::<T> {
client: uuid,
_t: PhantomData,
});
}
MessageData::ClientLeft(uuid) => {
debug!("[{}] Client left uuid = {uuid}", self.uuid());
self.destinations.remove(&uuid);
leaveevent.write(ClientLeft::<T> {
client: uuid,
_t: PhantomData,
});
}
}
}
Ok(_) => {}
Err(e) => {
warn!(
"[{}:{}/{}] Terminating connection for socket error: {e}.",
self.is_host_c(),
self_uuid,
type_name
);
termevent.write(TerminateEvent::<T> { _t: PhantomData });
}
}
Ok(())
}
pub(crate) fn promote_new_host(&mut self, new_host: ClientId, port: Option<u16>) {
trace!("Sending promotion...");
self.session.promote_to_host(new_host, port);
}
pub fn broadcast(&mut self, m: T) {
let Ok(m) = from_message(m) else {
error!("Error creating message.");
return;
};
if let Err(r) = self.session.broadcast(m) {
error!("Error sending message {r:?}.");
}
}
pub fn send_to(&mut self, to: ClientId, m: T) {
let Ok(m) = from_message(m) else {
error!("Error creating message.");
return;
};
if let Err(r) = self.session.send_to(to, m) {
error!("Error sending message {r:?}.");
}
}
#[must_use]
pub fn is_host(&self) -> bool {
self.session.is_server()
}
pub(crate) fn is_host_c(&self) -> char {
if self.session.is_server() { 'H' } else { 'C' }
}
#[must_use]
pub fn uuid(&self) -> ClientId {
self.session.uuid()
}
#[must_use]
pub fn host_uuid(&self) -> Option<ClientId> {
self.session.server_uuid()
}
#[must_use]
pub fn destinations(&self) -> HashSet<ClientId> {
if self.session.is_server() {
let mut res = self.session.clients().iter().copied().collect::<HashSet<_>>();
res.insert(self.session.uuid());
res
} else {
let mut res = self.destinations.iter().copied().collect::<HashSet<_>>();
if let Some(uuid) = self.session.server_uuid() {
res.insert(uuid);
}
res
}
}
#[must_use]
pub fn total_read(&self) -> usize {
self.session.total_read()
}
#[must_use]
pub fn total_sent(&self) -> usize {
self.session.total_sent()
}
}
/// Wire format shared by `to_message` and `from_message`: bincode's standard
/// configuration switched to fixed-width integers and big-endian byte order.
const BINCODE_OPTIONS: bincode::config::Configuration<
    bincode::config::BigEndian,
    bincode::config::Fixint,
> = bincode::config::standard()
    .with_fixed_int_encoding()
    .with_big_endian();
/// Deserializes one `T` from `m` using `BINCODE_OPTIONS` and boxes it.
///
/// The consumed-length part of the decode result is discarded, so trailing bytes after
/// the first encoded value are not rejected.
fn to_message<T: Message>(m: &[u8]) -> Result<Box<T>, bincode::error::DecodeError> {
    bincode::serde::decode_from_slice::<T, _>(m, BINCODE_OPTIONS)
        .map(|(decoded, _bytes_read)| Box::new(decoded))
}
/// Serializes `m` into a newly allocated byte vector using `BINCODE_OPTIONS`.
fn from_message<T: Message>(m: T) -> Result<Vec<u8>, bincode::error::EncodeError> {
    bincode::serde::encode_to_vec::<T, _>(m, BINCODE_OPTIONS)
}