cubecl-core 0.2.0

CubeCL core crate
Documentation
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

/// Trait whose single method is generic at the *function* level (`test::<C>`), while the
/// trait itself has no type parameters.
///
/// `mod tests` calls the `__expand_test` method produced for it by `#[cube]` and checks the
/// resulting scope against the expansion of [`simple`].
#[cube]
trait FunctionGeneric {
    // NOTE(review): presumably silences `unused` lints on code generated by `#[cube]` for
    // this declaration (a body-less trait method does not warn on its own) — TODO confirm.
    #[allow(unused)]
    fn test<C: Float>(lhs: C, rhs: C) -> C;
}

/// Trait generic over the float type at the *trait* level (`TraitGeneric<C>`); its method
/// takes no type parameters of its own.
///
/// `mod tests` calls `<Test as TraitGeneric<F32>>::__expand_test` and checks the resulting
/// scope against the expansion of [`simple`].
#[cube]
trait TraitGeneric<C: Float> {
    // NOTE(review): presumably silences `unused` lints on code generated by `#[cube]` for
    // this declaration — TODO confirm.
    #[allow(unused)]
    fn test(lhs: C, rhs: C) -> C;
}

/// Trait combining both styles of generics: a trait-level float input type `C` and a
/// method-level numeric output type `O`.
///
/// `mod tests` calls `__expand_test::<UInt>` on it with `C = F32` and checks the resulting
/// scope against the expansion of [`with_cast`].
#[cube]
trait CombinedTraitFunctionGeneric<C: Float> {
    // NOTE(review): presumably silences `unused` lints on code generated by `#[cube]` for
    // this declaration — TODO confirm.
    #[allow(unused)]
    fn test<O: Numeric>(lhs: C, rhs: C) -> O;
}

/// Marker type implementing [`FunctionGeneric`], [`TraitGeneric`] and
/// [`CombinedTraitFunctionGeneric`] so the tests can expand each trait's method.
struct Test;

/// Function-level generic implementation: returns `lhs + rhs`.
#[cube]
impl FunctionGeneric for Test {
    fn test<C: Float>(lhs: C, rhs: C) -> C {
        // Kept textually identical to `simple`: `test_function_generic` asserts that both
        // expansions produce equal scopes.
        lhs + rhs
    }
}

/// Trait-level generic implementation, provided for every `C: Float`: returns `lhs + rhs`.
#[cube]
impl<C: Float> TraitGeneric<C> for Test {
    fn test(lhs: C, rhs: C) -> C {
        // Kept textually identical to `simple`: `test_trait_generic` asserts that both
        // expansions produce equal scopes.
        lhs + rhs
    }
}

/// Combined-generics implementation: adds the operands in `C`, then converts the sum to `O`
/// with `O::cast_from`.
#[cube]
impl<C: Float> CombinedTraitFunctionGeneric<C> for Test {
    fn test<O: Numeric>(lhs: C, rhs: C) -> O {
        // Kept textually identical to `with_cast`: `test_combined_function_generic` asserts
        // that both expansions produce equal scopes.
        O::cast_from(lhs + rhs)
    }
}

/// Returns `lhs + rhs` for two values of the same float type `C`.
///
/// Reference kernel: the tests' `simple_scope` helper expands this function (via the
/// `#[cube]`-generated `simple::__expand`) to obtain the expected scope.
#[cube]
pub fn simple<C: Float>(lhs: C, rhs: C) -> C {
    lhs + rhs
}

/// Adds two values of float type `C` and converts the sum to numeric type `O` via
/// `O::cast_from`.
///
/// Reference kernel: the tests' `with_cast_scope` helper expands this function (via the
/// `#[cube]`-generated `with_cast::__expand`) to obtain the expected scope.
#[cube]
pub fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
    O::cast_from(lhs + rhs)
}

// Checks that each way of declaring generics on a `#[cube]` trait expands to exactly the
// same IR scope as the equivalent free `#[cube]` function.
//
// Gated on `cfg(test)` so that non-test builds do not compile this module; otherwise the
// `#[test]` functions are stripped and `simple_scope` / `with_cast_scope` are left behind as
// dead code.
#[cfg(test)]
mod tests {
    use cubecl_core::ir::{Item, Scope};

    use super::*;

    /// Function-level generics: `<Test as FunctionGeneric>::test::<F32>` must expand like
    /// `simple::<F32>`.
    #[test]
    fn test_function_generic() {
        let mut context = CubeContext::root();
        let lhs = context.create_local(Item::new(F32::as_elem()));
        let rhs = context.create_local(Item::new(F32::as_elem()));

        <Test as FunctionGeneric>::__expand_test::<F32>(&mut context, lhs.into(), rhs.into());

        assert_eq!(simple_scope(), context.into_scope());
    }

    /// Trait-level generics: `<Test as TraitGeneric<F32>>::test` must expand like
    /// `simple::<F32>`.
    #[test]
    fn test_trait_generic() {
        let mut context = CubeContext::root();
        let lhs = context.create_local(Item::new(F32::as_elem()));
        let rhs = context.create_local(Item::new(F32::as_elem()));

        <Test as TraitGeneric<F32>>::__expand_test(&mut context, lhs.into(), rhs.into());

        assert_eq!(simple_scope(), context.into_scope());
    }

    /// Combined generics: `<Test as CombinedTraitFunctionGeneric<F32>>::test::<UInt>` must
    /// expand like `with_cast::<F32, UInt>`.
    #[test]
    fn test_combined_function_generic() {
        let mut context = CubeContext::root();
        let lhs = context.create_local(Item::new(F32::as_elem()));
        let rhs = context.create_local(Item::new(F32::as_elem()));

        <Test as CombinedTraitFunctionGeneric<F32>>::__expand_test::<UInt>(
            &mut context,
            lhs.into(),
            rhs.into(),
        );

        assert_eq!(with_cast_scope(), context.into_scope());
    }

    /// Expected scope: `simple::<F32>` expanded on two fresh `F32` locals.
    fn simple_scope() -> Scope {
        let mut context_ref = CubeContext::root();
        // Locals are created in the same order as in the tests, presumably so the generated
        // variable ids line up for the equality check — TODO confirm.
        let lhs = context_ref.create_local(Item::new(F32::as_elem()));
        let rhs = context_ref.create_local(Item::new(F32::as_elem()));

        simple::__expand::<F32>(&mut context_ref, lhs.into(), rhs.into());
        context_ref.into_scope()
    }

    /// Expected scope: `with_cast::<F32, UInt>` expanded on two fresh `F32` locals.
    fn with_cast_scope() -> Scope {
        let mut context_ref = CubeContext::root();
        // Same creation order as in `test_combined_function_generic`.
        let lhs = context_ref.create_local(Item::new(F32::as_elem()));
        let rhs = context_ref.create_local(Item::new(F32::as_elem()));

        with_cast::__expand::<F32, UInt>(&mut context_ref, lhs.into(), rhs.into());
        context_ref.into_scope()
    }
}