1#![warn(missing_docs)]
41
42mod builder;
43mod capabilities;
44mod error;
45mod nodes;
46mod printer;
47mod types;
48mod validation;
49
50pub mod lower_cuda;
51pub mod lower_msl;
52pub mod lower_wgsl;
53pub mod optimize;
54
55pub use builder::{IrBuilder, IrBuilderScope};
56pub use capabilities::{BackendCapabilities, Capabilities, CapabilityFlag};
57pub use error::{IrError, IrResult};
58pub use lower_cuda::{
59 lower_to_cuda, lower_to_cuda_with_config, CudaLowering, CudaLoweringConfig, LoweringError,
60};
61pub use lower_msl::{
62 lower_to_msl, lower_to_msl_with_config, MslLowering, MslLoweringConfig, MslLoweringError,
63};
64pub use lower_wgsl::{
65 lower_to_wgsl, lower_to_wgsl_with_config, WgslLowering, WgslLoweringConfig, WgslLoweringError,
66};
67pub use nodes::*;
68pub use optimize::{
69 optimize, run_constant_folding, run_dce, AlgebraicSimplification, ConstantFolding,
70 DeadBlockElimination, DeadCodeElimination, OptimizationPass, OptimizationResult, PassManager,
71};
72pub use printer::IrPrinter;
73pub use types::{IrType, ScalarType, VectorType};
74pub use validation::{ValidationLevel, ValidationResult, Validator};
75
76use std::collections::HashMap;
77use std::sync::atomic::{AtomicU64, Ordering};
78
/// Identifier of a value stored in an [`IrModule`].
///
/// Ids produced by [`ValueId::new`] come from a process-wide atomic counter,
/// so each call yields an id different from every earlier call (the counter
/// would only repeat after wrapping around `u64::MAX`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ValueId(u64);

impl ValueId {
    /// Allocates a fresh value id.
    pub fn new() -> Self {
        // One atomic read-modify-write per call is all uniqueness needs; the
        // counter publishes no other data, so `Relaxed` ordering is enough.
        static NEXT_VALUE_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_VALUE_ID.fetch_add(1, Ordering::Relaxed);
        ValueId(raw)
    }

    /// Returns the underlying numeric id.
    pub fn raw(&self) -> u64 {
        let ValueId(raw) = *self;
        raw
    }
}

impl Default for ValueId {
    /// Same as [`ValueId::new`]: every default id is freshly allocated.
    fn default() -> Self {
        ValueId::new()
    }
}

impl std::fmt::Display for ValueId {
    /// Writes the id as `%` followed by its number, e.g. `%7`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "%{}", self.raw())
    }
}
107
/// Identifier of a [`Block`] within an [`IrModule`].
///
/// Ids produced by [`BlockId::new`] come from a process-wide atomic counter
/// (separate from the one used for [`ValueId`]), so each call yields an id
/// different from every earlier call until the counter wraps past `u64::MAX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BlockId(u64);

impl BlockId {
    /// Allocates a fresh block id.
    pub fn new() -> Self {
        // The counter guards no other memory, so a `Relaxed` fetch-add is
        // sufficient to hand out distinct ids across threads.
        static NEXT_BLOCK_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_BLOCK_ID.fetch_add(1, Ordering::Relaxed);
        BlockId(raw)
    }

    /// Returns the underlying numeric id.
    pub fn raw(&self) -> u64 {
        let BlockId(raw) = *self;
        raw
    }
}

impl Default for BlockId {
    /// Same as [`BlockId::new`]: every default id is freshly allocated.
    fn default() -> Self {
        BlockId::new()
    }
}

impl std::fmt::Display for BlockId {
    /// Writes the id as `bb` followed by its number, e.g. `bb3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "bb{}", self.raw())
    }
}
136
/// An IR module (one kernel): its parameters, its blocks rooted at
/// `entry_block`, the values it defines, and kernel-level settings.
#[derive(Debug, Clone)]
pub struct IrModule {
    /// Name of the module / kernel.
    pub name: String,
    /// Module parameters.
    /// NOTE(review): presumably in declaration order, matching `Parameter::index` — confirm.
    pub parameters: Vec<Parameter>,
    /// Id of the entry block. `IrModule::new` inserts this block into
    /// `blocks`, and `IrModule::entry` panics if it is missing.
    pub entry_block: BlockId,
    /// All blocks of the module, keyed by their id.
    pub blocks: HashMap<BlockId, Block>,
    /// All values of the module, keyed by their id.
    pub values: HashMap<ValueId, Value>,
    /// Capabilities recorded as required by this module
    /// (starts as `Capabilities::default()` in `IrModule::new`).
    pub required_capabilities: Capabilities,
    /// Kernel configuration (starts as `KernelConfig::default()` in `IrModule::new`).
    pub config: KernelConfig,
}
155
156impl IrModule {
157 pub fn new(name: impl Into<String>) -> Self {
159 let entry = BlockId::new();
160 let mut blocks = HashMap::new();
161 blocks.insert(entry, Block::new(entry, "entry"));
162
163 Self {
164 name: name.into(),
165 parameters: Vec::new(),
166 entry_block: entry,
167 blocks,
168 values: HashMap::new(),
169 required_capabilities: Capabilities::default(),
170 config: KernelConfig::default(),
171 }
172 }
173
174 pub fn get_block(&self, id: BlockId) -> Option<&Block> {
176 self.blocks.get(&id)
177 }
178
179 pub fn get_block_mut(&mut self, id: BlockId) -> Option<&mut Block> {
181 self.blocks.get_mut(&id)
182 }
183
184 pub fn get_value(&self, id: ValueId) -> Option<&Value> {
186 self.values.get(&id)
187 }
188
189 pub fn add_value(&mut self, value: Value) -> ValueId {
191 let id = value.id;
192 self.values.insert(id, value);
193 id
194 }
195
196 pub fn entry(&self) -> &Block {
198 self.blocks
199 .get(&self.entry_block)
200 .expect("entry block must exist")
201 }
202
203 pub fn validate(&self, level: ValidationLevel) -> ValidationResult {
205 Validator::new(level).validate(self)
206 }
207
208 pub fn pretty_print(&self) -> String {
210 IrPrinter::new().print(self)
211 }
212}
213
/// A parameter of an [`IrModule`].
#[derive(Debug, Clone)]
pub struct Parameter {
    /// Parameter name.
    pub name: String,
    /// Parameter type.
    pub ty: IrType,
    /// Id of the value that represents this parameter.
    /// NOTE(review): presumably also present in `IrModule::values` — confirm.
    pub value_id: ValueId,
    /// Parameter position.
    /// NOTE(review): presumably its index in `IrModule::parameters` — confirm.
    pub index: usize,
}
226
/// A block of instructions with an optional terminator and its
/// control-flow neighbours.
#[derive(Debug, Clone)]
pub struct Block {
    /// This block's id.
    pub id: BlockId,
    /// Human-readable label (the entry block created by `IrModule::new` uses `"entry"`).
    pub label: String,
    /// Instructions in order of insertion.
    pub instructions: Vec<Instruction>,
    /// Terminator, or `None` while the block is still open.
    pub terminator: Option<Terminator>,
    /// Ids of predecessor blocks. Not maintained by `Block`'s own methods.
    pub predecessors: Vec<BlockId>,
    /// Ids of successor blocks. Not updated by `set_terminator`.
    pub successors: Vec<BlockId>,
}
243
244impl Block {
245 pub fn new(id: BlockId, label: impl Into<String>) -> Self {
247 Self {
248 id,
249 label: label.into(),
250 instructions: Vec::new(),
251 terminator: None,
252 predecessors: Vec::new(),
253 successors: Vec::new(),
254 }
255 }
256
257 pub fn add_instruction(&mut self, inst: Instruction) {
259 self.instructions.push(inst);
260 }
261
262 pub fn set_terminator(&mut self, term: Terminator) {
264 self.terminator = Some(term);
265 }
266
267 pub fn is_terminated(&self) -> bool {
269 self.terminator.is_some()
270 }
271}
272
/// A value in the IR: its id, its type, and the node that defines it.
#[derive(Debug, Clone)]
pub struct Value {
    /// Id under which the value is stored (see `IrModule::add_value`).
    pub id: ValueId,
    /// Type of the value.
    pub ty: IrType,
    /// IR node defining the value.
    pub node: IrNode,
}
283
284impl Value {
285 pub fn new(ty: IrType, node: IrNode) -> Self {
287 Self {
288 id: ValueId::new(),
289 ty,
290 node,
291 }
292 }
293}
294
/// Kernel-level configuration attached to an [`IrModule`].
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Block dimensions `(x, y, z)`; `(256, 1, 1)` by default.
    pub block_size: (u32, u32, u32),
    /// Grid dimensions `(x, y, z)`, or `None` (the default).
    /// NOTE(review): presumably `None` means "chosen at launch" — confirm.
    pub grid_size: Option<(u32, u32, u32)>,
    /// Shared memory size in bytes; `0` by default.
    pub shared_memory_bytes: u32,
    /// Persistent-kernel flag; `false` by default.
    /// NOTE(review): overlaps with `mode == KernelMode::Persistent`; nothing
    /// here keeps the two in sync.
    pub is_persistent: bool,
    /// Kernel mode; [`KernelMode::Compute`] by default.
    pub mode: KernelMode,
}

impl Default for KernelConfig {
    /// Block size `(256, 1, 1)`, no grid size, no shared memory,
    /// not persistent, [`KernelMode::Compute`].
    fn default() -> Self {
        let block_size = (256, 1, 1);
        KernelConfig {
            mode: KernelMode::Compute,
            is_persistent: false,
            shared_memory_bytes: 0,
            grid_size: None,
            block_size,
        }
    }
}

/// Kind of kernel described by a [`KernelConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelMode {
    /// Compute kernel (the default).
    Compute,
    /// Persistent kernel.
    Persistent,
    /// Stencil kernel.
    Stencil,
}
332
/// One of the three axes `X`, `Y`, `Z`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dimension {
    /// The x axis.
    X,
    /// The y axis.
    Y,
    /// The z axis.
    Z,
}

impl std::fmt::Display for Dimension {
    /// Writes the lowercase axis name: `x`, `y` or `z`.
    ///
    /// Uses `Formatter::pad`, so width, fill, alignment and precision flags
    /// are honoured the same way they are for `str`. For example,
    /// `format!("{:>3}", Dimension::X)` is `"  x"`. Plain `{}` output is
    /// unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Dimension::X => "x",
            Dimension::Y => "y",
            Dimension::Z => "z",
        };
        f.pad(name)
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_value_id_unique() {
        let id1 = ValueId::new();
        let id2 = ValueId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_value_id_display() {
        // Value ids render as `%<n>`.
        let id = ValueId::new();
        assert_eq!(id.to_string(), format!("%{}", id.raw()));
    }

    #[test]
    fn test_block_id_unique() {
        let id1 = BlockId::new();
        let id2 = BlockId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_block_id_display() {
        // Block ids render as `bb<n>`.
        let id = BlockId::new();
        assert_eq!(id.to_string(), format!("bb{}", id.raw()));
    }

    #[test]
    fn test_ir_module_new() {
        let module = IrModule::new("test_kernel");
        assert_eq!(module.name, "test_kernel");
        assert!(module.parameters.is_empty());
        assert!(module.values.is_empty());
        assert!(module.blocks.contains_key(&module.entry_block));
        // `new` creates exactly one block: the entry block.
        assert_eq!(module.blocks.len(), 1);
    }

    #[test]
    fn test_ir_module_entry_block() {
        let module = IrModule::new("test_kernel");
        let entry = module.entry();
        assert_eq!(entry.id, module.entry_block);
        assert_eq!(entry.label, "entry");
        assert!(!entry.is_terminated());
        assert!(module.get_block(module.entry_block).is_some());
    }

    #[test]
    fn test_block_operations() {
        let id = BlockId::new();
        let mut block = Block::new(id, "test");

        assert!(!block.is_terminated());
        assert!(block.instructions.is_empty());

        block.set_terminator(Terminator::Return(None));
        assert!(block.is_terminated());
    }

    #[test]
    fn test_dimension_display() {
        assert_eq!(format!("{}", Dimension::X), "x");
        assert_eq!(format!("{}", Dimension::Y), "y");
        assert_eq!(format!("{}", Dimension::Z), "z");
    }

    #[test]
    fn test_kernel_config_default() {
        let config = KernelConfig::default();
        assert_eq!(config.block_size, (256, 1, 1));
        assert!(config.grid_size.is_none());
        assert_eq!(config.shared_memory_bytes, 0);
        assert!(!config.is_persistent);
        // The default mode must agree with `is_persistent == false`.
        assert_eq!(config.mode, KernelMode::Compute);
    }
}