use std::any::{Any, TypeId};
use std::collections::HashMap;
#[cfg(feature = "network")]
use crate::network::{NetworkMetadata, NetworkScope};
/// Marker trait for values that can be published on an [`EventBus`].
///
/// Implementors must be `Clone`, thread-safe (`Send + Sync`) and `'static`.
/// With the `network` feature enabled, the associated functions below let an
/// event type opt into network replication; both have local-only defaults.
pub trait Event: Clone + Send + Sync + 'static {
    /// Whether published events of this type are forwarded to the network
    /// backend. Defaults to `false`.
    #[cfg(feature = "network")]
    fn is_networked() -> bool {
        false
    }

    /// Scope attached to outbound network events of this type.
    /// Defaults to `NetworkScope::default()`.
    #[cfg(feature = "network")]
    fn network_scope() -> NetworkScope {
        NetworkScope::default()
    }
}
/// Type-keyed, double-buffered event bus.
///
/// Every event type gets its own channel. Events published since the last
/// [`EventBus::dispatch`] are buffered and only become readable after the
/// next dispatch.
pub struct EventBus {
    /// One type-erased channel per event type, keyed by the event's `TypeId`.
    channels: HashMap<TypeId, Box<dyn EventChannelStorage>>,
    /// Network state; `None` until `with_network` is called.
    #[cfg(feature = "network")]
    network: Option<NetworkState>,
    /// Optional tracer notified on publish and on frame changes.
    tracer: Option<std::sync::Arc<std::sync::Mutex<crate::trace::EventChainTracer>>>,
    /// Optional recorder that receives published events and frame changes.
    recorder: Option<std::sync::Arc<std::sync::Mutex<crate::replay::EventRecorder>>>,
    /// Frame number last set through `set_frame` (`0` initially).
    current_frame: u64,
}
/// Runtime state of a bus that has a network backend attached.
#[cfg(feature = "network")]
struct NetworkState {
    /// Backend handle; a clone of it is owned by the background send worker.
    backend: std::sync::Arc<dyn crate::network::NetworkBackend>,
    /// Queue feeding the background send worker.
    tx: tokio::sync::mpsc::Sender<NetworkTask>,
    /// Inbound event stream from the backend, drained by `EventBus::poll_network`.
    rx: std::sync::Arc<
        std::sync::Mutex<tokio::sync::mpsc::Receiver<crate::network::backend::RawNetworkEvent>>,
    >,
    /// Sequence number stamped onto the next outbound event.
    sequence: std::sync::atomic::AtomicU64,
    /// Metadata of the inbound event `poll_network` is currently delivering;
    /// `None` at all other times.
    current_metadata: Option<NetworkMetadata>,
    /// Decoders for inbound events, keyed by `std::any::type_name` of the event type.
    deserializers: HashMap<String, Box<dyn EventDeserializer>>,
}
/// Work item for the background network send worker.
#[cfg(feature = "network")]
// `Shutdown` is not constructed anywhere in this file (the worker also stops
// when every sender is dropped), so it would otherwise trip `dead_code`.
#[allow(dead_code)]
enum NetworkTask {
    /// A bincode-encoded `RawNetworkEvent` to hand to the backend.
    Send(Vec<u8>),
    /// Ends the worker loop.
    Shutdown,
}
/// Type-erased decoder for one networked event type.
///
/// Lets the bus decode inbound payloads by type name without knowing the
/// concrete event type at the call site.
#[cfg(feature = "network")]
trait EventDeserializer: Send + Sync {
    /// Decodes `payload` and appends the resulting event to its channel in
    /// `channels`.
    fn deserialize_and_push(
        &self,
        payload: &[u8],
        channels: &mut HashMap<TypeId, Box<dyn EventChannelStorage>>,
    );
}
/// Stateless `EventDeserializer` for the concrete event type `E`.
#[cfg(feature = "network")]
struct TypedEventDeserializer<E: Event + serde::de::DeserializeOwned> {
    /// Binds the deserializer to `E` without storing a value of it.
    _phantom: std::marker::PhantomData<E>,
}
#[cfg(feature = "network")]
impl<E: Event + serde::de::DeserializeOwned> TypedEventDeserializer<E> {
    /// Creates a deserializer for `E`; it carries no state.
    fn new() -> Self {
        let _phantom = std::marker::PhantomData;
        Self { _phantom }
    }
}
#[cfg(feature = "network")]
impl<E: Event + serde::de::DeserializeOwned> EventDeserializer for TypedEventDeserializer<E> {
    /// Decodes `payload` with bincode as an `E` and pushes it into the `E`
    /// channel, creating that channel on first use.
    ///
    /// Payloads that fail to decode are dropped without notice.
    fn deserialize_and_push(
        &self,
        payload: &[u8],
        channels: &mut HashMap<TypeId, Box<dyn EventChannelStorage>>,
    ) {
        let event = match bincode::deserialize::<E>(payload) {
            Ok(event) => event,
            Err(_) => return,
        };
        let storage = channels
            .entry(TypeId::of::<E>())
            .or_insert_with(|| Box::new(EventChannel::<E>::new()));
        if let Some(channel) = storage.as_any_mut().downcast_mut::<EventChannel<E>>() {
            channel.push(event);
        }
    }
}
impl Default for EventBus {
fn default() -> Self {
Self::new()
}
}
impl EventBus {
    /// Creates an empty bus: no channels, no tracer or recorder, frame `0`,
    /// and (with the `network` feature) no network backend.
    pub fn new() -> Self {
        Self {
            channels: HashMap::new(),
            #[cfg(feature = "network")]
            network: None,
            tracer: None,
            recorder: None,
            current_frame: 0,
        }
    }

    /// Attaches `tracer`, replacing any previous one. It is notified of every
    /// published event and every frame change.
    pub fn set_tracer(
        &mut self,
        tracer: std::sync::Arc<std::sync::Mutex<crate::trace::EventChainTracer>>,
    ) {
        self.tracer = Some(tracer);
    }

    /// Detaches the tracer, if any.
    pub fn clear_tracer(&mut self) {
        self.tracer = None;
    }

    /// Attaches `recorder`, replacing any previous one. It receives every
    /// published event and every frame change.
    pub fn set_recorder(
        &mut self,
        recorder: std::sync::Arc<std::sync::Mutex<crate::replay::EventRecorder>>,
    ) {
        self.recorder = Some(recorder);
    }

    /// Detaches the recorder, if any.
    pub fn clear_recorder(&mut self) {
        self.recorder = None;
    }

    /// Sets the current frame number and forwards it to the tracer and recorder.
    ///
    /// A tracer or recorder whose mutex is poisoned is skipped.
    pub fn set_frame(&mut self, frame: u64) {
        self.current_frame = frame;
        if let Some(tracer) = &self.tracer {
            if let Ok(mut t) = tracer.lock() {
                t.set_frame(frame);
            }
        }
        if let Some(recorder) = &self.recorder {
            if let Ok(mut r) = recorder.lock() {
                r.set_frame(frame);
            }
        }
    }

    /// Returns the frame number last passed to [`EventBus::set_frame`]
    /// (`0` initially).
    pub fn current_frame(&self) -> u64 {
        self.current_frame
    }

    /// Publishes `event`.
    ///
    /// The event is handled in this order:
    /// 1. It is reported to the tracer and recorder, if attached. Poisoned
    ///    mutexes are skipped.
    /// 2. With the `network` feature, if `Event::is_networked` returns `true`
    ///    for `E` and a backend is configured, it is forwarded to the network
    ///    on a best-effort basis.
    /// 3. It is queued in the `E` channel.
    ///
    /// It becomes visible to [`EventBus::reader`] after the next
    /// [`EventBus::dispatch`].
    pub fn publish<E>(&mut self, event: E)
    where
        E: Event + serde::Serialize,
    {
        if let Some(tracer) = &self.tracer {
            if let Ok(mut t) = tracer.lock() {
                let event_type = std::any::type_name::<E>();
                t.record_simple(
                    crate::trace::TraceEntryType::EventPublished {
                        event_type: event_type.to_string(),
                        event_id: format!("{}@{}", event_type, self.current_frame),
                    },
                    "EventBus",
                );
            }
        }
        if let Some(recorder) = &self.recorder {
            if let Ok(mut r) = recorder.lock() {
                r.record(&event);
            }
        }
        #[cfg(feature = "network")]
        self.send_networked(&event);
        // All by-reference consumers are done, so the event is moved into its
        // channel rather than cloned.
        self.channel_mut::<E>().push(event);
    }

    /// Returns a reader over the `E` events made readable by the most recent
    /// [`EventBus::dispatch`].
    ///
    /// Takes `&mut self` because the `E` channel is created on first use.
    pub fn reader<E>(&mut self) -> EventReader<'_, E>
    where
        E: Event,
    {
        let channel = self.channel_mut::<E>();
        EventReader {
            events: channel.read(),
            cursor: 0,
        }
    }

    /// Swaps the buffers of every channel.
    ///
    /// Events published since the previous dispatch become readable. Events
    /// that were readable before are discarded.
    pub fn dispatch(&mut self) {
        for channel in self.channels.values_mut() {
            channel.swap_buffers();
        }
    }

    /// Returns the channel for `E`, creating it on first use.
    ///
    /// # Panics
    ///
    /// Panics if the entry under `TypeId::of::<E>()` is not an
    /// `EventChannel<E>`. Entries are only ever inserted with that pairing, so
    /// a panic here indicates a bug in this module.
    fn channel_mut<E>(&mut self) -> &mut EventChannel<E>
    where
        E: Event,
    {
        let entry = self
            .channels
            .entry(TypeId::of::<E>())
            .or_insert_with(|| Box::new(EventChannel::<E>::new()));
        entry
            .as_any_mut()
            .downcast_mut::<EventChannel<E>>()
            .expect("Stored channel type mismatch")
    }

    /// Attaches a network backend and spawns the background send worker.
    ///
    /// Replaces any previously configured network state. That includes
    /// deserializers added with [`EventBus::register_networked_event`], so
    /// register event types *after* calling this.
    ///
    /// The outbound queue holds up to 1000 pending events. `publish` never
    /// blocks; events are dropped while the queue is full.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the worker is started
    /// with `tokio::spawn`.
    #[cfg(feature = "network")]
    pub fn with_network(mut self, backend: impl crate::network::NetworkBackend) -> Self {
        use std::sync::{Arc, Mutex};
        use tokio::sync::mpsc;
        let backend = Arc::new(backend);
        let (tx, send_rx) = mpsc::channel(1000);
        let recv_rx = backend.receive_stream();
        let recv_rx = Arc::new(Mutex::new(recv_rx));
        let backend_clone = backend.clone();
        tokio::spawn(async move {
            network_send_worker(send_rx, backend_clone).await;
        });
        self.network = Some(NetworkState {
            backend,
            tx,
            rx: recv_rx,
            sequence: std::sync::atomic::AtomicU64::new(0),
            current_metadata: None,
            deserializers: HashMap::new(),
        });
        self
    }

    /// Metadata of the inbound network event currently being delivered.
    ///
    /// NOTE(review): `poll_network` sets this only while it delivers an event
    /// and clears it before returning. Because `poll_network` takes
    /// `&mut self`, this accessor cannot observe `Some` from outside. Confirm
    /// whether metadata should be exposed differently.
    #[cfg(feature = "network")]
    pub fn current_metadata(&self) -> Option<&NetworkMetadata> {
        self.network
            .as_ref()
            .and_then(|n| n.current_metadata.as_ref())
    }

    /// Returns `true` once a network backend has been attached with
    /// [`EventBus::with_network`].
    #[cfg(feature = "network")]
    pub fn is_networked(&self) -> bool {
        self.network.is_some()
    }

    /// Registers `E` so inbound network events of that type are decoded and
    /// delivered by [`EventBus::poll_network`].
    ///
    /// Does nothing if no backend is attached yet. Call it after
    /// [`EventBus::with_network`].
    ///
    /// NOTE(review): events are matched across nodes by
    /// `std::any::type_name::<E>()`. The standard library documents that this
    /// string is neither guaranteed unique nor stable across compiler versions,
    /// so every node must be built the same way.
    #[cfg(feature = "network")]
    pub fn register_networked_event<E>(&mut self)
    where
        E: Event + serde::de::DeserializeOwned + 'static,
    {
        if let Some(net) = self.network.as_mut() {
            let type_name = std::any::type_name::<E>().to_string();
            net.deserializers
                .insert(type_name, Box::new(TypedEventDeserializer::<E>::new()));
        }
    }

    /// Drains all inbound network events that are ready, without blocking.
    ///
    /// Events whose type was registered are decoded and queued in their
    /// channel. They become readable after the next [`EventBus::dispatch`].
    /// Unregistered types and undecodable payloads are ignored. If the inbound
    /// receiver's mutex is busy or poisoned, nothing is drained on this call.
    #[cfg(feature = "network")]
    pub fn poll_network(&mut self) {
        use crate::network::backend::RawNetworkEvent;
        let net = match self.network.as_mut() {
            Some(net) => net,
            None => return,
        };
        // Collect first so the receiver lock is released before decoding.
        let events: Vec<RawNetworkEvent> = match net.rx.try_lock() {
            Ok(mut rx) => std::iter::from_fn(|| rx.try_recv().ok()).collect(),
            Err(_) => return,
        };
        for raw_event in events {
            net.current_metadata = Some(raw_event.metadata.clone());
            if let Some(deserializer) = net.deserializers.get(&raw_event.type_name) {
                deserializer.deserialize_and_push(&raw_event.payload, &mut self.channels);
            }
            net.current_metadata = None;
        }
    }

    /// Enqueues `event` for the send worker if `E` is networked and a backend
    /// is attached.
    ///
    /// This is best-effort. If `event` cannot be encoded, nothing is sent and no
    /// sequence number is consumed. If the queue is full or closed, the event is
    /// dropped.
    #[cfg(feature = "network")]
    fn send_networked<E>(&self, event: &E)
    where
        E: Event + serde::Serialize,
    {
        if !E::is_networked() {
            return;
        }
        let net = match &self.network {
            Some(net) => net,
            None => return,
        };
        // An empty fallback payload would reach peers as an event they can never
        // decode, so an encoding failure skips the send entirely.
        let payload = match bincode::serialize(event) {
            Ok(payload) => payload,
            Err(_) => return,
        };
        let sequence = net
            .sequence
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let raw_event = crate::network::backend::RawNetworkEvent {
            metadata: NetworkMetadata::new(net.backend.node_id(), sequence),
            scope: E::network_scope(),
            type_name: std::any::type_name::<E>().to_string(),
            payload,
        };
        if let Ok(serialized) = bincode::serialize(&raw_event) {
            // `try_send` keeps `publish` non-blocking.
            let _ = net.tx.try_send(NetworkTask::Send(serialized));
        }
    }
}
/// Background task that forwards queued outbound events to `backend`.
///
/// Runs until it receives `NetworkTask::Shutdown` or every sender for `rx` has
/// been dropped, at which point `recv` yields `None`.
///
/// Each `Send` payload is expected to be a bincode-encoded `RawNetworkEvent`.
/// Decode and send failures are reported on stderr and the loop continues.
#[cfg(feature = "network")]
async fn network_send_worker(
    mut rx: tokio::sync::mpsc::Receiver<NetworkTask>,
    backend: std::sync::Arc<dyn crate::network::NetworkBackend>,
) {
    use crate::network::backend::RawNetworkEvent;
    while let Some(task) = rx.recv().await {
        let data = match task {
            NetworkTask::Send(data) => data,
            NetworkTask::Shutdown => break,
        };
        // Payloads are produced by `bincode::serialize` in this module, so a
        // decode failure is a bug. Report it instead of dropping the event silently.
        let event = match bincode::deserialize::<RawNetworkEvent>(&data) {
            Ok(event) => event,
            Err(e) => {
                eprintln!("Failed to decode outbound network event: {:?}", e);
                continue;
            }
        };
        if let Err(e) = backend.send(event).await {
            eprintln!("Failed to send network event: {:?}", e);
        }
    }
}
/// Read-only view over the `E` events made readable by the most recent
/// [`EventBus::dispatch`].
pub struct EventReader<'a, E>
where
    E: Event,
{
    /// Borrowed read buffer of the `E` channel.
    events: &'a [E],
    /// Offset into `events` at which iteration starts.
    cursor: usize,
}
impl<'a, E> EventReader<'a, E>
where
    E: Event,
{
    /// Iterates over the unread events in publication order.
    ///
    /// Yields nothing instead of panicking if the cursor is past the end. This
    /// matches [`EventReader::len`], which saturates to zero in that case.
    pub fn iter(&self) -> impl Iterator<Item = &E> {
        self.events.get(self.cursor..).unwrap_or_default().iter()
    }

    /// Number of events [`EventReader::iter`] will yield.
    pub fn len(&self) -> usize {
        self.events.len().saturating_sub(self.cursor)
    }

    /// Returns `true` when there is nothing to read.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Double-buffered storage for events of a single type `E`.
struct EventChannel<E>
where
    E: Event,
{
    /// Write buffer: receives events from `push`.
    a: Vec<E>,
    /// Read buffer: what `read` exposes until the next swap.
    b: Vec<E>,
}
impl<E> EventChannel<E>
where
    E: Event,
{
    /// Creates a channel whose write and read buffers are both empty.
    fn new() -> Self {
        Self {
            a: Vec::default(),
            b: Vec::default(),
        }
    }

    /// Appends `event` to the write buffer. It is not readable until the next
    /// `swap_buffers`.
    fn push(&mut self, event: E) {
        self.a.push(event);
    }

    /// Returns the read buffer: the events promoted by the last swap.
    fn read(&self) -> &[E] {
        self.b.as_slice()
    }

    /// Promotes the write buffer to the read buffer and starts an empty write
    /// buffer.
    ///
    /// Events that were readable before the call are dropped. The old read
    /// buffer's allocation is reused as the new write buffer.
    fn swap_buffers(&mut self) {
        self.b.clear();
        std::mem::swap(&mut self.a, &mut self.b);
    }
}
/// Type-erased interface over `EventChannel<E>`. It lets channels of different
/// event types live in one map and be downcast back when needed.
trait EventChannelStorage: Any + Send + Sync {
    /// Swaps the channel's write and read buffers.
    fn swap_buffers(&mut self);

    /// Exposes the concrete channel as `dyn Any` for downcasting.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
impl<E> EventChannelStorage for EventChannel<E>
where
    E: Event,
{
    /// Forwards to the inherent `EventChannel::swap_buffers`.
    fn swap_buffers(&mut self) {
        EventChannel::<E>::swap_buffers(self)
    }

    /// Upcasts to `dyn Any` so callers can downcast back to `EventChannel<E>`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}
/// Collects clones of every readable `$event_type` event on `$bus` into a `Vec`.
///
/// Expands to the equivalent of
/// `$bus.reader::<$event_type>().iter().cloned().collect::<Vec<_>>()`.
/// Paths are fully qualified, so the macro works even where `Vec` is shadowed
/// or `Iterator` is not in scope at the call site.
#[macro_export]
macro_rules! collect_events {
    ($bus:expr, $event_type:ty) => {{
        ::std::iter::Iterator::collect::<::std::vec::Vec<$event_type>>(
            ::std::iter::Iterator::cloned($bus.reader::<$event_type>().iter()),
        )
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Damage(u32);

    impl Event for Damage {}

    /// Amounts of the `Damage` events currently readable on `bus`.
    fn readable_damage(bus: &mut EventBus) -> Vec<u32> {
        bus.reader::<Damage>().iter().map(|d| d.0).collect()
    }

    #[test]
    fn publish_requires_dispatch() {
        let mut bus = EventBus::new();
        bus.publish(Damage(10));
        // Published events stay in the write buffer until `dispatch`.
        assert!(bus.reader::<Damage>().is_empty());

        bus.dispatch();
        let reader = bus.reader::<Damage>();
        assert_eq!(reader.len(), 1);
        assert_eq!(reader.iter().next(), Some(&Damage(10)));
    }

    #[test]
    fn multiple_dispatch_cycles() {
        let mut bus = EventBus::new();

        bus.publish(Damage(1));
        bus.dispatch();
        assert_eq!(readable_damage(&mut bus), vec![1]);

        (2..=3).for_each(|amount| bus.publish(Damage(amount)));
        bus.dispatch();
        assert_eq!(readable_damage(&mut bus), vec![2, 3]);

        // Dispatching with nothing newly published clears what was readable.
        bus.dispatch();
        assert!(readable_damage(&mut bus).is_empty());
    }

    #[cfg(feature = "network")]
    #[tokio::test]
    async fn network_event_registration_and_polling() {
        use crate::network::backend::LocalOnlyBackend;
        let backend = LocalOnlyBackend::new();
        let mut bus = EventBus::new().with_network(backend);
        bus.register_networked_event::<Damage>();
        assert!(bus.is_networked());

        // Nothing was sent, so polling and dispatching yields no events.
        bus.poll_network();
        bus.dispatch();
        assert!(bus.reader::<Damage>().is_empty());
    }
}