Skip to main content

ringkernel_ir/
nodes.rs

//! IR node definitions.
//!
//! Defines the operations that can appear in the IR, together with the
//! instruction, constant, and block-terminator types that carry them.
4
5use crate::{BlockId, Dimension, IrType, ValueId};
6
7/// An IR instruction that produces a value.
8#[derive(Debug, Clone)]
9pub struct Instruction {
10    /// The value produced by this instruction.
11    pub result: ValueId,
12    /// The result type.
13    pub result_type: IrType,
14    /// The operation.
15    pub node: IrNode,
16}
17
18impl Instruction {
19    /// Create a new instruction.
20    pub fn new(result: ValueId, result_type: IrType, node: IrNode) -> Self {
21        Self {
22            result,
23            result_type,
24            node,
25        }
26    }
27}
28
/// IR node representing an operation.
///
/// An `IrNode` is the operation half of an [`Instruction`]; operands name
/// previously defined SSA values by [`ValueId`]. Block terminators are not
/// represented here — they live in [`Terminator`].
///
/// NOTE(review): operand order for the multi-operand variants is not fixed by
/// this definition. The per-variant notes below describe the presumed
/// convention and should be confirmed against the IR builder and backends.
#[derive(Debug, Clone)]
pub enum IrNode {
    // ========================================================================
    // Constants and Parameters
    // ========================================================================
    /// Constant value; its type is given by [`ConstantValue::ir_type`].
    Constant(ConstantValue),
    /// Parameter reference.
    ///
    /// NOTE(review): presumably a zero-based index into the kernel's
    /// parameter list — confirm.
    Parameter(usize),
    /// Undefined value (for phi nodes without all predecessors).
    Undef,

    // ========================================================================
    // Binary Operations
    // ========================================================================
    /// Binary operation applying the operator to two operands
    /// (presumably `lhs`, `rhs` in that order).
    BinaryOp(BinaryOp, ValueId, ValueId),

    // ========================================================================
    // Unary Operations
    // ========================================================================
    /// Unary operation applying the operator to a single operand.
    UnaryOp(UnaryOp, ValueId),

    // ========================================================================
    // Comparison Operations
    // ========================================================================
    /// Comparison of two operands (presumably `lhs`, `rhs` in that order).
    Compare(CompareOp, ValueId, ValueId),

    // ========================================================================
    // Type Conversions
    // ========================================================================
    /// Cast of the operand to the given target type, using the given
    /// conversion kind.
    Cast(CastKind, ValueId, IrType),

    // ========================================================================
    // Memory Operations
    // ========================================================================
    /// Load from the pointer operand.
    Load(ValueId),
    /// Store to pointer (no result value).
    ///
    /// NOTE(review): which operand is the pointer and which is the stored
    /// value is not visible here — confirm. Every `Instruction` carries a
    /// `result`, so for stores it is presumably ignored.
    Store(ValueId, ValueId),
    /// Get element pointer.
    ///
    /// NOTE(review): presumably a base pointer followed by a list of
    /// indices — confirm.
    GetElementPtr(ValueId, Vec<ValueId>),
    /// Allocate a local variable of the given type.
    Alloca(IrType),
    /// Allocate shared memory.
    ///
    /// NOTE(review): the `usize` is presumably an element count of the given
    /// type — confirm.
    SharedAlloc(IrType, usize),
    /// Extract struct field: the aggregate value and a field index
    /// (presumably zero-based).
    ExtractField(ValueId, usize),
    /// Insert struct field: the aggregate, a field index, and the new field
    /// value — presumably yielding the updated aggregate.
    InsertField(ValueId, usize, ValueId),

    // ========================================================================
    // GPU Index Operations
    // ========================================================================
    /// Get thread ID.
    ThreadId(Dimension),
    /// Get block ID.
    BlockId(Dimension),
    /// Get block dimension.
    BlockDim(Dimension),
    /// Get grid dimension.
    GridDim(Dimension),
    /// Get global thread ID (block_id * block_dim + thread_id).
    GlobalThreadId(Dimension),
    /// Get warp/wavefront ID.
    WarpId,
    /// Get lane ID within warp.
    LaneId,

    // ========================================================================
    // Synchronization Operations
    // ========================================================================
    /// Threadgroup/block barrier.
    Barrier,
    /// Memory fence with the given scope.
    MemoryFence(MemoryScope),
    /// Grid-wide sync (cooperative groups).
    GridSync,

    // ========================================================================
    // Atomic Operations
    // ========================================================================
    /// Atomic operation selected by [`AtomicOp`].
    ///
    /// NOTE(review): operands are presumably (pointer, value); for
    /// `AtomicOp::Load` the second operand has no obvious role — confirm.
    Atomic(AtomicOp, ValueId, ValueId),
    /// Atomic compare-and-swap.
    ///
    /// NOTE(review): operands are presumably (pointer, expected, desired) —
    /// confirm.
    AtomicCas(ValueId, ValueId, ValueId),

    // ========================================================================
    // Warp/Subgroup Operations
    // ========================================================================
    /// Warp vote (all, any, ballot) over a predicate operand.
    WarpVote(WarpVoteOp, ValueId),
    /// Warp shuffle of a value.
    ///
    /// NOTE(review): the second operand is presumably the source lane, lane
    /// delta, or XOR mask depending on [`WarpShuffleOp`] — confirm.
    WarpShuffle(WarpShuffleOp, ValueId, ValueId),
    /// Warp-wide reduction of the operand.
    WarpReduce(WarpReduceOp, ValueId),

    // ========================================================================
    // Math Operations
    // ========================================================================
    /// Math intrinsic. The operand count depends on the [`MathOp`] (e.g.
    /// `Atan2` takes two arguments) and is not enforced by this type.
    Math(MathOp, Vec<ValueId>),

    // ========================================================================
    // Control Flow (non-terminator)
    // ========================================================================
    /// Select (ternary operator).
    ///
    /// NOTE(review): presumably (condition, value-if-true, value-if-false) —
    /// confirm.
    Select(ValueId, ValueId, ValueId),
    /// Phi node for SSA: each entry pairs a block with the value taken when
    /// control arrives from that block.
    Phi(Vec<(BlockId, ValueId)>),

    // ========================================================================
    // RingKernel Messaging
    // ========================================================================
    /// Enqueue to output queue.
    K2HEnqueue(ValueId),
    /// Dequeue from input queue.
    H2KDequeue,
    /// Check if input queue is empty.
    H2KIsEmpty,
    /// Send K2K message.
    ///
    /// NOTE(review): operands are presumably (destination, message) — confirm.
    K2KSend(ValueId, ValueId),
    /// Receive K2K message.
    K2KRecv,
    /// Try receive K2K message (non-blocking).
    K2KTryRecv,

    // ========================================================================
    // HLC Operations
    // ========================================================================
    /// Get current HLC time.
    HlcNow,
    /// Tick HLC.
    HlcTick,
    /// Update HLC from incoming timestamp.
    HlcUpdate(ValueId),

    // ========================================================================
    // Function Call
    // ========================================================================
    /// Call the function with the given name, passing the operand values as
    /// arguments.
    Call(String, Vec<ValueId>),
}
176
177/// Constant values.
178#[derive(Debug, Clone, PartialEq)]
179pub enum ConstantValue {
180    /// Boolean constant.
181    Bool(bool),
182    /// 32-bit signed integer.
183    I32(i32),
184    /// 64-bit signed integer.
185    I64(i64),
186    /// 32-bit unsigned integer.
187    U32(u32),
188    /// 64-bit unsigned integer.
189    U64(u64),
190    /// 32-bit float.
191    F32(f32),
192    /// 64-bit float.
193    F64(f64),
194    /// Null pointer.
195    Null,
196    /// Array of constants.
197    Array(Vec<ConstantValue>),
198    /// Struct constant.
199    Struct(Vec<ConstantValue>),
200}
201
202impl ConstantValue {
203    /// Get the IR type of this constant.
204    pub fn ir_type(&self) -> IrType {
205        match self {
206            ConstantValue::Bool(_) => IrType::BOOL,
207            ConstantValue::I32(_) => IrType::I32,
208            ConstantValue::I64(_) => IrType::I64,
209            ConstantValue::U32(_) => IrType::U32,
210            ConstantValue::U64(_) => IrType::U64,
211            ConstantValue::F32(_) => IrType::F32,
212            ConstantValue::F64(_) => IrType::F64,
213            ConstantValue::Null => IrType::ptr(IrType::Void),
214            ConstantValue::Array(elements) => {
215                if elements.is_empty() {
216                    IrType::array(IrType::Void, 0)
217                } else {
218                    IrType::array(elements[0].ir_type(), elements.len())
219                }
220            }
221            ConstantValue::Struct(_) => IrType::Void, // Would need struct type info
222        }
223    }
224}
225
/// Binary operations.
///
/// Used as the operator of [`IrNode::BinaryOp`], which supplies exactly two
/// operands. `Hash` is derived alongside `Eq` so operators can be used as
/// `HashMap`/`HashSet` keys (for example, when keying expressions for value
/// numbering).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    // Arithmetic
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder/modulo.
    Rem,

    // Bitwise
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise XOR.
    Xor,
    /// Left shift.
    Shl,
    /// Logical right shift.
    Shr,
    /// Arithmetic right shift.
    Sar,

    // Floating-point specific
    /// Fused multiply-add.
    ///
    /// NOTE(review): FMA needs three operands but `IrNode::BinaryOp` carries
    /// only two; how the third operand is supplied is not visible here —
    /// confirm.
    Fma,
    /// Power.
    Pow,
    /// Minimum.
    Min,
    /// Maximum.
    Max,
}

impl std::fmt::Display for BinaryOp {
    /// Writes the lowercase mnemonic (e.g. `add`, `sar`).
    ///
    /// Uses `Formatter::pad` so width, fill, and alignment flags (e.g.
    /// `{:<5}`) are honored when aligning IR dumps.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            BinaryOp::Add => "add",
            BinaryOp::Sub => "sub",
            BinaryOp::Mul => "mul",
            BinaryOp::Div => "div",
            BinaryOp::Rem => "rem",
            BinaryOp::And => "and",
            BinaryOp::Or => "or",
            BinaryOp::Xor => "xor",
            BinaryOp::Shl => "shl",
            BinaryOp::Shr => "shr",
            BinaryOp::Sar => "sar",
            BinaryOp::Fma => "fma",
            BinaryOp::Pow => "pow",
            BinaryOp::Min => "min",
            BinaryOp::Max => "max",
        };
        f.pad(mnemonic)
    }
}
287
/// Unary operations.
///
/// Used as the operator of [`IrNode::UnaryOp`]. `Hash` is derived alongside
/// `Eq` so operators can be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Negation.
    Neg,
    /// Bitwise NOT.
    Not,
    /// Logical NOT (for booleans).
    LogicalNot,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root.
    Rsqrt,
    /// Floor.
    Floor,
    /// Ceiling.
    Ceil,
    /// Round to nearest.
    Round,
    /// Truncate (presumably rounding toward zero; distinct from the
    /// integer-narrowing `CastKind::Truncate`).
    Trunc,
    /// Sign.
    Sign,
}

impl std::fmt::Display for UnaryOp {
    /// Writes the lowercase mnemonic (e.g. `neg`, `lnot`).
    ///
    /// Uses `Formatter::pad` so width, fill, and alignment flags are honored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            UnaryOp::Neg => "neg",
            UnaryOp::Not => "not",
            UnaryOp::LogicalNot => "lnot",
            UnaryOp::Abs => "abs",
            UnaryOp::Sqrt => "sqrt",
            UnaryOp::Rsqrt => "rsqrt",
            UnaryOp::Floor => "floor",
            UnaryOp::Ceil => "ceil",
            UnaryOp::Round => "round",
            UnaryOp::Trunc => "trunc",
            UnaryOp::Sign => "sign",
        };
        f.pad(mnemonic)
    }
}
332
/// Comparison operations.
///
/// Used as the operator of [`IrNode::Compare`]. `Hash` is derived alongside
/// `Eq` so operators can be used as `HashMap`/`HashSet` keys.
///
/// NOTE(review): there are no separate signed/unsigned or ordered/unordered
/// float variants; backends presumably decide from the operand types —
/// confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}

impl std::fmt::Display for CompareOp {
    /// Writes the two-letter mnemonic (e.g. `eq`, `ge`).
    ///
    /// Uses `Formatter::pad` so width, fill, and alignment flags are honored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            CompareOp::Eq => "eq",
            CompareOp::Ne => "ne",
            CompareOp::Lt => "lt",
            CompareOp::Le => "le",
            CompareOp::Gt => "gt",
            CompareOp::Ge => "ge",
        };
        f.pad(mnemonic)
    }
}
362
/// Cast kinds used by [`IrNode::Cast`].
///
/// `Hash` is derived alongside `Eq` so cast kinds can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CastKind {
    /// Bitcast (same size, different type).
    Bitcast,
    /// Zero extend.
    ZeroExtend,
    /// Sign extend.
    SignExtend,
    /// Truncate.
    Truncate,
    /// Float to int.
    FloatToInt,
    /// Int to float.
    IntToFloat,
    /// Float to float (change precision).
    FloatConvert,
    /// Pointer cast.
    PtrCast,
}
383
/// Memory scope for fences ([`IrNode::MemoryFence`]).
///
/// `Hash` is derived alongside `Eq` so scopes can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryScope {
    /// Thread-local scope.
    Thread,
    /// Threadgroup/block scope.
    Threadgroup,
    /// Device scope.
    Device,
    /// System scope.
    System,
}
396
/// Atomic operations used by [`IrNode::Atomic`].
///
/// Compare-and-swap is a separate node ([`IrNode::AtomicCas`]) rather than a
/// variant here. `Hash` is derived alongside `Eq` so operations can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtomicOp {
    /// Atomic load.
    Load,
    /// Atomic store.
    Store,
    /// Atomic exchange.
    Exchange,
    /// Atomic add.
    Add,
    /// Atomic sub.
    Sub,
    /// Atomic min.
    Min,
    /// Atomic max.
    Max,
    /// Atomic AND.
    And,
    /// Atomic OR.
    Or,
    /// Atomic XOR.
    Xor,
}
421
/// Warp vote operations used by [`IrNode::WarpVote`].
///
/// `Hash` is derived alongside `Eq` so operations can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpVoteOp {
    /// All threads have true.
    All,
    /// Any thread has true.
    Any,
    /// Ballot (bitmask of predicates).
    Ballot,
}
432
/// Warp shuffle operations used by [`IrNode::WarpShuffle`].
///
/// `Hash` is derived alongside `Eq` so operations can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpShuffleOp {
    /// Shuffle indexed.
    Index,
    /// Shuffle up.
    Up,
    /// Shuffle down.
    Down,
    /// Shuffle XOR.
    Xor,
}
445
/// Warp reduce operations used by [`IrNode::WarpReduce`].
///
/// `Hash` is derived alongside `Eq` so operations can be used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpReduceOp {
    /// Sum reduction.
    Sum,
    /// Product reduction.
    Product,
    /// Minimum reduction.
    Min,
    /// Maximum reduction.
    Max,
    /// AND reduction.
    And,
    /// OR reduction.
    Or,
    /// XOR reduction.
    Xor,
}
464
/// Math operations (intrinsics) used by [`IrNode::Math`].
///
/// Arity varies by operation (e.g. `Atan2` takes two arguments), and
/// `IrNode::Math` does not enforce it. `Hash` is derived alongside `Eq` so
/// operations can be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MathOp {
    // Trigonometric
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Arc sine.
    Asin,
    /// Arc cosine.
    Acos,
    /// Arc tangent.
    Atan,
    /// Arc tangent with two arguments.
    Atan2,

    // Hyperbolic
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,

    // Exponential/Logarithmic
    /// Exponential (e^x).
    Exp,
    /// Exponential base 2.
    Exp2,
    /// Natural logarithm.
    Log,
    /// Logarithm base 2.
    Log2,
    /// Logarithm base 10.
    Log10,

    // Other
    /// Linear interpolation.
    Lerp,
    /// Clamp.
    Clamp,
    /// Step function.
    Step,
    /// Smooth step.
    SmoothStep,
    /// Fract (fractional part).
    Fract,
    /// Copy sign.
    CopySign,
}
518
/// Block terminator instructions.
///
/// Control-flow instructions that end a block, transferring control to other
/// blocks or out of the kernel. Non-terminating control flow (`Select`, `Phi`)
/// lives in [`IrNode`].
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Return from kernel, optionally with a value (`None` for no value).
    Return(Option<ValueId>),
    /// Unconditional branch to the target block.
    Branch(BlockId),
    /// Conditional branch.
    ///
    /// NOTE(review): presumably (condition, target-if-true, target-if-false) —
    /// confirm.
    CondBranch(ValueId, BlockId, BlockId),
    /// Switch statement over the first operand.
    ///
    /// NOTE(review): the lone `BlockId` is presumably the default target and
    /// each `(ConstantValue, BlockId)` pair a case — confirm.
    Switch(ValueId, BlockId, Vec<(ConstantValue, BlockId)>),
    /// Unreachable (for optimization).
    Unreachable,
}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_ir_type() {
        assert_eq!(ConstantValue::I32(42).ir_type(), IrType::I32);
        assert_eq!(ConstantValue::F32(3.125).ir_type(), IrType::F32);
        assert_eq!(ConstantValue::Bool(true).ir_type(), IrType::BOOL);
    }

    #[test]
    fn test_constant_ir_type_remaining_scalars() {
        assert_eq!(ConstantValue::I64(-7).ir_type(), IrType::I64);
        assert_eq!(ConstantValue::U32(7).ir_type(), IrType::U32);
        assert_eq!(ConstantValue::U64(7).ir_type(), IrType::U64);
        assert_eq!(ConstantValue::F64(0.5).ir_type(), IrType::F64);
    }

    #[test]
    fn test_constant_ir_type_null_and_arrays() {
        assert_eq!(ConstantValue::Null.ir_type(), IrType::ptr(IrType::Void));
        // An empty array has no element to take a type from.
        assert_eq!(
            ConstantValue::Array(Vec::new()).ir_type(),
            IrType::array(IrType::Void, 0)
        );
        let values = ConstantValue::Array(vec![
            ConstantValue::U32(1),
            ConstantValue::U32(2),
            ConstantValue::U32(3),
        ]);
        assert_eq!(values.ir_type(), IrType::array(IrType::U32, 3));
    }

    #[test]
    fn test_binary_op_display() {
        let cases = [
            (BinaryOp::Add, "add"),
            (BinaryOp::Sub, "sub"),
            (BinaryOp::Mul, "mul"),
            (BinaryOp::Div, "div"),
            (BinaryOp::Rem, "rem"),
            (BinaryOp::And, "and"),
            (BinaryOp::Or, "or"),
            (BinaryOp::Xor, "xor"),
            (BinaryOp::Shl, "shl"),
            (BinaryOp::Shr, "shr"),
            (BinaryOp::Sar, "sar"),
            (BinaryOp::Fma, "fma"),
            (BinaryOp::Pow, "pow"),
            (BinaryOp::Min, "min"),
            (BinaryOp::Max, "max"),
        ];
        for &(op, expected) in cases.iter() {
            assert_eq!(format!("{}", op), expected);
        }
    }

    #[test]
    fn test_unary_op_display() {
        let cases = [
            (UnaryOp::Neg, "neg"),
            (UnaryOp::Not, "not"),
            (UnaryOp::LogicalNot, "lnot"),
            (UnaryOp::Abs, "abs"),
            (UnaryOp::Sqrt, "sqrt"),
            (UnaryOp::Rsqrt, "rsqrt"),
            (UnaryOp::Floor, "floor"),
            (UnaryOp::Ceil, "ceil"),
            (UnaryOp::Round, "round"),
            (UnaryOp::Trunc, "trunc"),
            (UnaryOp::Sign, "sign"),
        ];
        for &(op, expected) in cases.iter() {
            assert_eq!(format!("{}", op), expected);
        }
    }

    #[test]
    fn test_compare_op_display() {
        let cases = [
            (CompareOp::Eq, "eq"),
            (CompareOp::Ne, "ne"),
            (CompareOp::Lt, "lt"),
            (CompareOp::Le, "le"),
            (CompareOp::Gt, "gt"),
            (CompareOp::Ge, "ge"),
        ];
        for &(op, expected) in cases.iter() {
            assert_eq!(format!("{}", op), expected);
        }
    }
}