1use std::hash::{Hash, Hasher};
7use std::mem::discriminant;
8
9use morok_dtype::DeviceSpec;
10use morok_dtype::{DType, ScalarDType};
11
/// A scalar constant stored in the widest representation of its class.
///
/// `derive_more::From` supplies `From` conversions from each payload type (`i64`, `u64`,
/// `f64`, `bool`); narrower integers and `f32` are widened by the explicit `From` impls
/// elsewhere in this file.
///
/// Equality is the derived `PartialEq`, so `Float` payloads follow IEEE-754 comparison
/// (`NaN != NaN`, `0.0 == -0.0`). `Eq` is not derived because `f64` is not `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, derive_more::From)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum ConstValue {
    /// Signed integer payload.
    Int(i64),
    /// Unsigned integer payload.
    UInt(u64),
    /// Floating-point payload.
    Float(f64),
    /// Boolean payload.
    Bool(bool),
}
21
22macro_rules! impl_from_widening {
23 ($($ty:ty => Int),+ $(,)?) => { $(
24 impl From<$ty> for ConstValue {
25 fn from(v: $ty) -> Self { ConstValue::Int(v as i64) }
26 }
27 )+ };
28 ($($ty:ty => UInt),+ $(,)?) => { $(
29 impl From<$ty> for ConstValue {
30 fn from(v: $ty) -> Self { ConstValue::UInt(v as u64) }
31 }
32 )+ };
33}
34
35impl_from_widening!(i8 => Int, i16 => Int, i32 => Int);
36impl_from_widening!(u8 => UInt, u16 => UInt, u32 => UInt);
37
38impl From<f32> for ConstValue {
39 fn from(v: f32) -> Self {
40 ConstValue::Float(v as f64)
41 }
42}
43
44impl Hash for ConstValue {
47 fn hash<H: Hasher>(&self, state: &mut H) {
48 discriminant(self).hash(state);
49 match self {
50 ConstValue::Int(v) => v.hash(state),
51 ConstValue::UInt(v) => v.hash(state),
52 ConstValue::Float(v) => v.to_bits().hash(state),
53 ConstValue::Bool(v) => v.hash(state),
54 }
55 }
56}
57
/// Casts `$value` to the narrow type with `as`, then widens the result to the storage
/// type with a second `as`.
///
/// Integer sources wrap to the narrow width; float sources saturate (Rust `as` rules).
macro_rules! cast_via {
    ($value:expr, $narrow:ty, $wide:ty) => {{
        let narrowed: $narrow = $value as $narrow;
        narrowed as $wide
    }};
}
64
/// Dispatches a `ConstValue` to the `cast_*` helper matching its variant.
///
/// Each helper result is unwrapped with `?`, so this macro may only be used inside a
/// function returning `Option`; it evaluates to the converted `ConstValue`.
macro_rules! impl_cast {
    ($value:expr, $target:expr) => {{
        // Evaluate both operands once, value first, before dispatching.
        let (constant, target) = ($value, $target);
        match constant {
            ConstValue::Bool(flag) => cast_bool(flag, target)?,
            ConstValue::Int(signed) => cast_int(signed, target)?,
            ConstValue::UInt(unsigned) => cast_uint(unsigned, target)?,
            ConstValue::Float(real) => cast_float(real, target)?,
        }
    }};
}
76
77#[inline]
78fn cast_bool(v: bool, to: ScalarDType) -> Option<ConstValue> {
79 use ScalarDType::*;
80 Some(match to {
81 Bool => ConstValue::Bool(v),
82 Int8 | Int16 | Int32 | Int64 | Index => ConstValue::Int(v as i64),
83 UInt8 | UInt16 | UInt32 | UInt64 => ConstValue::UInt(v as u64),
84 Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as u8 as f64),
85 _ => return None,
86 })
87}
88
89#[inline]
90fn cast_int(v: i64, to: ScalarDType) -> Option<ConstValue> {
91 use ScalarDType::*;
92 Some(match to {
93 Bool => ConstValue::Bool(v != 0),
94 Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
95 Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
96 Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
97 Int64 | Index => ConstValue::Int(v),
98 UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
99 UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
100 UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
101 UInt64 => ConstValue::UInt(v as u64),
102 Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
103 _ => return None,
104 })
105}
106
107#[inline]
108fn cast_uint(v: u64, to: ScalarDType) -> Option<ConstValue> {
109 use ScalarDType::*;
110 Some(match to {
111 Bool => ConstValue::Bool(v != 0),
112 Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
113 Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
114 Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
115 Int64 | Index => ConstValue::Int(v as i64),
116 UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
117 UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
118 UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
119 UInt64 => ConstValue::UInt(v),
120 Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
121 _ => return None,
122 })
123}
124
125#[inline]
126fn cast_float(v: f64, to: ScalarDType) -> Option<ConstValue> {
127 use ScalarDType::*;
128 Some(match to {
129 Bool => ConstValue::Bool(v != 0.0),
130 Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
131 Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
132 Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
133 Int64 | Index => ConstValue::Int(v as i64),
134 UInt8 => ConstValue::UInt(cast_via!(v as i64, u8, u64)),
136 UInt16 => ConstValue::UInt(cast_via!(v as i64, u16, u64)),
137 UInt32 => ConstValue::UInt(cast_via!(v as i64, u32, u64)),
138 UInt64 => ConstValue::UInt((v as i64) as u64),
139 Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v),
140 _ => return None,
141 })
142}
143
144impl ConstValue {
145 pub const fn dtype(&self) -> DType {
146 match self {
147 ConstValue::Int(_) => DType::Int64,
148 ConstValue::UInt(_) => DType::UInt64,
149 ConstValue::Float(_) => DType::Float64,
150 ConstValue::Bool(_) => DType::Bool,
151 }
152 }
153
154 pub const fn zero(dtype: ScalarDType) -> Self {
155 use ScalarDType::*;
156 match dtype {
157 Bool => Self::Bool(false),
158 Int8 | Int16 | Int32 | Int64 => Self::Int(0),
159 UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
160 FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(0.0),
161 Void | Index => Self::Int(0), }
163 }
164
165 pub const fn one(dtype: ScalarDType) -> Self {
166 use ScalarDType::*;
167 match dtype {
168 Bool => Self::Bool(true),
169 Int8 | Int16 | Int32 | Int64 => Self::Int(1),
170 UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(1),
171 FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(1.0),
172 Void | Index => Self::Int(1), }
174 }
175
176 pub const fn neg_one(dtype: ScalarDType) -> Option<Self> {
177 use ScalarDType::*;
178 Some(match dtype {
179 Int8 | Int16 | Int32 | Int64 | Index => Self::Int(-1),
180 FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(-1.0),
181 _ => return None,
182 })
183 }
184
185 pub const fn min(dtype: ScalarDType) -> Self {
187 use ScalarDType::*;
188 match dtype {
189 Bool => Self::Bool(false),
190 Int8 => Self::Int(i8::MIN as i64),
191 Int16 => Self::Int(i16::MIN as i64),
192 Int32 => Self::Int(i32::MIN as i64),
193 Int64 | Index => Self::Int(i64::MIN),
194 UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
195 FP8E4M3 | FP8E5M2 | Float16 => Self::Float(-65504.0),
196 BFloat16 => Self::Float(-3.38953e38),
197 Float32 => Self::Float(f32::MIN as f64),
198 Float64 => Self::Float(f64::MIN),
199 Void => Self::Int(0),
200 }
201 }
202
203 pub const fn max(dtype: ScalarDType) -> Self {
205 use ScalarDType::*;
206 match dtype {
207 Bool => Self::Bool(true),
208 Int8 => Self::Int(i8::MAX as i64),
209 Int16 => Self::Int(i16::MAX as i64),
210 Int32 => Self::Int(i32::MAX as i64),
211 Int64 | Index => Self::Int(i64::MAX),
212 UInt8 => Self::UInt(u8::MAX as u64),
213 UInt16 => Self::UInt(u16::MAX as u64),
214 UInt32 => Self::UInt(u32::MAX as u64),
215 UInt64 => Self::UInt(u64::MAX),
216 FP8E4M3 | FP8E5M2 | Float16 => Self::Float(65504.0),
217 BFloat16 => Self::Float(3.38953e38),
218 Float32 => Self::Float(f32::MAX as f64),
219 Float64 => Self::Float(f64::MAX),
220 Void => Self::Int(0),
221 }
222 }
223
224 pub fn cast(&self, dtype: &DType) -> Option<Self> {
244 let scalar_dtype = dtype.scalar()?;
245
246 Some(impl_cast!(*self, scalar_dtype))
247 }
248
249 pub const fn is_zero(&self) -> bool {
253 match self {
254 Self::Int(0) | Self::UInt(0) | Self::Bool(false) => true,
255 Self::Float(f) => *f == 0.0,
256 _ => false,
257 }
258 }
259
260 pub const fn is_one(&self) -> bool {
264 match self {
265 Self::Int(1) | Self::UInt(1) | Self::Bool(true) => true,
266 Self::Float(f) => *f == 1.0,
267 _ => false,
268 }
269 }
270
271 pub const fn is_neg_one(&self) -> bool {
275 match self {
276 Self::Int(-1) => true,
277 Self::Float(f) => *f == -1.0,
278 _ => false,
279 }
280 }
281
282 pub const fn try_int(&self) -> Option<i64> {
286 match self {
287 Self::Int(v) => Some(*v),
288 Self::UInt(v) => Some(*v as i64),
289 _ => None,
290 }
291 }
292
293 pub const fn try_float(&self) -> Option<f64> {
297 match self {
298 Self::Float(v) => Some(*v),
299 _ => None,
300 }
301 }
302
303 pub fn truncate(self, dtype: ScalarDType) -> Self {
308 use ScalarDType::*;
309 match (self, dtype) {
310 (Self::Int(v), Int8) => Self::Int((v as i8) as i64),
312 (Self::Int(v), Int16) => Self::Int((v as i16) as i64),
313 (Self::Int(v), Int32) => Self::Int((v as i32) as i64),
314 (Self::Int(v), Int64 | Index) => Self::Int(v),
315
316 (Self::UInt(v), UInt8) => Self::UInt((v as u8) as u64),
318 (Self::UInt(v), UInt16) => Self::UInt((v as u16) as u64),
319 (Self::UInt(v), UInt32) => Self::UInt((v as u32) as u64),
320 (Self::UInt(v), UInt64) => Self::UInt(v),
321
322 _ => self,
324 }
325 }
326}
327
328pub use morok_dtype::AddrSpace;
330
/// Options for bufferization (per the type name): an optional device, an address space
/// and a `removable` flag.
///
/// NOTE(review): the consumers of these options are not visible in this file; the field
/// notes below describe only what the constructors in this file set.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct BufferizeOpts {
    /// Device specification; `Some` when built with `BufferizeOpts::new`, `None` when
    /// built with `BufferizeOpts::local`.
    pub device: Option<DeviceSpec>,
    /// Address space (`Global` from `new`, `Local` from `local`).
    pub addrspace: AddrSpace,
    /// Set to `true` by both constructors. Presumably marks the buffer as eligible for
    /// removal by later passes — TODO confirm against the users of this flag.
    pub removable: bool,
}
345
346impl BufferizeOpts {
347 pub fn new(device: DeviceSpec) -> Self {
348 Self { device: Some(device), addrspace: AddrSpace::Global, removable: true }
349 }
350
351 pub fn local() -> Self {
352 Self { device: None, addrspace: AddrSpace::Local, removable: true }
353 }
354}
355
/// A hint record made of an operation name plus an optional axis and integer argument.
///
/// NOTE(review): the producers and consumers of this hint are outside this file, so the
/// accepted values of `op` and the meaning of `arg` cannot be stated here — confirm
/// against call sites.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ContiguousHint {
    /// Operation name, as a free-form string.
    pub op: String,
    /// Axis index, if one applies.
    pub axis: Option<usize>,
    /// Integer argument, if one applies.
    pub arg: Option<i64>,
}
373
/// Kind of an axis.
///
/// Ordering (`PartialOrd`/`Ord`) is implemented by hand in terms of `priority()`, not
/// declaration order, and each variant has a one-character tag from `letter()`.
/// The per-variant notes below only restate how the predicates in this file classify
/// each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum AxisType {
    /// The only variant reported by `is_kernel_boundary`.
    Outer,
    /// Counted as parallel by `is_parallel`.
    Global,
    /// Counted as parallel by `is_parallel`.
    Warp,
    /// Counted as parallel by `is_parallel`.
    Local,
    /// Neither parallel nor reduce according to the predicates on this type.
    Loop,
    /// Counted as a reduction axis by `is_reduce`.
    GroupReduce,
    /// Counted as a reduction axis by `is_reduce`.
    Reduce,
    /// Neither parallel nor reduce according to the predicates on this type.
    Upcast,
    /// Counted as a reduction axis by `is_reduce`.
    Unroll,
    /// Counted as parallel by `is_parallel`.
    Thread,
    /// Has the lowest priority, so it sorts before every other variant.
    Placeholder,
}
406
407impl AxisType {
408 pub const fn is_kernel_boundary(&self) -> bool {
414 matches!(self, Self::Outer)
415 }
416
417 pub const fn priority(self) -> i32 {
432 match self {
433 Self::Outer => -2,
434 Self::Loop => -1,
435 Self::Global | Self::Thread => 0,
436 Self::Warp => 1,
437 Self::Local | Self::GroupReduce => 2,
438 Self::Upcast => 3,
439 Self::Reduce => 4,
440 Self::Unroll => 5,
441 Self::Placeholder => -3,
442 }
443 }
444
445 pub const fn letter(self) -> char {
461 match self {
462 Self::Outer => 'O',
463 Self::Loop => 'L',
464 Self::Global => 'g',
465 Self::Thread => 't',
466 Self::Warp => 'w',
467 Self::Local => 'l',
468 Self::GroupReduce => 'G',
469 Self::Upcast => 'u',
470 Self::Reduce => 'R',
471 Self::Unroll => 'r',
472 Self::Placeholder => 'P',
473 }
474 }
475
476 pub const fn is_parallel(self) -> bool {
481 matches!(self, Self::Global | Self::Thread | Self::Local | Self::Warp)
482 }
483
484 pub const fn is_reduce(self) -> bool {
486 matches!(self, Self::Reduce | Self::GroupReduce | Self::Unroll)
487 }
488}
489
490impl PartialOrd for AxisType {
491 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
492 Some(self.cmp(other))
493 }
494}
495
496impl Ord for AxisType {
497 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
498 self.priority().cmp(&other.priority())
499 }
500}
501
502impl std::fmt::Display for AxisType {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 write!(f, "{}", self.letter())
505 }
506}
507
/// Identifier of an axis, tagged by whether it has been renumbered.
///
/// The derived `Ord` compares the variant first (every `Unrenumbered` sorts before every
/// `Renumbered`) and then the wrapped index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum AxisId {
    /// An index that has not been renumbered; displayed with a `U` prefix.
    Unrenumbered(usize),
    /// An index that has been renumbered; displayed with an `R` prefix.
    Renumbered(usize),
}
524
525impl AxisId {
526 pub fn value(&self) -> usize {
528 match self {
529 AxisId::Unrenumbered(n) | AxisId::Renumbered(n) => *n,
530 }
531 }
532
533 pub fn is_renumbered(&self) -> bool {
535 matches!(self, AxisId::Renumbered(_))
536 }
537}
538
539impl std::fmt::Display for AxisId {
540 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541 match self {
542 AxisId::Unrenumbered(n) => write!(f, "U{}", n),
543 AxisId::Renumbered(n) => write!(f, "R{}", n),
544 }
545 }
546}
547
/// Reduction operator kinds.
///
/// Variant notes are read off the variant names; the evaluation of these operators
/// lives outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum ReduceOp {
    /// Sum reduction.
    Add,
    /// Product reduction.
    Mul,
    /// Maximum reduction.
    Max,
    /// Minimum reduction.
    Min,
}
561
/// Unary operation kinds.
///
/// The strum derives expose each variant's name as a string (`AsRefStr`) and the full
/// list of names (`VariantNames`). Variant notes are read off the variant names; the
/// evaluation semantics live outside this file, so details marked "presumably" should
/// be confirmed there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum UnaryOp {
    /// Arithmetic negation.
    Neg,
    /// NOT (whether logical, bitwise or both is not visible in this file).
    Not,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root (presumably `1 / sqrt(x)`).
    Rsqrt,
    /// Exponential (presumably base e; base 2 is `Exp2`).
    Exp,
    /// Base-2 exponential.
    Exp2,
    /// Logarithm (presumably natural; base 2 is `Log2`).
    Log,
    /// Base-2 logarithm.
    Log2,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Reciprocal (presumably `1 / x`).
    Reciprocal,
    /// Truncation (presumably rounding toward zero).
    Trunc,
    /// Floor.
    Floor,
    /// Ceiling.
    Ceil,
    /// Rounding (the tie-breaking mode is not visible in this file).
    Round,
    /// Sign.
    Sign,
    /// Error function.
    Erf,
    /// Square (presumably `x * x`).
    Square,
}
609
/// Binary operation kinds.
///
/// The strum derives expose each variant's name as a string (`AsRefStr`) and the full
/// list of names (`VariantNames`). The grouping comments follow the `is_arithmetic`,
/// `is_comparison` and `is_bitwise` predicates in this file. Variant notes are read off
/// the variant names; evaluation semantics live outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum BinaryOp {
    // Arithmetic (see `is_arithmetic`).
    /// Addition.
    Add,
    /// Multiplication.
    Mul,
    /// Subtraction.
    Sub,
    /// Remainder (sign convention for negative operands is not visible in this file).
    Mod,
    /// Maximum of the two operands.
    Max,
    /// Exponentiation.
    Pow,
    /// Integer division (rounding convention is not visible in this file).
    Idiv,
    /// Floating-point division.
    Fdiv,

    // Comparisons (see `is_comparison`).
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,

    // Bitwise (see `is_bitwise`).
    /// AND.
    And,
    /// OR.
    Or,
    /// XOR.
    Xor,
    /// Shift left.
    Shl,
    /// Shift right.
    Shr,

    /// Not classified by any predicate in this file. Presumably the Threefry
    /// counter-based random-number mixing function — TODO confirm.
    Threefry,
}
683
684impl BinaryOp {
685 pub fn is_comparison(self) -> bool {
687 matches!(self, Self::Lt | Self::Le | Self::Eq | Self::Ne | Self::Gt | Self::Ge)
688 }
689
690 pub fn is_arithmetic(self) -> bool {
692 matches!(self, Self::Add | Self::Mul | Self::Sub | Self::Mod | Self::Max | Self::Pow | Self::Idiv | Self::Fdiv)
693 }
694
695 pub fn is_bitwise(self) -> bool {
697 matches!(self, Self::And | Self::Or | Self::Xor | Self::Shl | Self::Shr)
698 }
699
700 pub fn is_associative(self) -> bool {
702 matches!(self, Self::Add | Self::Mul | Self::And | Self::Or | Self::Max)
703 }
704
705 pub fn is_commutative(self) -> bool {
707 matches!(self, Self::Add | Self::Mul | Self::Eq | Self::Ne | Self::And | Self::Or | Self::Xor | Self::Max)
708 }
709
710 pub fn is_idempotent(self) -> bool {
712 matches!(self, Self::Or | Self::And | Self::Max)
713 }
714}
715
/// Ternary operation kinds.
///
/// The strum derives expose each variant's name as a string (`AsRefStr`) and the full
/// list of names (`VariantNames`). Variant notes are read off the variant names; the
/// operand order is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum TernaryOp {
    /// Conditional select (presumably `where(cond, a, b)` — TODO confirm operand order).
    Where,
    /// Multiply-accumulate (presumably `a * b + c` — TODO confirm operand order).
    MulAcc,
}
725
/// Per-operand axis lists for a WMMA operation.
///
/// Each entry is a pair whose first element is treated as an axis id (see
/// `all_axis_ids`) and whose second element is treated as a size (see `source_size`).
/// `by_index` maps operand indices 0, 1 and 2 to `a`, `b` and `c` respectively.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct WmmaUpcastAxes {
    /// Axes for operand index 0.
    pub a: Vec<(usize, usize)>,
    /// Axes for operand index 1.
    pub b: Vec<(usize, usize)>,
    /// Axes for operand index 2.
    pub c: Vec<(usize, usize)>,
}
741
742impl WmmaUpcastAxes {
743 pub fn all_axis_ids(&self) -> Vec<usize> {
745 let mut ids: Vec<usize> = self.a.iter().chain(self.b.iter()).chain(self.c.iter()).map(|(id, _)| *id).collect();
746 ids.sort_unstable();
747 ids.dedup();
748 ids
749 }
750
751 pub fn by_index(&self, index: usize) -> &[(usize, usize)] {
753 match index {
754 0 => &self.a,
755 1 => &self.b,
756 2 => &self.c,
757 _ => panic!("WMMA operand index must be 0, 1, or 2"),
758 }
759 }
760
761 pub fn source_size(&self, index: usize) -> usize {
763 self.by_index(index).iter().map(|(_, s)| s).product::<usize>().max(1)
764 }
765}
766
/// Descriptor of a WMMA operation.
///
/// NOTE(review): the code that builds and consumes this metadata is outside this file;
/// the field notes below are read off the field names and types and should be confirmed
/// against those call sites.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct WmmaMetadata {
    /// Name of the operation.
    pub name: String,
    /// Three dimensions of the operation (presumably the M, N, K tile sizes — TODO
    /// confirm the order).
    pub dims: (usize, usize, usize),
    /// Input element dtype.
    pub dtype_in: DType,
    /// Output element dtype.
    pub dtype_out: DType,
    /// Device name, as a string.
    pub device: String,
    /// Thread count (presumably threads cooperating on one operation — TODO confirm).
    pub threads: usize,
    /// Per-operand upcast axes.
    pub upcast_axes: WmmaUpcastAxes,
    /// Reduce axes (presumably axis ids — TODO confirm).
    pub reduce_axes: Vec<usize>,
    /// Tile grid as a pair (meaning of each component is not visible in this file).
    pub tile_grid: (usize, usize),
}
793
/// Wrapper around [`ConstValue`] that implements `Eq` and `Hash`.
///
/// Equality compares `Float` payloads by bit pattern (see the `PartialEq` impl in this
/// file), which makes it reflexive for NaN and consistent with the bitwise hash, so the
/// wrapper can be used as a hash-map or hash-set key.
#[derive(Debug, Clone, Copy)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct ConstValueHash(pub ConstValue);
805
806impl PartialEq for ConstValueHash {
807 fn eq(&self, other: &Self) -> bool {
808 match (self.0, other.0) {
809 (ConstValue::Int(a), ConstValue::Int(b)) => a == b,
810 (ConstValue::UInt(a), ConstValue::UInt(b)) => a == b,
811 (ConstValue::Float(a), ConstValue::Float(b)) => a.to_bits() == b.to_bits(),
812 (ConstValue::Bool(a), ConstValue::Bool(b)) => a == b,
813 _ => false,
814 }
815 }
816}
817
818impl Eq for ConstValueHash {}
819
820impl Hash for ConstValueHash {
821 fn hash<H: Hasher>(&self, state: &mut H) {
822 (discriminant(&self.0)).hash(state);
823 match self.0 {
824 ConstValue::Int(v) => v.hash(state),
825 ConstValue::UInt(v) => v.hash(state),
826 ConstValue::Float(v) => v.to_bits().hash(state),
827 ConstValue::Bool(v) => v.hash(state),
828 }
829 }
830}