use std::hash::{Hash, Hasher};
use std::mem::discriminant;
use morok_dtype::DeviceSpec;
use morok_dtype::{DType, ScalarDType};
/// A scalar constant, stored in the widest representation of its kind
/// (`i64`, `u64`, `f64`, or `bool`).
///
/// `From` is derived through `derive_more`. Equality is the derived
/// `PartialEq`, so float payloads compare with `f64`'s `==`
/// (`NaN != NaN`, `0.0 == -0.0`).
#[derive(Debug, Clone, Copy, PartialEq, derive_more::From)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum ConstValue {
/// Signed integer payload.
Int(i64),
/// Unsigned integer payload.
UInt(u64),
/// Floating-point payload.
Float(f64),
/// Boolean payload.
Bool(bool),
}
macro_rules! impl_from_widening {
($($ty:ty => Int),+ $(,)?) => { $(
impl From<$ty> for ConstValue {
fn from(v: $ty) -> Self { ConstValue::Int(v as i64) }
}
)+ };
($($ty:ty => UInt),+ $(,)?) => { $(
impl From<$ty> for ConstValue {
fn from(v: $ty) -> Self { ConstValue::UInt(v as u64) }
}
)+ };
}
impl_from_widening!(i8 => Int, i16 => Int, i32 => Int);
impl_from_widening!(u8 => UInt, u16 => UInt, u32 => UInt);
impl From<f32> for ConstValue {
fn from(v: f32) -> Self {
ConstValue::Float(v as f64)
}
}
impl Hash for ConstValue {
fn hash<H: Hasher>(&self, state: &mut H) {
discriminant(self).hash(state);
match self {
ConstValue::Int(v) => v.hash(state),
ConstValue::UInt(v) => v.hash(state),
ConstValue::Float(v) => v.to_bits().hash(state),
ConstValue::Bool(v) => v.hash(state),
}
}
}
// Casts `$value` to the narrow type `$narrow` with `as`, then widens the result
// to the storage type `$wide`. The narrowing step applies the target width's
// wrap-around (integers) or saturation (floats) before the value is stored.
macro_rules! cast_via {
    ($value:expr, $narrow:ty, $wide:ty) => {{
        let narrowed: $narrow = $value as $narrow;
        narrowed as $wide
    }};
}
// Dispatches a `ConstValue` to the cast helper matching its variant and yields
// the converted value.
//
// The expansion uses `?`, so it must be used inside a function returning
// `Option`; an unsupported target dtype makes that function return `None`.
macro_rules! impl_cast {
    ($value:expr, $target:expr) => {{
        // Evaluate both operands once, value first, as a tuple would.
        let (value, target) = ($value, $target);
        match value {
            ConstValue::Float(real) => cast_float(real, target)?,
            ConstValue::UInt(unsigned) => cast_uint(unsigned, target)?,
            ConstValue::Int(signed) => cast_int(signed, target)?,
            ConstValue::Bool(flag) => cast_bool(flag, target)?,
        }
    }};
}
/// Converts a boolean constant to the value representation of `to`.
///
/// `false`/`true` become `0`/`1` for integer targets (including `Index`) and
/// `0.0`/`1.0` for `Float16`, `BFloat16`, `Float32` and `Float64`. Returns
/// `None` for any other target dtype.
#[inline]
fn cast_bool(v: bool, to: ScalarDType) -> Option<ConstValue> {
    use ScalarDType::*;
    let converted = match to {
        Bool => ConstValue::Bool(v),
        Int8 | Int16 | Int32 | Int64 | Index => ConstValue::Int(i64::from(v)),
        UInt8 | UInt16 | UInt32 | UInt64 => ConstValue::UInt(u64::from(v)),
        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(if v { 1.0 } else { 0.0 }),
        _ => return None,
    };
    Some(converted)
}
#[inline]
fn cast_int(v: i64, to: ScalarDType) -> Option<ConstValue> {
use ScalarDType::*;
Some(match to {
Bool => ConstValue::Bool(v != 0),
Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
Int64 | Index => ConstValue::Int(v),
UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
UInt64 => ConstValue::UInt(v as u64),
Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
_ => return None,
})
}
#[inline]
fn cast_uint(v: u64, to: ScalarDType) -> Option<ConstValue> {
use ScalarDType::*;
Some(match to {
Bool => ConstValue::Bool(v != 0),
Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
Int64 | Index => ConstValue::Int(v as i64),
UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
UInt64 => ConstValue::UInt(v),
Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
_ => return None,
})
}
/// Converts a float constant to the value representation of `to`.
///
/// * `Bool`: `true` for every value other than ±0.0 (so NaN is `true`).
/// * Signed integers: Rust `as` semantics at the target width (truncate toward
///   zero, saturate at the bounds, NaN becomes 0), then widened to `i64`.
/// * Unsigned integers: the value is first converted to `i64`, so negative
///   inputs wrap modulo 2^N instead of clamping to zero. For `UInt64`, inputs
///   at or above 2^63 are converted directly, so values that fit in `u64` are
///   preserved and larger ones (including +∞) saturate at `u64::MAX`.
/// * `Float16`, `BFloat16`, `Float32`, `Float64`: the `f64` payload is kept
///   unchanged.
///
/// Returns `None` for any other target dtype.
#[inline]
fn cast_float(v: f64, to: ScalarDType) -> Option<ConstValue> {
    use ScalarDType::*;
    // 2^63: the first value that no longer fits in `i64` but still fits in `u64`.
    const I64_RANGE_END: f64 = 9_223_372_036_854_775_808.0;
    Some(match to {
        Bool => ConstValue::Bool(v != 0.0),
        Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
        Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
        Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
        Int64 | Index => ConstValue::Int(v as i64),
        UInt8 => ConstValue::UInt(cast_via!(v as i64, u8, u64)),
        UInt16 => ConstValue::UInt(cast_via!(v as i64, u16, u64)),
        UInt32 => ConstValue::UInt(cast_via!(v as i64, u32, u64)),
        UInt64 => {
            // The `i64` detour exists to give negative inputs wrap-around
            // semantics. Applied to values >= 2^63 it would saturate at
            // `i64::MAX` and corrupt inputs that `u64` can represent, so those
            // are converted directly. NaN fails the comparison and still maps
            // to 0 through the `i64` path.
            let converted = if v >= I64_RANGE_END { v as u64 } else { (v as i64) as u64 };
            ConstValue::UInt(converted)
        }
        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v),
        _ => return None,
    })
}
impl ConstValue {
    /// Dtype of the stored payload: `Int64`, `UInt64`, `Float64` or `Bool`.
    pub const fn dtype(&self) -> DType {
        match *self {
            Self::Bool(_) => DType::Bool,
            Self::Float(_) => DType::Float64,
            Self::UInt(_) => DType::UInt64,
            Self::Int(_) => DType::Int64,
        }
    }

    /// Additive identity for `dtype`. `Void` and `Index` yield `Int(0)`.
    pub const fn zero(dtype: ScalarDType) -> Self {
        use ScalarDType::*;
        match dtype {
            Bool => Self::Bool(false),
            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(0.0),
            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
            Int8 | Int16 | Int32 | Int64 | Void | Index => Self::Int(0),
        }
    }

    /// Multiplicative identity for `dtype`. `Void` and `Index` yield `Int(1)`.
    pub const fn one(dtype: ScalarDType) -> Self {
        use ScalarDType::*;
        match dtype {
            Bool => Self::Bool(true),
            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(1.0),
            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(1),
            Int8 | Int16 | Int32 | Int64 | Void | Index => Self::Int(1),
        }
    }

    /// `-1` in `dtype`, or `None` when `dtype` is not a signed integer,
    /// `Index`, or a float type.
    pub const fn neg_one(dtype: ScalarDType) -> Option<Self> {
        use ScalarDType::*;
        match dtype {
            Int8 | Int16 | Int32 | Int64 | Index => Some(Self::Int(-1)),
            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Some(Self::Float(-1.0)),
            _ => None,
        }
    }

    /// Smallest value associated with `dtype`; `Void` yields `Int(0)`.
    ///
    /// NOTE(review): `FP8E4M3` and `FP8E5M2` share Float16's bound (-65504)
    /// here — confirm that this is intended for the FP8 formats.
    pub const fn min(dtype: ScalarDType) -> Self {
        use ScalarDType::*;
        match dtype {
            Float64 => Self::Float(f64::MIN),
            Float32 => Self::Float(f32::MIN as f64),
            BFloat16 => Self::Float(-3.38953e38),
            FP8E4M3 | FP8E5M2 | Float16 => Self::Float(-65504.0),
            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
            Int64 | Index => Self::Int(i64::MIN),
            Int32 => Self::Int(i32::MIN as i64),
            Int16 => Self::Int(i16::MIN as i64),
            Int8 => Self::Int(i8::MIN as i64),
            Bool => Self::Bool(false),
            Void => Self::Int(0),
        }
    }

    /// Largest value associated with `dtype`; `Void` yields `Int(0)`.
    ///
    /// NOTE(review): `FP8E4M3` and `FP8E5M2` share Float16's bound (65504)
    /// here — confirm that this is intended for the FP8 formats.
    pub const fn max(dtype: ScalarDType) -> Self {
        use ScalarDType::*;
        match dtype {
            Float64 => Self::Float(f64::MAX),
            Float32 => Self::Float(f32::MAX as f64),
            BFloat16 => Self::Float(3.38953e38),
            FP8E4M3 | FP8E5M2 | Float16 => Self::Float(65504.0),
            UInt64 => Self::UInt(u64::MAX),
            UInt32 => Self::UInt(u32::MAX as u64),
            UInt16 => Self::UInt(u16::MAX as u64),
            UInt8 => Self::UInt(u8::MAX as u64),
            Int64 | Index => Self::Int(i64::MAX),
            Int32 => Self::Int(i32::MAX as i64),
            Int16 => Self::Int(i16::MAX as i64),
            Int8 => Self::Int(i8::MAX as i64),
            Bool => Self::Bool(true),
            Void => Self::Int(0),
        }
    }

    /// Converts this constant to `dtype`.
    ///
    /// Returns `None` when `dtype.scalar()` yields `None` or when the per-variant
    /// cast helper does not support the resulting scalar type.
    pub fn cast(&self, dtype: &DType) -> Option<Self> {
        let scalar_dtype = dtype.scalar()?;
        let converted = impl_cast!(*self, scalar_dtype);
        Some(converted)
    }

    /// `true` when the payload is zero (`false` for `Bool`; both `0.0` and
    /// `-0.0` for `Float`).
    pub const fn is_zero(&self) -> bool {
        match *self {
            Self::Int(signed) => signed == 0,
            Self::UInt(unsigned) => unsigned == 0,
            Self::Bool(flag) => !flag,
            Self::Float(real) => real == 0.0,
        }
    }

    /// `true` when the payload is one (`true` for `Bool`).
    pub const fn is_one(&self) -> bool {
        match *self {
            Self::Int(signed) => signed == 1,
            Self::UInt(unsigned) => unsigned == 1,
            Self::Bool(flag) => flag,
            Self::Float(real) => real == 1.0,
        }
    }

    /// `true` when the payload is `-1`; always `false` for `UInt` and `Bool`.
    pub const fn is_neg_one(&self) -> bool {
        match *self {
            Self::Int(signed) => signed == -1,
            Self::Float(real) => real == -1.0,
            Self::UInt(_) | Self::Bool(_) => false,
        }
    }

    /// Integer payload as `i64`, or `None` for `Float` and `Bool`.
    ///
    /// NOTE(review): `UInt` payloads are reinterpreted with `as i64`, so values
    /// above `i64::MAX` come back negative rather than as `None`.
    pub const fn try_int(&self) -> Option<i64> {
        match *self {
            Self::Int(signed) => Some(signed),
            Self::UInt(unsigned) => Some(unsigned as i64),
            Self::Float(_) | Self::Bool(_) => None,
        }
    }

    /// Float payload, or `None` for every non-`Float` variant.
    pub const fn try_float(&self) -> Option<f64> {
        if let Self::Float(real) = *self { Some(real) } else { None }
    }

    /// Wraps an integer payload to the width of `dtype`.
    ///
    /// Only same-signedness pairs are affected (`Int` with a signed dtype or
    /// `Index`, `UInt` with an unsigned dtype); every other combination returns
    /// `self` unchanged.
    pub fn truncate(self, dtype: ScalarDType) -> Self {
        use ScalarDType::*;
        match self {
            Self::Int(wide) => Self::Int(match dtype {
                Int8 => i64::from(wide as i8),
                Int16 => i64::from(wide as i16),
                Int32 => i64::from(wide as i32),
                _ => wide,
            }),
            Self::UInt(wide) => Self::UInt(match dtype {
                UInt8 => u64::from(wide as u8),
                UInt16 => u64::from(wide as u16),
                UInt32 => u64::from(wide as u32),
                _ => wide,
            }),
            other => other,
        }
    }
}
pub use morok_dtype::AddrSpace;
/// Options attached to a bufferize operation.
///
/// NOTE(review): field meanings below are inferred from names and from the
/// constructors in `impl BufferizeOpts`; confirm against the code that consumes
/// them.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct BufferizeOpts {
/// Target device, if any; `local()` leaves this as `None`.
pub device: Option<DeviceSpec>,
/// Address space of the buffer (`Global` from `new`, `Local` from `local`).
pub addrspace: AddrSpace,
/// Both constructors set this to `true`; presumably marks the buffer as one
/// that later passes may remove — TODO confirm.
pub removable: bool,
}
impl BufferizeOpts {
    /// Options for a removable buffer in the global address space on `device`.
    pub fn new(device: DeviceSpec) -> Self {
        let device = Some(device);
        Self { device, addrspace: AddrSpace::Global, removable: true }
    }

    /// Options for a removable buffer in the local address space with no device.
    pub fn local() -> Self {
        Self { removable: true, addrspace: AddrSpace::Local, device: None }
    }
}
/// A hint record consisting of an operation name plus optional axis and
/// argument.
///
/// NOTE(review): how these fields are interpreted is not visible in this file;
/// the per-field notes only restate the types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ContiguousHint {
/// Name of the operation the hint refers to.
pub op: String,
/// Optional axis index.
pub axis: Option<usize>,
/// Optional integer argument.
pub arg: Option<i64>,
}
/// Category of an axis.
///
/// Ordering (`PartialOrd`/`Ord`) is implemented by hand in this file and follows
/// `priority()`, not declaration order. Equality and hashing are derived, so
/// they distinguish every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum AxisType {
// Variant semantics are defined by the code that consumes them; see
// `priority()` and `letter()` for the per-variant rank and display character.
Outer,
Global,
Warp,
Local,
Loop,
GroupReduce,
Reduce,
Upcast,
Unroll,
Thread,
Placeholder,
}
impl AxisType {
    /// `true` only for `Outer`.
    pub const fn is_kernel_boundary(&self) -> bool {
        match *self {
            Self::Outer => true,
            Self::Global
            | Self::Warp
            | Self::Local
            | Self::Loop
            | Self::GroupReduce
            | Self::Reduce
            | Self::Upcast
            | Self::Unroll
            | Self::Thread
            | Self::Placeholder => false,
        }
    }

    /// Sort rank of the axis type; lower ranks order first.
    ///
    /// Some variants deliberately share a rank (`Global`/`Thread`,
    /// `Local`/`GroupReduce`).
    pub const fn priority(self) -> i32 {
        match self {
            Self::Placeholder => -3,
            Self::Outer => -2,
            Self::Loop => -1,
            Self::Global | Self::Thread => 0,
            Self::Warp => 1,
            Self::Local | Self::GroupReduce => 2,
            Self::Upcast => 3,
            Self::Reduce => 4,
            Self::Unroll => 5,
        }
    }

    /// Single-character tag used when displaying the axis type.
    pub const fn letter(self) -> char {
        match self {
            Self::Placeholder => 'P',
            Self::Outer => 'O',
            Self::Loop => 'L',
            Self::Global => 'g',
            Self::Thread => 't',
            Self::Warp => 'w',
            Self::Local => 'l',
            Self::GroupReduce => 'G',
            Self::Upcast => 'u',
            Self::Reduce => 'R',
            Self::Unroll => 'r',
        }
    }

    /// `true` for `Global`, `Thread`, `Local` and `Warp`.
    pub const fn is_parallel(self) -> bool {
        match self {
            Self::Global | Self::Thread | Self::Local | Self::Warp => true,
            Self::Outer
            | Self::Loop
            | Self::GroupReduce
            | Self::Reduce
            | Self::Upcast
            | Self::Unroll
            | Self::Placeholder => false,
        }
    }

    /// `true` for `Reduce`, `GroupReduce` and `Unroll`.
    pub const fn is_reduce(self) -> bool {
        match self {
            Self::Reduce | Self::GroupReduce | Self::Unroll => true,
            Self::Outer
            | Self::Global
            | Self::Warp
            | Self::Local
            | Self::Loop
            | Self::Upcast
            | Self::Thread
            | Self::Placeholder => false,
        }
    }
}
/// Partial ordering that always agrees with the total order from `Ord`.
impl PartialOrd for AxisType {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Orders axis types by `priority()` alone.
///
/// NOTE(review): variants sharing a priority (`Global`/`Thread`,
/// `Local`/`GroupReduce`) compare as `Equal` here although the derived `==`
/// distinguishes them, so ordered collections keyed on `AxisType` treat such a
/// pair as one key.
impl Ord for AxisType {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (ours, theirs) = (self.priority(), other.priority());
        ours.cmp(&theirs)
    }
}
/// Displays the axis type as its single-character `letter()`.
impl std::fmt::Display for AxisType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = self.letter();
        write!(f, "{symbol}")
    }
}
/// Numeric axis identifier tagged with whether it has been renumbered.
///
/// The derived `Ord` compares the variant first (`Unrenumbered` before
/// `Renumbered`) and the wrapped number second.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum AxisId {
/// Identifier that has not been renumbered; displayed as `U<n>`.
Unrenumbered(usize),
/// Identifier after renumbering; displayed as `R<n>`.
Renumbered(usize),
}
impl AxisId {
    /// The wrapped number, regardless of variant.
    pub fn value(&self) -> usize {
        let (AxisId::Unrenumbered(number) | AxisId::Renumbered(number)) = *self;
        number
    }

    /// `true` for `Renumbered`, `false` for `Unrenumbered`.
    pub fn is_renumbered(&self) -> bool {
        match *self {
            AxisId::Renumbered(_) => true,
            AxisId::Unrenumbered(_) => false,
        }
    }
}
/// Displays the id as `U<n>` (unrenumbered) or `R<n>` (renumbered).
impl std::fmt::Display for AxisId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (tag, number) = match *self {
            AxisId::Unrenumbered(number) => ('U', number),
            AxisId::Renumbered(number) => ('R', number),
        };
        write!(f, "{tag}{number}")
    }
}
/// Kinds of reduction operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum ReduceOp {
// Presumably sum, product, maximum and minimum reductions respectively;
// the evaluation semantics live outside this file.
Add,
Mul,
Max,
Min,
}
/// Kinds of unary operation.
///
/// The `strum` derives expose each variant's name as a string
/// (`AsRefStr`) and the list of all variant names (`VariantNames`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum UnaryOp {
// Variant names presumably carry their usual mathematical meaning; the
// evaluation semantics are defined outside this file.
Neg,
Not,
Abs,
Sqrt,
Rsqrt,
Exp,
Exp2,
Log,
Log2,
Sin,
Cos,
Tan,
Reciprocal,
Trunc,
Floor,
Ceil,
Round,
Sign,
Erf,
Square,
}
/// Kinds of binary operation.
///
/// The classification helpers in `impl BinaryOp` group these into comparison,
/// arithmetic and bitwise operations; `Threefry` belongs to none of those
/// groups. The `strum` derives expose variant names as strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum BinaryOp {
// Classified as arithmetic by `is_arithmetic` (`Idiv`/`Fdiv` are presumably
// integer and float division).
Add,
Mul,
Sub,
Mod,
Max,
Pow,
Idiv,
Fdiv,
// Classified as comparisons by `is_comparison`.
Lt,
Le,
Eq,
Ne,
Gt,
Ge,
// Classified as bitwise by `is_bitwise`.
And,
Or,
Xor,
Shl,
Shr,
// Not part of any classification group in this file.
Threefry,
}
impl BinaryOp {
    /// `true` for `Lt`, `Le`, `Eq`, `Ne`, `Gt` and `Ge`.
    pub fn is_comparison(self) -> bool {
        match self {
            Self::Lt | Self::Le | Self::Eq | Self::Ne | Self::Gt | Self::Ge => true,
            _ => false,
        }
    }

    /// `true` for `Add`, `Mul`, `Sub`, `Mod`, `Max`, `Pow`, `Idiv` and `Fdiv`.
    pub fn is_arithmetic(self) -> bool {
        match self {
            Self::Add
            | Self::Mul
            | Self::Sub
            | Self::Mod
            | Self::Max
            | Self::Pow
            | Self::Idiv
            | Self::Fdiv => true,
            _ => false,
        }
    }

    /// `true` for `And`, `Or`, `Xor`, `Shl` and `Shr`.
    pub fn is_bitwise(self) -> bool {
        match self {
            Self::And | Self::Or | Self::Xor | Self::Shl | Self::Shr => true,
            _ => false,
        }
    }

    /// `true` for the operations treated as associative: `Add`, `Mul`, `And`,
    /// `Or` and `Max`.
    pub fn is_associative(self) -> bool {
        match self {
            Self::Add | Self::Mul | Self::And | Self::Or | Self::Max => true,
            _ => false,
        }
    }

    /// `true` for the operations treated as commutative: `Add`, `Mul`, `Eq`,
    /// `Ne`, `And`, `Or`, `Xor` and `Max`.
    pub fn is_commutative(self) -> bool {
        match self {
            Self::Add
            | Self::Mul
            | Self::Eq
            | Self::Ne
            | Self::And
            | Self::Or
            | Self::Xor
            | Self::Max => true,
            _ => false,
        }
    }

    /// `true` for the operations treated as idempotent: `Or`, `And` and `Max`.
    pub fn is_idempotent(self) -> bool {
        match self {
            Self::Or | Self::And | Self::Max => true,
            _ => false,
        }
    }
}
/// Kinds of ternary operation.
///
/// The `strum` derives expose variant names as strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum TernaryOp {
// Presumably a conditional select — TODO confirm against the evaluator.
Where,
// Presumably a fused multiply-accumulate — TODO confirm against the evaluator.
MulAcc,
}
/// Per-operand lists of `(axis id, size)` pairs for a WMMA operation.
///
/// The pair layout follows from `all_axis_ids` (which collects the first
/// element as an id) and `source_size` (which multiplies the second element).
/// Operands are addressed as index 0, 1 and 2 by `by_index`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct WmmaUpcastAxes {
/// Axes for operand 0.
pub a: Vec<(usize, usize)>,
/// Axes for operand 1.
pub b: Vec<(usize, usize)>,
/// Axes for operand 2.
pub c: Vec<(usize, usize)>,
}
impl WmmaUpcastAxes {
    /// Every axis id used by any operand, sorted ascending without duplicates.
    pub fn all_axis_ids(&self) -> Vec<usize> {
        let operands = [&self.a, &self.b, &self.c];
        let total: usize = operands.iter().map(|axes| axes.len()).sum();
        let mut ids: Vec<usize> = Vec::with_capacity(total);
        for axes in operands {
            ids.extend(axes.iter().map(|&(axis, _)| axis));
        }
        ids.sort_unstable();
        ids.dedup();
        ids
    }

    /// The axis list of operand `index` (0 = `a`, 1 = `b`, 2 = `c`).
    ///
    /// # Panics
    ///
    /// Panics when `index` is greater than 2.
    pub fn by_index(&self, index: usize) -> &[(usize, usize)] {
        let operand = match index {
            0 => &self.a,
            1 => &self.b,
            2 => &self.c,
            _ => panic!("WMMA operand index must be 0, 1, or 2"),
        };
        operand.as_slice()
    }

    /// Product of the axis sizes of operand `index`, never less than 1.
    ///
    /// # Panics
    ///
    /// Panics when `index` is greater than 2 (see [`Self::by_index`]).
    pub fn source_size(&self, index: usize) -> usize {
        let elements: usize = self.by_index(index).iter().map(|&(_, size)| size).product();
        elements.max(1)
    }
}
/// Descriptor of a WMMA operation.
///
/// NOTE(review): nothing in this file reads these fields, so the per-field
/// notes are assumptions based on names and types — confirm against the code
/// that constructs and consumes this struct.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct WmmaMetadata {
/// Name of the operation.
pub name: String,
/// Three dimensions of the operation; presumably the matrix tile shape.
pub dims: (usize, usize, usize),
/// Dtype of the inputs.
pub dtype_in: DType,
/// Dtype of the output.
pub dtype_out: DType,
/// Device name.
pub device: String,
/// Thread count; presumably threads cooperating on one operation.
pub threads: usize,
/// Per-operand upcast axes.
pub upcast_axes: WmmaUpcastAxes,
/// Reduce axes; presumably axis ids.
pub reduce_axes: Vec<usize>,
/// Two-dimensional tile grid.
pub tile_grid: (usize, usize),
}
/// Wrapper around `ConstValue` with bitwise float equality.
///
/// The hand-written `PartialEq`, `Eq` and `Hash` impls in this file compare and
/// hash float payloads by bit pattern, which makes equality total and
/// consistent with hashing.
#[derive(Debug, Clone, Copy)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConstValueHash(pub ConstValue);
/// Equality that compares floats by bit pattern, so `NaN` equals an identical
/// `NaN` and `0.0` differs from `-0.0`. Values of different variants are never
/// equal.
impl PartialEq for ConstValueHash {
    fn eq(&self, other: &Self) -> bool {
        match (self.0, other.0) {
            (ConstValue::Float(lhs), ConstValue::Float(rhs)) => lhs.to_bits() == rhs.to_bits(),
            // Non-float payloads have total equality already; mixed variants
            // compare unequal under the derived `PartialEq`.
            (lhs, rhs) => lhs == rhs,
        }
    }
}

// Sound because the bitwise float comparison above is reflexive.
impl Eq for ConstValueHash {}
impl Hash for ConstValueHash {
fn hash<H: Hasher>(&self, state: &mut H) {
(discriminant(&self.0)).hash(state);
match self.0 {
ConstValue::Int(v) => v.hash(state),
ConstValue::UInt(v) => v.hash(state),
ConstValue::Float(v) => v.to_bits().hash(state),
ConstValue::Bool(v) => v.hash(state),
}
}
}