cubecl-core 0.2.0

CubeCL core crate
Documentation
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

/// Plain value passed to the kernels below as a `Comptime<State>` argument.
///
/// The tests in this file build it with struct-literal syntax and hand it
/// straight to the generated `__expand` functions.
#[derive(Clone)]
pub struct State {
    /// Read by `comptime_with_map_bool` to decide which branch is taken.
    cond: bool,
    /// Read by `comptime_with_map_uint` as the upper bound of its loop.
    bound: u32,
}

/// `Init` implementation for `State` that performs no work.
impl Init for State {
    /// Returns `self` unchanged; the context argument is not used.
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}

/// Kernel fixture that branches on a comptime boolean.
///
/// Computes `lhs + 4` when `cond` is true and `lhs - 5` otherwise; the result
/// is discarded in both cases. The tests in this file expect the expansion to
/// contain only the operation of the selected branch (compare with
/// `inline_macro_ref_comptime`).
#[cube]
pub fn comptime_if_else<T: Numeric>(lhs: T, cond: Comptime<bool>) {
    if Comptime::get(cond) {
        let _ = lhs + T::from_int(4);
    } else {
        let _ = lhs - T::from_int(5);
    }
}

/// Kernel fixture selecting one of three operations from two comptime
/// booleans, written with an `if` nested inside the `else` block.
///
/// * `cond1` true: `lhs + 4`
/// * `cond1` false, `cond2` true: `lhs + 5`
/// * both false: `lhs - 6`
///
/// `cube_comptime_elsif_test` asserts that this form expands to the same
/// operations as the `else if` form in `comptime_elsif`.
#[cube]
// The nested `else { if .. }` shape is intentional, so the lint that would
// suggest collapsing it is silenced.
#[allow(clippy::collapsible_else_if)]
pub fn comptime_else_then_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
    if Comptime::get(cond1) {
        let _ = lhs + T::from_int(4);
    } else {
        if Comptime::get(cond2) {
            let _ = lhs + T::from_int(5);
        } else {
            let _ = lhs - T::from_int(6);
        }
    }
}

/// Kernel fixture that wraps the `F32` value `0.0` in a `Comptime` and then
/// converts it with `Comptime::runtime`.
///
/// The converted value is bound to an unused local and nothing is returned.
/// NOTE(review): no test visible in this file expands this function;
/// presumably it only needs to compile — confirm.
#[cube]
pub fn comptime_float() {
    let comptime_float = Comptime::new(F32::new(0.0));
    let _runtime_float = Comptime::runtime(comptime_float);
}

/// Kernel fixture selecting one of three operations from two comptime
/// booleans, written as an `if` / `else if` / `else` chain.
///
/// * `cond1` true: `lhs + 4`
/// * `cond1` false, `cond2` true: `lhs + 5`
/// * both false: `lhs - 6`
///
/// `cube_comptime_elsif_test` asserts that this form expands to the same
/// operations as the nested form in `comptime_else_then_if`.
#[cube]
pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
    if Comptime::get(cond1) {
        let _ = lhs + T::from_int(4);
    } else if Comptime::get(cond2) {
        let _ = lhs + T::from_int(5);
    } else {
        let _ = lhs - T::from_int(6);
    }
}

/// Kernel fixture mixing a comptime condition (checked first) with a runtime
/// condition (checked in the `else if`).
///
/// The runtime comparison `lhs >= 2` is computed up front. Per the reference
/// in `inline_macro_ref_elsif_runtime1`, a true `comptime_cond` is expected to
/// leave only `lhs + 4` after that comparison, while a false one is expected
/// to produce a runtime if/else choosing between `lhs + 5` and `lhs - 6`.
#[cube]
pub fn comptime_elsif_with_runtime1<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
    // Evaluated before the chain so it is present whichever branch is chosen.
    let runtime_cond = lhs >= T::from_int(2);
    if Comptime::get(comptime_cond) {
        let _ = lhs + T::from_int(4);
    } else if runtime_cond {
        let _ = lhs + T::from_int(5);
    } else {
        let _ = lhs - T::from_int(6);
    }
}

/// Kernel fixture mixing a runtime condition (checked first) with a comptime
/// condition (checked in the `else if`).
///
/// Per the reference in `inline_macro_ref_elsif_runtime2`, a runtime if/else
/// on `lhs >= 2` is always expected: its `then` side computes `lhs + 4`, and
/// its `else` side holds either `lhs + 5` or `lhs - 6` depending on
/// `comptime_cond`.
#[cube]
pub fn comptime_elsif_with_runtime2<T: Numeric>(lhs: T, comptime_cond: Comptime<bool>) {
    let runtime_cond = lhs >= T::from_int(2);
    if runtime_cond {
        let _ = lhs + T::from_int(4);
    } else if Comptime::get(comptime_cond) {
        // Only one of the two remaining branches ends up in the else side.
        let _ = lhs + T::from_int(5);
    } else {
        let _ = lhs - T::from_int(6);
    }
}

/// Kernel fixture whose branch condition is an expression over comptime
/// unsigned integers rather than a plain comptime boolean.
///
/// Forms `y2 = x + y` and branches on `x < y2`: `lhs + 4` when the comparison
/// holds, `lhs - 5` otherwise. `cube_comptime_if_numeric_test` calls it with
/// `x = 4`, `y = 5` and expects the same operations as the "true" case of
/// `inline_macro_ref_comptime`.
#[cube]
pub fn comptime_if_expr<T: Numeric>(lhs: T, x: Comptime<UInt>, y: Comptime<UInt>) {
    let y2 = x + y;

    if x < y2 {
        let _ = lhs + T::from_int(4);
    } else {
        let _ = lhs - T::from_int(5);
    }
}

/// Kernel fixture that derives a comptime boolean from a comptime struct via
/// `Comptime::map` and branches on it.
///
/// Starts from `x = 3`, then adds 4 when `state.cond` is true or subtracts 4
/// when it is false, and returns `x`. `cube_comptime_map_bool_test` asserts
/// that the two settings expand to different operations.
#[cube]
pub fn comptime_with_map_bool<T: Numeric>(state: Comptime<State>) -> T {
    // Project the `cond` field out of the comptime state.
    let cond = Comptime::map(state, |s: State| s.cond);

    let mut x = T::from_int(3);
    if Comptime::get(cond) {
        x += T::from_int(4);
    } else {
        x -= T::from_int(4);
    }
    x
}

/// Kernel fixture that derives a comptime loop bound from a comptime struct
/// via `Comptime::map`.
///
/// Starts from `x = 3`, adds 4 on every iteration of a `range` loop running
/// from 0 to `state.bound`, and returns `x`. `cube_comptime_map_uint_test`
/// asserts that the expanded operations do not contain the text `RangeLoop`.
#[cube]
pub fn comptime_with_map_uint<T: Numeric>(state: Comptime<State>) -> T {
    // Project the `bound` field out of the comptime state.
    let bound = Comptime::map(state, |s: State| s.bound);

    let mut x = T::from_int(3);
    // NOTE(review): the third argument is presumably the unroll flag, which
    // would explain why the test expects no `RangeLoop` — confirm.
    for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) {
        x += T::from_int(4);
    }

    x
}

mod tests {
    use super::*;
    use cubecl_core::{
        cpa,
        frontend::{CubeContext, CubePrimitive, F32},
        ir::{Elem, Item, Variable},
    };

    /// Element type used to instantiate every generic kernel under test.
    type ElemType = F32;

    /// Formats a value with `{:?}`; operation lists are compared through this
    /// textual form throughout the module.
    fn render_ops<T: std::fmt::Debug>(ops: &T) -> String {
        format!("{:?}", ops)
    }

    /// A true comptime condition must expand to the "true" reference.
    #[test]
    fn cube_comptime_if_test() {
        let mut ctx = CubeContext::root();
        let input = ctx.create_local(Item::new(ElemType::as_elem()));

        comptime_if_else::__expand::<ElemType>(&mut ctx, input.into(), true);

        let expanded = render_ops(&ctx.into_scope().operations);
        assert_eq!(expanded, inline_macro_ref_comptime(true));
    }

    /// `4 < 4 + 5` holds, so the numeric variant must match the "true"
    /// reference as well.
    #[test]
    fn cube_comptime_if_numeric_test() {
        let mut ctx = CubeContext::root();
        let input = ctx.create_local(Item::new(ElemType::as_elem()));
        let (x, y) = (UInt::new(4), UInt::new(5));

        comptime_if_expr::__expand::<ElemType>(&mut ctx, input.into(), x, y);

        let expanded = render_ops(&ctx.into_scope().operations);
        assert_eq!(expanded, inline_macro_ref_comptime(true));
    }

    /// A false comptime condition must expand to the "false" reference.
    #[test]
    fn cube_comptime_else_test() {
        let mut ctx = CubeContext::root();
        let input = ctx.create_local(Item::new(ElemType::as_elem()));

        comptime_if_else::__expand::<ElemType>(&mut ctx, input.into(), false);

        let expanded = render_ops(&ctx.into_scope().operations);
        assert_eq!(expanded, inline_macro_ref_comptime(false));
    }

    /// The nested `else { if }` kernel and the `else if` kernel must expand
    /// identically for every combination of the two conditions.
    #[test]
    fn cube_comptime_elsif_test() {
        let combinations = [(false, false), (false, true), (true, false), (true, true)];

        for (cond1, cond2) in combinations {
            let nested_form = {
                let mut ctx = CubeContext::root();
                let input = ctx.create_local(Item::new(ElemType::as_elem()));
                comptime_else_then_if::__expand::<ElemType>(
                    &mut ctx,
                    input.into(),
                    cond1,
                    cond2,
                );
                render_ops(&ctx.into_scope().operations)
            };

            let chained_form = {
                let mut ctx = CubeContext::root();
                let input = ctx.create_local(Item::new(ElemType::as_elem()));
                comptime_elsif::__expand::<ElemType>(&mut ctx, input.into(), cond1, cond2);
                render_ops(&ctx.into_scope().operations)
            };

            assert_eq!(nested_form, chained_form);
        }
    }

    /// Comptime condition first, runtime condition second.
    #[test]
    fn cube_comptime_elsif_runtime1_test() {
        for comptime_cond in [false, true] {
            let mut ctx = CubeContext::root();
            let input = ctx.create_local(Item::new(ElemType::as_elem()));

            comptime_elsif_with_runtime1::__expand::<ElemType>(
                &mut ctx,
                input.into(),
                comptime_cond,
            );

            let expanded = render_ops(&ctx.into_scope().operations);
            assert_eq!(expanded, inline_macro_ref_elsif_runtime1(comptime_cond));
        }
    }

    /// Runtime condition first, comptime condition second.
    #[test]
    fn cube_comptime_elsif_runtime2_test() {
        for comptime_cond in [false, true] {
            let mut ctx = CubeContext::root();
            let input = ctx.create_local(Item::new(ElemType::as_elem()));

            comptime_elsif_with_runtime2::__expand::<ElemType>(
                &mut ctx,
                input.into(),
                comptime_cond,
            );

            let expanded = render_ops(&ctx.into_scope().operations);
            assert_eq!(expanded, inline_macro_ref_elsif_runtime2(comptime_cond));
        }
    }

    /// Flipping `State::cond` must change the expanded operations.
    #[test]
    fn cube_comptime_map_bool_test() {
        // Each call expands the kernel in its own fresh root context.
        let expand_with = |cond: bool| {
            let mut ctx = CubeContext::root();
            comptime_with_map_bool::__expand::<ElemType>(&mut ctx, State { cond, bound: 4 });
            render_ops(&ctx.into_scope().operations)
        };

        let when_true = expand_with(true);
        let when_false = expand_with(false);

        assert_ne!(when_true, when_false);
    }

    /// The loop bounded by `State::bound` must not show up as a `RangeLoop`
    /// in the expanded operations.
    #[test]
    fn cube_comptime_map_uint_test() {
        let mut ctx = CubeContext::root();
        let state = State {
            cond: true,
            bound: 4,
        };

        comptime_with_map_uint::__expand::<ElemType>(&mut ctx, state);

        let expanded = render_ops(&ctx.into_scope().operations);
        assert!(!expanded.contains("RangeLoop"));
    }

    /// Hand-built reference for the plain comptime if/else kernels:
    /// `lhs + 4` when `cond` is true, `lhs - 5` otherwise.
    fn inline_macro_ref_comptime(cond: bool) -> String {
        let mut ctx = CubeContext::root();
        let float_item = Item::new(ElemType::as_elem());
        let lhs = ctx.create_local(float_item);

        let mut root = ctx.into_scope();
        let lhs: Variable = lhs.into();
        let out = root.create_local(float_item);

        if cond {
            cpa!(root, out = lhs + 4.0f32);
        } else {
            cpa!(root, out = lhs - 5.0f32);
        }

        render_ops(&root.operations)
    }

    /// Hand-built reference for `comptime_elsif_with_runtime1`.
    ///
    /// The runtime comparison is always registered; a true comptime condition
    /// then adds a single addition, a false one adds a runtime if/else.
    fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String {
        let mut ctx = CubeContext::root();
        let float_item = Item::new(ElemType::as_elem());
        let lhs = ctx.create_local(float_item);

        let mut root = ctx.into_scope();
        let lhs: Variable = lhs.into();
        // Creation order matters: the boolean local comes before the output.
        let ge_two = root.create_local(Item::new(Elem::Bool));
        let out = root.create_local(float_item);
        cpa!(root, ge_two = lhs >= 2.0f32);

        if comptime_cond {
            cpa!(root, out = lhs + 4.0f32);
            return render_ops(&root.operations);
        }

        cpa!(&mut root, if(ge_two).then(|scope| {
            cpa!(scope, out = lhs + 5.0f32);
        }).else(|scope| {
            cpa!(scope, out = lhs - 6.0f32);
        }));

        render_ops(&root.operations)
    }

    /// Hand-built reference for `comptime_elsif_with_runtime2`.
    ///
    /// The runtime if/else is always registered; the comptime condition only
    /// decides which operation goes into its else side.
    fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String {
        let mut ctx = CubeContext::root();
        let float_item = Item::new(ElemType::as_elem());
        let lhs = ctx.create_local(float_item);

        let mut root = ctx.into_scope();
        let lhs: Variable = lhs.into();
        // Creation order matters: the boolean local comes before the output.
        let ge_two = root.create_local(Item::new(Elem::Bool));
        let out = root.create_local(float_item);
        cpa!(root, ge_two = lhs >= 2.0f32);

        cpa!(&mut root, if(ge_two).then(|scope| {
            cpa!(scope, out = lhs + 4.0f32);
        }).else(|scope| {
            if comptime_cond {
                cpa!(scope, out = lhs + 5.0f32);
            } else {
                cpa!(scope, out = lhs - 6.0f32);
            }
        }));

        render_ops(&root.operations)
    }
}