1use crate::{
6 nodes::*, Block, BlockId, CapabilityFlag, Dimension, Instruction, IrModule, IrType,
7 KernelConfig, KernelMode, Parameter, Terminator, Value, ValueId,
8};
9
10pub struct IrBuilder {
12 module: IrModule,
13 current_block: BlockId,
14}
15
16impl IrBuilder {
17 pub fn new(name: impl Into<String>) -> Self {
19 let module = IrModule::new(name);
20 let entry = module.entry_block;
21 Self {
22 module,
23 current_block: entry,
24 }
25 }
26
27 pub fn build(self) -> IrModule {
29 self.module
30 }
31
32 pub fn module(&self) -> &IrModule {
34 &self.module
35 }
36
37 pub fn set_config(&mut self, config: KernelConfig) {
39 self.module.config = config;
40 }
41
42 pub fn set_block_size(&mut self, x: u32, y: u32, z: u32) {
44 self.module.config.block_size = (x, y, z);
45 }
46
47 pub fn set_persistent(&mut self, persistent: bool) {
49 self.module.config.is_persistent = persistent;
50 if persistent {
51 self.module.config.mode = KernelMode::Persistent;
52 }
53 }
54
55 pub fn parameter(&mut self, name: impl Into<String>, ty: IrType) -> ValueId {
61 let value_id = ValueId::new();
62 let index = self.module.parameters.len();
63
64 self.module.parameters.push(Parameter {
65 name: name.into(),
66 ty: ty.clone(),
67 value_id,
68 index,
69 });
70
71 let value = Value::new(ty, IrNode::Parameter(index));
72 self.module.values.insert(value_id, value);
73
74 value_id
75 }
76
77 pub fn create_block(&mut self, label: impl Into<String>) -> BlockId {
83 let id = BlockId::new();
84 self.module.blocks.insert(id, Block::new(id, label));
85 id
86 }
87
88 pub fn switch_to_block(&mut self, block: BlockId) {
90 self.current_block = block;
91 }
92
93 pub fn current_block(&self) -> BlockId {
95 self.current_block
96 }
97
98 pub fn const_i32(&mut self, value: i32) -> ValueId {
104 self.add_value(IrType::I32, IrNode::Constant(ConstantValue::I32(value)))
105 }
106
107 pub fn const_i64(&mut self, value: i64) -> ValueId {
109 self.module.required_capabilities.add(CapabilityFlag::Int64);
110 self.add_value(IrType::I64, IrNode::Constant(ConstantValue::I64(value)))
111 }
112
113 pub fn const_u32(&mut self, value: u32) -> ValueId {
115 self.add_value(IrType::U32, IrNode::Constant(ConstantValue::U32(value)))
116 }
117
118 pub fn const_u64(&mut self, value: u64) -> ValueId {
120 self.module.required_capabilities.add(CapabilityFlag::Int64);
121 self.add_value(IrType::U64, IrNode::Constant(ConstantValue::U64(value)))
122 }
123
124 pub fn const_f32(&mut self, value: f32) -> ValueId {
126 self.add_value(IrType::F32, IrNode::Constant(ConstantValue::F32(value)))
127 }
128
129 pub fn const_f64(&mut self, value: f64) -> ValueId {
131 self.module
132 .required_capabilities
133 .add(CapabilityFlag::Float64);
134 self.add_value(IrType::F64, IrNode::Constant(ConstantValue::F64(value)))
135 }
136
137 pub fn const_bool(&mut self, value: bool) -> ValueId {
139 self.add_value(IrType::BOOL, IrNode::Constant(ConstantValue::Bool(value)))
140 }
141
142 pub fn add(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
148 let ty = self.get_value_type(lhs);
149 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Add, lhs, rhs))
150 }
151
152 pub fn sub(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
154 let ty = self.get_value_type(lhs);
155 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Sub, lhs, rhs))
156 }
157
158 pub fn mul(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
160 let ty = self.get_value_type(lhs);
161 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Mul, lhs, rhs))
162 }
163
164 pub fn div(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
166 let ty = self.get_value_type(lhs);
167 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Div, lhs, rhs))
168 }
169
170 pub fn rem(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
172 let ty = self.get_value_type(lhs);
173 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Rem, lhs, rhs))
174 }
175
176 pub fn and(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
178 let ty = self.get_value_type(lhs);
179 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::And, lhs, rhs))
180 }
181
182 pub fn or(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
184 let ty = self.get_value_type(lhs);
185 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Or, lhs, rhs))
186 }
187
188 pub fn xor(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
190 let ty = self.get_value_type(lhs);
191 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Xor, lhs, rhs))
192 }
193
194 pub fn shl(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
196 let ty = self.get_value_type(lhs);
197 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Shl, lhs, rhs))
198 }
199
200 pub fn shr(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
202 let ty = self.get_value_type(lhs);
203 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Shr, lhs, rhs))
204 }
205
206 pub fn min(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
208 let ty = self.get_value_type(lhs);
209 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Min, lhs, rhs))
210 }
211
212 pub fn max(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
214 let ty = self.get_value_type(lhs);
215 self.add_instruction(ty, IrNode::BinaryOp(BinaryOp::Max, lhs, rhs))
216 }
217
218 pub fn neg(&mut self, value: ValueId) -> ValueId {
224 let ty = self.get_value_type(value);
225 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Neg, value))
226 }
227
228 pub fn not(&mut self, value: ValueId) -> ValueId {
230 let ty = self.get_value_type(value);
231 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Not, value))
232 }
233
234 pub fn abs(&mut self, value: ValueId) -> ValueId {
236 let ty = self.get_value_type(value);
237 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Abs, value))
238 }
239
240 pub fn sqrt(&mut self, value: ValueId) -> ValueId {
242 let ty = self.get_value_type(value);
243 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Sqrt, value))
244 }
245
246 pub fn floor(&mut self, value: ValueId) -> ValueId {
248 let ty = self.get_value_type(value);
249 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Floor, value))
250 }
251
252 pub fn ceil(&mut self, value: ValueId) -> ValueId {
254 let ty = self.get_value_type(value);
255 self.add_instruction(ty, IrNode::UnaryOp(UnaryOp::Ceil, value))
256 }
257
258 pub fn eq(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
264 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Eq, lhs, rhs))
265 }
266
267 pub fn ne(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
269 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Ne, lhs, rhs))
270 }
271
272 pub fn lt(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
274 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Lt, lhs, rhs))
275 }
276
277 pub fn le(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
279 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Le, lhs, rhs))
280 }
281
282 pub fn gt(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
284 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Gt, lhs, rhs))
285 }
286
287 pub fn ge(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
289 self.add_instruction(IrType::BOOL, IrNode::Compare(CompareOp::Ge, lhs, rhs))
290 }
291
292 pub fn load(&mut self, ptr: ValueId) -> ValueId {
298 let ptr_ty = self.get_value_type(ptr);
299 let elem_ty = match ptr_ty {
300 IrType::Ptr(inner) => (*inner).clone(),
301 _ => IrType::Void,
302 };
303 self.add_instruction(elem_ty, IrNode::Load(ptr))
304 }
305
306 pub fn store(&mut self, ptr: ValueId, value: ValueId) {
308 self.add_instruction(IrType::Void, IrNode::Store(ptr, value));
309 }
310
311 pub fn gep(&mut self, ptr: ValueId, indices: Vec<ValueId>) -> ValueId {
313 let ty = self.get_value_type(ptr);
314 self.add_instruction(ty, IrNode::GetElementPtr(ptr, indices))
315 }
316
317 pub fn shared_alloc(&mut self, ty: IrType, count: usize) -> ValueId {
319 self.module
320 .required_capabilities
321 .add(CapabilityFlag::SharedMemory);
322 let ptr_ty = IrType::ptr(ty.clone());
323 self.add_instruction(ptr_ty, IrNode::SharedAlloc(ty, count))
324 }
325
326 pub fn thread_id(&mut self, dim: Dimension) -> ValueId {
332 self.add_instruction(IrType::U32, IrNode::ThreadId(dim))
333 }
334
335 pub fn block_id(&mut self, dim: Dimension) -> ValueId {
337 self.add_instruction(IrType::U32, IrNode::BlockId(dim))
338 }
339
340 pub fn block_dim(&mut self, dim: Dimension) -> ValueId {
342 self.add_instruction(IrType::U32, IrNode::BlockDim(dim))
343 }
344
345 pub fn grid_dim(&mut self, dim: Dimension) -> ValueId {
347 self.add_instruction(IrType::U32, IrNode::GridDim(dim))
348 }
349
350 pub fn global_thread_id(&mut self, dim: Dimension) -> ValueId {
352 self.add_instruction(IrType::U32, IrNode::GlobalThreadId(dim))
353 }
354
355 pub fn barrier(&mut self) {
361 self.add_instruction(IrType::Void, IrNode::Barrier);
362 }
363
364 pub fn fence(&mut self, scope: MemoryScope) {
366 self.add_instruction(IrType::Void, IrNode::MemoryFence(scope));
367 }
368
369 pub fn grid_sync(&mut self) {
371 self.module
372 .required_capabilities
373 .add(CapabilityFlag::CooperativeGroups);
374 self.add_instruction(IrType::Void, IrNode::GridSync);
375 }
376
377 pub fn atomic_add(&mut self, ptr: ValueId, value: ValueId) -> ValueId {
383 let ty = self.get_value_type(value);
384 self.add_instruction(ty, IrNode::Atomic(AtomicOp::Add, ptr, value))
385 }
386
387 pub fn atomic_exchange(&mut self, ptr: ValueId, value: ValueId) -> ValueId {
389 let ty = self.get_value_type(value);
390 self.add_instruction(ty, IrNode::Atomic(AtomicOp::Exchange, ptr, value))
391 }
392
393 pub fn atomic_cas(&mut self, ptr: ValueId, expected: ValueId, desired: ValueId) -> ValueId {
395 let ty = self.get_value_type(expected);
396 self.add_instruction(ty, IrNode::AtomicCas(ptr, expected, desired))
397 }
398
399 pub fn select(&mut self, cond: ValueId, then_val: ValueId, else_val: ValueId) -> ValueId {
405 let ty = self.get_value_type(then_val);
406 self.add_instruction(ty, IrNode::Select(cond, then_val, else_val))
407 }
408
409 pub fn branch(&mut self, target: BlockId) {
411 self.set_terminator(Terminator::Branch(target));
412 self.add_successor(target);
413 }
414
415 pub fn cond_branch(&mut self, cond: ValueId, then_block: BlockId, else_block: BlockId) {
417 self.set_terminator(Terminator::CondBranch(cond, then_block, else_block));
418 self.add_successor(then_block);
419 self.add_successor(else_block);
420 }
421
422 pub fn ret(&mut self) {
424 self.set_terminator(Terminator::Return(None));
425 }
426
427 pub fn ret_value(&mut self, value: ValueId) {
429 self.set_terminator(Terminator::Return(Some(value)));
430 }
431
432 pub fn k2h_enqueue(&mut self, message: ValueId) {
438 self.add_instruction(IrType::Void, IrNode::K2HEnqueue(message));
439 }
440
441 pub fn h2k_dequeue(&mut self, msg_ty: IrType) -> ValueId {
443 self.add_instruction(msg_ty, IrNode::H2KDequeue)
444 }
445
446 pub fn h2k_is_empty(&mut self) -> ValueId {
448 self.add_instruction(IrType::BOOL, IrNode::H2KIsEmpty)
449 }
450
451 pub fn k2k_send(&mut self, dest: ValueId, message: ValueId) {
453 self.add_instruction(IrType::Void, IrNode::K2KSend(dest, message));
454 }
455
456 pub fn k2k_try_recv(&mut self, msg_ty: IrType) -> ValueId {
458 self.add_instruction(msg_ty, IrNode::K2KTryRecv)
459 }
460
461 pub fn hlc_now(&mut self) -> ValueId {
467 self.add_instruction(IrType::U64, IrNode::HlcNow)
468 }
469
470 pub fn hlc_tick(&mut self) -> ValueId {
472 self.add_instruction(IrType::U64, IrNode::HlcTick)
473 }
474
475 fn add_value(&mut self, ty: IrType, node: IrNode) -> ValueId {
480 let value = Value::new(ty, node);
481 let id = value.id;
482 self.module.values.insert(id, value);
483 id
484 }
485
486 fn add_instruction(&mut self, ty: IrType, node: IrNode) -> ValueId {
487 let result = ValueId::new();
488 let inst = Instruction::new(result, ty.clone(), node.clone());
489
490 if let Some(block) = self.module.blocks.get_mut(&self.current_block) {
491 block.add_instruction(inst);
492 }
493
494 let value = Value::new(ty, node);
496 self.module.values.insert(result, value);
497
498 result
499 }
500
501 fn set_terminator(&mut self, term: Terminator) {
502 if let Some(block) = self.module.blocks.get_mut(&self.current_block) {
503 block.set_terminator(term);
504 }
505 }
506
507 fn add_successor(&mut self, succ: BlockId) {
508 let current = self.current_block;
509 if let Some(block) = self.module.blocks.get_mut(¤t) {
510 block.successors.push(succ);
511 }
512 if let Some(succ_block) = self.module.blocks.get_mut(&succ) {
513 succ_block.predecessors.push(current);
514 }
515 }
516
517 fn get_value_type(&self, id: ValueId) -> IrType {
518 self.module
519 .values
520 .get(&id)
521 .map(|v| v.ty.clone())
522 .unwrap_or(IrType::Void)
523 }
524}
525
526pub struct IrBuilderScope<'a> {
528 builder: &'a mut IrBuilder,
529}
530
531impl<'a> IrBuilderScope<'a> {
532 pub fn new(builder: &'a mut IrBuilder) -> Self {
534 Self { builder }
535 }
536
537 pub fn add(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
539 self.builder.add(lhs, rhs)
540 }
541
542 pub fn mul(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
544 self.builder.mul(lhs, rhs)
545 }
546
547 pub fn load(&mut self, ptr: ValueId) -> ValueId {
549 self.builder.load(ptr)
550 }
551
552 pub fn store(&mut self, ptr: ValueId, value: ValueId) {
554 self.builder.store(ptr, value);
555 }
556
557 pub fn thread_id(&mut self, dim: Dimension) -> ValueId {
559 self.builder.thread_id(dim)
560 }
561}
562
#[cfg(test)]
mod tests {
    //! Unit tests for `IrBuilder`: kernel construction, constants, control
    //! flow, capability tracking and persistent-kernel configuration.
    use super::*;

    /// Builds a small load/add/store kernel and checks the module name and
    /// parameter count.
    #[test]
    fn test_builder_basic() {
        let mut b = IrBuilder::new("test");

        let x_ptr = b.parameter("x", IrType::ptr(IrType::F32));
        let y_ptr = b.parameter("y", IrType::ptr(IrType::F32));

        let _tid = b.thread_id(Dimension::X);
        let lhs = b.load(x_ptr);
        let rhs = b.load(y_ptr);
        let sum = b.add(lhs, rhs);
        b.store(y_ptr, sum);
        b.ret();

        let module = b.build();
        assert_eq!(module.name, "test");
        assert_eq!(module.parameters.len(), 2);
    }

    /// Every constant helper must register its value in the module.
    #[test]
    fn test_builder_constants() {
        let mut b = IrBuilder::new("test");

        let ids = [b.const_i32(42), b.const_f32(3.125), b.const_bool(true)];

        let module = b.build();
        assert!(ids.iter().all(|id| module.values.contains_key(id)));
    }

    /// An if-then diamond yields entry + two created blocks.
    #[test]
    fn test_builder_control_flow() {
        let mut b = IrBuilder::new("test");

        let limit = b.parameter("n", IrType::I32);
        let tid = b.thread_id(Dimension::X);
        let in_range = b.lt(tid, limit);

        let then_bb = b.create_block("then");
        let exit_bb = b.create_block("end");

        b.cond_branch(in_range, then_bb, exit_bb);

        b.switch_to_block(then_bb);
        b.branch(exit_bb);

        b.switch_to_block(exit_bb);
        b.ret();

        let module = b.build();
        assert_eq!(module.blocks.len(), 3);
    }

    /// f64 constants and grid sync must each record their capability.
    #[test]
    fn test_builder_capabilities() {
        let mut b = IrBuilder::new("test");

        let _ = b.const_f64(1.0);
        b.grid_sync();

        let module = b.build();
        let caps = &module.required_capabilities;
        assert!(caps.has(CapabilityFlag::Float64));
        assert!(caps.has(CapabilityFlag::CooperativeGroups));
    }

    /// `set_persistent(true)` flips both the flag and the mode; block size is
    /// stored verbatim.
    #[test]
    fn test_builder_persistent_config() {
        let mut b = IrBuilder::new("persistent_kernel");
        b.set_persistent(true);
        b.set_block_size(128, 1, 1);

        let module = b.build();
        let cfg = &module.config;
        assert!(cfg.is_persistent);
        assert_eq!(cfg.mode, KernelMode::Persistent);
        assert_eq!(cfg.block_size, (128, 1, 1));
    }
}