Skip to main content

morok_codegen/c/
mod.rs

//! C source code generation backend.
//!
//! Generates C source code from linearized UOp IR, suitable for compilation
//! with `clang -shared -O2` and loading via `dlopen`.
//!
//! # Kernel Signature
//!
//! ```c
//! void kernel(void** args, long long* vars);
//! ```
//! - `args[i]` = buffer pointer (order from `DefineGlobal` index)
//! - `vars[i]` = i64 variable value (order from `var_names`)
//! - Thread ID is appended as the last var when `global_size[0] > 1`

mod amx;
pub mod ops;
pub mod types;

use std::sync::Arc;

use morok_ir::pattern::TypedPatternMatcher;
use morok_ir::rewrite::graph_rewrite_bottom_up;
use morok_ir::{AxisType, Op, prelude::*};
use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
use morok_schedule::rangeify::patterns::pm_bool_devectorize;

use crate::{BufferArg, RenderedKernel, Result};

use self::ops::{CContext, count_references, render_uop};
use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
31
/// C source code renderer for CPU execution via clang.
///
/// A stateless unit struct: all per-kernel rendering state is built
/// inside [`crate::Renderer::render`] (see `CContext` in `ops`).
pub struct CRenderer;
34
35impl CRenderer {
36    pub fn new() -> Self {
37        Self
38    }
39}
40
41impl Default for CRenderer {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
impl crate::Renderer for CRenderer {
    /// Renders a UOp kernel graph into a complete C translation unit.
    ///
    /// Pipeline: devectorize bools → linearize the DAG → clean up gated
    /// stores → emit includes/typedefs/AMX defines → emit the kernel
    /// function (buffer casts, locals, var loads, accumulators, body).
    ///
    /// NOTE: declaration order in `code_lines` and registration order in
    /// `ctx` are load-bearing — later passes (`render_uop`) look names up
    /// by node id, and `args[i]`/`vars[i]` indexing must match the order
    /// recorded in `buffer_args`/`var_names`.
    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
        let kernel_name = name.unwrap_or("kernel");

        // Apply pm_bool_devectorize as safety fallback
        let uop = graph_rewrite_bottom_up(pm_bool_devectorize(), uop.clone(), &mut ());

        tracing::debug!(ast_after_pm_bool_devectorize = %uop.tree(), "c codegen: after pm_bool_devectorize");

        // Linearize the UOp DAG
        let nodes = linearize_with_cfg(uop);

        // Apply line rewrite cleanups (gated stores → if/store/endif)
        let nodes = line_rewrite_cleanups(nodes);

        for (i, node) in nodes.iter().enumerate() {
            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
        }

        // Collect buffers and variables from linearized stream
        let mut buffers: Vec<Arc<UOp>> = Vec::new();
        let mut variables: Vec<Arc<UOp>> = Vec::new();

        for node in &nodes {
            match node.op() {
                Op::DefineGlobal(_) => buffers.push(node.clone()),
                Op::DefineVar { .. } => variables.push(node.clone()),
                _ => {}
            }
        }

        // Sort by DefineGlobal id so data{i} lines up with args[i] at the
        // call site; usize::MAX fallback is unreachable (all are DefineGlobal).
        buffers.sort_by_key(|b| if let Op::DefineGlobal(id) = b.op() { *id } else { usize::MAX });

        // Detect threading: a Thread-axis Range with a constant integer end.
        // Non-constant thread counts are not detected here — presumably they
        // cannot occur at this stage; TODO confirm upstream invariant.
        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
            if let Op::Range { axis_type, end, .. } = n.op()
                && matches!(axis_type, AxisType::Thread)
                && let Op::Const(cv) = end.op()
                && let ConstValue::Int(count) = cv.0
            {
                return Some((n.clone(), count as usize));
            }
            None
        });

        let has_threading = thread_info.is_some();
        let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);

        // Build buffer args metadata
        let mut buffer_args: Vec<BufferArg> = Vec::new();
        for (i, buf) in buffers.iter().enumerate() {
            if let Op::DefineGlobal(id) = buf.op() {
                // is_output: any store in the stream targets this buffer.
                let is_output = is_output_buffer(buf, &nodes);
                buffer_args.push(BufferArg { index: *id, name: format!("data{i}"), dtype: buf.dtype(), is_output });
            }
        }

        // Build var_names (order defines the vars[i] ABI; see module docs)
        let mut var_names: Vec<String> = Vec::new();
        for var in &variables {
            if let Op::DefineVar { name, .. } = var.op() {
                var_names.push(name.clone());
            }
        }
        if has_threading {
            // Thread id is always the last slot in vars[].
            var_names.push("thread_id".to_string());
        }

        // Count references for SSA inlining decisions
        let ref_counts = count_references(&nodes);
        let mut ctx = CContext::new(ref_counts);

        // === Build C source ===
        let mut code_lines: Vec<String> = Vec::new();

        // Includes
        code_lines.push("#include <math.h>".to_string());
        code_lines.push("#include <stdbool.h>".to_string());
        code_lines.push("".to_string());

        // Vector typedefs
        let typedefs = collect_vector_typedefs(&nodes);
        for td in &typedefs {
            code_lines.push(td.clone());
        }
        if !typedefs.is_empty() {
            code_lines.push("".to_string());
        }

        // WMMA (AMX) defines and static functions
        let wmma_defines = amx::collect_wmma_defines(&nodes);
        for def in &wmma_defines {
            code_lines.push(def.clone());
        }
        if !wmma_defines.is_empty() {
            code_lines.push("".to_string());
        }

        // Function signature
        code_lines.push(format!("void {kernel_name}(void** args, long long* vars) {{"));

        // Buffer pointer casts: args[i] → typed data{i} pointer.
        for (i, buf) in buffers.iter().enumerate() {
            let buf_dtype = buf.dtype();
            let elem_type = match &buf_dtype {
                DType::Ptr { base, .. } => c_dtype(base),
                _ => c_dtype(&buf_dtype),
            };
            let name = format!("data{i}");
            code_lines.push(format!("  {elem_type}* {name} = ({elem_type}*)args[{i}];"));
            ctx.register(buf.id, name);
        }

        // Local memory allocations (stack arrays on CPU)
        for node in &nodes {
            if let Op::DefineLocal(id) = node.op() {
                // Unsized pointers degrade to a 1-element array.
                let (base, size) = match node.dtype() {
                    DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
                    other => (c_dtype(&other), 1),
                };
                let name = format!("local{id}");
                code_lines.push(format!("  {base} {name}[{size}];"));
                ctx.register(node.id, name);
            }
        }

        // Variable loads: vars[i] → typed local; skip the cast when the
        // target type is already long long (the vars[] element type).
        for (i, var) in variables.iter().enumerate() {
            if let Op::DefineVar { name, .. } = var.op() {
                let var_dtype = &var.dtype();
                let c_type = c_dtype(var_dtype);
                if c_type == "long long" {
                    code_lines.push(format!("  long long {name} = vars[{i}];"));
                } else {
                    code_lines.push(format!("  {c_type} {name} = ({c_type})vars[{i}];"));
                }
                ctx.register(var.id, name.clone());
            }
        }

        // Thread ID: loaded from the slot right after the declared vars,
        // matching the "thread_id" entry appended to var_names above.
        if let Some((thread_range, _)) = &thread_info {
            let thread_idx = variables.len();
            let range_dtype = &thread_range.dtype();
            let c_type = c_dtype(range_dtype);
            if c_type == "long long" {
                code_lines.push(format!("  long long thread_id = vars[{thread_idx}];"));
            } else {
                code_lines.push(format!("  {c_type} thread_id = ({c_type})vars[{thread_idx}];"));
            }

            if let Op::Range { axis_id, .. } = thread_range.op() {
                ctx.register(thread_range.id, "thread_id".to_string());
                let _ = axis_id; // Thread range is registered above
            }
        }

        code_lines.push("".to_string());

        // Reduction accumulator declarations (need to be in outer scope)
        for node in &nodes {
            if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
                if ranges.is_empty() {
                    // Rangeless reduce needs no accumulator variable.
                    continue;
                }
                let dtype = &node.dtype();
                let c_type = c_dtype(dtype);
                let identity = c_reduce_identity(*reduce_op, dtype);
                let acc_name = format!("acc{}", node.id);
                code_lines.push(format!("  {c_type} {acc_name} = {identity};"));
                // Pre-register so the ops.rs render_uop finds it
                ctx.register(node.id, acc_name);
            }
        }

        // Register constants: rendered inline as C literals, no declaration.
        for node in &nodes {
            match node.op() {
                Op::Const(cv) => {
                    let val = c_const(&cv.0, &node.dtype());
                    ctx.register(node.id, val);
                }
                Op::VConst { values } => {
                    let val = c_vconst(values, &node.dtype());
                    ctx.register(node.id, val);
                }
                _ => {}
            }
        }

        // Pre-register range variable names (thread ranges were already
        // registered as "thread_id" above).
        for node in &nodes {
            if let Op::Range { axis_id, axis_type, .. } = node.op()
                && !matches!(axis_type, AxisType::Thread)
            {
                let name = format!("ridx{}", axis_id.value());
                ctx.register(node.id, name);
            }
        }

        // Render all instructions; thread ranges emit no loop — the host
        // launches one invocation per thread id instead.
        let mut kernel_body: Vec<String> = Vec::new();
        for node in &nodes {
            if let Op::Range { axis_type, .. } = node.op()
                && matches!(axis_type, AxisType::Thread)
            {
                continue;
            }
            render_uop(node, &mut ctx, &mut kernel_body);
        }

        code_lines.extend(kernel_body);
        code_lines.push("}".to_string());
        code_lines.push("".to_string());

        let code = code_lines.join("\n");

        tracing::debug!(generated_c = code, "c codegen: final generated code");

        let mut result = RenderedKernel::new(code, kernel_name.to_string());
        result.buffer_args = buffer_args;
        result.var_names = var_names;

        // Only report a launch grid when there is real parallelism.
        if thread_count > 1 {
            result.global_size = Some([thread_count, 1, 1]);
            result.local_size = Some([1, 1, 1]);
        }

        Ok(result)
    }

    /// Identifier of the toolchain this backend targets.
    fn backend_name(&self) -> &str {
        "clang"
    }

    /// No pre-render decomposition patterns for C.
    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
        // C has math.h with sqrt, exp, sin, etc. — no decomposition needed
        // for standard transcendentals. Threefry is handled by XOR in render.
        None
    }
}
288
289fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
290    let buffer_id = def_global.id;
291
292    for node in nodes {
293        if let Some(buffer) = node.store_buffer() {
294            if buffer.id == buffer_id {
295                return true;
296            }
297            if let Op::Index { buffer: idx_buf, .. } = buffer.op()
298                && idx_buf.id == buffer_id
299            {
300                return true;
301            }
302        }
303    }
304    false
305}
306
307/// Public render function for the C backend.
308pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
309    let renderer = CRenderer::new();
310    crate::Renderer::render(&renderer, uop, name)
311}