use std::{
any::{Any, TypeId},
cell::RefCell,
collections::HashMap,
rc::Rc,
};
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, intrinsic};
/// An event bus evaluated at comptime, while cube functions are being expanded.
///
/// Listeners are registered with `listener` and invoked by `event`. They are grouped by
/// the `TypeId` of the event type they handle.
#[derive(CubeType, Clone)]
pub struct ComptimeEventBus {
    // The field is only accessed through the generated `ComptimeEventBusExpand` (inside the
    // `intrinsic!` bodies), which is presumably why the lint is silenced here.
    #[allow(unused)]
    // Comptime field: the expand type holds this exact Rust value. It is wrapped in `Rc` so
    // that every clone of the bus (including the one handed to listeners) shares one table.
    #[cube(comptime)]
    listener_family: Rc<RefCell<HashMap<TypeId, Vec<EventItem>>>>,
}
/// A type-erased listener entry stored in the bus. It is downcast back to its concrete
/// payload type when an event of the matching type is dispatched.
type EventItem = Box<dyn Any + 'static>;
/// A type-erased trampoline that forwards one event to one stored listener.
///
/// Arguments, in order:
/// - the `Scope` being expanded into,
/// - the boxed listener, which the closure downcasts to its concrete type,
/// - the expanded event,
/// - a handle to the bus.
type Call<E> = Box<
    dyn for<'scope, 'item> Fn(
            &'scope mut Scope,
            &'item Box<dyn Any>,
            <E as CubeType>::ExpandType,
            ComptimeEventBusExpand,
        ) + 'static,
>;
/// One registered listener for events of type `E`. It pairs the type-erased listener
/// with the trampoline that knows how to invoke it.
struct Payload<E>
where
    E: CubeType,
{
    /// The listener's expand value, boxed as `dyn Any`.
    listener: Box<dyn Any>,
    /// Forwards an event of type `E` to `listener`.
    call: Call<E>,
}
impl Default for ComptimeEventBus {
fn default() -> Self {
Self::new()
}
}
#[cube]
impl ComptimeEventBus {
    /// Creates a bus with no registered listeners.
    ///
    /// The body is an `intrinsic!`, so the listener table is built directly on the generated
    /// `ComptimeEventBusExpand`. Clones of that value share the same table through the `Rc`.
    pub fn new() -> Self {
        intrinsic!(|_| {
            ComptimeEventBusExpand {
                listener_family: Rc::new(RefCell::new(HashMap::new())),
            }
        })
    }

    /// Registers `listener` so it is invoked by every later `event` call whose event
    /// type is `L::Event`.
    ///
    /// Listeners for the same event type are invoked in registration order. It is safe to
    /// call this from inside another listener's `on_event`. The newly registered listener is
    /// not invoked for the event currently being dispatched.
    // NOTE(review): the allow is presumably needed because the non-expanded body generated
    // by `#[cube]` does not use the parameters.
    #[allow(unused_variables)]
    pub fn listener<L: EventListener>(&mut self, listener: L) {
        intrinsic!(|_| {
            // Trampoline: recovers the concrete listener expand type from the erased box and
            // forwards to its `on_event` expansion. It runs on a clone, as before, so the
            // stored listener is never mutated.
            let call = |scope: &mut cubecl::prelude::Scope,
                        listener: &Box<dyn Any>,
                        event: <L::Event as cubecl::prelude::CubeType>::ExpandType,
                        bus: <ComptimeEventBus as cubecl::prelude::CubeType>::ExpandType| {
                let listener: &L::ExpandType = listener
                    .downcast_ref()
                    .expect("listener is always stored as L::ExpandType");
                listener.clone().__expand_on_event_method(scope, event, bus)
            };
            let call: Call<L::Event> = Box::new(call);
            let listener: Box<dyn Any> = Box::new(listener);
            // The payload is stored behind an `Rc` so that `event` can snapshot the listener
            // list cheaply and release the `RefCell` borrow before calling any listener.
            let payload: EventItem = Box::new(Rc::new(Payload::<L::Event> { listener, call }));
            self.listener_family
                .borrow_mut()
                .entry(TypeId::of::<L::Event>())
                .or_default()
                .push(payload);
        })
    }

    /// Dispatches `event` to every listener registered for type `E`, in registration order.
    ///
    /// Each listener receives its own clone of `event` and a clone of the bus, which shares
    /// the same listener table. Nothing happens if no listener is registered for `E`.
    // NOTE(review): the allow is presumably needed because the non-expanded body generated
    // by `#[cube]` does not use the parameters.
    #[allow(unused_variables)]
    pub fn event<E: CubeType + 'static>(&mut self, event: E) {
        intrinsic!(|scope| {
            // Snapshot the listeners and drop the `RefCell` borrow before dispatching.
            // Listeners receive this same bus. If the borrow were still held, a listener that
            // registers another listener (`borrow_mut`) would panic.
            let payloads: Vec<Rc<Payload<E>>> = self
                .listener_family
                .borrow()
                .get(&TypeId::of::<E>())
                .map(|items| {
                    items
                        .iter()
                        .map(|item| {
                            Rc::clone(
                                item.downcast_ref::<Rc<Payload<E>>>()
                                    .expect("entries are keyed by their event TypeId"),
                            )
                        })
                        .collect()
                })
                .unwrap_or_default();

            for payload in payloads {
                (payload.call)(scope, &payload.listener, event.clone(), self.clone());
            }
        })
    }
}
/// A comptime listener that reacts to events of type `Self::Event` dispatched through a
/// `ComptimeEventBus`.
#[cube]
pub trait EventListener: 'static {
    /// The event type this listener subscribes to. The bus looks listeners up by this
    /// type's `TypeId`.
    type Event: CubeType + 'static;
    /// Handles one event.
    ///
    /// The bus calls this on a clone of the registered listener, so changes made to `self`
    /// are not kept for later events. `bus` shares its listener table with the bus that
    /// dispatched the event, so a handler may dispatch further events on it.
    fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus);
}