1use std::{
2 any::{Any, TypeId},
3 cell::RefCell,
4 collections::HashMap,
5 rc::Rc,
6};
7
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, intrinsic};
10
/// Registry of [`EventListener`]s, keyed by event type, that is populated and
/// dispatched at expand time: `listener` registers a handler for an event type
/// and `event` invokes every handler registered for that event's type.
///
/// The registry lives behind an `Rc<RefCell<..>>`, so cloning a bus yields a
/// handle to the *same* registry rather than an independent copy.
#[derive(CubeType, Clone)]
pub struct ComptimeEventBus {
    // Within this file the field is only read through the expand type (inside
    // `intrinsic!` blocks); the plain runtime struct never touches it.
    #[allow(unused)]
    // Carried into the expand type unchanged (the expand struct is built with
    // the same `Rc<RefCell<HashMap<..>>>` value in `new`). Keys are the
    // `TypeId` of an event type; values are that type's registered payloads.
    #[cube(comptime)]
    listener_family: Rc<RefCell<HashMap<TypeId, Vec<EventItem>>>>,
}
23
/// Type-erased entry in the listener registry: it holds whatever listener
/// payload `ComptimeEventBus::listener` stored for one event type.
type EventItem = Box<dyn Any + 'static>;
25type Call<E> =
26 Box<dyn Fn(&mut Scope, &Box<dyn Any>, <E as CubeType>::ExpandType, ComptimeEventBusExpand)>;
27
/// A registered listener paired with the closure that knows how to invoke it.
///
/// The concrete listener type is erased; only the event type `E` survives in
/// the type, which is what `ComptimeEventBus::event::<E>` downcasts to.
struct Payload<E: CubeType> {
    /// The listener's expand value, boxed so it can be stored without naming
    /// its concrete type.
    listener: Box<dyn Any>,
    /// Downcasts `listener` back to its concrete expand type and runs its
    /// `on_event` handler.
    call: Call<E>,
}
32
33impl Default for ComptimeEventBus {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39#[cube]
40impl ComptimeEventBus {
41 pub fn new() -> Self {
43 intrinsic!(|_| {
44 ComptimeEventBusExpand {
45 listener_family: Rc::new(RefCell::new(HashMap::new())),
46 }
47 })
48 }
49
50 #[allow(unused_variables)]
51 pub fn listener<L: EventListener>(&mut self, listener: L) {
58 intrinsic!(|_| {
59 let type_id = TypeId::of::<L::Event>();
60 let mut listeners = self.listener_family.borrow_mut();
61
62 let call =
67 |scope: &mut cubecl::prelude::Scope,
68 listener: &Box<dyn Any>,
69 event: <L::Event as cubecl::prelude::CubeType>::ExpandType,
70 bus: <ComptimeEventBus as cubecl::prelude::CubeType>::ExpandType| {
71 let listener: &L::ExpandType = listener.downcast_ref().unwrap();
72 listener.clone().__expand_on_event_method(scope, event, bus)
73 };
74 let call: Call<L::Event> = Box::new(call);
75
76 let listener: Box<dyn Any> = Box::new(listener);
77 let payload = Payload::<L::Event> { listener, call };
78
79 let listener_dyn: Box<dyn Any> = Box::new(payload);
82
83 match listeners.get_mut(&type_id) {
84 Some(list) => list.push(listener_dyn),
85 None => {
86 listeners.insert(type_id, vec![listener_dyn]);
87 }
88 }
89 })
90 }
91
92 #[allow(unused_variables)]
93 pub fn event<E: CubeType + 'static>(&mut self, event: E) {
95 intrinsic!(|scope| {
96 let type_id = TypeId::of::<E>();
97 let family = self.listener_family.borrow();
98 let listeners = match family.get(&type_id) {
99 Some(val) => val,
100 None => return,
101 };
102
103 for listener in listeners.iter() {
104 let payload = listener.downcast_ref::<Payload<E>>().unwrap();
105 let call = &payload.call;
106 call(scope, &payload.listener, event.clone(), self.clone());
107 }
108 })
109 }
110}
111
/// A handler that a [`ComptimeEventBus`] invokes for events of the associated
/// `Event` type.
///
/// Register an implementor with `ComptimeEventBus::listener`. The bus matches
/// events to listeners by the `TypeId` of `Event`, which is why both the
/// listener and its event type must be `'static`.
#[cube]
pub trait EventListener: 'static {
    /// The event type this listener reacts to.
    type Event: CubeType + 'static;

    /// Handles one emitted `Self::Event`.
    ///
    /// `bus` is a handle to the bus that dispatched the event and can be used
    /// to emit further events from inside the handler.
    fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus);
}