cubecl_std/event/
mod.rs

1use std::{
2    any::{Any, TypeId},
3    cell::RefCell,
4    collections::HashMap,
5    rc::Rc,
6};
7
8use cubecl::prelude::*;
9use cubecl_core::{self as cubecl, intrinsic};
10
11#[derive(CubeType, Clone)]
12/// This event bus allows users to trigger events at compilation time to modify the generated code.
13///
14/// # Warning
15///
16/// Recursion isn't supported with a runtime end condition, the compilation will fail with a
17/// strange error.
18pub struct ComptimeEventBus {
19    #[allow(unused)]
20    #[cube(comptime)]
21    listener_family: Rc<RefCell<HashMap<TypeId, Vec<EventItem>>>>,
22}
23
24type EventItem = Box<dyn Any>;
25type Call<E> =
26    Box<dyn Fn(&mut Scope, &Box<dyn Any>, <E as CubeType>::ExpandType, ComptimeEventBusExpand)>;
27
28struct Payload<E: CubeType> {
29    listener: Box<dyn Any>,
30    call: Call<E>,
31}
32
33impl Default for ComptimeEventBus {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39#[cube]
40impl ComptimeEventBus {
41    /// Creates a new event bus.
42    pub fn new() -> Self {
43        intrinsic!(|_| {
44            ComptimeEventBusExpand {
45                listener_family: Rc::new(RefCell::new(HashMap::new())),
46            }
47        })
48    }
49
50    #[allow(unused_variables)]
51    /// Registers a new callback to be called when its event is launched.
52    ///
53    /// # Notes
54    ///
55    /// Multiple listeners for a single event type is supported. All the listeners will be called
56    /// for each event in the same order they were registered.
57    pub fn listener<L: EventListener>(&mut self, listener: L) {
58        intrinsic!(|_| {
59            let type_id = TypeId::of::<L::Event>();
60            let mut listeners = self.listener_family.borrow_mut();
61
62            // The call dynamic function erases the [EventListener] type.
63            //
64            // This is necessary since we need to clone the expand type when calling the expand
65            // method. The listener is passed as a dynamic type and casted during the function call.
66            let call =
67                |scope: &mut cubecl::prelude::Scope,
68                 listener: &Box<dyn Any>,
69                 event: <L::Event as cubecl::prelude::CubeType>::ExpandType,
70                 bus: <ComptimeEventBus as cubecl::prelude::CubeType>::ExpandType| {
71                    let listener: &L::ExpandType = listener.downcast_ref().unwrap();
72                    listener.clone().__expand_on_event_method(scope, event, bus)
73                };
74            let call: Call<L::Event> = Box::new(call);
75
76            let listener: Box<dyn Any> = Box::new(listener);
77            let payload = Payload::<L::Event> { listener, call };
78
79            // Here we erase the event type, so that all listeners can be stored in the same event
80            // bus which support multiple event types.
81            let listener_dyn: Box<dyn Any> = Box::new(payload);
82
83            match listeners.get_mut(&type_id) {
84                Some(list) => list.push(listener_dyn),
85                None => {
86                    listeners.insert(type_id, vec![listener_dyn]);
87                }
88            }
89        })
90    }
91
92    #[allow(unused_variables)]
93    /// Registers a new event to be processed by [registered listeners](EventListener).
94    pub fn event<E: CubeType + 'static>(&mut self, event: E) {
95        intrinsic!(|scope| {
96            let type_id = TypeId::of::<E>();
97            let family = self.listener_family.borrow();
98            let listeners = match family.get(&type_id) {
99                Some(val) => val,
100                None => return,
101            };
102
103            for listener in listeners.iter() {
104                let payload = listener.downcast_ref::<Payload<E>>().unwrap();
105                let call = &payload.call;
106                call(scope, &payload.listener, event.clone(), self.clone());
107            }
108        })
109    }
110}
111
112#[cube]
113/// Defines a listener that is called each time an event is triggered on an
114/// [event bus](ComptimeEventBus).
115pub trait EventListener: 'static {
116    /// The event type triggering the [EventListener::on_event] callback.
117    type Event: CubeType + 'static;
118
119    /// The function called when an event of the type [EventListener::Event] is registered on the
120    /// [ComptimeEventBus].
121    fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus);
122}