Skip to main content

morok_codegen/c/
mod.rs

1//! C source code generation backend.
2//!
3//! Generates C source code from linearized UOp IR, suitable for compilation
4//! with `clang -shared -O2` and loading via `dlopen`.
5//!
6//! # Kernel Signature
7//!
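//! Emits a single function whose parameters are, in order: one typed `restrict` pointer per
//! global buffer (`data0`, `data1`, …), one `const` scalar per symbolic variable, and — when the
//! kernel contains a `Thread` range with a constant integer bound — a trailing `const thread_id`:
//!
//! ```c
//! void kernel(float* restrict data0, const int N) { /* body */ }
//! ```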
13
14mod amx;
15pub mod ops;
16pub mod types;
17
18use std::sync::Arc;
19
20use morok_ir::pattern::TypedPatternMatcher;
21use morok_ir::rewrite::graph_rewrite_bottom_up;
22use morok_ir::{AxisType, Op, prelude::*};
23use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
24use morok_schedule::rangeify::patterns::pm_bool_devectorize;
25
26use crate::{BufferArg, RenderedKernel, Result};
27
28use self::ops::{CContext, count_references, render_uop};
29use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
30
/// C source code renderer for CPU execution via clang.
///
/// This is a unit struct with no fields, so it is free to construct and copy; all
/// per-kernel state is created inside each `render` call.
#[derive(Debug, Clone, Copy)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a new C renderer.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for CRenderer {
    /// Equivalent to [`CRenderer::new`].
    fn default() -> Self {
        Self::new()
    }
}
45
46impl crate::Renderer for CRenderer {
47    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
48        let kernel_name = name.unwrap_or("kernel");
49
50        // Apply pm_bool_devectorize as safety fallback
51        let uop = graph_rewrite_bottom_up(pm_bool_devectorize(), uop.clone(), &mut ());
52
53        tracing::debug!(ast_after_pm_bool_devectorize = %uop.tree(), "c codegen: after pm_bool_devectorize");
54
55        // Linearize the UOp DAG
56        let nodes = linearize_with_cfg(uop);
57
58        // Apply line rewrite cleanups (gated stores → if/store/endif)
59        let nodes = line_rewrite_cleanups(nodes);
60
61        for (i, node) in nodes.iter().enumerate() {
62            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
63        }
64
65        // Collect buffers and variables from linearized stream
66        let mut buffers: Vec<Arc<UOp>> = Vec::new();
67        let mut variables: Vec<Arc<UOp>> = Vec::new();
68
69        for node in &nodes {
70            match node.op() {
71                Op::DefineGlobal(_) => buffers.push(node.clone()),
72                Op::DefineVar { .. } => variables.push(node.clone()),
73                _ => {}
74            }
75        }
76
77        buffers.sort_by_key(|b| if let Op::DefineGlobal(id) = b.op() { *id } else { usize::MAX });
78
79        // Detect threading
80        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
81            if let Op::Range { axis_type, end, .. } = n.op()
82                && matches!(axis_type, AxisType::Thread)
83                && let Op::Const(cv) = end.op()
84                && let ConstValue::Int(count) = cv.0
85            {
86                return Some((n.clone(), count as usize));
87            }
88            None
89        });
90
91        let has_threading = thread_info.is_some();
92        let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
93
94        // Build buffer args metadata
95        let mut buffer_args: Vec<BufferArg> = Vec::new();
96        for (i, buf) in buffers.iter().enumerate() {
97            if let Op::DefineGlobal(id) = buf.op() {
98                let is_output = is_output_buffer(buf, &nodes);
99                buffer_args.push(BufferArg { index: *id, name: format!("data{i}"), dtype: buf.dtype(), is_output });
100            }
101        }
102
103        // Build var_names
104        let mut var_names: Vec<String> = Vec::new();
105        for var in &variables {
106            if let Op::DefineVar { name, .. } = var.op() {
107                var_names.push(name.clone());
108            }
109        }
110        if has_threading {
111            var_names.push("thread_id".to_string());
112        }
113
114        // Count references for SSA inlining decisions
115        let ref_counts = count_references(&nodes);
116        let mut ctx = CContext::new(ref_counts);
117
118        // === Build C source ===
119        let mut code_lines: Vec<String> = Vec::new();
120
121        // Includes
122        code_lines.push("#include <stdbool.h>".to_string());
123        code_lines.push("".to_string());
124
125        // Vector typedefs
126        let typedefs = collect_vector_typedefs(&nodes);
127        for td in &typedefs {
128            code_lines.push(td.clone());
129        }
130        if !typedefs.is_empty() {
131            code_lines.push("".to_string());
132        }
133
134        // WMMA (AMX) defines and static functions
135        let wmma_defines = amx::collect_wmma_defines(&nodes);
136        for def in &wmma_defines {
137            code_lines.push(def.clone());
138        }
139        if !wmma_defines.is_empty() {
140            code_lines.push("".to_string());
141        }
142
143        // Build typed function params
144        let mut params: Vec<String> = Vec::new();
145
146        // Buffer parameters
147        for (i, buf) in buffers.iter().enumerate() {
148            let buf_dtype = buf.dtype();
149            let elem_type = match &buf_dtype {
150                DType::Ptr { base, .. } => c_dtype(base),
151                _ => c_dtype(&buf_dtype),
152            };
153            let name = format!("data{i}");
154            params.push(format!("{elem_type}* restrict {name}"));
155            ctx.register(buf.id, name);
156        }
157
158        // Variable parameters
159        for var in &variables {
160            if let Op::DefineVar { name, .. } = var.op() {
161                let var_dtype = &var.dtype();
162                let c_type = c_dtype(var_dtype);
163                params.push(format!("const {c_type} {name}"));
164                ctx.register(var.id, name.clone());
165            }
166        }
167
168        // Thread ID parameter
169        if let Some((thread_range, _)) = &thread_info {
170            let range_dtype = &thread_range.dtype();
171            let c_type = c_dtype(range_dtype);
172            params.push(format!("const {c_type} thread_id"));
173            ctx.register(thread_range.id, "thread_id".to_string());
174        }
175
176        // Function signature
177        code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
178
179        // Local memory allocations (stack arrays on CPU)
180        for node in &nodes {
181            if let Op::DefineLocal(id) = node.op() {
182                let (base, size) = match node.dtype() {
183                    DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
184                    other => (c_dtype(&other), 1),
185                };
186                let name = format!("local{id}");
187                code_lines.push(format!("  {base} {name}[{size}];"));
188                ctx.register(node.id, name);
189            }
190        }
191
192        code_lines.push("".to_string());
193
194        // Reduction accumulator declarations (need to be in outer scope)
195        for node in &nodes {
196            if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
197                if ranges.is_empty() {
198                    continue;
199                }
200                let dtype = &node.dtype();
201                let c_type = c_dtype(dtype);
202                let identity = c_reduce_identity(*reduce_op, dtype);
203                let acc_name = format!("acc{}", node.id);
204                code_lines.push(format!("  {c_type} {acc_name} = {identity};"));
205                // Pre-register so the ops.rs render_uop finds it
206                ctx.register(node.id, acc_name);
207            }
208        }
209
210        // Register constants
211        for node in &nodes {
212            match node.op() {
213                Op::Const(cv) => {
214                    let val = c_const(&cv.0, &node.dtype());
215                    ctx.register(node.id, val);
216                }
217                Op::VConst { values } => {
218                    let val = c_vconst(values, &node.dtype());
219                    ctx.register(node.id, val);
220                }
221                _ => {}
222            }
223        }
224
225        // Pre-register range variable names
226        for node in &nodes {
227            if let Op::Range { axis_id, axis_type, .. } = node.op()
228                && !matches!(axis_type, AxisType::Thread)
229            {
230                let name = format!("ridx{}", axis_id.value());
231                ctx.register(node.id, name);
232            }
233        }
234
235        // Render all instructions
236        let mut kernel_body: Vec<String> = Vec::new();
237        for node in &nodes {
238            if let Op::Range { axis_type, .. } = node.op()
239                && matches!(axis_type, AxisType::Thread)
240            {
241                continue;
242            }
243            render_uop(node, &mut ctx, &mut kernel_body);
244        }
245
246        code_lines.extend(kernel_body);
247        code_lines.push("}".to_string());
248        code_lines.push("".to_string());
249
250        let code = code_lines.join("\n");
251
252        tracing::debug!(generated_c = code, "c codegen: final generated code");
253
254        let mut result = RenderedKernel::new(code, kernel_name.to_string());
255        result.buffer_args = buffer_args;
256        result.var_names = var_names;
257
258        if thread_count > 1 {
259            result.global_size = Some([thread_count, 1, 1]);
260            result.local_size = Some([1, 1, 1]);
261        }
262
263        Ok(result)
264    }
265
266    fn backend_name(&self) -> &str {
267        "clang"
268    }
269
270    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
271        // C uses __builtin_ math functions (sqrt, exp, sin, etc.) — no decomposition needed.
272        // Threefry is handled by XOR in render.
273        None
274    }
275}
276
277fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
278    let buffer_id = def_global.id;
279
280    for node in nodes {
281        if let Some(buffer) = node.store_buffer() {
282            if buffer.id == buffer_id {
283                return true;
284            }
285            if let Op::Index { buffer: idx_buf, .. } = buffer.op()
286                && idx_buf.id == buffer_id
287            {
288                return true;
289            }
290        }
291    }
292    false
293}
294
295/// Public render function for the C backend.
296pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
297    let renderer = CRenderer::new();
298    crate::Renderer::render(&renderer, uop, name)
299}