use cubecl_ir::{ConstantValue, ManagedVariable};
use cubecl_runtime::runtime::Runtime;
use num_traits::{NumCast, One, Zero};
use crate::compute::KernelLauncher;
use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder};
use crate::{
frontend::{Abs, Remainder},
unexpanded,
};
use crate::{
frontend::{CubePrimitive, CubeType},
prelude::InputScalar,
};
use crate::{
ir::{Scope, Variable},
prelude::Scalar,
};
use super::{LaunchArg, NativeAssign, NativeExpand};
/// Numeric scalar types usable in cube kernels.
///
/// Bundles the arithmetic (`num_traits::NumAssign`), casting (`NumCast`), comparison and
/// formatting traits required by the frontend. It also provides expansion counterparts
/// (`__expand_*`) for building min/max values and integer constants inside a [`Scope`].
pub trait Numeric:
    Copy
    + Abs
    + Remainder
    + Scalar
    + NativeAssign
    + Into<NativeExpand<Self>>
    + Into<ConstantValue>
    + num_traits::NumCast
    + num_traits::NumAssign
    + core::cmp::PartialOrd
    + core::cmp::PartialEq
    + core::fmt::Debug
    + bytemuck::Zeroable
{
    /// Minimum value of this type; see `__expand_min_value` for the expansion counterpart.
    fn min_value() -> Self;
    /// Maximum value of this type; see `__expand_max_value` for the expansion counterpart.
    fn max_value() -> Self;

    /// Expansion counterpart of [`Numeric::min_value`].
    ///
    /// Resolves the element type of `Self` in `scope` and wraps its `min_variable()` in a
    /// plain managed variable.
    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
        let elem = Self::as_type(scope).elem_type();
        let var = elem.min_variable();
        let expand = ManagedVariable::Plain(var);
        expand.into()
    }

    /// Expansion counterpart of [`Numeric::max_value`].
    ///
    /// Resolves the element type of `Self` in `scope` and wraps its `max_variable()` in a
    /// plain managed variable.
    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
        let elem = Self::as_type(scope).elem_type();
        let var = elem.max_variable();
        let expand = ManagedVariable::Plain(var);
        expand.into()
    }

    /// Converts an `i64` into `Self` via [`NumCast`].
    ///
    /// # Panics
    /// Panics if `val` is not representable in `Self` (for example, when it is out of range).
    fn from_int(val: i64) -> Self {
        <Self as NumCast>::from(val)
            .expect("Numeric::from_int: value is not representable in the target type")
    }

    /// Converts an `i128` into `Self` via [`NumCast`].
    ///
    /// # Panics
    /// Panics if `val` is not representable in `Self` (for example, when it is out of range).
    fn from_int_128(val: i128) -> Self {
        <Self as NumCast>::from(val)
            .expect("Numeric::from_int_128: value is not representable in the target type")
    }

    /// Builds a value from a fixed-size array of `u32`.
    ///
    /// This host-side body is `unexpanded!()`, so calling it directly outside expansion is
    /// not supported.
    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
        unexpanded!()
    }

    /// Expansion counterpart of [`Numeric::from_int`].
    ///
    /// Builds a constant of `Self`'s element type from the constant carried by `val`.
    ///
    /// # Panics
    /// Panics if `val` does not carry a constant value (`val.constant()` yields nothing).
    fn __expand_from_int(
        scope: &mut Scope,
        val: NativeExpand<i64>,
    ) -> <Self as CubeType>::ExpandType {
        let elem = Self::as_type(scope).elem_type();
        let value = val
            .constant()
            .expect("Numeric::from_int: argument must be a constant value");
        let var: Variable = elem.constant(value);
        ManagedVariable::Plain(var).into()
    }
}
/// Types that can be passed to a kernel launch as scalar arguments.
pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
    /// Registers this value with `launcher` for the upcoming launch.
    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);

    /// Requests a scalar of this type's storage type from `builder` and converts it into
    /// an expand handle.
    fn expand_scalar(builder: &mut KernelBuilder) -> NativeExpand<Self> {
        let storage = Self::as_type(&builder.scope).storage_type();
        let scalar_input = builder.scalar(storage);
        scalar_input.into()
    }
}
// Every `ScalarArgType` registers a copy of itself directly via `register_scalar`.
impl<E: ScalarArgType> ScalarArgSettings for E {
    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
        let scalar: E = *self;
        launcher.register_scalar(scalar);
    }
}
// `usize` is registered as an `InputScalar` typed with the unsigned type of the launcher's
// configured address type.
impl ScalarArgSettings for usize {
    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
        InputScalar::register(
            InputScalar::new(*self, launcher.settings.address_type.unsigned_type()),
            launcher,
        );
    }
}
// `isize` is registered as an `InputScalar` typed with the signed type of the launcher's
// configured address type.
impl ScalarArgSettings for isize {
    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
        InputScalar::register(
            InputScalar::new(*self, launcher.settings.address_type.signed_type()),
            launcher,
        );
    }
}
// Any `ScalarArgSettings` type is a launch argument whose runtime value is the value itself
// and whose compilation argument is the unit type.
impl<T: ScalarArgSettings> LaunchArg for T {
    type RuntimeArg<R: Runtime> = T;
    type CompilationArg = ();

    fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {
        <T as ScalarArgSettings>::register(&arg, launcher);
    }

    fn expand(_: &(), builder: &mut KernelBuilder) -> NativeExpand<Self> {
        <T as ScalarArgSettings>::expand_scalar(builder)
    }
}
/// Expansion counterpart of [`Zero::zero`]: produces the zero value within `scope`.
pub trait ZeroExpand: CubeType + Zero {
    /// Returns the expanded zero of `Self`.
    fn __expand_zero(scope: &mut Scope) -> <Self as CubeType>::ExpandType;
}
/// Expansion counterpart of [`One::one`]: produces the unit value within `scope`.
pub trait OneExpand: CubeType + One {
    /// Returns the expanded one of `Self`.
    fn __expand_one(scope: &mut Scope) -> <Self as CubeType>::ExpandType;
}
// Blanket impl: build the host-side zero, then lower it into `scope` with
// `__expand_runtime_method`.
impl<T> ZeroExpand for T
where
    T: CubeType + Zero + IntoRuntime,
{
    fn __expand_zero(scope: &mut Scope) -> Self::ExpandType {
        let zero = <T as Zero>::zero();
        zero.__expand_runtime_method(scope)
    }
}
// Blanket impl: build the host-side one, then lower it into `scope` with
// `__expand_runtime_method`.
impl<T> OneExpand for T
where
    T: CubeType + One + IntoRuntime,
{
    fn __expand_one(scope: &mut Scope) -> Self::ExpandType {
        let one = <T as One>::one();
        one.__expand_runtime_method(scope)
    }
}