cubecl_core/frontend/element/
base.rsuse super::{CubePrimitive, Numeric, Vectorized};
use crate::{
ir::{ConstantScalarValue, Elem, FloatKind, Item, Operator, Variable},
prelude::{assign, init_expand, CubeContext, CubeIndex, KernelBuilder, KernelLauncher},
Runtime,
};
use alloc::rc::Rc;
use half::{bf16, f16};
use std::{marker::PhantomData, num::NonZero};
pub trait CubeType {
type ExpandType: Clone + Init;
fn init(context: &mut CubeContext, expand: Self::ExpandType) -> Self::ExpandType {
expand.init(context)
}
}
/// Conversion of a value into its runtime form.
pub trait IntoRuntime: CubeType + Sized {
    /// Non-expanded form: hands `self` back unchanged.
    fn runtime(self) -> Self {
        self
    }

    /// Expanded form of [`IntoRuntime::runtime`], producing this type's `ExpandType`
    /// within `context`.
    fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType;
}
/// Initialization hook applied to expanded values, consuming the value and returning the
/// initialized one.
pub trait Init: Sized {
    /// Initializes `self` within `context`.
    fn init(self, context: &mut CubeContext) -> Self;
}
/// Expansion of a launch argument from its compilation-time description.
pub trait LaunchArgExpand: CubeType {
    /// Compilation-time description of the argument. It must be comparable, hashable,
    /// debuggable and shareable across threads for the whole program lifetime.
    type CompilationArg: Clone
        + PartialEq
        + Eq
        + core::hash::Hash
        + core::fmt::Debug
        + Send
        + Sync
        + 'static;

    /// Builds the expanded value for an input argument using `builder`.
    fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> Self::ExpandType;

    /// Builds the expanded value for an output argument; by default identical to
    /// [`LaunchArgExpand::expand`].
    fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> Self::ExpandType {
        <Self as LaunchArgExpand>::expand(arg, builder)
    }
}
/// A kernel launch argument: couples a per-runtime argument type with the compilation
/// argument derived from it.
pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
    /// The value supplied at launch time for runtime `R`; it registers itself with a launcher
    /// through [`ArgSettings`].
    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;

    /// Derives the compilation argument from a runtime argument.
    fn compilation_arg<R>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg
    where
        R: Runtime;
}
/// The unit type as a launch argument: it carries no data at runtime or at compilation time.
impl LaunchArg for () {
    type RuntimeArg<'a, R: Runtime> = ();

    // Signature mirrors the trait declaration exactly (independent elided lifetimes for the
    // reference and the GAT) instead of tying both to one explicit `'a`.
    fn compilation_arg<R: Runtime>(_runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {}
}
impl<R> ArgSettings<R> for ()
where
    R: Runtime,
{
    /// The unit argument has nothing to register with the launcher.
    fn register(&self, _launcher: &mut KernelLauncher<R>) {}
}
impl LaunchArgExpand for () {
    /// The unit type needs no compilation-time information.
    type CompilationArg = ();

    /// Nothing to expand: the unit expanded value is produced without touching the builder.
    fn expand(
        _arg: &Self::CompilationArg,
        _builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
    }
}
/// The unit type expands to itself.
impl CubeType for () {
    type ExpandType = Self;
}
impl Init for () {
    /// The unit value needs no initialization; the context is left untouched.
    fn init(self, _context: &mut CubeContext) -> Self {
        // Nothing to do: the body evaluates to `()`.
    }
}
/// A runtime argument that knows how to register itself with a [`KernelLauncher`].
pub trait ArgSettings<R>: Send + Sync
where
    R: Runtime,
{
    /// Registers this argument with `launcher`.
    fn register(&self, launcher: &mut KernelLauncher<R>);
}
/// An expanded variable, either shared behind a reference count or held by value.
#[derive(Debug, Clone)]
pub enum ExpandElement {
    /// Shared through an [`Rc`]; `can_mut` inspects its strong count.
    Managed(Rc<Variable>),
    /// Held directly by value; `can_mut` always reports `false` for it.
    Plain(Variable),
}
/// An [`ExpandElement`] tagged at compile time with the cube type `T` it represents.
#[derive(new)]
pub struct ExpandElementTyped<T: CubeType> {
    /// The untyped expanded element.
    pub(crate) expand: ExpandElement,
    /// Zero-sized marker tying this element to `T`.
    pub(crate) _type: PhantomData<T>,
}
// Implements `From<T> for ExpandElementTyped<T>` for every listed literal type `T` by converting
// the value into a `Variable` and wrapping it as a plain (non-shared) element. Accepts one or
// more types, so single-type invocations keep working.
macro_rules! from_const {
    ($($lit:ty),+ $(,)?) => {
        $(
            impl From<$lit> for ExpandElementTyped<$lit> {
                fn from(value: $lit) -> Self {
                    ExpandElement::Plain(Into::<Variable>::into(value)).into()
                }
            }
        )+
    };
}

from_const!(u32, i64, i32, f64, f32, bool);
impl From<f16> for ExpandElementTyped<f16> {
    /// Wraps a half-precision literal as an `F16` scalar constant (value stored as `f64`).
    fn from(value: f16) -> Self {
        let constant = ConstantScalarValue::Float(value.to_f64(), FloatKind::F16);
        ExpandElement::Plain(Variable::ConstantScalar(constant)).into()
    }
}
impl From<bf16> for ExpandElementTyped<bf16> {
    /// Wraps a bfloat16 literal as a `BF16` scalar constant (value stored as `f64`).
    fn from(value: bf16) -> Self {
        let constant = ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16);
        ExpandElement::Plain(Variable::ConstantScalar(constant)).into()
    }
}
// Implements `CubeType` for a tuple whose elements are all `CubeType`; the tuple expands to the
// tuple of its elements' expand types.
macro_rules! tuple_cube_type {
    ($($P:ident),*) => {
        impl<$($P),*> CubeType for ($($P,)*)
        where
            $($P: CubeType),*
        {
            type ExpandType = ($(<$P as CubeType>::ExpandType,)*);
        }
    }
}
// Implements `Init` for a tuple by initializing each element in order against the same context.
macro_rules! tuple_init {
    ($($P:ident),*) => {
        impl<$($P),*> Init for ($($P,)*)
        where
            $($P: Init),*
        {
            // The type-parameter identifiers double as binding names for the tuple fields.
            #[allow(non_snake_case)]
            fn init(self, context: &mut CubeContext) -> Self {
                let ($($P,)*) = self;
                ($(Init::init($P, context),)*)
            }
        }
    }
}
// Implements `IntoRuntime` for a tuple by expanding each element in order.
macro_rules! tuple_runtime {
    ($($P:ident),*) => {
        impl<$($P),*> IntoRuntime for ($($P,)*)
        where
            $($P: IntoRuntime),*
        {
            // The type-parameter identifiers double as binding names for the tuple fields.
            #[allow(non_snake_case)]
            fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType {
                let ($($P,)*) = self;
                ($(IntoRuntime::__expand_runtime_method($P, context),)*)
            }
        }
    }
}
// Tuple support for arities 1 through 6, grouped by arity: every arity gets all three impls.
tuple_cube_type!(P1);
tuple_init!(P1);
tuple_runtime!(P1);

tuple_cube_type!(P1, P2);
tuple_init!(P1, P2);
tuple_runtime!(P1, P2);

tuple_cube_type!(P1, P2, P3);
tuple_init!(P1, P2, P3);
tuple_runtime!(P1, P2, P3);

tuple_cube_type!(P1, P2, P3, P4);
tuple_init!(P1, P2, P3, P4);
tuple_runtime!(P1, P2, P3, P4);

tuple_cube_type!(P1, P2, P3, P4, P5);
tuple_init!(P1, P2, P3, P4, P5);
tuple_runtime!(P1, P2, P3, P4, P5);

tuple_cube_type!(P1, P2, P3, P4, P5, P6);
tuple_init!(P1, P2, P3, P4, P5, P6);
tuple_runtime!(P1, P2, P3, P4, P5, P6);
/// Element-level initialization used by the [`Init`] impl of [`ExpandElementTyped`].
pub trait ExpandElementBaseInit: CubeType {
    /// Initializes the untyped `elem` within `context` and returns the element to use afterwards.
    fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
}
impl<T: ExpandElementBaseInit> Init for ExpandElementTyped<T> {
fn init(self, context: &mut CubeContext) -> Self {
<T as ExpandElementBaseInit>::init_elem(context, self.into()).into()
}
}
impl<T: CubeType> Vectorized for ExpandElementTyped<T> {
fn vectorization_factor(&self) -> u32 {
self.expand.vectorization_factor()
}
fn vectorize(self, factor: u32) -> Self {
Self {
expand: self.expand.vectorize(factor),
_type: PhantomData,
}
}
}
impl<T: CubeType> ExpandElementTyped<T> {
pub fn __expand_vectorization_factor_method(self, _context: &mut CubeContext) -> u32 {
self.expand
.item()
.vectorization
.map(|it| it.get())
.unwrap_or(1) as u32
}
pub fn __expand_vectorize_method(self, _context: &mut CubeContext, factor: u32) -> Self {
Self {
expand: self.expand.vectorize(factor),
_type: PhantomData,
}
}
}
impl<T: CubeType> Clone for ExpandElementTyped<T> {
fn clone(&self) -> Self {
Self {
expand: self.expand.clone(),
_type: PhantomData,
}
}
}
impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
fn from(expand: ExpandElement) -> Self {
Self {
expand,
_type: PhantomData,
}
}
}
impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
fn from(value: ExpandElementTyped<T>) -> Self {
value.expand
}
}
impl<T: CubePrimitive> ExpandElementTyped<T> {
    /// Builds a plain constant element from a literal, converted through `T`'s element type
    /// (`T::as_elem().from_constant`).
    pub fn from_lit<L: Into<Variable>>(lit: L) -> Self {
        let elem = T::as_elem();
        let converted = elem.from_constant(Into::<Variable>::into(lit));
        Self::new(ExpandElement::Plain(converted))
    }

    /// Returns the scalar constant value if this element is a `ConstantScalar`, else `None`.
    pub fn constant(&self) -> Option<ConstantScalarValue> {
        if let Variable::ConstantScalar(value) = *self.expand {
            Some(value)
        } else {
            None
        }
    }
}
impl ExpandElement {
    /// Whether this element may be mutated in place: only a `Managed` `Local` variable whose
    /// `Rc` strong count is at most 2 qualifies.
    pub fn can_mut(&self) -> bool {
        match self {
            ExpandElement::Plain(_) => false,
            // NOTE(review): the `<= 2` threshold presumably allows for one other owner of the
            // Rc (e.g. a scope/context) besides this handle — confirm against `CubeContext`.
            ExpandElement::Managed(var) => {
                matches!(var.as_ref(), Variable::Local { .. }) && Rc::strong_count(var) <= 2
            }
        }
    }

    /// Consumes the element and returns a copy of the wrapped `Variable`.
    pub fn consume(self) -> Variable {
        Variable::from(self)
    }
}
impl core::ops::Deref for ExpandElement {
type Target = Variable;
fn deref(&self) -> &Self::Target {
match self {
ExpandElement::Managed(var) => var.as_ref(),
ExpandElement::Plain(var) => var,
}
}
}
impl From<ExpandElement> for Variable {
    /// Extracts the variable: moved out of `Plain`, copied out of the `Rc` for `Managed`.
    fn from(value: ExpandElement) -> Self {
        match value {
            ExpandElement::Plain(variable) => variable,
            ExpandElement::Managed(shared) => *shared,
        }
    }
}
/// Prepares an element for use as an initialized value.
///
/// Elements that [`ExpandElement::can_mut`] already accepts are returned untouched. Scalars,
/// locals, bindings and builtin position/dimension variables are passed through
/// `init_expand(context, elem, Operator::Assign)`. Arrays, shared memory, slices and matrices
/// are returned unchanged.
pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
    context: &mut CubeContext,
    element: E,
) -> ExpandElement {
    let elem: ExpandElement = element.into();
    if elem.can_mut() {
        return elem;
    }

    // Exhaustive on purpose (no `_` arm) so that a new `Variable` variant forces a decision.
    let needs_init = match *elem {
        Variable::GlobalScalar { .. }
        | Variable::ConstantScalar { .. }
        | Variable::Local { .. }
        | Variable::Versioned { .. }
        | Variable::LocalBinding { .. }
        | Variable::Rank
        | Variable::UnitPos
        | Variable::UnitPosX
        | Variable::UnitPosY
        | Variable::UnitPosZ
        | Variable::CubePos
        | Variable::CubePosX
        | Variable::CubePosY
        | Variable::CubePosZ
        | Variable::CubeDim
        | Variable::CubeDimX
        | Variable::CubeDimY
        | Variable::CubeDimZ
        | Variable::CubeCount
        | Variable::CubeCountX
        | Variable::CubeCountY
        | Variable::CubeCountZ
        | Variable::SubcubeDim
        | Variable::AbsolutePos
        | Variable::AbsolutePosX
        | Variable::AbsolutePosY
        | Variable::AbsolutePosZ => true,
        Variable::SharedMemory { .. }
        | Variable::GlobalInputArray { .. }
        | Variable::GlobalOutputArray { .. }
        | Variable::LocalArray { .. }
        | Variable::ConstantArray { .. }
        | Variable::Slice { .. }
        | Variable::Matrix { .. } => false,
    };

    if needs_init {
        init_expand(context, elem, Operator::Assign)
    } else {
        elem
    }
}
impl Init for ExpandElement {
    /// Delegates to [`init_expand_element`].
    fn init(self, context: &mut CubeContext) -> Self {
        init_expand_element::<ExpandElement>(context, self)
    }
}
impl<T: Init> Init for Option<T> {
    /// Initializes the contained value, if any; `None` stays `None`.
    fn init(self, context: &mut CubeContext) -> Self {
        self.map(|inner| inner.init(context))
    }
}
/// A vector expands element-wise into a vector of expanded elements.
impl<T: CubeType> CubeType for Vec<T> {
    type ExpandType = Vec<<T as CubeType>::ExpandType>;
}
/// A mutable borrow of a vector expands to the same owned vector of expanded elements.
impl<'a, T: CubeType> CubeType for &'a mut Vec<T> {
    type ExpandType = Vec<<T as CubeType>::ExpandType>;
}
impl<T: Init> Init for Vec<T> {
fn init(self, context: &mut CubeContext) -> Self {
self.into_iter().map(|e| e.init(context)).collect()
}
}
/// Expanded form of a numeric `new`: converts `val` into the `Out` numeric type via
/// `Out::from` and wraps the result as an expanded element.
///
/// # Panics
///
/// Panics if `Out::from` reports that `val` cannot be converted into `Out`.
pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
    _context: &mut CubeContext,
    val: C,
) -> ExpandElementTyped<Out> {
    let val = Out::from(val)
        .expect("numeric value is not representable in the requested output element type");
    val.into()
}
/// Expanded form of a vectorized numeric constructor.
///
/// Creates a local binding with item `elem` vectorized by `vectorization` (a factor of 0
/// yields no vectorization). It then converts `val` into `Out`, assigns it into that binding,
/// and returns the binding.
///
/// # Panics
///
/// - if `vectorization` does not fit in a `u8` (the width `Item::vectorized` accepts);
/// - if `Out::from` reports that `val` cannot be converted into `Out`.
pub(crate) fn __expand_vectorized<C: Numeric + CubeIndex<u32>, Out: Numeric>(
    context: &mut CubeContext,
    val: C,
    vectorization: u32,
    elem: Elem,
) -> ExpandElementTyped<Out> {
    // A plain `as u8` would silently wrap (256 -> 0, i.e. unvectorized; 257 -> 1), producing a
    // binding with the wrong width. Reject out-of-range factors explicitly instead.
    let factor = u8::try_from(vectorization).unwrap_or_else(|_| {
        panic!(
            "vectorization factor {vectorization} exceeds the maximum of {}",
            u8::MAX
        )
    });
    let new_var = context.create_local_binding(Item::vectorized(elem, NonZero::new(factor)));

    let val: ExpandElementTyped<Out> = Out::from(val)
        .expect("numeric value is not representable in the requested output element type")
        .into();
    assign::expand(context, val, new_var.clone().into());
    new_var.into()
}