Skip to main content

ringkernel_ir/
lib.rs

1//! RingKernel Intermediate Representation (IR)
2//!
3//! This crate provides a unified IR for GPU code generation across multiple backends
4//! (CUDA, WGSL, MSL). The IR is SSA-based and captures GPU-specific operations.
5//!
6//! # Architecture
7//!
8//! ```text
9//! Rust DSL → IR → Backend-specific lowering → CUDA/WGSL/MSL
10//! ```
11//!
12//! # Example
13//!
14//! ```ignore
15//! use ringkernel_ir::{IrBuilder, IrType, Dimension};
16//!
17//! let mut builder = IrBuilder::new("saxpy");
18//!
19//! // Define parameters
20//! let x = builder.parameter("x", IrType::Ptr(Box::new(IrType::F32)));
21//! let y = builder.parameter("y", IrType::Ptr(Box::new(IrType::F32)));
22//! let a = builder.parameter("a", IrType::F32);
23//! let n = builder.parameter("n", IrType::I32);
24//!
25//! // Get thread index
26//! let idx = builder.thread_id(Dimension::X);
27//!
28//! // Bounds check
29//! let in_bounds = builder.lt(idx, n);
30//! builder.if_then(in_bounds, |b| {
31//!     let x_val = b.load(x, idx);
32//!     let y_val = b.load(y, idx);
33//!     let result = b.add(b.mul(a, x_val), y_val);
34//!     b.store(y, idx, result);
35//! });
36//!
37//! let ir = builder.build();
38//! ```
39
40#![warn(missing_docs)]
41
42mod builder;
43mod capabilities;
44mod error;
45mod nodes;
46mod printer;
47mod types;
48mod validation;
49
50pub mod lower_cuda;
51pub mod lower_msl;
52pub mod lower_wgsl;
53pub mod optimize;
54
55pub use builder::{IrBuilder, IrBuilderScope};
56pub use capabilities::{BackendCapabilities, Capabilities, CapabilityFlag};
57pub use error::{IrError, IrResult};
58pub use lower_cuda::{
59    lower_to_cuda, lower_to_cuda_with_config, CudaLowering, CudaLoweringConfig, LoweringError,
60};
61pub use lower_msl::{
62    lower_to_msl, lower_to_msl_with_config, MslLowering, MslLoweringConfig, MslLoweringError,
63};
64pub use lower_wgsl::{
65    lower_to_wgsl, lower_to_wgsl_with_config, WgslLowering, WgslLoweringConfig, WgslLoweringError,
66};
67pub use nodes::*;
68pub use optimize::{
69    optimize, run_constant_folding, run_dce, AlgebraicSimplification, ConstantFolding,
70    DeadBlockElimination, DeadCodeElimination, OptimizationPass, OptimizationResult, PassManager,
71};
72pub use printer::IrPrinter;
73pub use types::{IrType, ScalarType, VectorType};
74pub use validation::{ValidationLevel, ValidationResult, Validator};
75
76use std::collections::HashMap;
77use std::sync::atomic::{AtomicU64, Ordering};
78
/// Unique identifier for IR values.
///
/// Every call to [`ValueId::new`] draws the next number from a process-wide
/// counter, so two IDs obtained that way never compare equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ValueId(u64);

impl ValueId {
    /// Allocate a fresh, process-unique value ID.
    ///
    /// IDs are handed out in increasing order, starting from 0 for the first
    /// call in the process.
    pub fn new() -> Self {
        // Uniqueness only depends on the increment being atomic; no other
        // memory is published through this counter, so Relaxed suffices.
        static NEXT: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT.fetch_add(1, Ordering::Relaxed);
        ValueId(raw)
    }

    /// Return the underlying numeric ID.
    pub fn raw(&self) -> u64 {
        let ValueId(raw) = *self;
        raw
    }
}

impl Default for ValueId {
    /// Same as [`ValueId::new`]: allocates a fresh unique ID.
    fn default() -> Self {
        ValueId::new()
    }
}

impl std::fmt::Display for ValueId {
    /// Formats the ID as `%N`, e.g. `%42`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("%")?;
        write!(f, "{}", self.raw())
    }
}
107
/// Unique identifier for IR blocks.
///
/// IDs produced by [`BlockId::new`] come from a process-wide counter and are
/// therefore distinct from one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockId(u64);

impl BlockId {
    /// Allocate a fresh, process-unique block ID.
    pub fn new() -> Self {
        // Only the atomicity of the increment matters for uniqueness.
        static NEXT_BLOCK: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_BLOCK.fetch_add(1, Ordering::Relaxed);
        BlockId(id)
    }

    /// Return the underlying numeric ID.
    pub fn raw(&self) -> u64 {
        let &BlockId(raw) = self;
        raw
    }
}

impl Default for BlockId {
    /// Same as [`BlockId::new`]: allocates a fresh unique ID.
    fn default() -> Self {
        BlockId::new()
    }
}

impl std::fmt::Display for BlockId {
    /// Formats the ID as `bbN`, e.g. `bb3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "bb{raw}", raw = self.raw())
    }
}
136
137/// A complete IR module representing a GPU kernel.
138#[derive(Debug, Clone)]
139pub struct IrModule {
140    /// Module name (kernel name).
141    pub name: String,
142    /// Function parameters.
143    pub parameters: Vec<Parameter>,
144    /// Entry block.
145    pub entry_block: BlockId,
146    /// All blocks in the module.
147    pub blocks: HashMap<BlockId, Block>,
148    /// All values defined in the module.
149    pub values: HashMap<ValueId, Value>,
150    /// Required capabilities for this module.
151    pub required_capabilities: Capabilities,
152    /// Kernel configuration.
153    pub config: KernelConfig,
154}
155
156impl IrModule {
157    /// Create a new empty module.
158    pub fn new(name: impl Into<String>) -> Self {
159        let entry = BlockId::new();
160        let mut blocks = HashMap::new();
161        blocks.insert(entry, Block::new(entry, "entry"));
162
163        Self {
164            name: name.into(),
165            parameters: Vec::new(),
166            entry_block: entry,
167            blocks,
168            values: HashMap::new(),
169            required_capabilities: Capabilities::default(),
170            config: KernelConfig::default(),
171        }
172    }
173
174    /// Get a block by ID.
175    pub fn get_block(&self, id: BlockId) -> Option<&Block> {
176        self.blocks.get(&id)
177    }
178
179    /// Get a mutable block by ID.
180    pub fn get_block_mut(&mut self, id: BlockId) -> Option<&mut Block> {
181        self.blocks.get_mut(&id)
182    }
183
184    /// Get a value by ID.
185    pub fn get_value(&self, id: ValueId) -> Option<&Value> {
186        self.values.get(&id)
187    }
188
189    /// Add a value to the module.
190    pub fn add_value(&mut self, value: Value) -> ValueId {
191        let id = value.id;
192        self.values.insert(id, value);
193        id
194    }
195
196    /// Get the entry block.
197    pub fn entry(&self) -> &Block {
198        self.blocks
199            .get(&self.entry_block)
200            .expect("entry block must exist")
201    }
202
203    /// Validate the module.
204    pub fn validate(&self, level: ValidationLevel) -> ValidationResult {
205        Validator::new(level).validate(self)
206    }
207
208    /// Pretty-print the IR.
209    pub fn pretty_print(&self) -> String {
210        IrPrinter::new().print(self)
211    }
212}
213
214/// A function parameter.
215#[derive(Debug, Clone)]
216pub struct Parameter {
217    /// Parameter name.
218    pub name: String,
219    /// Parameter type.
220    pub ty: IrType,
221    /// Value ID for this parameter.
222    pub value_id: ValueId,
223    /// Parameter index.
224    pub index: usize,
225}
226
227/// A basic block containing IR nodes.
228#[derive(Debug, Clone)]
229pub struct Block {
230    /// Block identifier.
231    pub id: BlockId,
232    /// Block label (for debugging).
233    pub label: String,
234    /// Instructions in this block.
235    pub instructions: Vec<Instruction>,
236    /// Terminator instruction.
237    pub terminator: Option<Terminator>,
238    /// Predecessor blocks.
239    pub predecessors: Vec<BlockId>,
240    /// Successor blocks.
241    pub successors: Vec<BlockId>,
242}
243
244impl Block {
245    /// Create a new block.
246    pub fn new(id: BlockId, label: impl Into<String>) -> Self {
247        Self {
248            id,
249            label: label.into(),
250            instructions: Vec::new(),
251            terminator: None,
252            predecessors: Vec::new(),
253            successors: Vec::new(),
254        }
255    }
256
257    /// Add an instruction to the block.
258    pub fn add_instruction(&mut self, inst: Instruction) {
259        self.instructions.push(inst);
260    }
261
262    /// Set the terminator.
263    pub fn set_terminator(&mut self, term: Terminator) {
264        self.terminator = Some(term);
265    }
266
267    /// Check if block is terminated.
268    pub fn is_terminated(&self) -> bool {
269        self.terminator.is_some()
270    }
271}
272
273/// An IR value with type information.
274#[derive(Debug, Clone)]
275pub struct Value {
276    /// Value identifier.
277    pub id: ValueId,
278    /// Value type.
279    pub ty: IrType,
280    /// The node that produces this value.
281    pub node: IrNode,
282}
283
284impl Value {
285    /// Create a new value.
286    pub fn new(ty: IrType, node: IrNode) -> Self {
287        Self {
288            id: ValueId::new(),
289            ty,
290            node,
291        }
292    }
293}
294
/// Kernel launch and execution configuration.
///
/// Implements `PartialEq`/`Eq`/`Hash` so configurations can be compared
/// (e.g. against `KernelConfig::default()`) or used as cache keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KernelConfig {
    /// Block size (threads per block). Defaults to `(256, 1, 1)`.
    pub block_size: (u32, u32, u32),
    /// Grid size (blocks per grid), if known statically. Defaults to `None`.
    pub grid_size: Option<(u32, u32, u32)>,
    /// Shared memory size in bytes. Defaults to `0`.
    pub shared_memory_bytes: u32,
    /// Whether this is a persistent kernel. Defaults to `false`.
    ///
    /// NOTE(review): this overlaps with `mode == KernelMode::Persistent`, and
    /// nothing in this type keeps the two in sync — confirm which one the
    /// lowerings consult.
    pub is_persistent: bool,
    /// Kernel mode. Defaults to [`KernelMode::Compute`].
    pub mode: KernelMode,
}

impl Default for KernelConfig {
    /// A non-persistent compute kernel with a `(256, 1, 1)` block size, no
    /// static grid size, and no shared memory.
    fn default() -> Self {
        Self {
            block_size: (256, 1, 1),
            grid_size: None,
            shared_memory_bytes: 0,
            is_persistent: false,
            mode: KernelMode::Compute,
        }
    }
}

/// Kernel execution mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelMode {
    /// Standard compute kernel.
    Compute,
    /// Persistent message-processing kernel.
    Persistent,
    /// Stencil computation kernel.
    Stencil,
}
332
/// Dimension for GPU indexing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dimension {
    /// X dimension.
    X,
    /// Y dimension.
    Y,
    /// Z dimension.
    Z,
}

impl std::fmt::Display for Dimension {
    /// Formats the lowercase axis name: `x`, `y` or `z`.
    ///
    /// Uses `Formatter::pad`, so width/fill/alignment flags such as `{:>3}`
    /// are honoured instead of being silently ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Dimension::X => "x",
            Dimension::Y => "y",
            Dimension::Z => "z",
        };
        f.pad(name)
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_value_id_unique() {
        let id1 = ValueId::new();
        let id2 = ValueId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_value_id_raw_and_display() {
        let id = ValueId(42);
        assert_eq!(id.raw(), 42);
        assert_eq!(id.to_string(), "%42");

        // The counter is monotonic even if other tests allocate IDs concurrently.
        let a = ValueId::new();
        let b = ValueId::new();
        assert!(b.raw() > a.raw());
    }

    #[test]
    fn test_block_id_unique() {
        let id1 = BlockId::new();
        let id2 = BlockId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_block_id_raw_and_display() {
        let id = BlockId(7);
        assert_eq!(id.raw(), 7);
        assert_eq!(format!("{}", id), "bb7");
    }

    #[test]
    fn test_ir_module_new() {
        let module = IrModule::new("test_kernel");
        assert_eq!(module.name, "test_kernel");
        assert!(module.parameters.is_empty());
        assert!(module.values.is_empty());
        assert_eq!(module.blocks.len(), 1);
        assert!(module.blocks.contains_key(&module.entry_block));

        let entry = module.entry();
        assert_eq!(entry.id, module.entry_block);
        assert_eq!(entry.label, "entry");
        assert!(!entry.is_terminated());
    }

    #[test]
    fn test_ir_module_block_lookup() {
        let mut module = IrModule::new("lookup");
        let entry = module.entry_block;
        assert!(module.get_block(entry).is_some());
        assert!(module.get_block(BlockId::new()).is_none());

        module
            .get_block_mut(entry)
            .expect("entry block must be present")
            .set_terminator(Terminator::Return(None));
        assert!(module.entry().is_terminated());
    }

    #[test]
    fn test_block_operations() {
        let id = BlockId::new();
        let mut block = Block::new(id, "test");

        assert_eq!(block.id, id);
        assert_eq!(block.label, "test");
        assert!(!block.is_terminated());
        assert!(block.instructions.is_empty());
        assert!(block.predecessors.is_empty());
        assert!(block.successors.is_empty());

        block.set_terminator(Terminator::Return(None));
        assert!(block.is_terminated());
    }

    #[test]
    fn test_dimension_display() {
        assert_eq!(format!("{}", Dimension::X), "x");
        assert_eq!(format!("{}", Dimension::Y), "y");
        assert_eq!(format!("{}", Dimension::Z), "z");
    }

    #[test]
    fn test_kernel_config_default() {
        let config = KernelConfig::default();
        assert_eq!(config.block_size, (256, 1, 1));
        assert!(config.grid_size.is_none());
        assert_eq!(config.shared_memory_bytes, 0);
        assert!(!config.is_persistent);
        assert_eq!(config.mode, KernelMode::Compute);
    }
}