cubecl-std 0.10.0-pre.3

CubeCL Standard Library.
Documentation
use crate::event::{ComptimeEventBus, EventListener, EventListenerExpand};
use cubecl::prelude::*;
use cubecl_core as cubecl;

/// Event carrying a comptime `u32` payload.
///
/// Consumed by `EventListenerPosZero` and `EventListenerPosOne`, and published
/// by `EventListenerPosZero` and `EventListenerPosTwo`.
#[derive(CubeType)]
pub struct EventUInt {
    // Comptime-only payload: known while the kernel is being expanded.
    #[cube(comptime)]
    pub value: u32,
}

/// Event carrying a comptime `f32` payload.
///
/// Consumed (and re-published) by `EventListenerPosTwo`.
#[derive(CubeType)]
pub struct EventFloat {
    // Comptime-only payload: known while the kernel is being expanded.
    #[cube(comptime)]
    pub value: f32,
}

/// Listener for `EventUInt` whose observable output is slot 0 of `items`.
#[derive(CubeType, Clone)]
pub struct EventListenerPosZero {
    // Output buffer; this listener only writes index 0.
    items: SliceMut<f32>,
}

/// Listener for `EventUInt` that accumulates into slot 1 of `items`.
#[derive(CubeType, Clone)]
pub struct EventListenerPosOne {
    // Output buffer; this listener only reads/writes index 1.
    items: SliceMut<f32>,
}

/// Listener for `EventFloat` that accumulates into slot 2 of `items` and keeps
/// a comptime invocation counter used to bound how often it re-publishes events.
#[derive(CubeType, Clone)]
pub struct EventListenerPosTwo {
    // Output buffer; this listener only reads/writes index 2.
    items: SliceMut<f32>,
    // Comptime counter, read and bumped on every `on_event` call.
    times: ComptimeCell<Counter>,
}

/// Comptime counter stored inside a `ComptimeCell` by `EventListenerPosTwo`.
#[derive(CubeType, Clone)]
pub struct Counter {
    // Current count; exists only at comptime.
    #[cube(comptime)]
    value: u32,
}

#[cube]
impl EventListener for EventListenerPosZero {
    type Event = EventUInt;

    /// Handles an `EventUInt`; the branch is resolved at comptime.
    ///
    /// A payload of 10 or more is cast to `f32` and written to `items[0]`.
    /// A smaller payload writes nothing and is instead re-published on `bus`
    /// with 15 added to it.
    fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus) {
        if comptime!(event.value >= 10) {
            comment!("On event pos zero >= 10");
            self.items[0] = f32::cast_from(event.value);
        } else {
            comment!("On event pos zero < 10");
            bus.event::<EventUInt>(EventUInt {
                value: comptime!(15u32 + event.value),
            });
        }
    }
}

#[cube]
impl EventListener for EventListenerPosOne {
    type Event = EventUInt;

    /// Adds twice the event payload (as `f32`) to `items[1]`.
    ///
    /// Never publishes follow-up events, so the bus is unused.
    fn on_event(&mut self, event: Self::Event, _bus: &mut ComptimeEventBus) {
        comment!("On event pos one");
        let doubled = f32::cast_from(event.value) * 2.0;
        self.items[1] = doubled + self.items[1];
    }
}

#[cube]
impl EventListener for EventListenerPosTwo {
    type Event = EventFloat;

    /// Adds the event payload to `items[2]` and increments the comptime counter.
    ///
    /// If the counter value observed on entry is below 4, it publishes two
    /// events: an `EventFloat` with the payload doubled, and an `EventUInt`
    /// carrying that doubled payload truncated to `u32`.
    fn on_event(&mut self, event: Self::Event, bus: &mut ComptimeEventBus) {
        comment!("On event pos two");
        self.items[2] = event.value + self.items[2];

        // Snapshot before bumping: the guard below uses the value seen on entry.
        let previous = self.times.read();
        self.times.store(Counter {
            value: comptime!(previous.value + 1),
        });

        if comptime!(previous.value < 4) {
            let doubled = comptime!(event.value * 2.0);
            bus.event::<EventFloat>(EventFloat {
                value: comptime!(doubled),
            });
            bus.event::<EventUInt>(EventUInt {
                value: comptime!(doubled as u32),
            });
        }
    }
}

/// Registers the slot-0 and slot-1 listeners on a fresh bus (slot 0 first),
/// then publishes `EventUInt { value: 5 }`.
#[cube]
fn test_1(items: SliceMut<f32>) {
    let mut bus = ComptimeEventBus::new();
    bus.listener::<EventListenerPosZero>(EventListenerPosZero { items });
    bus.listener::<EventListenerPosOne>(EventListenerPosOne { items });
    bus.event::<EventUInt>(EventUInt { value: 5u32 });
}

/// Registers the slot-0 and slot-1 listeners on a fresh bus (slot 0 first),
/// then publishes `EventUInt { value: 15 }`.
#[cube]
fn test_2(items: SliceMut<f32>) {
    let mut bus = ComptimeEventBus::new();
    bus.listener::<EventListenerPosZero>(EventListenerPosZero { items });
    bus.listener::<EventListenerPosOne>(EventListenerPosOne { items });
    bus.event::<EventUInt>(EventUInt { value: 15u32 });
}

/// Registers the slot-0, slot-1 and slot-2 listeners on a fresh bus, with the
/// slot-2 counter starting at zero, then publishes `EventFloat { value: 15.0 }`.
#[cube]
fn test_3(items: SliceMut<f32>) {
    let mut bus = ComptimeEventBus::new();
    bus.listener::<EventListenerPosZero>(EventListenerPosZero { items });
    bus.listener::<EventListenerPosOne>(EventListenerPosOne { items });
    bus.listener::<EventListenerPosTwo>(EventListenerPosTwo {
        items,
        times: ComptimeCell::new(Counter { value: 0u32 }),
    });
    bus.event::<EventFloat>(EventFloat { value: 15.0f32 });
}

/// Kernel entry for `test_1`.
///
/// Zeroes `output[0]` and `output[1]` first, because the slot-1 listener
/// reads-then-writes its slot. It then hands `output` to `test_1` as a
/// mutable slice.
#[cube(launch_unchecked)]
fn launch_test_1(output: &mut Array<f32>) {
    output[0] = 0.0;
    output[1] = 0.0;
    let slots = output.to_slice_mut();
    test_1(slots);
}

/// Kernel entry for `test_2`.
///
/// Zeroes `output[0]` and `output[1]` first, because the slot-1 listener
/// reads-then-writes its slot. It then hands `output` to `test_2` as a
/// mutable slice.
#[cube(launch_unchecked)]
fn launch_test_2(output: &mut Array<f32>) {
    output[0] = 0.0;
    output[1] = 0.0;
    let slots = output.to_slice_mut();
    test_2(slots);
}

/// Kernel entry for `test_3`.
///
/// Zeroes `output[0..=2]` first, because the slot-1 and slot-2 listeners
/// read-then-write their slots. It then hands `output` to `test_3` as a
/// mutable slice.
#[cube(launch_unchecked)]
fn launch_test_3(output: &mut Array<f32>) {
    output[0] = 0.0;
    output[1] = 0.0;
    output[2] = 0.0;
    let slots = output.to_slice_mut();
    test_3(slots);
}

/// Host-side check for `launch_test_1`.
///
/// Launches a single unit over a 2-element `f32` buffer, reads the buffer back
/// and asserts it equals `[20.0, 50.0]`.
pub fn event_test_1<R: Runtime>(client: ComputeClient<R>) {
    let expected = [20.0f32, 50.0];
    // Derive the byte size from the element count so the allocation and the
    // `ArrayArg` length can never drift apart.
    let output = client.empty(expected.len() * core::mem::size_of::<f32>());

    // SAFETY: the kernel only touches indices 0 and 1, and the buffer is
    // allocated for exactly `expected.len()` (= 2) `f32` elements.
    unsafe {
        launch_test_1::launch_unchecked::<R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim { x: 1, y: 1, z: 1 },
            ArrayArg::from_raw_parts(output.clone(), expected.len()),
        );
    }

    let bytes = client.read_one_unchecked(output);
    let actual = f32::from_bytes(&bytes);

    assert_eq!(actual, &expected);
}

/// Host-side check for `launch_test_2`.
///
/// Launches a single unit over a 2-element `f32` buffer, reads the buffer back
/// and asserts it equals `[15.0, 30.0]`.
pub fn event_test_2<R: Runtime>(client: ComputeClient<R>) {
    let expected = [15.0f32, 30.0];
    // Derive the byte size from the element count so the allocation and the
    // `ArrayArg` length can never drift apart.
    let output = client.empty(expected.len() * core::mem::size_of::<f32>());

    // SAFETY: the kernel only touches indices 0 and 1, and the buffer is
    // allocated for exactly `expected.len()` (= 2) `f32` elements.
    unsafe {
        launch_test_2::launch_unchecked::<R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim { x: 1, y: 1, z: 1 },
            ArrayArg::from_raw_parts(output.clone(), expected.len()),
        );
    }

    let bytes = client.read_one_unchecked(output);
    let actual = f32::from_bytes(&bytes);

    assert_eq!(actual, &expected);
}

/// Host-side check for `launch_test_3`.
///
/// Launches a single unit over a 3-element `f32` buffer, reads the buffer back
/// and asserts it equals `[30.0, 900.0, 465.0]`.
pub fn event_test_3<R: Runtime>(client: ComputeClient<R>) {
    let expected = [30.0f32, 900.0, 465.0];
    // Derive the byte size from the element count so the allocation and the
    // `ArrayArg` length can never drift apart.
    let output = client.empty(expected.len() * core::mem::size_of::<f32>());

    // SAFETY: the kernel only touches indices 0, 1 and 2, and the buffer is
    // allocated for exactly `expected.len()` (= 3) `f32` elements.
    unsafe {
        launch_test_3::launch_unchecked::<R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim { x: 1, y: 1, z: 1 },
            ArrayArg::from_raw_parts(output.clone(), expected.len()),
        );
    }

    let bytes = client.read_one_unchecked(output);
    let actual = f32::from_bytes(&bytes);

    assert_eq!(actual, &expected);
}

/// Generates a `mod event` containing one `#[test]` per `event_test_*` helper.
///
/// The invocation site must have a `TestRuntime` type in scope, since it is
/// brought in via `use super::*`. All paths into this crate go through
/// `$crate`, so the macro keeps working when the crate is renamed by a
/// dependent or invoked from inside this crate.
#[macro_export]
macro_rules! testgen_event {
    () => {
        mod event {
            use super::*;

            #[$crate::tests::test_log::test]
            fn test_1() {
                let client = TestRuntime::client(&Default::default());
                $crate::tests::event::event_test_1(client);
            }

            #[$crate::tests::test_log::test]
            fn test_2() {
                let client = TestRuntime::client(&Default::default());
                $crate::tests::event::event_test_2(client);
            }

            #[$crate::tests::test_log::test]
            fn test_3() {
                let client = TestRuntime::client(&Default::default());
                $crate::tests::event::event_test_3(client);
            }
        }
    };
}