use crate::core::io::ChannelInfo;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use crate::core::utils::UniqueId;
use crate::protocol::{Frame, MaybeVersioned};
/// Identifier of a connection.
///
/// A thin wrapper around a [`UniqueId`]. Values are created crate-internally via
/// `ConnectionId::new`. Every [`ChannelId`] records the `ConnectionId` it was
/// created under (see [`ChannelId::connection_id`]).
///
/// `Debug` is implemented manually and prints only the type name, not the
/// underlying identifier.
#[cfg_attr(feature = "specta", derive(specta::Type))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct ConnectionId(UniqueId);
/// Identifier of a channel, scoped to the connection it was created under.
///
/// Equality and hashing are derived over both components, so two `ChannelId`s
/// are equal only when the connection *and* the per-channel id match.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct ChannelId {
/// Connection this channel belongs to.
connection: ConnectionId,
/// Identifier of the channel itself within its connection.
channel: UniqueId,
}
/// A frame paired with [`ChannelInfo`] describing the channel it is associated with.
///
/// NOTE(review): presumably the channel the frame was received on — confirm against callers.
/// Can be split back into its parts via `From<IncomingFrame<V>> for (Frame<V>, ChannelInfo)`.
#[derive(Clone, Debug)]
pub struct IncomingFrame<V: MaybeVersioned> {
/// The frame itself.
frame: Frame<V>,
/// Information about the associated channel.
channel: ChannelInfo,
}
/// A frame intended for sending, together with the [`BroadcastScope`] that
/// decides which channels should receive it.
///
/// The frame is held in an [`Arc`], so cloning an `OutgoingFrame` shares the
/// frame instead of copying it.
#[derive(Clone, Debug)]
pub struct OutgoingFrame<V: MaybeVersioned> {
/// Shared, immutable frame payload.
frame: Arc<Frame<V>>,
/// Which channels/connections the frame is addressed to.
scope: BroadcastScope,
}
/// Selects which channels an [`OutgoingFrame`] should be delivered to.
///
/// The default is [`BroadcastScope::All`].
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub enum BroadcastScope {
/// Deliver to every channel.
#[default]
All,
/// Deliver only to the given channel.
ExactChannel(ChannelId),
/// Deliver to every channel except the given one.
ExceptChannel(ChannelId),
/// Deliver to every *other* channel of the given channel's connection; channels
/// of other connections and the given channel itself are excluded.
ExceptChannelWithin(ChannelId),
/// Deliver only to channels belonging to the given connection.
ExactConnection(ConnectionId),
/// Deliver to every channel except those belonging to the given connection.
ExceptConnection(ConnectionId),
}
impl ConnectionId {
pub(crate) fn new() -> Self {
Self(UniqueId::new())
}
#[inline(always)]
pub fn contains(&self, channel_id: ChannelId) -> bool {
&channel_id.connection == self
}
}
impl ChannelId {
pub(crate) fn new(connection_id: ConnectionId) -> Self {
Self {
connection: connection_id,
channel: UniqueId::new(),
}
}
pub fn connection_id(&self) -> ConnectionId {
self.connection
}
#[inline(always)]
pub fn belongs_to(&self, connection_id: ConnectionId) -> bool {
connection_id.contains(*self)
}
}
impl Debug for ConnectionId {
    // Emits only the type name; the wrapped identifier is deliberately not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("ConnectionId")
    }
}
impl Debug for ChannelId {
    // Emits only the type name; neither the connection nor the channel id is printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("ChannelId")
    }
}
impl From<ConnectionId> for ChannelId {
fn from(value: ConnectionId) -> Self {
Self::new(value)
}
}
impl<V: MaybeVersioned> IncomingFrame<V> {
pub fn new(frame: Frame<V>, channel: ChannelInfo) -> Self {
Self { frame, channel }
}
}
impl<V: MaybeVersioned> From<IncomingFrame<V>> for (Frame<V>, ChannelInfo) {
fn from(value: IncomingFrame<V>) -> Self {
(value.frame, value.channel)
}
}
impl<V: MaybeVersioned> OutgoingFrame<V> {
pub fn new(frame: Frame<V>) -> Self {
Self {
frame: Arc::new(frame),
scope: BroadcastScope::All,
}
}
pub(crate) fn scoped(frame: Frame<V>, scope: BroadcastScope) -> Self {
Self {
frame: Arc::new(frame),
scope,
}
}
#[inline]
pub fn frame(&self) -> &Frame<V> {
self.frame.as_ref()
}
#[inline]
pub fn scope(&self) -> BroadcastScope {
self.scope
}
#[inline]
pub(crate) fn set_scope(&mut self, scope: BroadcastScope) {
self.scope = scope;
}
pub(crate) fn matches_connection_reroute(&mut self, connection_id: ConnectionId) -> bool {
match self.scope() {
BroadcastScope::ExceptConnection(conn_id) if connection_id == conn_id => false,
BroadcastScope::ExactConnection(conn_id) if connection_id == conn_id => {
self.set_scope(BroadcastScope::All);
true
}
_ => true,
}
}
pub(crate) fn should_send_to(&self, channel_id: ChannelId) -> bool {
match self.scope {
BroadcastScope::All => true,
BroadcastScope::ExactChannel(sender_id) => sender_id == channel_id,
BroadcastScope::ExceptChannel(sender_id) => sender_id != channel_id,
BroadcastScope::ExceptChannelWithin(sender_id) => {
channel_id.connection_id().contains(sender_id) && sender_id != channel_id
}
BroadcastScope::ExactConnection(conn_id) => conn_id.contains(channel_id),
BroadcastScope::ExceptConnection(conn_id) => !conn_id.contains(channel_id),
}
}
}
impl<V: MaybeVersioned> From<OutgoingFrame<V>> for Frame<V> {
fn from(value: OutgoingFrame<V>) -> Self {
value.frame.as_ref().clone()
}
}