// morok_codegen/llvm/text/mod.rs

//! Text-based LLVM IR code generation (main entry point).
//!
//! This module generates LLVM IR as plain strings using `format!` macros,
//! following Tinygrad's approach in `renderer/llvmir.py`.
//!
//! # Kernel Signature
//!
//! Generates a single function with direct typed parameters and `noalias align 32`
//! buffer annotations:
//! ```llvm
//! define void @kernel(ptr noalias align 32 %buf0, ..., i32 %N) #0 { ... }
//! ```
13
14use std::sync::Arc;
15
16use morok_ir::pattern::TypedPatternMatcher;
17use morok_ir::rewrite::graph_rewrite_bottom_up;
18use morok_ir::{AxisType, Op, prelude::*};
19use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
20use morok_schedule::rangeify::patterns::pm_bool_devectorize;
21
22use crate::llvm::common::{RenderContext, ldt};
23use crate::llvm::cpu::{reduce_identity, render_uop};
24use crate::{BufferArg, RenderedKernel, Renderer, Result};
25
/// Text-based LLVM IR renderer.
///
/// Generates LLVM IR as strings, suitable for compilation via external clang.
/// Produces a single function with direct typed parameters.
///
/// The type carries no state, so it is freely copyable and every value is
/// interchangeable with any other.
#[derive(Debug, Clone, Copy)]
pub struct LlvmTextRenderer;
31
32impl LlvmTextRenderer {
33    pub fn new() -> Self {
34        Self
35    }
36}
37
38impl Default for LlvmTextRenderer {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl Renderer for LlvmTextRenderer {
45    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
46        let kernel_name = name.unwrap_or("kernel");
47
48        // NOTE: pm_render() is now called in schedule/optimizer Stage 19.
49        // Keeping pm_bool_devectorize as safety fallback for direct codegen paths.
50        let uop = graph_rewrite_bottom_up(pm_bool_devectorize(), uop.clone(), &mut ());
51
52        tracing::debug!(ast_after_pm_bool_devectorize = %uop.tree(), "codegen: after pm_bool_devectorize");
53
54        let nodes = linearize_with_cfg(uop);
55
56        // Stage 22: Apply line rewrite cleanups to handle gated INDEX operations.
57        // Converts gated STOREs to IF/STORE/ENDIF sequences.
58        // Based on Tinygrad's pm_linearize_cleanups (codegen/__init__.py:107-113).
59        let nodes = line_rewrite_cleanups(nodes);
60
61        for (i, node) in nodes.iter().enumerate() {
62            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "linearized node");
63        }
64
65        let mut ctx = RenderContext::new();
66        let mut kernel: Vec<String> = Vec::new();
67        let mut buffer_args: Vec<BufferArg> = Vec::new();
68        let mut var_names: Vec<String> = Vec::new();
69
70        let mut buffers: Vec<Arc<UOp>> = Vec::new();
71        let mut variables: Vec<Arc<UOp>> = Vec::new();
72
73        for node in &nodes {
74            match node.op() {
75                Op::DefineGlobal(_) => {
76                    buffers.push(node.clone());
77                }
78                Op::DefineVar { .. } => {
79                    variables.push(node.clone());
80                }
81                _ => {}
82            }
83        }
84
85        buffers.sort_by_key(|b| if let Op::DefineGlobal(id) = b.op() { *id } else { usize::MAX });
86
87        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
88            if let Op::Range { axis_type, end, .. } = n.op()
89                && matches!(axis_type, AxisType::Thread)
90                && let Op::Const(cv) = end.op()
91                && let ConstValue::Int(count) = cv.0
92            {
93                return Some((n.clone(), count as usize));
94            }
95            None
96        });
97
98        let has_threading = thread_info.is_some();
99        let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
100
101        for (i, buf) in buffers.iter().enumerate() {
102            if let Op::DefineGlobal(id) = buf.op() {
103                let is_output = is_output_buffer(buf, &nodes);
104                buffer_args.push(BufferArg { index: *id, name: format!("data{i}"), dtype: buf.dtype(), is_output });
105            }
106        }
107
108        for var in &variables {
109            if let Op::DefineVar { name, .. } = var.op() {
110                var_names.push(name.clone());
111            }
112        }
113        if has_threading {
114            var_names.push("thread_id".to_string());
115        }
116
117        // -- Build function parameters --
118        let mut inner_params: Vec<String> = Vec::new();
119
120        // Buffer pointer parameters
121        for (i, buf) in buffers.iter().enumerate() {
122            inner_params.push(format!("ptr noalias align 32 %buf{i}"));
123            ctx.register(buf.id, format!("%buf{i}"));
124        }
125
126        // Variable parameters
127        for var in &variables {
128            let var_base_name =
129                if let Op::DefineVar { name, .. } = var.op() { name.clone() } else { "var".to_string() };
130            let var_dtype = var.dtype();
131            let var_dtype_str = ldt(&var_dtype);
132            inner_params.push(format!("{var_dtype_str} %{var_base_name}"));
133            ctx.register(var.id, format!("%{var_base_name}"));
134        }
135
136        // Thread ID parameter
137        if let Some((thread_range, _)) = &thread_info {
138            let range_dtype = thread_range.dtype();
139            let range_dtype_str = ldt(&range_dtype);
140            inner_params.push(format!("{range_dtype_str} %thread_id"));
141
142            if let Op::Range { axis_id, .. } = thread_range.op() {
143                ctx.register(thread_range.id, "%thread_id".to_string());
144                ctx.register_range(axis_id.value(), "%thread_id".to_string());
145            }
146        }
147
148        // -- Build function body --
149        kernel.push("  ; Reduction accumulators".to_string());
150        for node in &nodes {
151            if let Op::Reduce { reduce_op, .. } = node.op() {
152                let dtype = ldt(&node.dtype());
153                let identity = reduce_identity(*reduce_op, &node.dtype());
154                let acc_name = format!("%reduce_{}", node.id);
155                kernel.push(format!("  {acc_name} = alloca {dtype}"));
156                kernel.push(format!("  store {dtype} {identity}, ptr {acc_name}"));
157                ctx.register(node.id, acc_name);
158            }
159        }
160        kernel.push("".to_string());
161
162        for node in &nodes {
163            match node.op() {
164                Op::Const(cv) => {
165                    let val = crate::llvm::common::lconst(&cv.0, &node.dtype());
166                    ctx.register(node.id, val);
167                }
168                Op::VConst { .. } => {
169                    ctx.name(node);
170                }
171                _ => {}
172            }
173        }
174
175        for node in &nodes {
176            if let Op::Range { axis_id, axis_type, .. } = node.op()
177                && !matches!(axis_type, AxisType::Thread)
178            {
179                let name = format!("%r{}", axis_id.value());
180                ctx.register(node.id, name);
181            }
182        }
183
184        for node in &nodes {
185            if let Op::Range { axis_type, .. } = node.op()
186                && matches!(axis_type, AxisType::Thread)
187            {
188                continue;
189            }
190            render_uop(node, &mut ctx, &mut kernel);
191        }
192
193        kernel.push("  ret void".to_string());
194
195        let ir = format!(
196            r#"; ModuleID = '{kernel_name}'
197source_filename = "{kernel_name}"
198
199{intrinsics}
200
201define void @{kernel_name}({inner_params}) #0 {{
202entry:
203{inner_body}
204}}
205
206attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
207"#,
208            intrinsics = generate_intrinsic_declarations(&kernel),
209            inner_params = inner_params.join(", "),
210            inner_body = kernel.join("\n"),
211        );
212
213        tracing::debug!(generated_code = ir, "llvm codegen: final generated code");
214
215        let mut result = RenderedKernel::new(ir, kernel_name.to_string());
216        result.buffer_args = buffer_args;
217        result.var_names = var_names;
218
219        if thread_count > 1 {
220            result.global_size = Some([thread_count, 1, 1]);
221            result.local_size = Some([1, 1, 1]);
222        }
223
224        Ok(result)
225    }
226
227    fn backend_name(&self) -> &str {
228        "llvm-text"
229    }
230
231    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
232        None
233    }
234}
235
236fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
237    let buffer_id = def_global.id;
238
239    for node in nodes {
240        if let Some(buffer) = node.store_buffer() {
241            if buffer.id == buffer_id {
242                return true;
243            }
244            if let Op::Index { buffer: idx_buf, .. } = buffer.op()
245                && idx_buf.id == buffer_id
246            {
247                return true;
248            }
249        }
250    }
251    false
252}
253
/// Converts a textual LLVM type into the suffix LLVM uses in overloaded intrinsic names.
///
/// * Floating-point scalars map to their short form: `half` -> `f16`, `bfloat` -> `bf16`,
///   `float` -> `f32`, `double` -> `f64`, `x86_fp80` -> `f80`, `fp128` -> `f128`,
///   `ppc_fp128` -> `ppcf128`.
/// * Fixed vectors `<N x T>` become `vN` followed by the mangled element type,
///   e.g. `<4 x float>` -> `v4f32`.
/// * Scalable vectors `<vscale x N x T>` become `nxvN` followed by the mangled
///   element type, e.g. `<vscale x 4 x float>` -> `nxv4f32`.
/// * Anything else (integer types such as `i32`, or text that does not have one of
///   the shapes above) is returned unchanged.
fn mangle_type(llvm_type: &str) -> String {
    let float_suffix = match llvm_type {
        "half" => Some("f16"),
        "bfloat" => Some("bf16"),
        "float" => Some("f32"),
        "double" => Some("f64"),
        "x86_fp80" => Some("f80"),
        "fp128" => Some("f128"),
        "ppc_fp128" => Some("ppcf128"),
        _ => None,
    };
    if let Some(suffix) = float_suffix {
        return suffix.to_string();
    }

    // Only `<...>` is treated as a vector; everything else is already its own suffix.
    let Some(inner) = llvm_type.strip_prefix('<').and_then(|rest| rest.strip_suffix('>')) else {
        return llvm_type.to_string();
    };

    let (tag, shape) = match inner.trim().strip_prefix("vscale x ") {
        Some(rest) => ("nxv", rest),
        None => ("v", inner),
    };

    match shape.split_once(" x ") {
        // Exactly one `count x element` pair; extra ` x ` separators are not a valid vector.
        Some((count, element)) if !element.contains(" x ") => {
            format!("{tag}{}{}", count.trim(), mangle_type(element.trim()))
        }
        _ => llvm_type.to_string(),
    }
}
277
/// Builds the `declare` lines for every known LLVM intrinsic referenced in `kernel`.
///
/// The kernel lines are searched for `@llvm.<name>.<suffix>`. Floating-point
/// intrinsics are checked against a fixed set of scalar and vector types; `abs` is
/// checked against `i8`..`i64`. Declarations are emitted in table order
/// (intrinsic-major, then type), with the `abs` declarations last, joined by newlines.
/// Returns an empty string when nothing matches.
fn generate_intrinsic_declarations(kernel: &[String]) -> String {
    const FLOAT_INTRINSICS: [&str; 17] = [
        "sqrt", "exp", "exp2", "log", "log2", "sin", "cos", "pow", "fabs", "floor", "ceil", "trunc", "round", "maxnum",
        "minnum", "fmuladd", "erf",
    ];
    // (LLVM type, intrinsic name suffix) pairs.
    const FLOAT_TYPES: [(&str, &str); 8] = [
        ("float", "f32"),
        ("double", "f64"),
        ("half", "f16"),
        ("<2 x float>", "v2f32"),
        ("<4 x float>", "v4f32"),
        ("<8 x float>", "v8f32"),
        ("<2 x double>", "v2f64"),
        ("<4 x double>", "v4f64"),
    ];
    const INT_TYPES: [&str; 4] = ["i8", "i16", "i32", "i64"];

    let body = kernel.join("\n");

    let float_decls = FLOAT_INTRINSICS
        .iter()
        .flat_map(|name| FLOAT_TYPES.iter().map(move |(ty, suffix)| (*name, *ty, *suffix)))
        .filter_map(|(name, ty, suffix)| {
            let symbol = format!("@llvm.{name}.{suffix}");
            if !body.contains(&symbol) {
                return None;
            }
            // Every operand has the same type as the result; only the count varies.
            let operand_count = match name {
                "fmuladd" => 3,
                "pow" | "maxnum" | "minnum" => 2,
                _ => 1,
            };
            let operands = vec![ty; operand_count].join(", ");
            Some(format!("declare {ty} {symbol}({operands})"))
        });

    let int_decls = INT_TYPES.iter().filter_map(|ty| {
        let symbol = format!("@llvm.abs.{ty}");
        body.contains(&symbol).then(|| format!("declare {ty} {symbol}({ty}, i1)"))
    });

    float_decls.chain(int_decls).collect::<Vec<_>>().join("\n")
}
315
316pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
317    let renderer = LlvmTextRenderer::new();
318    renderer.render(uop, name)
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::{AddrSpace, DType};
    use morok_ir::{BinaryOp, Op};

    /// Renders `out[0] = a[0] + b[0]` over three float buffers and checks the shape
    /// of the emitted IR.
    #[test]
    fn test_simple_add() {
        let float_buffer = |slot| UOp::define_global(slot, DType::Float32.ptr(Some(1), AddrSpace::Global));
        let lhs = float_buffer(0);
        let rhs = float_buffer(1);
        let dst = float_buffer(2);

        let zero = UOp::index_const(0);
        let element_zero =
            |buffer: &Arc<UOp>| UOp::index().buffer(buffer.clone()).indices(vec![zero.clone()]).call().unwrap();
        let lhs_at = element_zero(&lhs);
        let rhs_at = element_zero(&rhs);
        let dst_at = element_zero(&dst);

        let lhs_value = UOp::load().buffer(lhs.clone()).index(lhs_at).call();
        let rhs_value = UOp::load().buffer(rhs.clone()).index(rhs_at).call();

        let sum = UOp::new(Op::Binary(BinaryOp::Add, lhs_value, rhs_value), DType::Float32);

        let sink = UOp::sink(vec![dst_at.store(sum)]);

        let result = render(&sink, Some("test_add")).unwrap();
        println!("{}", result.code);

        // Fragments the IR must contain, in the order they are checked.
        assert!(result.code.contains("define void @test_add("));
        assert!(result.code.contains("noalias align 32"));

        // Fragments the IR must not contain.
        for forbidden in ["_inner", "ptr %args"] {
            assert!(!result.code.contains(forbidden));
        }

        for required in ["fadd", "load", "store"] {
            assert!(result.code.contains(required));
        }
    }
}
357}