Skip to main content

morok_ir/
types.rs

1//! Type definitions for IR operations.
2//!
3//! This module contains all the fundamental type enums and structs used throughout
4//! the IR, including operation types, constant values, and metadata structures.
5
6use std::hash::{Hash, Hasher};
7use std::mem::discriminant;
8
9use morok_dtype::DeviceSpec;
10use morok_dtype::{DType, ScalarDType};
11
12/// Constant value that can be stored in a UOp.
13#[derive(Debug, Clone, Copy, PartialEq, derive_more::From)]
14#[derive(serde::Serialize, serde::Deserialize)]
15pub enum ConstValue {
16    Int(i64),
17    UInt(u64),
18    Float(f64),
19    Bool(bool),
20}
21
22macro_rules! impl_from_widening {
23    ($($ty:ty => Int),+ $(,)?) => { $(
24        impl From<$ty> for ConstValue {
25            fn from(v: $ty) -> Self { ConstValue::Int(v as i64) }
26        }
27    )+ };
28    ($($ty:ty => UInt),+ $(,)?) => { $(
29        impl From<$ty> for ConstValue {
30            fn from(v: $ty) -> Self { ConstValue::UInt(v as u64) }
31        }
32    )+ };
33}
34
35impl_from_widening!(i8 => Int, i16 => Int, i32 => Int);
36impl_from_widening!(u8 => UInt, u16 => UInt, u32 => UInt);
37
38impl From<f32> for ConstValue {
39    fn from(v: f32) -> Self {
40        ConstValue::Float(v as f64)
41    }
42}
43
44/// Manual Hash impl because f64 doesn't implement Hash.
45/// Uses to_bits() for floats, which means NaN values with identical bit patterns hash equally.
46impl Hash for ConstValue {
47    fn hash<H: Hasher>(&self, state: &mut H) {
48        discriminant(self).hash(state);
49        match self {
50            ConstValue::Int(v) => v.hash(state),
51            ConstValue::UInt(v) => v.hash(state),
52            ConstValue::Float(v) => v.to_bits().hash(state),
53            ConstValue::Bool(v) => v.hash(state),
54        }
55    }
56}
57
/// Casts `$value` to the narrow target type, then widens it back to the storage type.
///
/// The round trip applies the target width's truncation / sign- or zero-extension
/// while keeping the result in the wide storage type.
macro_rules! cast_via {
    ($value:expr, $narrow:ty, $wide:ty) => {{
        let narrowed: $narrow = $value as $narrow;
        narrowed as $wide
    }};
}
64
/// Dispatches a constant cast to the helper matching the source variant.
///
/// Expands to an expression that uses `?`, so it may only be used inside a function
/// returning `Option`: a `None` from the selected helper is propagated to the caller.
macro_rules! impl_cast {
    ($value:expr, $target:expr) => {{
        // Evaluate both operands exactly once, source first, before dispatching.
        let source = $value;
        let target = $target;
        match source {
            ConstValue::Bool(flag) => cast_bool(flag, target)?,
            ConstValue::Int(signed) => cast_int(signed, target)?,
            ConstValue::UInt(unsigned) => cast_uint(unsigned, target)?,
            ConstValue::Float(real) => cast_float(real, target)?,
        }
    }};
}
76
77#[inline]
78fn cast_bool(v: bool, to: ScalarDType) -> Option<ConstValue> {
79    use ScalarDType::*;
80    Some(match to {
81        Bool => ConstValue::Bool(v),
82        Int8 | Int16 | Int32 | Int64 | Index => ConstValue::Int(v as i64),
83        UInt8 | UInt16 | UInt32 | UInt64 => ConstValue::UInt(v as u64),
84        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as u8 as f64),
85        _ => return None,
86    })
87}
88
89#[inline]
90fn cast_int(v: i64, to: ScalarDType) -> Option<ConstValue> {
91    use ScalarDType::*;
92    Some(match to {
93        Bool => ConstValue::Bool(v != 0),
94        Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
95        Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
96        Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
97        Int64 | Index => ConstValue::Int(v),
98        UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
99        UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
100        UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
101        UInt64 => ConstValue::UInt(v as u64),
102        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
103        _ => return None,
104    })
105}
106
107#[inline]
108fn cast_uint(v: u64, to: ScalarDType) -> Option<ConstValue> {
109    use ScalarDType::*;
110    Some(match to {
111        Bool => ConstValue::Bool(v != 0),
112        Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
113        Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
114        Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
115        Int64 | Index => ConstValue::Int(v as i64),
116        UInt8 => ConstValue::UInt(cast_via!(v, u8, u64)),
117        UInt16 => ConstValue::UInt(cast_via!(v, u16, u64)),
118        UInt32 => ConstValue::UInt(cast_via!(v, u32, u64)),
119        UInt64 => ConstValue::UInt(v),
120        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v as f64),
121        _ => return None,
122    })
123}
124
125#[inline]
126fn cast_float(v: f64, to: ScalarDType) -> Option<ConstValue> {
127    use ScalarDType::*;
128    Some(match to {
129        Bool => ConstValue::Bool(v != 0.0),
130        Int8 => ConstValue::Int(cast_via!(v, i8, i64)),
131        Int16 => ConstValue::Int(cast_via!(v, i16, i64)),
132        Int32 => ConstValue::Int(cast_via!(v, i32, i64)),
133        Int64 | Index => ConstValue::Int(v as i64),
134        // Float-to-unsigned: route through i64 first (matches Tinygrad behavior)
135        UInt8 => ConstValue::UInt(cast_via!(v as i64, u8, u64)),
136        UInt16 => ConstValue::UInt(cast_via!(v as i64, u16, u64)),
137        UInt32 => ConstValue::UInt(cast_via!(v as i64, u32, u64)),
138        UInt64 => ConstValue::UInt((v as i64) as u64),
139        Float16 | BFloat16 | Float32 | Float64 => ConstValue::Float(v),
140        _ => return None,
141    })
142}
143
144impl ConstValue {
145    pub const fn dtype(&self) -> DType {
146        match self {
147            ConstValue::Int(_) => DType::Int64,
148            ConstValue::UInt(_) => DType::UInt64,
149            ConstValue::Float(_) => DType::Float64,
150            ConstValue::Bool(_) => DType::Bool,
151        }
152    }
153
154    pub const fn zero(dtype: ScalarDType) -> Self {
155        use ScalarDType::*;
156        match dtype {
157            Bool => Self::Bool(false),
158            Int8 | Int16 | Int32 | Int64 => Self::Int(0),
159            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
160            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(0.0),
161            Void | Index => Self::Int(0), // TODO: remove this types from scalars
162        }
163    }
164
165    pub const fn one(dtype: ScalarDType) -> Self {
166        use ScalarDType::*;
167        match dtype {
168            Bool => Self::Bool(true),
169            Int8 | Int16 | Int32 | Int64 => Self::Int(1),
170            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(1),
171            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(1.0),
172            Void | Index => Self::Int(1), // TODO: remove this types from scalars
173        }
174    }
175
176    pub const fn neg_one(dtype: ScalarDType) -> Option<Self> {
177        use ScalarDType::*;
178        Some(match dtype {
179            Int8 | Int16 | Int32 | Int64 | Index => Self::Int(-1),
180            FP8E4M3 | FP8E5M2 | Float16 | BFloat16 | Float32 | Float64 => Self::Float(-1.0),
181            _ => return None,
182        })
183    }
184
185    /// Minimum representable value for a scalar dtype (matches Tinygrad's `dtypes.min`).
186    pub const fn min(dtype: ScalarDType) -> Self {
187        use ScalarDType::*;
188        match dtype {
189            Bool => Self::Bool(false),
190            Int8 => Self::Int(i8::MIN as i64),
191            Int16 => Self::Int(i16::MIN as i64),
192            Int32 => Self::Int(i32::MIN as i64),
193            Int64 | Index => Self::Int(i64::MIN),
194            UInt8 | UInt16 | UInt32 | UInt64 => Self::UInt(0),
195            FP8E4M3 | FP8E5M2 | Float16 => Self::Float(-65504.0),
196            BFloat16 => Self::Float(-3.38953e38),
197            Float32 => Self::Float(f32::MIN as f64),
198            Float64 => Self::Float(f64::MIN),
199            Void => Self::Int(0),
200        }
201    }
202
203    /// Maximum representable value for a scalar dtype (matches Tinygrad's `dtypes.max`).
204    pub const fn max(dtype: ScalarDType) -> Self {
205        use ScalarDType::*;
206        match dtype {
207            Bool => Self::Bool(true),
208            Int8 => Self::Int(i8::MAX as i64),
209            Int16 => Self::Int(i16::MAX as i64),
210            Int32 => Self::Int(i32::MAX as i64),
211            Int64 | Index => Self::Int(i64::MAX),
212            UInt8 => Self::UInt(u8::MAX as u64),
213            UInt16 => Self::UInt(u16::MAX as u64),
214            UInt32 => Self::UInt(u32::MAX as u64),
215            UInt64 => Self::UInt(u64::MAX),
216            FP8E4M3 | FP8E5M2 | Float16 => Self::Float(65504.0),
217            BFloat16 => Self::Float(3.38953e38),
218            Float32 => Self::Float(f32::MAX as f64),
219            Float64 => Self::Float(f64::MAX),
220            Void => Self::Int(0),
221        }
222    }
223
224    /// Cast this constant value to the target dtype.
225    ///
226    /// Returns `None` if:
227    /// - The target dtype is not a scalar type
228    /// - The target dtype is not representable as a ConstValue (e.g., Void, Index, special float formats)
229    ///
230    /// # Safety and Semantics
231    ///
232    /// This method performs constant folding for cast operations and allows ALL casts
233    /// (including lossy ones like float->int) since the user explicitly wrote the cast operation.
234    ///
235    /// Uses Rust's `as` operator for conversions, which follows C semantics:
236    /// - Truncation for narrowing conversions (e.g., i64 -> i32)
237    /// - Wrap-around for unsigned overflow
238    /// - Truncation toward zero for float-to-int conversions
239    ///
240    /// For multi-stage conversions (e.g., casting through intermediate types),
241    /// the value is cast to the target width and then extended back to the storage type.
242    /// Example: i64 -> i8 -> i64 ensures proper sign extension.
243    pub fn cast(&self, dtype: &DType) -> Option<Self> {
244        let scalar_dtype = dtype.scalar()?;
245
246        Some(impl_cast!(*self, scalar_dtype))
247    }
248
249    /// Returns true if this constant is zero (additive identity).
250    ///
251    /// Works for all numeric types: Int, UInt, Float, Bool.
252    pub const fn is_zero(&self) -> bool {
253        match self {
254            Self::Int(0) | Self::UInt(0) | Self::Bool(false) => true,
255            Self::Float(f) => *f == 0.0,
256            _ => false,
257        }
258    }
259
260    /// Returns true if this constant is one (multiplicative identity).
261    ///
262    /// Works for all numeric types: Int, UInt, Float, Bool.
263    pub const fn is_one(&self) -> bool {
264        match self {
265            Self::Int(1) | Self::UInt(1) | Self::Bool(true) => true,
266            Self::Float(f) => *f == 1.0,
267            _ => false,
268        }
269    }
270
271    /// Returns true if this constant is negative one.
272    ///
273    /// Used for patterns like `x // -1 → -x`.
274    pub const fn is_neg_one(&self) -> bool {
275        match self {
276            Self::Int(-1) => true,
277            Self::Float(f) => *f == -1.0,
278            _ => false,
279        }
280    }
281
282    /// Try to extract an integer value (i64 or u64 as i64).
283    ///
284    /// Used for constant pattern matching with specific integer values.
285    pub const fn try_int(&self) -> Option<i64> {
286        match self {
287            Self::Int(v) => Some(*v),
288            Self::UInt(v) => Some(*v as i64),
289            _ => None,
290        }
291    }
292
293    /// Try to extract a float value (f64).
294    ///
295    /// Used for constant pattern matching with specific float values.
296    pub const fn try_float(&self) -> Option<f64> {
297        match self {
298            Self::Float(v) => Some(*v),
299            _ => None,
300        }
301    }
302
303    /// Truncate value to fit within dtype boundaries (two's complement wrapping).
304    ///
305    /// This is equivalent to Tinygrad's ctypes-based truncation. Used for constant
306    /// folding to ensure results respect the target dtype's bit width.
307    pub fn truncate(self, dtype: ScalarDType) -> Self {
308        use ScalarDType::*;
309        match (self, dtype) {
310            // Signed integers: cast to target width, then back to i64
311            (Self::Int(v), Int8) => Self::Int((v as i8) as i64),
312            (Self::Int(v), Int16) => Self::Int((v as i16) as i64),
313            (Self::Int(v), Int32) => Self::Int((v as i32) as i64),
314            (Self::Int(v), Int64 | Index) => Self::Int(v),
315
316            // Unsigned integers: cast to target width, then back to u64
317            (Self::UInt(v), UInt8) => Self::UInt((v as u8) as u64),
318            (Self::UInt(v), UInt16) => Self::UInt((v as u16) as u64),
319            (Self::UInt(v), UInt32) => Self::UInt((v as u32) as u64),
320            (Self::UInt(v), UInt64) => Self::UInt(v),
321
322            // Float/Bool: no truncation needed
323            _ => self,
324        }
325    }
326}
327
328// Re-export AddrSpace from dtype to avoid duplication
329pub use morok_dtype::AddrSpace;
330
331/// Options for BUFFERIZE operation.
332#[derive(Debug, Clone, PartialEq, Eq, Hash)]
333#[derive(serde::Serialize, serde::Deserialize)]
334pub struct BufferizeOpts {
335    /// Device specification or None for local buffers.
336    pub device: Option<DeviceSpec>,
337    /// Address space (GLOBAL or LOCAL).
338    pub addrspace: AddrSpace,
339    /// Whether buffer_removal may inline this BUFFERIZE.
340    /// Multi-consumer realize boundaries set this to `false` so that
341    /// `dead_axis_removal` (which creates new BUFFERIZE nodes) preserves
342    /// the protection across mega-pass fixpoint iterations.
343    pub removable: bool,
344}
345
346impl BufferizeOpts {
347    pub fn new(device: DeviceSpec) -> Self {
348        Self { device: Some(device), addrspace: AddrSpace::Global, removable: true }
349    }
350
351    pub fn local() -> Self {
352        Self { device: None, addrspace: AddrSpace::Local, removable: true }
353    }
354}
355
356/// Optimization hint carried by CONTIGUOUS ops.
357///
358/// This is a simplified representation of optimizer hints that can be
359/// converted to/from the full `Opt` type in the schedule crate.
360/// Keeps the IR layer decoupled from optimizer-specific types.
361///
362/// Based on Tinygrad's CONTIGUOUS.arg which carries Opt tuples.
363#[derive(Debug, Clone, PartialEq, Eq, Hash)]
364#[derive(serde::Serialize, serde::Deserialize)]
365pub struct ContiguousHint {
366    /// Operation name (e.g., "UPCAST", "LOCAL", "UNROLL")
367    pub op: String,
368    /// Target axis index (if applicable)
369    pub axis: Option<usize>,
370    /// Integer argument (amount, size, etc.)
371    pub arg: Option<i64>,
372}
373
374/// Axis type for loop ranges and reductions.
375#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
376#[derive(serde::Serialize, serde::Deserialize)]
377pub enum AxisType {
378    /// Outer kernel-level scheduling dimension (doesn't go inside kernels).
379    ///
380    /// Used to mark ranges that exist at the scheduling/orchestration level
381    /// but don't become part of kernel execution. These ranges are used during
382    /// kernel splitting to identify boundaries.
383    Outer,
384    /// GPU grid dimension.
385    Global,
386    /// Warp/wavefront dimension.
387    Warp,
388    /// GPU block/workgroup dimension (local memory scope).
389    Local,
390    /// Regular loop.
391    Loop,
392    /// Grouped reduction.
393    GroupReduce,
394    /// Reduction axis.
395    Reduce,
396    /// Vectorization axis (upcast).
397    Upcast,
398    /// Unrolled loop.
399    Unroll,
400    /// Thread dimension.
401    Thread,
402    /// Temporary canonicalized range for RESHAPE caching (Tinygrad: AxisType.PLACEHOLDER).
403    /// Substituted in before `_apply_reshape` and substituted back after.
404    Placeholder,
405}
406
407impl AxisType {
408    /// Returns true if this axis type represents a kernel boundary.
409    ///
410    /// Kernel boundary ranges (Outer) exist at the scheduling level and
411    /// don't go inside individual kernels. During kernel splitting, operations
412    /// with outer ranges are skipped from being packaged into KERNEL ops.
413    pub const fn is_kernel_boundary(&self) -> bool {
414        matches!(self, Self::Outer)
415    }
416
417    /// Returns the priority for sorting ranges.
418    ///
419    /// Lower values are outer loops, higher values are inner loops.
420    /// Matches Tinygrad's axis_to_pos ordering for kernel optimization.
421    ///
422    /// **Priority Order:**
423    /// - Outer: -2 (kernel-level boundary)
424    /// - Loop: -1 (not yet parallelized)
425    /// - Global/Thread: 0 (outer parallelism)
426    /// - Warp: 1 (sub-group parallelism)
427    /// - Local/GroupReduce: 2 (workgroup parallelism + synchronization)
428    /// - Upcast: 3 (vectorization)
429    /// - Reduce: 4 (reduction loops)
430    /// - Unroll: 5 (unrolled loops, innermost)
431    pub const fn priority(self) -> i32 {
432        match self {
433            Self::Outer => -2,
434            Self::Loop => -1,
435            Self::Global | Self::Thread => 0,
436            Self::Warp => 1,
437            Self::Local | Self::GroupReduce => 2,
438            Self::Upcast => 3,
439            Self::Reduce => 4,
440            Self::Unroll => 5,
441            Self::Placeholder => -3,
442        }
443    }
444
445    /// Returns the single-letter code for this axis type.
446    ///
447    /// Used in kernel name generation and debug output.
448    ///
449    /// **Letter Codes:**
450    /// - O: Outer
451    /// - L: Loop
452    /// - g: Global
453    /// - t: Thread
454    /// - w: Warp
455    /// - l: Local
456    /// - G: GroupReduce
457    /// - u: Upcast
458    /// - R: Reduce
459    /// - r: Unroll
460    pub const fn letter(self) -> char {
461        match self {
462            Self::Outer => 'O',
463            Self::Loop => 'L',
464            Self::Global => 'g',
465            Self::Thread => 't',
466            Self::Warp => 'w',
467            Self::Local => 'l',
468            Self::GroupReduce => 'G',
469            Self::Upcast => 'u',
470            Self::Reduce => 'R',
471            Self::Unroll => 'r',
472            Self::Placeholder => 'P',
473        }
474    }
475
476    /// Returns true if this is a parallelizable axis type.
477    ///
478    /// Parallel axes represent GPU/thread dispatch dimensions that don't
479    /// contribute to accumulator placement in reduce_to_acc.
480    pub const fn is_parallel(self) -> bool {
481        matches!(self, Self::Global | Self::Thread | Self::Local | Self::Warp)
482    }
483
484    /// Returns true if this is a reduction axis type.
485    pub const fn is_reduce(self) -> bool {
486        matches!(self, Self::Reduce | Self::GroupReduce | Self::Unroll)
487    }
488}
489
490impl PartialOrd for AxisType {
491    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
492        Some(self.cmp(other))
493    }
494}
495
496impl Ord for AxisType {
497    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
498        self.priority().cmp(&other.priority())
499    }
500}
501
502impl std::fmt::Display for AxisType {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        write!(f, "{}", self.letter())
505    }
506}
507
508/// State of range numbering for kernel deduplication.
509///
510/// Ranges go through two states during the compilation pipeline:
511/// - `Unrenumbered`: Created during rangeify with unique IDs for graph construction
512/// - `Renumbered`: Assigned sequential IDs starting from 0 within each kernel
513///
514/// The enum makes the renumber_range pattern naturally idempotent:
515/// it only matches `Unrenumbered` variants and produces `Renumbered` variants.
516#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
517#[derive(serde::Serialize, serde::Deserialize)]
518pub enum AxisId {
519    /// Range created during rangeify, not yet renumbered.
520    Unrenumbered(usize),
521    /// Range renumbered for kernel deduplication.
522    Renumbered(usize),
523}
524
525impl AxisId {
526    /// Get the numeric value, regardless of state.
527    pub fn value(&self) -> usize {
528        match self {
529            AxisId::Unrenumbered(n) | AxisId::Renumbered(n) => *n,
530        }
531    }
532
533    /// Check if this range has been renumbered.
534    pub fn is_renumbered(&self) -> bool {
535        matches!(self, AxisId::Renumbered(_))
536    }
537}
538
539impl std::fmt::Display for AxisId {
540    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541        match self {
542            AxisId::Unrenumbered(n) => write!(f, "U{}", n),
543            AxisId::Renumbered(n) => write!(f, "R{}", n),
544        }
545    }
546}
547
548/// Reduction operation types.
549#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
550#[derive(serde::Serialize, serde::Deserialize)]
551pub enum ReduceOp {
552    /// Sum reduction (a + b).
553    Add,
554    /// Product reduction (a * b).
555    Mul,
556    /// Maximum reduction (max(a, b)).
557    Max,
558    /// Minimum reduction (min(a, b)).
559    Min,
560}
561
562/// Unary operation types.
563///
564/// All unary operations preserve the input dtype.
565#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
566#[derive(serde::Serialize, serde::Deserialize)]
567pub enum UnaryOp {
568    /// Negation: -x
569    Neg,
570    /// Logical/bitwise NOT: !x (bool) or ~x (int)
571    Not,
572    /// Absolute value: |x|
573    Abs,
574    /// Square root: √x
575    Sqrt,
576    /// Reciprocal square root: 1/√x
577    Rsqrt,
578    /// Natural exponential: e^x
579    Exp,
580    /// Base-2 exponential: 2^x
581    Exp2,
582    /// Natural logarithm: ln(x)
583    Log,
584    /// Base-2 logarithm: log₂(x)
585    Log2,
586    /// Sine: sin(x) (float only)
587    Sin,
588    /// Cosine: cos(x) (float only)
589    Cos,
590    /// Tangent: tan(x) (float only)
591    Tan,
592    /// Reciprocal: 1/x
593    Reciprocal,
594    /// Truncate towards zero (remove fractional part)
595    Trunc,
596    /// Floor: round towards -∞
597    Floor,
598    /// Ceiling: round towards +∞
599    Ceil,
600    /// Round: round to nearest integer (half to even)
601    Round,
602    /// Sign: -1 for negative, 0 for zero, 1 for positive
603    Sign,
604    /// Error function: erf(x) (float only)
605    Erf,
606    /// Square: x²
607    Square,
608}
609
610/// Binary operation types.
611///
612/// Arithmetic operations (Add, Mul, Sub, Mod, Max, Pow, Idiv, Fdiv) preserve the LHS dtype.
613/// Comparison operations (Lt, Eq, Ne) always return DType::Bool.
614/// Bitwise operations (And, Or, Xor, Shl, Shr) preserve dtype and require int/bool types.
615#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
616#[derive(serde::Serialize, serde::Deserialize)]
617pub enum BinaryOp {
618    // Arithmetic operations
619    /// Addition: a + b
620    Add,
621    /// Multiplication: a * b
622    Mul,
623    /// Subtraction: a - b
624    Sub,
625    /// Modulo: a % b (C-style remainder)
626    ///
627    /// Uses C/Rust semantics where result has the sign of the dividend (first operand).
628    /// This matches Tinygrad's MOD and C's % operator.
629    ///
630    /// **NOT** Python's modulo operator (which has sign of the divisor).
631    ///
632    /// Examples: -9 % 5 = -4 (Python gives 1), 9 % -5 = 4 (Python gives -1)
633    Mod,
634    /// Maximum: max(a, b)
635    Max,
636    /// Power: a^b
637    Pow,
638    /// Integer division: a / b (truncated toward zero)
639    ///
640    /// Uses C-style truncation, NOT floor division.
641    /// This matches Tinygrad's IDIV and C's / operator for integers.
642    ///
643    /// **NOT** Python's // floor division (which rounds toward -∞).
644    ///
645    /// Examples: -9 / 5 = -1 (Python's // gives -2), 9 / -5 = -1 (Python's // gives -2)
646    Idiv,
647    /// Float division: a / b (exact IEEE 754 division)
648    ///
649    /// Only used for float dtypes. Performs exact floating-point division.
650    /// Matches Tinygrad's FDIV.
651    Fdiv,
652
653    // Comparison operations
654    /// Less than: a < b
655    Lt,
656    /// Less than or equal: a <= b
657    Le,
658    /// Equality: a == b
659    Eq,
660    /// Inequality: a != b
661    Ne,
662    /// Greater than: a > b
663    Gt,
664    /// Greater than or equal: a >= b
665    Ge,
666
667    // Bitwise operations (int/bool only)
668    /// Bitwise AND: a & b
669    And,
670    /// Bitwise OR: a | b
671    Or,
672    /// Bitwise XOR: a ^ b
673    Xor,
674    /// Left shift: a << b
675    Shl,
676    /// Right shift: a >> b
677    Shr,
678
679    // Special operations
680    /// Threefry PRNG: threefry(x, key) -> uint64
681    Threefry,
682}
683
684impl BinaryOp {
685    /// Returns true if this is a comparison operation.
686    pub fn is_comparison(self) -> bool {
687        matches!(self, Self::Lt | Self::Le | Self::Eq | Self::Ne | Self::Gt | Self::Ge)
688    }
689
690    /// Returns true if this is an arithmetic operation.
691    pub fn is_arithmetic(self) -> bool {
692        matches!(self, Self::Add | Self::Mul | Self::Sub | Self::Mod | Self::Max | Self::Pow | Self::Idiv | Self::Fdiv)
693    }
694
695    /// Returns true if this is a bitwise operation.
696    pub fn is_bitwise(self) -> bool {
697        matches!(self, Self::And | Self::Or | Self::Xor | Self::Shl | Self::Shr)
698    }
699
700    /// Returns true if this operation is associative.
701    pub fn is_associative(self) -> bool {
702        matches!(self, Self::Add | Self::Mul | Self::And | Self::Or | Self::Max)
703    }
704
705    /// Returns true if this operation is commutative.
706    pub fn is_commutative(self) -> bool {
707        matches!(self, Self::Add | Self::Mul | Self::Eq | Self::Ne | Self::And | Self::Or | Self::Xor | Self::Max)
708    }
709
710    /// Returns true if this operation is idempotent (f(x, x) = x).
711    pub fn is_idempotent(self) -> bool {
712        matches!(self, Self::Or | Self::And | Self::Max)
713    }
714}
715
716/// Ternary operation types.
717#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::AsRefStr, strum::VariantNames)]
718#[derive(serde::Serialize, serde::Deserialize)]
719pub enum TernaryOp {
720    /// Conditional selection: condition ? true_val : false_val
721    Where,
722    /// Multiply-accumulate: a * b + c (fused operation)
723    MulAcc,
724}
725
726/// Per-source upcast axes for WMMA operations.
727///
728/// Each WMMA source (A, B, C) may have different upcast axis sizes
729/// based on `elements_per_thread`. For example, CUDA 8-16-16 with
730/// `elements_per_thread=(8,4,4)` produces A=8, B=4, C=4 element groups.
731#[derive(Debug, Clone, PartialEq, Eq, Hash)]
732#[derive(serde::Serialize, serde::Deserialize)]
733pub struct WmmaUpcastAxes {
734    /// A operand upcast axes (input matrix).
735    pub a: Vec<(usize, usize)>,
736    /// B operand upcast axes (input matrix).
737    pub b: Vec<(usize, usize)>,
738    /// C operand upcast axes (output/accumulator).
739    pub c: Vec<(usize, usize)>,
740}
741
742impl WmmaUpcastAxes {
743    /// Returns deduplicated axis IDs from all three operands.
744    pub fn all_axis_ids(&self) -> Vec<usize> {
745        let mut ids: Vec<usize> = self.a.iter().chain(self.b.iter()).chain(self.c.iter()).map(|(id, _)| *id).collect();
746        ids.sort_unstable();
747        ids.dedup();
748        ids
749    }
750
751    /// Returns the axes for operand at the given index (0=A, 1=B, 2=C).
752    pub fn by_index(&self, index: usize) -> &[(usize, usize)] {
753        match index {
754            0 => &self.a,
755            1 => &self.b,
756            2 => &self.c,
757            _ => panic!("WMMA operand index must be 0, 1, or 2"),
758        }
759    }
760
761    /// Returns the product of axis sizes for operand at given index.
762    pub fn source_size(&self, index: usize) -> usize {
763        self.by_index(index).iter().map(|(_, s)| s).product::<usize>().max(1)
764    }
765}
766
767/// Metadata for WMMA (Warp Matrix Multiply-Accumulate) operations.
768#[derive(Debug, Clone, PartialEq, Eq, Hash)]
769#[derive(serde::Serialize, serde::Deserialize)]
770pub struct WmmaMetadata {
771    /// Operation name (e.g., "WMMA_INSTRUCTION").
772    pub name: String,
773    /// Matrix dimensions (N, M, K).
774    pub dims: (usize, usize, usize),
775    /// Input matrix dtype.
776    pub dtype_in: DType,
777    /// Output/accumulator dtype.
778    pub dtype_out: DType,
779    /// Target device string.
780    pub device: String,
781    /// Thread count.
782    pub threads: usize,
783    /// Per-source upcast axes for vectorization (A, B, C each have their own).
784    pub upcast_axes: WmmaUpcastAxes,
785    /// TC reduce axis IDs (used for exclude_args in expansion).
786    pub reduce_axes: Vec<usize>,
787    /// Tile grid for multi-FMA batching (tile_y_count, tile_x_count).
788    ///
789    /// When > (1, 1), uses load-pair mode and emits multiple FMAs per K iteration
790    /// to compute a 2×2 grid of output tiles. Default is (1, 1).
791    pub tile_grid: (usize, usize),
792}
793
794/// Wrapper for ConstValue that implements Eq and Hash.
795///
796/// Floats don't implement Eq/Hash due to IEEE 754 NaN semantics (NaN != NaN).
797/// This wrapper uses bitwise comparison: two floats are equal if their bit patterns match.
798/// This means:
799/// - NaN values with identical bit patterns are considered equal
800/// - Different NaN representations are not equal
801/// - This is consistent with hash consing requirements
802#[derive(Debug, Clone, Copy)]
803#[derive(serde::Serialize, serde::Deserialize)]
804pub struct ConstValueHash(pub ConstValue);
805
806impl PartialEq for ConstValueHash {
807    fn eq(&self, other: &Self) -> bool {
808        match (self.0, other.0) {
809            (ConstValue::Int(a), ConstValue::Int(b)) => a == b,
810            (ConstValue::UInt(a), ConstValue::UInt(b)) => a == b,
811            (ConstValue::Float(a), ConstValue::Float(b)) => a.to_bits() == b.to_bits(),
812            (ConstValue::Bool(a), ConstValue::Bool(b)) => a == b,
813            _ => false,
814        }
815    }
816}
817
818impl Eq for ConstValueHash {}
819
820impl Hash for ConstValueHash {
821    fn hash<H: Hasher>(&self, state: &mut H) {
822        (discriminant(&self.0)).hash(state);
823        match self.0 {
824            ConstValue::Int(v) => v.hash(state),
825            ConstValue::UInt(v) => v.hash(state),
826            ConstValue::Float(v) => v.to_bits().hash(state),
827            ConstValue::Bool(v) => v.hash(state),
828        }
829    }
830}