Skip to main content

morok_codegen/llvm/text/
mod.rs

//! Text-based LLVM IR code generation (main entry point).
//!
//! This module generates LLVM IR as plain strings using `format!` macros,
//! following Tinygrad's approach in `renderer/llvmir.py`.
//!
//! # Kernel Signature
//!
//! Generates a single function with direct typed parameters and `noalias align 32`
//! buffer annotations. Buffer pointers come first (ordered by slot), followed by one
//! scalar per symbolic variable and, for threaded kernels, a trailing `%thread_id`:
//! ```llvm
//! define void @kernel(ptr noalias align 32 %buf0, ..., i32 %N) #0 { ... }
//! ```
13
14use std::sync::Arc;
15
16use morok_ir::pattern::TypedPatternMatcher;
17use morok_ir::{AxisType, Op, prelude::*};
18use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
19
20use crate::common::is_output_buffer;
21use crate::llvm::common::{RenderContext, ldt};
22use crate::llvm::cpu::{reduce_identity, render_uop};
23use crate::{BufferArg, RenderedKernel, Renderer, Result};
24
/// Text-based LLVM IR renderer.
///
/// Generates LLVM IR as strings, suitable for compilation via external clang.
/// Produces a single function with direct typed parameters.
///
/// The renderer holds no state, so it is freely copyable and can be built with
/// [`LlvmTextRenderer::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct LlvmTextRenderer;

impl LlvmTextRenderer {
    /// Creates a renderer; there is no configuration to supply.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for LlvmTextRenderer {
    fn default() -> Self {
        Self::new()
    }
}
42
43impl Renderer for LlvmTextRenderer {
44    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
45        let kernel_name = name.unwrap_or("kernel");
46
47        let nodes = linearize_with_cfg(uop.clone());
48
49        // Stage 22: Apply line rewrite cleanups to handle gated INDEX operations.
50        // Converts gated STOREs to IF/STORE/ENDIF sequences.
51        // Based on Tinygrad's pm_linearize_cleanups (codegen/__init__.py:107-113).
52        let nodes = line_rewrite_cleanups(nodes);
53
54        for (i, node) in nodes.iter().enumerate() {
55            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "linearized node");
56        }
57
58        let mut ctx = RenderContext::new();
59        let mut kernel: Vec<String> = Vec::new();
60        let mut buffer_args: Vec<BufferArg> = Vec::new();
61        let mut var_names: Vec<String> = Vec::new();
62
63        let mut buffers: Vec<Arc<UOp>> = Vec::new();
64        let mut variables: Vec<Arc<UOp>> = Vec::new();
65
66        for node in &nodes {
67            match node.op() {
68                Op::Param { device: None, .. } => {
69                    buffers.push(node.clone());
70                }
71                Op::DefineVar { .. } => {
72                    variables.push(node.clone());
73                }
74                _ => {}
75            }
76        }
77
78        buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
79
80        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
81            if let Op::Range { axis_type, end, .. } = n.op()
82                && matches!(axis_type, AxisType::Thread)
83                && let Op::Const(cv) = end.op()
84                && let ConstValue::Int(count) = cv.0
85            {
86                return Some((n.clone(), count as usize));
87            }
88            None
89        });
90
91        let has_threading = thread_info.is_some();
92        let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
93
94        for (i, buf) in buffers.iter().enumerate() {
95            if let Op::Param { slot, device: None, .. } = buf.op() {
96                let is_output = is_output_buffer(buf, &nodes);
97                buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
98            }
99        }
100
101        for var in &variables {
102            if let Op::DefineVar { name, .. } = var.op() {
103                var_names.push(name.clone());
104            }
105        }
106        if has_threading {
107            var_names.push("thread_id".to_string());
108        }
109
110        // -- Build function parameters --
111        let mut inner_params: Vec<String> = Vec::new();
112
113        // Buffer pointer parameters
114        for (i, buf) in buffers.iter().enumerate() {
115            inner_params.push(format!("ptr noalias align 32 %buf{i}"));
116            ctx.register(buf.id, format!("%buf{i}"));
117        }
118
119        // Variable parameters
120        for var in &variables {
121            let var_base_name =
122                if let Op::DefineVar { name, .. } = var.op() { name.clone() } else { "var".to_string() };
123            let var_dtype = var.dtype();
124            let var_dtype_str = ldt(&var_dtype);
125            inner_params.push(format!("{var_dtype_str} %{var_base_name}"));
126            ctx.register(var.id, format!("%{var_base_name}"));
127        }
128
129        // Thread ID parameter
130        if let Some((thread_range, _)) = &thread_info {
131            let range_dtype = thread_range.dtype();
132            let range_dtype_str = ldt(&range_dtype);
133            inner_params.push(format!("{range_dtype_str} %thread_id"));
134
135            if let Op::Range { axis_id, .. } = thread_range.op() {
136                ctx.register(thread_range.id, "%thread_id".to_string());
137                ctx.register_range(axis_id.value(), "%thread_id".to_string());
138            }
139        }
140
141        // -- Build function body --
142        kernel.push("  ; Reduction accumulators".to_string());
143        for node in &nodes {
144            if let Op::Reduce { reduce_op, .. } = node.op() {
145                let dtype = ldt(&node.dtype());
146                let identity = reduce_identity(*reduce_op, &node.dtype());
147                let acc_name = format!("%reduce_{}", node.id);
148                kernel.push(format!("  {acc_name} = alloca {dtype}"));
149                kernel.push(format!("  store {dtype} {identity}, ptr {acc_name}"));
150                ctx.register(node.id, acc_name);
151            }
152        }
153        kernel.push("".to_string());
154
155        for node in &nodes {
156            match node.op() {
157                Op::Const(cv) => {
158                    let val = crate::llvm::common::lconst(&cv.0, &node.dtype());
159                    ctx.register(node.id, val);
160                }
161                Op::VConst { .. } => {
162                    ctx.name(node);
163                }
164                _ => {}
165            }
166        }
167
168        for node in &nodes {
169            if let Op::Range { axis_id, axis_type, .. } = node.op()
170                && !matches!(axis_type, AxisType::Thread)
171            {
172                let name = format!("%r{}", axis_id.value());
173                ctx.register(node.id, name);
174            }
175        }
176
177        for node in &nodes {
178            if matches!(node.op(), Op::Noop | Op::Group { .. }) {
179                ctx.register(node.id, String::new());
180                continue;
181            }
182            if let Op::Range { axis_type, .. } = node.op()
183                && matches!(axis_type, AxisType::Thread)
184            {
185                continue;
186            }
187            render_uop(node, &mut ctx, &mut kernel);
188        }
189
190        kernel.push("  ret void".to_string());
191
192        let ir = format!(
193            r#"; ModuleID = '{kernel_name}'
194source_filename = "{kernel_name}"
195
196{intrinsics}
197
198define void @{kernel_name}({inner_params}) #0 {{
199entry:
200{inner_body}
201}}
202
203attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
204"#,
205            intrinsics = generate_intrinsic_declarations(&kernel),
206            inner_params = inner_params.join(", "),
207            inner_body = kernel.join("\n"),
208        );
209
210        tracing::debug!(generated_code = ir, "llvm codegen: final generated code");
211
212        let mut result = RenderedKernel::new(ir, kernel_name.to_string());
213        result.buffer_args = buffer_args;
214        result.var_names = var_names;
215
216        if thread_count > 1 {
217            result.global_size = Some([thread_count, 1, 1]);
218            result.local_size = Some([1, 1, 1]);
219        }
220
221        Ok(result)
222    }
223
224    fn backend_name(&self) -> &str {
225        "llvm-text"
226    }
227
228    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
229        None
230    }
231}
232
/// Returns LLVM's overloaded-intrinsic name suffix for a textual IR type.
///
/// Follows LLVM's intrinsic type mangling:
/// - floating point: `half`→`f16`, `bfloat`→`bf16`, `float`→`f32`, `double`→`f64`,
///   `x86_fp80`→`f80`, `fp128`→`f128`;
/// - integers (`iN`) are already in mangled form and pass through unchanged;
/// - opaque pointers: `ptr`→`p0`, `ptr addrspace(N)`→`pN`;
/// - fixed vectors: `<N x T>`→`vN` followed by the mangled `T` (e.g. `<4 x float>`→`v4f32`);
/// - scalable vectors: `<vscale x N x T>`→`nxvN` followed by the mangled `T`.
///
/// Any other string is returned unchanged (minus surrounding whitespace).
fn mangle_type(llvm_type: &str) -> String {
    let ty = llvm_type.trim();

    let scalar = match ty {
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        "float" => Some("f32"),
        "double" => Some("f64"),
        "x86_fp80" => Some("f80"),
        "fp128" => Some("f128"),
        "ptr" => Some("p0"),
        _ => None,
    };
    if let Some(mangled) = scalar {
        return mangled.to_string();
    }

    if let Some(space) = ty.strip_prefix("ptr addrspace(").and_then(|rest| rest.strip_suffix(')')) {
        return format!("p{}", space.trim());
    }

    if let Some(inner) = ty.strip_prefix('<').and_then(|rest| rest.strip_suffix('>')) {
        let inner = inner.trim();
        // Scalable vectors carry a `vscale x` prefix and mangle as `nxv...`.
        let (prefix, fixed) = match inner.strip_prefix("vscale x ") {
            Some(rest) => ("nxv", rest),
            None => ("v", inner),
        };
        let parts: Vec<&str> = fixed.split(" x ").collect();
        if let [count, elem] = parts.as_slice() {
            return format!("{prefix}{}{}", count.trim(), mangle_type(elem));
        }
    }

    // Integers (`iN`) are already spelled the way LLVM mangles them.
    ty.to_string()
}
256
257fn generate_intrinsic_declarations(kernel: &[String]) -> String {
258    let mut decls = Vec::new();
259    let kernel_str = kernel.join("\n");
260
261    for intrinsic in &[
262        "sqrt", "exp", "exp2", "log", "log2", "sin", "cos", "pow", "fabs", "floor", "ceil", "trunc", "round", "maxnum",
263        "minnum", "fmuladd", "erf",
264    ] {
265        for llvm_type in
266            &["float", "double", "half", "<2 x float>", "<4 x float>", "<8 x float>", "<2 x double>", "<4 x double>"]
267        {
268            let mangled = mangle_type(llvm_type);
269            let pattern = format!("@llvm.{intrinsic}.{mangled}");
270            if kernel_str.contains(&pattern) {
271                let decl = match *intrinsic {
272                    "fmuladd" => format!(
273                        "declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type}, {llvm_type}, {llvm_type})"
274                    ),
275                    "pow" | "maxnum" | "minnum" => {
276                        format!("declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type}, {llvm_type})")
277                    }
278                    _ => format!("declare {llvm_type} @llvm.{intrinsic}.{mangled}({llvm_type})"),
279                };
280                decls.push(decl);
281            }
282        }
283    }
284
285    for bits in &["i8", "i16", "i32", "i64"] {
286        let pattern = format!("@llvm.abs.{bits}");
287        if kernel_str.contains(&pattern) {
288            decls.push(format!("declare {bits} @llvm.abs.{bits}({bits}, i1)"));
289        }
290    }
291
292    decls.join("\n")
293}
294
295pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
296    let renderer = LlvmTextRenderer::new();
297    renderer.render(uop, name)
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::{AddrSpace, DType};
    use morok_ir::{BinaryOp, Op};

    /// `out[0] = a[0] + b[0]` over three one-element f32 global buffers must
    /// render as a single directly-callable function: one `define` with
    /// `noalias align 32` pointer parameters, no `_inner` wrapper, and no
    /// packed `%args` pointer.
    #[test]
    fn test_simple_add() {
        let f32_global_ptr = || DType::Float32.ptr(Some(1), AddrSpace::Global);
        let lhs = UOp::param(0, 1, f32_global_ptr(), None);
        let rhs = UOp::param(1, 1, f32_global_ptr(), None);
        let dst = UOp::param(2, 1, f32_global_ptr(), None);

        // Every access targets element 0 of its buffer.
        let zero = UOp::index_const(0);
        let lhs_at = UOp::index().buffer(lhs.clone()).indices(vec![zero.clone()]).call().unwrap();
        let rhs_at = UOp::index().buffer(rhs.clone()).indices(vec![zero.clone()]).call().unwrap();
        let dst_at = UOp::index().buffer(dst.clone()).indices(vec![zero.clone()]).call().unwrap();

        let sum = UOp::new(
            Op::Binary(
                BinaryOp::Add,
                UOp::load().buffer(lhs.clone()).index(lhs_at).call(),
                UOp::load().buffer(rhs.clone()).index(rhs_at).call(),
            ),
            DType::Float32,
        );
        let sink = UOp::sink(vec![dst_at.store(sum)]);

        let rendered = render(&sink, Some("test_add")).unwrap();
        let code = &rendered.code;
        println!("{code}");

        assert!(code.contains("define void @test_add("));
        assert!(code.contains("noalias align 32"));
        assert!(!code.contains("_inner"));
        assert!(!code.contains("ptr %args"));
        assert!(code.contains("fadd"));
        assert!(code.contains("load"));
        assert!(code.contains("store"));
    }
}