Skip to main content

morok_codegen/c/
mod.rs

//! C source code generation backend.
//!
//! Generates C source code from linearized UOp IR, suitable for compilation
//! with `clang -shared -O2` and loading via `dlopen`.
//!
//! # Kernel Signature
//!
//! Emits a single function with typed `restrict` pointer params and const variable params:
//!
//! ```c
//! void kernel(float* restrict data0, const int N) { /* body */ }
//! ```
//!
//! When a thread-axis range with a constant end is present, a trailing
//! `const` `thread_id` parameter is appended after the variable params.
13
14mod amx;
15pub mod ops;
16pub mod types;
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use morok_ir::pattern::TypedPatternMatcher;
22use morok_ir::{AxisType, Op, prelude::*};
23use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
24
25use crate::common::is_output_buffer;
26use crate::{BufferArg, RenderedKernel, Result};
27
28use self::ops::{CContext, count_references, render_uop};
29use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
30
/// C source code renderer for CPU execution via clang.
///
/// This is a stateless unit type: it holds no configuration, so every instance
/// is interchangeable. `Debug` is derived so the renderer can appear in
/// diagnostics, and `Default` is derived rather than hand-written.
#[derive(Debug, Default)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a new renderer. Equivalent to [`CRenderer::default`].
    pub fn new() -> Self {
        Self
    }
}
45
46impl crate::Renderer for CRenderer {
47    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
48        let kernel_name = name.unwrap_or("kernel");
49
50        let nodes = linearize_with_cfg(uop.clone());
51
52        // Apply line rewrite cleanups (gated stores → if/store/endif)
53        let nodes = line_rewrite_cleanups(nodes);
54
55        for (i, node) in nodes.iter().enumerate() {
56            tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
57        }
58
59        // Collect buffers and variables from linearized stream
60        let mut buffers: Vec<Arc<UOp>> = Vec::new();
61        let mut variables: Vec<Arc<UOp>> = Vec::new();
62
63        for node in &nodes {
64            match node.op() {
65                Op::Param { device: None, .. } => buffers.push(node.clone()),
66                Op::DefineVar { .. } => variables.push(node.clone()),
67                _ => {}
68            }
69        }
70
71        buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
72
73        // Detect threading
74        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
75            if let Op::Range { axis_type, end, .. } = n.op()
76                && matches!(axis_type, AxisType::Thread)
77                && let Op::Const(cv) = end.op()
78                && let ConstValue::Int(count) = cv.0
79            {
80                return Some((n.clone(), count as usize));
81            }
82            None
83        });
84
85        let has_threading = thread_info.is_some();
86        let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
87
88        // Build buffer args metadata
89        let mut buffer_args: Vec<BufferArg> = Vec::new();
90        for (i, buf) in buffers.iter().enumerate() {
91            if let Op::Param { slot, device: None, .. } = buf.op() {
92                let is_output = is_output_buffer(buf, &nodes);
93                buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
94            }
95        }
96
97        // Build var_names
98        let mut var_names: Vec<String> = Vec::new();
99        for var in &variables {
100            if let Op::DefineVar { name, .. } = var.op() {
101                var_names.push(name.clone());
102            }
103        }
104        if has_threading {
105            var_names.push("thread_id".to_string());
106        }
107
108        // Count references for SSA inlining decisions
109        let ref_counts = count_references(&nodes);
110        let scope_escaping = find_scope_escaping_vars(&nodes, &ref_counts);
111        let mut ctx = CContext::new(ref_counts, scope_escaping);
112
113        // === Build C source ===
114        let mut code_lines: Vec<String> = Vec::new();
115
116        // Includes
117        code_lines.push("#include <stdbool.h>".to_string());
118        code_lines.push("".to_string());
119
120        // Vector typedefs
121        let typedefs = collect_vector_typedefs(&nodes);
122        for td in &typedefs {
123            code_lines.push(td.clone());
124        }
125        if !typedefs.is_empty() {
126            code_lines.push("".to_string());
127        }
128
129        // WMMA (AMX) defines and static functions
130        let wmma_defines = amx::collect_wmma_defines(&nodes);
131        for def in &wmma_defines {
132            code_lines.push(def.clone());
133        }
134        if !wmma_defines.is_empty() {
135            code_lines.push("".to_string());
136        }
137
138        // Build typed function params
139        let mut params: Vec<String> = Vec::new();
140
141        // Buffer parameters
142        for (i, buf) in buffers.iter().enumerate() {
143            let buf_dtype = buf.dtype();
144            let elem_type = match &buf_dtype {
145                DType::Ptr { base, .. } => c_dtype(base),
146                _ => c_dtype(&buf_dtype),
147            };
148            let name = format!("data{i}");
149            params.push(format!("{elem_type}* restrict {name}"));
150            ctx.register(buf.id, name);
151        }
152
153        // Variable parameters
154        for var in &variables {
155            if let Op::DefineVar { name, .. } = var.op() {
156                let var_dtype = &var.dtype();
157                let c_type = c_dtype(var_dtype);
158                params.push(format!("const {c_type} {name}"));
159                ctx.register(var.id, name.clone());
160            }
161        }
162
163        // Thread ID parameter
164        if let Some((thread_range, _)) = &thread_info {
165            let range_dtype = &thread_range.dtype();
166            let c_type = c_dtype(range_dtype);
167            params.push(format!("const {c_type} thread_id"));
168            ctx.register(thread_range.id, "thread_id".to_string());
169        }
170
171        // Function signature
172        code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
173
174        // Local memory allocations (stack arrays on CPU)
175        for node in &nodes {
176            if let Op::DefineLocal(id) = node.op() {
177                let (base, size) = match node.dtype() {
178                    DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
179                    other => (c_dtype(&other), 1),
180                };
181                let name = format!("local{id}");
182                code_lines.push(format!("  {base} {name}[{size}];"));
183                ctx.register(node.id, name);
184            }
185        }
186
187        code_lines.push("".to_string());
188
189        // Reduction accumulator declarations (need to be in outer scope)
190        for node in &nodes {
191            if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
192                if ranges.is_empty() {
193                    continue;
194                }
195                let dtype = &node.dtype();
196                let c_type = c_dtype(dtype);
197                let identity = c_reduce_identity(*reduce_op, dtype);
198                let acc_name = format!("acc{}", node.id);
199                code_lines.push(format!("  {c_type} {acc_name} = {identity};"));
200                // Pre-register so the ops.rs render_uop finds it
201                ctx.register(node.id, acc_name);
202            }
203        }
204
205        // Register constants
206        for node in &nodes {
207            match node.op() {
208                Op::Const(cv) => {
209                    let val = c_const(&cv.0, &node.dtype());
210                    ctx.register(node.id, val);
211                }
212                Op::VConst { values } => {
213                    let val = c_vconst(values, &node.dtype());
214                    ctx.register(node.id, val);
215                }
216                _ => {}
217            }
218        }
219
220        // Pre-register range variable names
221        for node in &nodes {
222            if let Op::Range { axis_id, axis_type, .. } = node.op()
223                && !matches!(axis_type, AxisType::Thread)
224            {
225                let name = format!("ridx{}", axis_id.value());
226                ctx.register(node.id, name);
227            }
228        }
229
230        // Render all instructions
231        // Skip NOOP and GROUP — they are structural no-ops (Tinygrad cstyle.py:175)
232        let mut kernel_body: Vec<String> = Vec::new();
233        for node in &nodes {
234            if matches!(node.op(), Op::Noop | Op::Group { .. }) {
235                // Register with empty string so downstream UNROLL/CONTRACT can alias them.
236                // Matches LLVM backend behavior — these are structural no-ops.
237                ctx.register(node.id, String::new());
238                continue;
239            }
240            if let Op::Range { axis_type, .. } = node.op()
241                && matches!(axis_type, AxisType::Thread)
242            {
243                continue;
244            }
245            render_uop(node, &mut ctx, &mut kernel_body);
246        }
247
248        // Emit hoisted declarations for scope-escaping variables (before kernel body)
249        if !ctx.hoisted_declarations.is_empty() {
250            code_lines.append(&mut ctx.hoisted_declarations);
251        }
252        code_lines.extend(kernel_body);
253        code_lines.push("}".to_string());
254        code_lines.push("".to_string());
255
256        let code = code_lines.join("\n");
257
258        tracing::debug!(generated_c = code, "c codegen: final generated code");
259
260        let mut result = RenderedKernel::new(code, kernel_name.to_string());
261        result.buffer_args = buffer_args;
262        result.var_names = var_names;
263
264        if thread_count > 1 {
265            result.global_size = Some([thread_count, 1, 1]);
266            result.local_size = Some([1, 1, 1]);
267        }
268
269        Ok(result)
270    }
271
272    fn backend_name(&self) -> &str {
273        "clang"
274    }
275
276    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
277        // C uses __builtin_ math functions (sqrt, exp, sin, etc.) — no decomposition needed.
278        // Threefry is handled by XOR in render.
279        None
280    }
281}
282
283/// Find variables that escape their declaration scope.
284///
285/// Walks the linearized instruction list tracking scope depth. A variable "escapes"
286/// if it's defined at a deeper scope than where it's used. Returns the set of UOp IDs
287/// that need function-scope declarations to avoid "use of undeclared identifier" errors.
288///
289/// This handles the case where pm_decomp creates sibling ENDs that share sub-DAG nodes.
290/// The linearizer places the shared node inside one loop, but another consumer is outside.
291fn find_scope_escaping_vars(nodes: &[Arc<UOp>], ref_counts: &HashMap<u64, usize>) -> HashSet<u64> {
292    let mut depth = 0usize;
293    let mut def_depth: HashMap<u64, usize> = HashMap::new();
294    let mut min_use_depth: HashMap<u64, usize> = HashMap::new();
295
296    for node in nodes {
297        // Track scope depth changes
298        match node.op() {
299            Op::Range { .. } | Op::If { .. } => {
300                // Definition of this node is at current depth (before entering)
301                if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
302                    def_depth.entry(node.id).or_insert(depth);
303                }
304                // Record usages of sources at current depth
305                for src in node.op().sources() {
306                    min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
307                }
308                depth += 1;
309                continue;
310            }
311            Op::End { .. } | Op::EndIf { .. } => {
312                depth = depth.saturating_sub(1);
313            }
314            _ => {}
315        }
316
317        // Record definition depth for multi-use values
318        if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
319            def_depth.entry(node.id).or_insert(depth);
320        }
321
322        // Record minimum usage depth for all source operands
323        for src in node.op().sources() {
324            min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
325        }
326    }
327
328    // Variables where any use is at a shallower depth than definition
329    def_depth
330        .into_iter()
331        .filter(|(id, def_d)| min_use_depth.get(id).copied().unwrap_or(*def_d) < *def_d)
332        .map(|(id, _)| id)
333        .collect()
334}
335
336/// Public render function for the C backend.
337pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
338    let renderer = CRenderer::new();
339    crate::Renderer::render(&renderer, uop, name)
340}