ringkernel_ir/nodes.rs
1//! IR node definitions.
2//!
3//! Defines all operations that can appear in the IR.
4
5use crate::{BlockId, Dimension, IrType, ValueId};
6
7/// An IR instruction that produces a value.
8#[derive(Debug, Clone)]
9pub struct Instruction {
10 /// The value produced by this instruction.
11 pub result: ValueId,
12 /// The result type.
13 pub result_type: IrType,
14 /// The operation.
15 pub node: IrNode,
16}
17
18impl Instruction {
19 /// Create a new instruction.
20 pub fn new(result: ValueId, result_type: IrType, node: IrNode) -> Self {
21 Self {
22 result,
23 result_type,
24 node,
25 }
26 }
27}
28
29/// IR node representing an operation.
30#[derive(Debug, Clone)]
31pub enum IrNode {
32 // ========================================================================
33 // Constants and Parameters
34 // ========================================================================
35 /// Constant value.
36 Constant(ConstantValue),
37 /// Parameter reference.
38 Parameter(usize),
39 /// Undefined value (for phi nodes without all predecessors).
40 Undef,
41
42 // ========================================================================
43 // Binary Operations
44 // ========================================================================
45 /// Binary operation.
46 BinaryOp(BinaryOp, ValueId, ValueId),
47
48 // ========================================================================
49 // Unary Operations
50 // ========================================================================
51 /// Unary operation.
52 UnaryOp(UnaryOp, ValueId),
53
54 // ========================================================================
55 // Comparison Operations
56 // ========================================================================
57 /// Comparison operation.
58 Compare(CompareOp, ValueId, ValueId),
59
60 // ========================================================================
61 // Type Conversions
62 // ========================================================================
63 /// Cast to a different type.
64 Cast(CastKind, ValueId, IrType),
65
66 // ========================================================================
67 // Memory Operations
68 // ========================================================================
69 /// Load from pointer.
70 Load(ValueId),
71 /// Store to pointer (no result value).
72 Store(ValueId, ValueId),
73 /// Get element pointer.
74 GetElementPtr(ValueId, Vec<ValueId>),
75 /// Allocate local variable.
76 Alloca(IrType),
77 /// Allocate shared memory.
78 SharedAlloc(IrType, usize),
79 /// Extract struct field.
80 ExtractField(ValueId, usize),
81 /// Insert struct field.
82 InsertField(ValueId, usize, ValueId),
83
84 // ========================================================================
85 // GPU Index Operations
86 // ========================================================================
87 /// Get thread ID.
88 ThreadId(Dimension),
89 /// Get block ID.
90 BlockId(Dimension),
91 /// Get block dimension.
92 BlockDim(Dimension),
93 /// Get grid dimension.
94 GridDim(Dimension),
95 /// Get global thread ID (block_id * block_dim + thread_id).
96 GlobalThreadId(Dimension),
97 /// Get warp/wavefront ID.
98 WarpId,
99 /// Get lane ID within warp.
100 LaneId,
101
102 // ========================================================================
103 // Synchronization Operations
104 // ========================================================================
105 /// Threadgroup/block barrier.
106 Barrier,
107 /// Memory fence.
108 MemoryFence(MemoryScope),
109 /// Grid-wide sync (cooperative groups).
110 GridSync,
111
112 // ========================================================================
113 // Atomic Operations
114 // ========================================================================
115 /// Atomic operation.
116 Atomic(AtomicOp, ValueId, ValueId),
117 /// Atomic compare-and-swap.
118 AtomicCas(ValueId, ValueId, ValueId),
119
120 // ========================================================================
121 // Warp/Subgroup Operations
122 // ========================================================================
123 /// Warp vote (all, any, ballot).
124 WarpVote(WarpVoteOp, ValueId),
125 /// Warp shuffle.
126 WarpShuffle(WarpShuffleOp, ValueId, ValueId),
127 /// Warp reduce.
128 WarpReduce(WarpReduceOp, ValueId),
129
130 // ========================================================================
131 // Math Operations
132 // ========================================================================
133 /// Math function.
134 Math(MathOp, Vec<ValueId>),
135
136 // ========================================================================
137 // Control Flow (non-terminator)
138 // ========================================================================
139 /// Select (ternary operator).
140 Select(ValueId, ValueId, ValueId),
141 /// Phi node for SSA.
142 Phi(Vec<(BlockId, ValueId)>),
143
144 // ========================================================================
145 // RingKernel Messaging
146 // ========================================================================
147 /// Enqueue to output queue.
148 K2HEnqueue(ValueId),
149 /// Dequeue from input queue.
150 H2KDequeue,
151 /// Check if input queue is empty.
152 H2KIsEmpty,
153 /// Send K2K message.
154 K2KSend(ValueId, ValueId),
155 /// Receive K2K message.
156 K2KRecv,
157 /// Try receive K2K message (non-blocking).
158 K2KTryRecv,
159
160 // ========================================================================
161 // HLC Operations
162 // ========================================================================
163 /// Get current HLC time.
164 HlcNow,
165 /// Tick HLC.
166 HlcTick,
167 /// Update HLC from incoming timestamp.
168 HlcUpdate(ValueId),
169
170 // ========================================================================
171 // Function Call
172 // ========================================================================
173 /// Call a function.
174 Call(String, Vec<ValueId>),
175}
176
/// A compile-time constant operand.
///
/// Only `PartialEq` is derived (not `Eq`/`Hash`) because the float variants
/// do not satisfy total equality.
#[derive(Clone, Debug, PartialEq)]
pub enum ConstantValue {
    /// `bool` literal.
    Bool(bool),
    /// `i32` literal.
    I32(i32),
    /// `i64` literal.
    I64(i64),
    /// `u32` literal.
    U32(u32),
    /// `u64` literal.
    U64(u64),
    /// `f32` literal.
    F32(f32),
    /// `f64` literal.
    F64(f64),
    /// The null pointer.
    Null,
    /// Constant array; elements are themselves constants.
    Array(Vec<ConstantValue>),
    /// Constant struct; fields are listed in order.
    Struct(Vec<ConstantValue>),
}
201
202impl ConstantValue {
203 /// Get the IR type of this constant.
204 pub fn ir_type(&self) -> IrType {
205 match self {
206 ConstantValue::Bool(_) => IrType::BOOL,
207 ConstantValue::I32(_) => IrType::I32,
208 ConstantValue::I64(_) => IrType::I64,
209 ConstantValue::U32(_) => IrType::U32,
210 ConstantValue::U64(_) => IrType::U64,
211 ConstantValue::F32(_) => IrType::F32,
212 ConstantValue::F64(_) => IrType::F64,
213 ConstantValue::Null => IrType::ptr(IrType::Void),
214 ConstantValue::Array(elements) => {
215 if elements.is_empty() {
216 IrType::array(IrType::Void, 0)
217 } else {
218 IrType::array(elements[0].ir_type(), elements.len())
219 }
220 }
221 ConstantValue::Struct(_) => IrType::Void, // Would need struct type info
222 }
223 }
224}
225
/// Binary operations.
///
/// Derives `Hash` (alongside `Eq`) so operations can key hash maps, e.g.
/// value-numbering or lowering tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    // Arithmetic
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder/modulo.
    Rem,

    // Bitwise
    /// Bitwise AND.
    And,
    /// Bitwise OR.
    Or,
    /// Bitwise XOR.
    Xor,
    /// Left shift.
    Shl,
    /// Logical right shift.
    Shr,
    /// Arithmetic right shift.
    Sar,

    // Floating-point specific
    /// Fused multiply-add.
    ///
    /// NOTE(review): FMA has three inputs, but `IrNode::BinaryOp` carries only
    /// two operands. Confirm how the addend is supplied.
    Fma,
    /// Power.
    Pow,
    /// Minimum.
    Min,
    /// Maximum.
    Max,
}

impl std::fmt::Display for BinaryOp {
    /// Writes the lowercase textual-IR mnemonic (e.g. `add`).
    ///
    /// Goes through `Formatter::pad`, so width/alignment flags such as
    /// `{:<6}` are honored when printing column-aligned IR. Plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            BinaryOp::Add => "add",
            BinaryOp::Sub => "sub",
            BinaryOp::Mul => "mul",
            BinaryOp::Div => "div",
            BinaryOp::Rem => "rem",
            BinaryOp::And => "and",
            BinaryOp::Or => "or",
            BinaryOp::Xor => "xor",
            BinaryOp::Shl => "shl",
            BinaryOp::Shr => "shr",
            BinaryOp::Sar => "sar",
            BinaryOp::Fma => "fma",
            BinaryOp::Pow => "pow",
            BinaryOp::Min => "min",
            BinaryOp::Max => "max",
        };
        f.pad(mnemonic)
    }
}
287
/// Unary operations.
///
/// Derives `Hash` (alongside `Eq`) so operations can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Negation.
    Neg,
    /// Bitwise NOT.
    Not,
    /// Logical NOT (for booleans).
    LogicalNot,
    /// Absolute value.
    Abs,
    /// Square root.
    Sqrt,
    /// Reciprocal square root.
    Rsqrt,
    /// Floor.
    Floor,
    /// Ceiling.
    Ceil,
    /// Round to nearest.
    Round,
    /// Truncate.
    Trunc,
    /// Sign.
    Sign,
}

impl std::fmt::Display for UnaryOp {
    /// Writes the lowercase textual-IR mnemonic (e.g. `neg`, `lnot`).
    ///
    /// Uses `Formatter::pad` so width/alignment flags are honored; plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            UnaryOp::Neg => "neg",
            UnaryOp::Not => "not",
            UnaryOp::LogicalNot => "lnot",
            UnaryOp::Abs => "abs",
            UnaryOp::Sqrt => "sqrt",
            UnaryOp::Rsqrt => "rsqrt",
            UnaryOp::Floor => "floor",
            UnaryOp::Ceil => "ceil",
            UnaryOp::Round => "round",
            UnaryOp::Trunc => "trunc",
            UnaryOp::Sign => "sign",
        };
        f.pad(mnemonic)
    }
}
332
/// Comparison operations.
///
/// Derives `Hash` (alongside `Eq`) so predicates can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}

impl std::fmt::Display for CompareOp {
    /// Writes the lowercase textual-IR predicate (e.g. `eq`, `lt`).
    ///
    /// Uses `Formatter::pad` so width/alignment flags are honored; plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mnemonic = match self {
            CompareOp::Eq => "eq",
            CompareOp::Ne => "ne",
            CompareOp::Lt => "lt",
            CompareOp::Le => "le",
            CompareOp::Gt => "gt",
            CompareOp::Ge => "ge",
        };
        f.pad(mnemonic)
    }
}
362
/// Cast kinds used by `IrNode::Cast`.
///
/// Derives `Hash` (alongside `Eq`) so cast kinds can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CastKind {
    /// Bitcast (same size, different type).
    Bitcast,
    /// Zero extend.
    ZeroExtend,
    /// Sign extend.
    SignExtend,
    /// Truncate.
    Truncate,
    /// Float to int.
    FloatToInt,
    /// Int to float.
    IntToFloat,
    /// Float to float (change precision).
    FloatConvert,
    /// Pointer cast.
    PtrCast,
}
383
/// Memory scope for fences.
///
/// Derives `Hash` (alongside `Eq`) so scopes can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryScope {
    /// Thread-local scope.
    Thread,
    /// Threadgroup/block scope.
    Threadgroup,
    /// Device scope.
    Device,
    /// System scope.
    System,
}
396
/// Atomic operations used by `IrNode::Atomic`.
///
/// Derives `Hash` (alongside `Eq`) so atomic ops can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtomicOp {
    /// Atomic load.
    Load,
    /// Atomic store.
    Store,
    /// Atomic exchange.
    Exchange,
    /// Atomic add.
    Add,
    /// Atomic sub.
    Sub,
    /// Atomic min.
    Min,
    /// Atomic max.
    Max,
    /// Atomic AND.
    And,
    /// Atomic OR.
    Or,
    /// Atomic XOR.
    Xor,
}
421
/// Warp vote operations.
///
/// Derives `Hash` (alongside `Eq`) so vote kinds can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpVoteOp {
    /// All threads have true.
    All,
    /// Any thread has true.
    Any,
    /// Ballot (bitmask of predicates).
    Ballot,
}
432
/// Warp shuffle operations.
///
/// Derives `Hash` (alongside `Eq`) so shuffle kinds can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpShuffleOp {
    /// Shuffle indexed.
    Index,
    /// Shuffle up.
    Up,
    /// Shuffle down.
    Down,
    /// Shuffle XOR.
    Xor,
}
445
/// Warp reduce operations.
///
/// Derives `Hash` (alongside `Eq`) so reduction kinds can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WarpReduceOp {
    /// Sum reduction.
    Sum,
    /// Product reduction.
    Product,
    /// Minimum reduction.
    Min,
    /// Maximum reduction.
    Max,
    /// AND reduction.
    And,
    /// OR reduction.
    Or,
    /// XOR reduction.
    Xor,
}
464
/// Math operations (intrinsics) used by `IrNode::Math`.
///
/// `IrNode::Math` carries a variable-length operand list. Operand counts are
/// not enforced here (see the per-variant notes).
///
/// Derives `Hash` (alongside `Eq`) so intrinsics can key hash maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MathOp {
    // Trigonometric
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Arc sine.
    Asin,
    /// Arc cosine.
    Acos,
    /// Arc tangent.
    Atan,
    /// Arc tangent with two arguments.
    Atan2,

    // Hyperbolic
    /// Hyperbolic sine.
    Sinh,
    /// Hyperbolic cosine.
    Cosh,
    /// Hyperbolic tangent.
    Tanh,

    // Exponential/Logarithmic
    /// Exponential (e^x).
    Exp,
    /// Exponential base 2.
    Exp2,
    /// Natural logarithm.
    Log,
    /// Logarithm base 2.
    Log2,
    /// Logarithm base 10.
    Log10,

    // Other
    /// Linear interpolation.
    /// NOTE(review): presumably three operands (a, b, t) — confirm order.
    Lerp,
    /// Clamp.
    /// NOTE(review): presumably three operands (x, lo, hi) — confirm order.
    Clamp,
    /// Step function.
    Step,
    /// Smooth step.
    SmoothStep,
    /// Fract (fractional part).
    Fract,
    /// Copy sign.
    CopySign,
}
518
519/// Block terminator instructions.
520#[derive(Debug, Clone)]
521pub enum Terminator {
522 /// Return from kernel.
523 Return(Option<ValueId>),
524 /// Unconditional branch.
525 Branch(BlockId),
526 /// Conditional branch.
527 CondBranch(ValueId, BlockId, BlockId),
528 /// Switch statement.
529 Switch(ValueId, BlockId, Vec<(ConstantValue, BlockId)>),
530 /// Unreachable (for optimization).
531 Unreachable,
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_ir_type() {
        assert_eq!(ConstantValue::I32(42).ir_type(), IrType::I32);
        assert_eq!(ConstantValue::F32(3.125).ir_type(), IrType::F32);
        assert_eq!(ConstantValue::Bool(true).ir_type(), IrType::BOOL);
        assert_eq!(ConstantValue::I64(-1).ir_type(), IrType::I64);
        assert_eq!(ConstantValue::U32(1).ir_type(), IrType::U32);
        assert_eq!(ConstantValue::U64(1).ir_type(), IrType::U64);
        assert_eq!(ConstantValue::F64(0.5).ir_type(), IrType::F64);
    }

    #[test]
    fn test_null_constant_is_void_pointer() {
        assert_eq!(ConstantValue::Null.ir_type(), IrType::ptr(IrType::Void));
    }

    #[test]
    fn test_array_constant_ir_type() {
        // Element type comes from the first element, length from the count.
        let arr = ConstantValue::Array(vec![
            ConstantValue::I32(1),
            ConstantValue::I32(2),
            ConstantValue::I32(3),
        ]);
        assert_eq!(arr.ir_type(), IrType::array(IrType::I32, 3));
        // An empty array has no element to inspect and falls back to `Void`.
        assert_eq!(
            ConstantValue::Array(Vec::new()).ir_type(),
            IrType::array(IrType::Void, 0)
        );
    }

    #[test]
    fn test_binary_op_display() {
        let expected = [
            (BinaryOp::Add, "add"),
            (BinaryOp::Sub, "sub"),
            (BinaryOp::Mul, "mul"),
            (BinaryOp::Div, "div"),
            (BinaryOp::Rem, "rem"),
            (BinaryOp::And, "and"),
            (BinaryOp::Or, "or"),
            (BinaryOp::Xor, "xor"),
            (BinaryOp::Shl, "shl"),
            (BinaryOp::Shr, "shr"),
            (BinaryOp::Sar, "sar"),
            (BinaryOp::Fma, "fma"),
            (BinaryOp::Pow, "pow"),
            (BinaryOp::Min, "min"),
            (BinaryOp::Max, "max"),
        ];
        for (op, text) in expected.iter() {
            assert_eq!(op.to_string(), *text);
        }
    }

    #[test]
    fn test_unary_op_display() {
        let expected = [
            (UnaryOp::Neg, "neg"),
            (UnaryOp::Not, "not"),
            (UnaryOp::LogicalNot, "lnot"),
            (UnaryOp::Abs, "abs"),
            (UnaryOp::Sqrt, "sqrt"),
            (UnaryOp::Rsqrt, "rsqrt"),
            (UnaryOp::Floor, "floor"),
            (UnaryOp::Ceil, "ceil"),
            (UnaryOp::Round, "round"),
            (UnaryOp::Trunc, "trunc"),
            (UnaryOp::Sign, "sign"),
        ];
        for (op, text) in expected.iter() {
            assert_eq!(op.to_string(), *text);
        }
    }

    #[test]
    fn test_compare_op_display() {
        let expected = [
            (CompareOp::Eq, "eq"),
            (CompareOp::Ne, "ne"),
            (CompareOp::Lt, "lt"),
            (CompareOp::Le, "le"),
            (CompareOp::Gt, "gt"),
            (CompareOp::Ge, "ge"),
        ];
        for (op, text) in expected.iter() {
            assert_eq!(op.to_string(), *text);
        }
    }
}
562}