Skip to main content

ringkernel_ir/
builder.rs

1//! IR builder API.
2//!
3//! Provides an ergonomic interface for constructing IR modules.
4
5use crate::{
6    nodes::*, Block, BlockId, CapabilityFlag, Dimension, Instruction, IrModule, IrType,
7    KernelConfig, KernelMode, Parameter, Terminator, Value, ValueId,
8};
9
10/// Builder for constructing IR modules.
11pub struct IrBuilder {
12    module: IrModule,
13    current_block: BlockId,
14}
15
16impl IrBuilder {
17    /// Create a new builder.
18    pub fn new(name: impl Into<String>) -> Self {
19        let module = IrModule::new(name);
20        let entry = module.entry_block;
21        Self {
22            module,
23            current_block: entry,
24        }
25    }
26
27    /// Build and return the IR module.
28    pub fn build(self) -> IrModule {
29        self.module
30    }
31
32    /// Get a reference to the module being built.
33    pub fn module(&self) -> &IrModule {
34        &self.module
35    }
36
37    /// Set kernel configuration.
38    pub fn set_config(&mut self, config: KernelConfig) {
39        self.module.config = config;
40    }
41
42    /// Set block size.
43    pub fn set_block_size(&mut self, x: u32, y: u32, z: u32) {
44        self.module.config.block_size = (x, y, z);
45    }
46
47    /// Mark as persistent kernel.
48    pub fn set_persistent(&mut self, persistent: bool) {
49        self.module.config.is_persistent = persistent;
50        if persistent {
51            self.module.config.mode = KernelMode::Persistent;
52        }
53    }
54
55    // ========================================================================
56    // Parameters
57    // ========================================================================
58
59    /// Add a parameter.
60    pub fn parameter(&mut self, name: impl Into<String>, ty: IrType) -> ValueId {
61        let value_id = ValueId::new();
62        let index = self.module.parameters.len();
63
64        self.module.parameters.push(Parameter {
65            name: name.into(),
66            ty: ty.clone(),
67            value_id,
68            index,
69        });
70
71        let value = Value::new(ty, IrNode::Parameter(index));
72        self.module.values.insert(value_id, value);
73
74        value_id
75    }
76
77    // ========================================================================
78    // Blocks
79    // ========================================================================
80
81    /// Create a new block.
82    pub fn create_block(&mut self, label: impl Into<String>) -> BlockId {
83        let id = BlockId::new();
84        self.module.blocks.insert(id, Block::new(id, label));
85        id
86    }
87
88    /// Switch to a different block.
89    pub fn switch_to_block(&mut self, block: BlockId) {
90        self.current_block = block;
91    }
92
93    /// Get current block ID.
94    pub fn current_block(&self) -> BlockId {
95        self.current_block
96    }
97
98    // ========================================================================
99    // Constants
100    // ========================================================================
101
102    /// Create an i32 constant.
103    pub fn const_i32(&mut self, value: i32) -> ValueId {
104        self.add_value(IrType::I32, IrNode::Constant(ConstantValue::I32(value)))
105    }
106
107    /// Create an i64 constant.
108    pub fn const_i64(&mut self, value: i64) -> ValueId {
109        self.module.required_capabilities.add(CapabilityFlag::Int64);
110        self.add_value(IrType::I64, IrNode::Constant(ConstantValue::I64(value)))
111    }
112
113    /// Create a u32 constant.
114    pub fn const_u32(&mut self, value: u32) -> ValueId {
115        self.add_value(IrType::U32, IrNode::Constant(ConstantValue::U32(value)))
116    }
117
118    /// Create a u64 constant.
119    pub fn const_u64(&mut self, value: u64) -> ValueId {
120        self.module.required_capabilities.add(CapabilityFlag::Int64);
121        self.add_value(IrType::U64, IrNode::Constant(ConstantValue::U64(value)))
122    }
123
124    /// Create an f32 constant.
125    pub fn const_f32(&mut self, value: f32) -> ValueId {
126        self.add_value(IrType::F32, IrNode::Constant(ConstantValue::F32(value)))
127    }
128
129    /// Create an f64 constant.
130    pub fn const_f64(&mut self, value: f64) -> ValueId {
131        self.module
132            .required_capabilities
133            .add(CapabilityFlag::Float64);
134        self.add_value(IrType::F64, IrNode::Constant(ConstantValue::F64(value)))
135    }
136
137    /// Create a boolean constant.
138    pub fn const_bool(&mut self, value: bool) -> ValueId {
139        self.add_value(IrType::BOOL, IrNode::Constant(ConstantValue::Bool(value)))
140    }
141
142    // ========================================================================
143    // Binary Operations
144    // ========================================================================
145
146    /// Add two values.
147    pub fn add(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
148        let ty = self.get_value_type(lhs);
149        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Add, lhs, rhs))
150    }
151
152    /// Subtract two values.
153    pub fn sub(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
154        let ty = self.get_value_type(lhs);
155        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Sub, lhs, rhs))
156    }
157
158    /// Multiply two values.
159    pub fn mul(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
160        let ty = self.get_value_type(lhs);
161        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Mul, lhs, rhs))
162    }
163
164    /// Divide two values.
165    pub fn div(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
166        let ty = self.get_value_type(lhs);
167        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Div, lhs, rhs))
168    }
169
170    /// Remainder.
171    pub fn rem(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
172        let ty = self.get_value_type(lhs);
173        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Rem, lhs, rhs))
174    }
175
176    /// Bitwise AND.
177    pub fn and(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
178        let ty = self.get_value_type(lhs);
179        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::And, lhs, rhs))
180    }
181
182    /// Bitwise OR.
183    pub fn or(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
184        let ty = self.get_value_type(lhs);
185        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Or, lhs, rhs))
186    }
187
188    /// Bitwise XOR.
189    pub fn xor(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
190        let ty = self.get_value_type(lhs);
191        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Xor, lhs, rhs))
192    }
193
194    /// Left shift.
195    pub fn shl(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
196        let ty = self.get_value_type(lhs);
197        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Shl, lhs, rhs))
198    }
199
200    /// Logical right shift.
201    pub fn shr(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
202        let ty = self.get_value_type(lhs);
203        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Shr, lhs, rhs))
204    }
205
206    /// Minimum.
207    pub fn min(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
208        let ty = self.get_value_type(lhs);
209        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Min, lhs, rhs))
210    }
211
212    /// Maximum.
213    pub fn max(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
214        let ty = self.get_value_type(lhs);
215        self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Max, lhs, rhs))
216    }
217
218    // ========================================================================
219    // Unary Operations
220    // ========================================================================
221
222    /// Negate.
223    pub fn neg(&mut self, value: ValueId) -> ValueId {
224        let ty = self.get_value_type(value);
225        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Neg, value))
226    }
227
228    /// Bitwise NOT.
229    pub fn not(&mut self, value: ValueId) -> ValueId {
230        let ty = self.get_value_type(value);
231        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Not, value))
232    }
233
234    /// Absolute value.
235    pub fn abs(&mut self, value: ValueId) -> ValueId {
236        let ty = self.get_value_type(value);
237        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Abs, value))
238    }
239
240    /// Square root.
241    pub fn sqrt(&mut self, value: ValueId) -> ValueId {
242        let ty = self.get_value_type(value);
243        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Sqrt, value))
244    }
245
246    /// Floor.
247    pub fn floor(&mut self, value: ValueId) -> ValueId {
248        let ty = self.get_value_type(value);
249        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Floor, value))
250    }
251
252    /// Ceiling.
253    pub fn ceil(&mut self, value: ValueId) -> ValueId {
254        let ty = self.get_value_type(value);
255        self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Ceil, value))
256    }
257
258    // ========================================================================
259    // Comparison
260    // ========================================================================
261
262    /// Equal comparison.
263    pub fn eq(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
264        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Eq, lhs, rhs))
265    }
266
267    /// Not equal comparison.
268    pub fn ne(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
269        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Ne, lhs, rhs))
270    }
271
272    /// Less than.
273    pub fn lt(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
274        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Lt, lhs, rhs))
275    }
276
277    /// Less than or equal.
278    pub fn le(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
279        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Le, lhs, rhs))
280    }
281
282    /// Greater than.
283    pub fn gt(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
284        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Gt, lhs, rhs))
285    }
286
287    /// Greater than or equal.
288    pub fn ge(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
289        self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Ge, lhs, rhs))
290    }
291
292    // ========================================================================
293    // Memory
294    // ========================================================================
295
296    /// Load from pointer.
297    pub fn load(&mut self, ptr: ValueId) -> ValueId {
298        let ptr_ty = self.get_value_type(ptr);
299        let elem_ty = match ptr_ty {
300            IrType::Ptr(inner) => (*inner).clone(),
301            _ => IrType::Void,
302        };
303        self.add_instruction(elem_ty, IrNode::Load(ptr))
304    }
305
306    /// Store to pointer.
307    pub fn store(&mut self, ptr: ValueId, value: ValueId) {
308        self.add_instruction(IrType::Void, IrNode::Store(ptr, value));
309    }
310
311    /// Get element pointer.
312    pub fn gep(&mut self, ptr: ValueId, indices: Vec<ValueId>) -> ValueId {
313        let ty = self.get_value_type(ptr);
314        self.add_instruction(ty, IrNode::GetElementPtr(ptr, indices))
315    }
316
317    /// Allocate shared memory.
318    pub fn shared_alloc(&mut self, ty: IrType, count: usize) -> ValueId {
319        self.module
320            .required_capabilities
321            .add(CapabilityFlag::SharedMemory);
322        let ptr_ty = IrType::ptr(ty.clone());
323        self.add_instruction(ptr_ty, IrNode::SharedAlloc(ty, count))
324    }
325
326    // ========================================================================
327    // GPU Indexing
328    // ========================================================================
329
330    /// Get thread ID.
331    pub fn thread_id(&mut self, dim: Dimension) -> ValueId {
332        self.add_instruction(IrType::U32, IrNode::ThreadId(dim))
333    }
334
335    /// Get block ID.
336    pub fn block_id(&mut self, dim: Dimension) -> ValueId {
337        self.add_instruction(IrType::U32, IrNode::BlockId(dim))
338    }
339
340    /// Get block dimension.
341    pub fn block_dim(&mut self, dim: Dimension) -> ValueId {
342        self.add_instruction(IrType::U32, IrNode::BlockDim(dim))
343    }
344
345    /// Get grid dimension.
346    pub fn grid_dim(&mut self, dim: Dimension) -> ValueId {
347        self.add_instruction(IrType::U32, IrNode::GridDim(dim))
348    }
349
350    /// Get global thread ID.
351    pub fn global_thread_id(&mut self, dim: Dimension) -> ValueId {
352        self.add_instruction(IrType::U32, IrNode::GlobalThreadId(dim))
353    }
354
355    // ========================================================================
356    // Synchronization
357    // ========================================================================
358
359    /// Block/threadgroup barrier.
360    pub fn barrier(&mut self) {
361        self.add_instruction(IrType::Void, IrNode::Barrier);
362    }
363
364    /// Memory fence.
365    pub fn fence(&mut self, scope: MemoryScope) {
366        self.add_instruction(IrType::Void, IrNode::MemoryFence(scope));
367    }
368
369    /// Grid sync (cooperative groups).
370    pub fn grid_sync(&mut self) {
371        self.module
372            .required_capabilities
373            .add(CapabilityFlag::CooperativeGroups);
374        self.add_instruction(IrType::Void, IrNode::GridSync);
375    }
376
377    // ========================================================================
378    // Atomics
379    // ========================================================================
380
381    /// Atomic add.
382    pub fn atomic_add(&mut self, ptr: ValueId, value: ValueId) -> ValueId {
383        let ty = self.get_value_type(value);
384        self.add_instruction(ty, IrNode::Atomic(AtomicOp::Add, ptr, value))
385    }
386
387    /// Atomic exchange.
388    pub fn atomic_exchange(&mut self, ptr: ValueId, value: ValueId) -> ValueId {
389        let ty = self.get_value_type(value);
390        self.add_instruction(ty, IrNode::Atomic(AtomicOp::Exchange, ptr, value))
391    }
392
393    /// Atomic compare-and-swap.
394    pub fn atomic_cas(&mut self, ptr: ValueId, expected: ValueId, desired: ValueId) -> ValueId {
395        let ty = self.get_value_type(expected);
396        self.add_instruction(ty, IrNode::AtomicCas(ptr, expected, desired))
397    }
398
399    // ========================================================================
400    // Control Flow
401    // ========================================================================
402
403    /// Select (ternary).
404    pub fn select(&mut self, cond: ValueId, then_val: ValueId, else_val: ValueId) -> ValueId {
405        let ty = self.get_value_type(then_val);
406        self.add_instruction(ty, IrNode::Select(cond, then_val, else_val))
407    }
408
409    /// Branch to block.
410    pub fn branch(&mut self, target: BlockId) {
411        self.set_terminator(Terminator::Branch(target));
412        self.add_successor(target);
413    }
414
415    /// Conditional branch.
416    pub fn cond_branch(&mut self, cond: ValueId, then_block: BlockId, else_block: BlockId) {
417        self.set_terminator(Terminator::CondBranch(cond, then_block, else_block));
418        self.add_successor(then_block);
419        self.add_successor(else_block);
420    }
421
422    /// Return from kernel.
423    pub fn ret(&mut self) {
424        self.set_terminator(Terminator::Return(None));
425    }
426
427    /// Return value from kernel.
428    pub fn ret_value(&mut self, value: ValueId) {
429        self.set_terminator(Terminator::Return(Some(value)));
430    }
431
432    // ========================================================================
433    // RingKernel Messaging
434    // ========================================================================
435
436    /// Enqueue to output (K2H).
437    pub fn k2h_enqueue(&mut self, message: ValueId) {
438        self.add_instruction(IrType::Void, IrNode::K2HEnqueue(message));
439    }
440
441    /// Dequeue from input (H2K).
442    pub fn h2k_dequeue(&mut self, msg_ty: IrType) -> ValueId {
443        self.add_instruction(msg_ty, IrNode::H2KDequeue)
444    }
445
446    /// Check if input queue is empty.
447    pub fn h2k_is_empty(&mut self) -> ValueId {
448        self.add_instruction(IrType::BOOL, IrNode::H2KIsEmpty)
449    }
450
451    /// Send K2K message.
452    pub fn k2k_send(&mut self, dest: ValueId, message: ValueId) {
453        self.add_instruction(IrType::Void, IrNode::K2KSend(dest, message));
454    }
455
456    /// Try receive K2K message.
457    pub fn k2k_try_recv(&mut self, msg_ty: IrType) -> ValueId {
458        self.add_instruction(msg_ty, IrNode::K2KTryRecv)
459    }
460
461    // ========================================================================
462    // HLC Operations
463    // ========================================================================
464
465    /// Get current HLC time.
466    pub fn hlc_now(&mut self) -> ValueId {
467        self.add_instruction(IrType::U64, IrNode::HlcNow)
468    }
469
470    /// Tick HLC.
471    pub fn hlc_tick(&mut self) -> ValueId {
472        self.add_instruction(IrType::U64, IrNode::HlcTick)
473    }
474
475    // ========================================================================
476    // Helper Methods
477    // ========================================================================
478
479    fn add_value(&mut self, ty: IrType, node: IrNode) -> ValueId {
480        let value = Value::new(ty, node);
481        let id = value.id;
482        self.module.values.insert(id, value);
483        id
484    }
485
486    fn add_instruction(&mut self, ty: IrType, node: IrNode) -> ValueId {
487        let result = ValueId::new();
488        let inst = Instruction::new(result, ty.clone(), node.clone());
489
490        if let Some(block) = self.module.blocks.get_mut(&self.current_block) {
491            block.add_instruction(inst);
492        }
493
494        // Also add to values map
495        let value = Value::new(ty, node);
496        self.module.values.insert(result, value);
497
498        result
499    }
500
501    fn set_terminator(&mut self, term: Terminator) {
502        if let Some(block) = self.module.blocks.get_mut(&self.current_block) {
503            block.set_terminator(term);
504        }
505    }
506
507    fn add_successor(&mut self, succ: BlockId) {
508        let current = self.current_block;
509        if let Some(block) = self.module.blocks.get_mut(&current) {
510            block.successors.push(succ);
511        }
512        if let Some(succ_block) = self.module.blocks.get_mut(&succ) {
513            succ_block.predecessors.push(current);
514        }
515    }
516
517    fn get_value_type(&self, id: ValueId) -> IrType {
518        self.module
519            .values
520            .get(&id)
521            .map(|v| v.ty.clone())
522            .unwrap_or(IrType::Void)
523    }
524}
525
526/// Scoped builder for structured control flow.
527pub struct IrBuilderScope<'a> {
528    builder: &'a mut IrBuilder,
529}
530
531impl<'a> IrBuilderScope<'a> {
532    /// Create a new scope.
533    pub fn new(builder: &'a mut IrBuilder) -> Self {
534        Self { builder }
535    }
536
537    /// Add two values.
538    pub fn add(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
539        self.builder.add(lhs, rhs)
540    }
541
542    /// Multiply two values.
543    pub fn mul(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
544        self.builder.mul(lhs, rhs)
545    }
546
547    /// Load from pointer.
548    pub fn load(&mut self, ptr: ValueId) -> ValueId {
549        self.builder.load(ptr)
550    }
551
552    /// Store to pointer.
553    pub fn store(&mut self, ptr: ValueId, value: ValueId) {
554        self.builder.store(ptr, value);
555    }
556
557    /// Get thread ID.
558    pub fn thread_id(&mut self, dim: Dimension) -> ValueId {
559        self.builder.thread_id(dim)
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Two pointer parameters, a load/add/store sequence and a return.
    #[test]
    fn test_builder_basic() {
        let mut b = IrBuilder::new("test");

        let src = b.parameter("x", IrType::ptr(IrType::F32));
        let dst = b.parameter("y", IrType::ptr(IrType::F32));

        let _tid = b.thread_id(Dimension::X);
        let lhs = b.load(src);
        let rhs = b.load(dst);
        let sum = b.add(lhs, rhs);
        b.store(dst, sum);
        b.ret();

        let built = b.build();
        assert_eq!(built.name, "test");
        assert_eq!(built.parameters.len(), 2);
    }

    /// Constants must be registered in the module's value table.
    #[test]
    fn test_builder_constants() {
        let mut b = IrBuilder::new("test");

        let int_const = b.const_i32(42);
        let float_const = b.const_f32(3.125);
        let bool_const = b.const_bool(true);

        let built = b.build();
        assert!(built.values.contains_key(&int_const));
        assert!(built.values.contains_key(&float_const));
        assert!(built.values.contains_key(&bool_const));
    }

    /// Entry block plus two created blocks, wired as entry -> {then, end}
    /// and then -> end.
    #[test]
    fn test_builder_control_flow() {
        let mut b = IrBuilder::new("test");

        let limit = b.parameter("n", IrType::I32);
        let tid = b.thread_id(Dimension::X);
        let in_range = b.lt(tid, limit);

        let then_bb = b.create_block("then");
        let end_bb = b.create_block("end");

        b.cond_branch(in_range, then_bb, end_bb);

        b.switch_to_block(then_bb);
        b.branch(end_bb);

        b.switch_to_block(end_bb);
        b.ret();

        let built = b.build();
        assert_eq!(built.blocks.len(), 3);
    }

    /// Operations that need optional hardware features must flag them.
    #[test]
    fn test_builder_capabilities() {
        let mut b = IrBuilder::new("test");

        // An f64 constant requires the Float64 capability.
        b.const_f64(1.0);

        // A grid-wide sync requires the CooperativeGroups capability.
        b.grid_sync();

        let built = b.build();
        assert!(built.required_capabilities.has(CapabilityFlag::Float64));
        assert!(built
            .required_capabilities
            .has(CapabilityFlag::CooperativeGroups));
    }

    /// Persistent flag, kernel mode and block size all land in the config.
    #[test]
    fn test_builder_persistent_config() {
        let mut b = IrBuilder::new("persistent_kernel");
        b.set_persistent(true);
        b.set_block_size(128, 1, 1);

        let built = b.build();
        assert!(built.config.is_persistent);
        assert_eq!(built.config.mode, KernelMode::Persistent);
        assert_eq!(built.config.block_size, (128, 1, 1));
    }
}