bevy_connect 0.18.6

Connectivity via TCP sessions
Documentation
use bevy_ecs::prelude::*;
use tracing::debug;
use tracing::error;
use tracing::warn;
use tubes::prelude::*;
use zeroize::Zeroize;

use crate::ClientId;
use crate::Fault;
use crate::TerminateEvent;
use crate::events::ClientLeft;
use crate::{
    Message,
    events::{ClientJoined, MessageReceivedEvent},
};
use std::any::type_name;
use std::collections::HashSet;
use std::error::Error;
use std::{marker::PhantomData, net::IpAddr};
use tracing::trace;

/// Runtime-tunable options of a session.
///
/// These options are meant to configure the session, and they are usually
/// changeable at runtime.
#[derive(Clone, Debug, Default)]
pub struct SessionOptions {
    /// When this is true, a promotion received from a client will be accepted.
    ///
    /// When this is false, only the current server can decide to trigger a
    /// host promotion.
    ///
    /// Relevant for servers: allows accepting a host promotion event from a
    /// client, so that this client becomes the new server and all other
    /// clients reconnect to it. The current server then also becomes a client
    /// connected to the new server.
    ///
    /// This promotion will not trigger [`super::events::SessionConnectedEvent`].
    ///
    /// This value is usually changed at runtime and it starts as false (the
    /// derived `Default`), to let users decide when to allow this behavior.
    pub host_promotion_from_client_allowed: bool,
}

/// This is the configuration needed to start or restart a session. It also
/// contains [`SessionOptions`], which is usually meant to be changed at
/// runtime while a session is open.
/// However, all the other properties here outside of `options` are generally
/// immutable and require a reconnection to be changed.
#[derive(Clone, Debug)]
pub enum SessionConfig {
    /// A session described by an explicit address and port.
    Direct {
        /// Address to connect to or to bind for.
        /// Defaults to 127.0.0.1 when None.
        addr: Option<IpAddr>,
        /// Port to host/connect to.
        /// Ideally give a different port for each variation T for `Connection<T>`.
        port: u16,
        /// The host is responsible for rerouting messages to all other
        /// clients and is also the listening TCP connection.
        /// The clients just connect to the host and send only once.
        host: bool,
        /// Extra options for this config.
        /// These options should work the same on all connection modes.
        options: SessionOptions,
        /// Enables network compression.
        compress: bool,
        /// Enables encryption with this key.
        key: Option<Vec<u8>>,
    },
}

impl Drop for SessionConfig {
    /// Zeroizes the encryption key, when one is set, as the config is dropped.
    fn drop(&mut self) {
        // Only the key holds secret material; a config without a key needs
        // no cleanup.
        if let SessionConfig::Direct {
            key: Some(secret), ..
        } = self
        {
            secret.zeroize();
        }
    }
}

impl From<&Config> for SessionConfig {
    fn from(value: &Config) -> Self {
        Self::Direct {
            addr: value.address,
            port: value.port,
            host: true,
            options: SessionOptions::default(),
            compress: value.compress,
            key: value.key.clone(),
        }
    }
}

impl From<&SessionConfig> for Config {
    /// Builds a session [`Config`] from a [`SessionConfig`].
    ///
    /// Only address, port, compression flag and key are carried over; every
    /// other `Config` field keeps its default value. The `host` flag and the
    /// runtime `options` are not part of the resulting `Config`.
    fn from(value: &SessionConfig) -> Self {
        match value {
            SessionConfig::Direct {
                addr: address,
                port,
                compress,
                key,
                ..
            } => Self {
                address: *address,
                port: *port,
                compress: *compress,
                key: key.clone(),
                ..Self::default()
            },
        }
    }
}

/// The presence of this resource (discerned by T) means a connection is
/// currently established.
///
/// This is a channel of communication for a specific message type T.
/// A [`super::events::MessageReceivedEvent`] is emitted whenever a message
/// is received.
/// To send a message, use [`Channel::broadcast`] or [`Channel::send_to`].
// NOTE(review): the reason for allowing `private_interfaces` is not visible
// from this file — confirm it is still required.
#[allow(private_interfaces)]
#[derive(Resource)]
pub struct Channel<T: Message> {
    /// Underlying session that performs the actual reads and writes.
    session: Session,
    /// Uuids tracked locally: seeded with this end's own uuid and updated on
    /// every client joined/left notification seen while polling.
    destinations: HashSet<ClientId>,
    /// Ties the channel to its message type without storing a `T`.
    _t: PhantomData<T>,
}

impl<T: Message> Channel<T> {
    pub(crate) fn new_host(config: &SessionConfig) -> Result<Self, Fault> {
        let config: Config = config.into();
        let mut session = Session::new_server(config);
        let uuid = session.uuid();
        session
            .start()
            .map_err(|_| Fault::Terminate(std::io::Error::other("")))?;
        Ok(Self {
            session,
            destinations: [uuid].into(),
            _t: PhantomData,
        })
    }

    pub(crate) fn try_new_client(config: &SessionConfig) -> Result<Self, Fault> {
        let config: Config = config.into();
        let mut session = Session::new_client(config);
        let uuid = session.uuid();
        session
            .start()
            .map_err(|_| Fault::Terminate(std::io::Error::other("")))?;
        Ok(Self {
            session,
            destinations: [uuid].into(),
            _t: PhantomData,
        })
    }

    /// Reads at most one pending item from the session and converts it into
    /// the matching event.
    ///
    /// * Broadcast and directed payloads are decoded into `T` and written to
    ///   `events`. A directed payload addressed to another uuid is logged and
    ///   dropped.
    /// * Join/leave notifications update the local destination set and are
    ///   forwarded through `joinevent` / `leaveevent`.
    /// * A read failure is logged and reported through `termevent`; it is not
    ///   returned as an error.
    ///
    /// # Errors
    ///
    /// Returns an error when a received payload cannot be decoded into `T`.
    /// The error text names the message type and the decoder's reason.
    pub(crate) fn poll(
        &mut self,
        events: &mut MessageWriter<MessageReceivedEvent<T>>,
        joinevent: &mut MessageWriter<ClientJoined<T>>,
        leaveevent: &mut MessageWriter<ClientLeft<T>>,
        termevent: &mut MessageWriter<TerminateEvent<T>>,
    ) -> Result<(), Box<dyn Error>> {
        let self_uuid = self.uuid();
        let type_name = type_name::<T>();
        // Keep the decoder's reason in the returned error: an empty message
        // makes a malformed or mismatched payload impossible to diagnose.
        let decode = |bytes: &[u8]| {
            to_message::<T>(bytes)
                .map_err(|e| format!("Failed to decode a message of type {type_name}: {e}"))
        };
        let incoming = self.session.read();
        match incoming {
            Ok(Some(MessageData::Broadcast { from: _, data })) => {
                events.write(MessageReceivedEvent::<T> {
                    message: decode(data.as_slice())?,
                    _t: PhantomData,
                });
            }
            Ok(Some(MessageData::Send { from: _, to, data })) => {
                // Directed messages are only delivered to their addressee.
                if self_uuid != to {
                    warn!(
                        "Received a message that was not intended for this end. {} != {}",
                        self_uuid, to
                    );
                    return Ok(());
                }
                events.write(MessageReceivedEvent::<T> {
                    message: decode(data.as_slice())?,
                    _t: PhantomData,
                });
            }
            Ok(Some(MessageData::ClientJoined(uuid))) => {
                debug!("[{}] Client joined uuid = {uuid}", self.uuid());
                self.destinations.insert(uuid);
                joinevent.write(ClientJoined::<T> {
                    client: uuid,
                    _t: PhantomData,
                });
            }
            Ok(Some(MessageData::ClientLeft(uuid))) => {
                debug!("[{}] Client left uuid = {uuid}", self.uuid());
                self.destinations.remove(&uuid);
                leaveevent.write(ClientLeft::<T> {
                    client: uuid,
                    _t: PhantomData,
                });
            }
            // Nothing pending on the session.
            Ok(None) => {}
            Err(e) => {
                warn!(
                    "[{}:{}/{}] Terminating connection for socket error: {e}.",
                    self.is_host_c(),
                    self_uuid,
                    type_name
                );
                termevent.write(TerminateEvent::<T> { _t: PhantomData });
            }
        }
        Ok(())
    }

    /// Asks the session to promote `new_host` to host, forwarding the
    /// optional `port` unchanged.
    ///
    /// NOTE(review): presumably `port` is the port the new host will listen
    /// on, with `None` selecting a session-defined default — confirm against
    /// the session's `promote_to_host`.
    pub(crate) fn promote_new_host(&mut self, new_host: ClientId, port: Option<u16>) {
        trace!("Sending promotion...");
        self.session.promote_to_host(new_host, port);
    }

    /// Sends a new message to this channel.
    pub fn broadcast(&mut self, m: T) {
        let Ok(m) = from_message(m) else {
            error!("Error creating message.");
            return;
        };
        if let Err(r) = self.session.broadcast(m) {
            error!("Error sending message {r:?}.");
        }
    }

    /// Sends a new message to the specific target, and nobody else.
    /// Naturally the message will still have to be passed to the host and then
    /// the host will only send it to the relevant client.
    /// If this channel is already the host, then it will only be send directly
    /// to the client.
    pub fn send_to(&mut self, to: ClientId, m: T) {
        let Ok(m) = from_message(m) else {
            error!("Error creating message.");
            return;
        };
        if let Err(r) = self.session.send_to(to, m) {
            error!("Error sending message {r:?}.");
        }
    }

    /// Returns whether this side of the channel is hosting.
    ///
    /// This is `true` when the underlying session reports itself as the
    /// server.
    #[must_use]
    pub fn is_host(&self) -> bool {
        self.session.is_server()
    }

    /// Returns a one-letter role tag: `'H'` when this end is the server,
    /// `'C'` otherwise.
    pub(crate) fn is_host_c(&self) -> char {
        let hosting = self.session.is_server();
        if hosting {
            return 'H';
        }
        'C'
    }

    /// Returns the uuid of this end of the channel.
    ///
    /// This represents the 'user' from the perspective of the local accessor:
    /// on the hosting side it is the host uuid, otherwise it is the uuid of
    /// this client.
    #[must_use]
    pub fn uuid(&self) -> ClientId {
        self.session.uuid()
    }

    /// Returns the uuid of the host, as reported by the session.
    ///
    /// On a client this is the uuid of the other end; on the host it is its
    /// own uuid.
    /// This can be useful to direct some messages straight to the host, by
    /// combining this method with [`Channel::send_to`].
    ///
    /// NOTE(review): presumably `None` means the host uuid is not known yet
    /// (e.g. before the connection is established) — confirm against the
    /// session's `server_uuid`.
    #[must_use]
    pub fn host_uuid(&self) -> Option<ClientId> {
        self.session.server_uuid()
    }

    /// Returns the set of every uuid that can currently be targeted.
    ///
    /// On the host this is the session's client list plus the host's own
    /// uuid. On a client it is the locally tracked set plus the host uuid,
    /// when the session knows it.
    #[must_use]
    pub fn destinations(&self) -> HashSet<ClientId> {
        if !self.session.is_server() {
            let mut known = self.destinations.clone();
            // `extend` over an `Option` inserts the value only when present.
            known.extend(self.session.server_uuid());
            return known;
        }
        let mut known: HashSet<ClientId> = self.session.clients().iter().copied().collect();
        known.insert(self.session.uuid());
        known
    }

    /// Returns the session's `total_read` counter.
    ///
    /// NOTE(review): presumably the amount of data read so far, in bytes —
    /// confirm the unit against the session implementation.
    #[must_use]
    pub fn total_read(&self) -> usize {
        self.session.total_read()
    }

    /// Returns the session's `total_sent` counter.
    ///
    /// NOTE(review): presumably the amount of data sent so far, in bytes —
    /// confirm the unit against the session implementation.
    #[must_use]
    pub fn total_sent(&self) -> usize {
        self.session.total_sent()
    }
}

/// Bincode configuration used for both encoding and decoding messages: the
/// standard configuration switched to big-endian byte order and fixed-width
/// integer encoding.
const BINCODE_OPTIONS: bincode::config::Configuration<
    bincode::config::BigEndian,
    bincode::config::Fixint,
> = bincode::config::standard()
    .with_big_endian()
    .with_fixed_int_encoding();

/// Decodes a received payload into a boxed `T` using [`BINCODE_OPTIONS`].
///
/// The number of bytes consumed by the decoder is discarded, so trailing
/// bytes after a valid value are not reported.
///
/// # Errors
///
/// Returns the decoder's error when `m` is not a valid encoding of `T`.
fn to_message<T: Message>(m: &[u8]) -> Result<Box<T>, bincode::error::DecodeError> {
    bincode::serde::decode_from_slice::<T, _>(m, BINCODE_OPTIONS)
        .map(|(decoded, _consumed)| Box::new(decoded))
}

/// Encodes `m` into its wire bytes using [`BINCODE_OPTIONS`].
///
/// # Errors
///
/// Returns the encoder's error when `m` cannot be serialized.
fn from_message<T: Message>(m: T) -> Result<Vec<u8>, bincode::error::EncodeError> {
    bincode::serde::encode_to_vec::<T, _>(m, BINCODE_OPTIONS)
}