use cubecl_core as cubecl;
use cubecl_core::prelude::*;
/// A binary operation on two values of the numeric type `T`.
///
/// Implementors are strategy types selected through a generic parameter, so a
/// single `#[cube]` function can be expanded with different operations.
#[cube]
pub trait Strategy<T: Numeric> {
    // Combines `lhs` and `rhs` into one value of the same type.
    fn operation(lhs: T, rhs: T) -> T;
}
/// Strategy type whose operation adds its two operands; in this file it is
/// only used as a type parameter, never as a value.
struct AddStrategy;
/// Returns the sum of two numeric values.
///
/// Shared by both of `AddStrategy`'s trait impls so they emit the same
/// operation.
#[cube]
pub fn add_strategy_operation<T: Numeric>(lhs: T, rhs: T) -> T {
    lhs + rhs
}
/// Addition strategy: delegates to [`add_strategy_operation`].
#[cube]
impl<T: Numeric> Strategy<T> for AddStrategy {
    fn operation(lhs: T, rhs: T) -> T {
        // The explicit `::<T>` is kept on purpose: at expansion time the call
        // presumably receives only `<T as CubeType>::ExpandType` values, from
        // which `T` cannot be inferred.
        add_strategy_operation::<T>(lhs, rhs)
    }
}
/// Strategy type whose operation subtracts the second operand from the
/// first; in this file it is only used as a type parameter.
struct SubStrategy;
/// Subtraction strategy: computes `lhs - rhs`.
#[cube]
impl<T: Numeric> Strategy<T> for SubStrategy {
    fn operation(lhs: T, rhs: T) -> T {
        // Operand order matters: the result is `lhs` minus `rhs`.
        lhs - rhs
    }
}
/// Applies the strategy `S` to `lhs` and `rhs` and returns its result.
#[cube]
pub fn with_strategy_trait<S: Strategy<T>, T: Numeric>(lhs: T, rhs: T) -> T {
    S::operation(lhs, rhs)
}
/// Chains two strategies over floating-point values.
///
/// Evaluates `S2(S1(lhs, rhs), rhs)`: `S1` combines both inputs, then `S2`
/// combines that intermediate result with `rhs` again.
#[cube]
pub fn two_strategy_traits<S1: Strategy<F>, S2: Strategy<F>, F: Float>(lhs: F, rhs: F) -> F {
    let first = S1::operation(lhs, rhs);
    S2::operation(first, rhs)
}
/// Variant of [`Strategy`] where the numeric type is a generic parameter of
/// the method rather than of the trait.
///
/// Written by hand (no `#[cube]` attribute): `operation` is the plain Rust
/// form, and `__expand_operation` is its expansion-time counterpart taking a
/// `CubeContext` and the operands' `ExpandType` values.
pub trait MethodTypedStrategy {
    /// Combines two values of type `T`.
    fn operation<T: Numeric>(lhs: T, rhs: T) -> T;

    /// Expansion-time form of [`Self::operation`]; presumably what `#[cube]`
    /// callers such as `with_trait_generic_method` invoke when expanded.
    fn __expand_operation<T: Numeric>(
        ctx: &mut CubeContext,
        lhs: <T as CubeType>::ExpandType,
        rhs: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType;
}
/// Addition via [`MethodTypedStrategy`]; both forms delegate to
/// [`add_strategy_operation`], so they agree with the `Strategy` impl.
impl MethodTypedStrategy for AddStrategy {
    fn operation<T: Numeric>(lhs: T, rhs: T) -> T {
        add_strategy_operation::<T>(lhs, rhs)
    }

    fn __expand_operation<T: Numeric>(
        ctx: &mut CubeContext,
        lhs: <T as CubeType>::ExpandType,
        rhs: <T as CubeType>::ExpandType,
    ) -> <T as CubeType>::ExpandType {
        // `T` has to be named explicitly: it cannot be inferred from the
        // `ExpandType` projections of the arguments.
        add_strategy_operation::__expand::<T>(ctx, lhs, rhs)
    }
}
/// Like [`with_strategy_trait`], but the strategy's method (not the strategy
/// trait) is generic over the numeric type `T`.
#[cube]
pub fn with_trait_generic_method<S: MethodTypedStrategy, T: Numeric>(lhs: T, rhs: T) -> T {
    S::operation::<T>(lhs, rhs)
}
/// IR-level tests: each test expands a `#[cube]` function on two fresh local
/// variables and compares the `Debug` form of the recorded scope operations
/// with a reference scope built directly with the `cpa!` macro.
// `cfg(test)` keeps the helpers and test-only imports out of non-test builds.
#[cfg(test)]
mod tests {
    use super::*;
    use cubecl_core::{
        cpa,
        ir::{Item, Variable},
    };

    type ElemType = F32;

    #[test]
    fn cube_strategy_trait_add_test() {
        let mut context = CubeContext::root();
        let x = context.create_local(Item::new(ElemType::as_elem()));
        let y = context.create_local(Item::new(ElemType::as_elem()));

        with_strategy_trait::__expand::<AddStrategy, ElemType>(&mut context, x.into(), y.into());
        let scope = context.into_scope();

        assert_eq!(
            format!("{:?}", scope.operations),
            inline_macro_ref_one(true)
        );
    }

    #[test]
    fn cube_strategy_trait_sub_test() {
        let mut context = CubeContext::root();
        let x = context.create_local(Item::new(ElemType::as_elem()));
        let y = context.create_local(Item::new(ElemType::as_elem()));

        with_strategy_trait::__expand::<SubStrategy, ElemType>(&mut context, x.into(), y.into());
        let scope = context.into_scope();

        assert_eq!(
            format!("{:?}", scope.operations),
            inline_macro_ref_one(false)
        );
    }

    #[test]
    fn cube_two_strategy_traits_test() {
        let mut context = CubeContext::root();
        let x = context.create_local(Item::new(ElemType::as_elem()));
        let y = context.create_local(Item::new(ElemType::as_elem()));

        two_strategy_traits::__expand::<SubStrategy, AddStrategy, ElemType>(
            &mut context,
            x.into(),
            y.into(),
        );
        let scope = context.into_scope();

        assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two());
    }

    #[test]
    fn cube_trait_generic_method_test() {
        let mut context = CubeContext::root();
        let x = context.create_local(Item::new(ElemType::as_elem()));
        let y = context.create_local(Item::new(ElemType::as_elem()));

        with_trait_generic_method::__expand::<AddStrategy, ElemType>(
            &mut context,
            x.into(),
            y.into(),
        );
        let scope = context.into_scope();

        assert_eq!(
            format!("{:?}", scope.operations),
            inline_macro_ref_one(true)
        );
    }

    /// Reference operations for one binary op on two locals: `x = x + y`
    /// when `is_add_strategy` is true, `x = x - y` otherwise.
    fn inline_macro_ref_one(is_add_strategy: bool) -> String {
        let mut context = CubeContext::root();
        let item = Item::new(ElemType::as_elem());
        let x = context.create_local(item);
        let y = context.create_local(item);

        let mut scope = context.into_scope();
        let x: Variable = x.into();
        let y: Variable = y.into();

        if is_add_strategy {
            cpa!(scope, x = x + y)
        } else {
            cpa!(scope, x = x - y)
        }

        format!("{:?}", scope.operations)
    }

    /// Reference operations for `SubStrategy` followed by `AddStrategy`:
    /// `x = x - y` then `x = x + y`.
    fn inline_macro_ref_two() -> String {
        let mut context = CubeContext::root();
        let item = Item::new(ElemType::as_elem());
        let x = context.create_local(item);
        let y = context.create_local(item);

        let mut scope = context.into_scope();
        let x: Variable = x.into();
        let y: Variable = y.into();

        cpa!(scope, x = x - y);
        cpa!(scope, x = x + y);

        format!("{:?}", scope.operations)
    }
}