1mod amx;
16pub mod ops;
17pub mod types;
18
19use std::sync::Arc;
20
21use morok_ir::pattern::TypedPatternMatcher;
22use morok_ir::rewrite::graph_rewrite_bottom_up;
23use morok_ir::{AxisType, Op, prelude::*};
24use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
25use morok_schedule::rangeify::patterns::pm_bool_devectorize;
26
27use crate::{BufferArg, RenderedKernel, Result};
28
29use self::ops::{CContext, count_references, render_uop};
30use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
31
/// C source renderer used by the `clang` CPU backend.
///
/// The renderer is a stateless unit struct: every call to
/// [`crate::Renderer::render`] builds a complete C translation unit from scratch,
/// so instances are freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct CRenderer;
34
35impl CRenderer {
36 pub fn new() -> Self {
37 Self
38 }
39}
40
41impl Default for CRenderer {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl crate::Renderer for CRenderer {
48 fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
49 let kernel_name = name.unwrap_or("kernel");
50
51 let uop = graph_rewrite_bottom_up(pm_bool_devectorize(), uop.clone(), &mut ());
53
54 tracing::debug!(ast_after_pm_bool_devectorize = %uop.tree(), "c codegen: after pm_bool_devectorize");
55
56 let nodes = linearize_with_cfg(uop);
58
59 let nodes = line_rewrite_cleanups(nodes);
61
62 for (i, node) in nodes.iter().enumerate() {
63 tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
64 }
65
66 let mut buffers: Vec<Arc<UOp>> = Vec::new();
68 let mut variables: Vec<Arc<UOp>> = Vec::new();
69
70 for node in &nodes {
71 match node.op() {
72 Op::DefineGlobal(_) => buffers.push(node.clone()),
73 Op::DefineVar { .. } => variables.push(node.clone()),
74 _ => {}
75 }
76 }
77
78 buffers.sort_by_key(|b| if let Op::DefineGlobal(id) = b.op() { *id } else { usize::MAX });
79
80 let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
82 if let Op::Range { axis_type, end, .. } = n.op()
83 && matches!(axis_type, AxisType::Thread)
84 && let Op::Const(cv) = end.op()
85 && let ConstValue::Int(count) = cv.0
86 {
87 return Some((n.clone(), count as usize));
88 }
89 None
90 });
91
92 let has_threading = thread_info.is_some();
93 let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
94
95 let mut buffer_args: Vec<BufferArg> = Vec::new();
97 for (i, buf) in buffers.iter().enumerate() {
98 if let Op::DefineGlobal(id) = buf.op() {
99 let is_output = is_output_buffer(buf, &nodes);
100 buffer_args.push(BufferArg { index: *id, name: format!("data{i}"), dtype: buf.dtype(), is_output });
101 }
102 }
103
104 let mut var_names: Vec<String> = Vec::new();
106 for var in &variables {
107 if let Op::DefineVar { name, .. } = var.op() {
108 var_names.push(name.clone());
109 }
110 }
111 if has_threading {
112 var_names.push("thread_id".to_string());
113 }
114
115 let ref_counts = count_references(&nodes);
117 let mut ctx = CContext::new(ref_counts);
118
119 let mut code_lines: Vec<String> = Vec::new();
121
122 code_lines.push("#include <math.h>".to_string());
124 code_lines.push("#include <stdbool.h>".to_string());
125 code_lines.push("".to_string());
126
127 let typedefs = collect_vector_typedefs(&nodes);
129 for td in &typedefs {
130 code_lines.push(td.clone());
131 }
132 if !typedefs.is_empty() {
133 code_lines.push("".to_string());
134 }
135
136 let wmma_defines = amx::collect_wmma_defines(&nodes);
138 for def in &wmma_defines {
139 code_lines.push(def.clone());
140 }
141 if !wmma_defines.is_empty() {
142 code_lines.push("".to_string());
143 }
144
145 code_lines.push(format!("void {kernel_name}(void** args, long long* vars) {{"));
147
148 for (i, buf) in buffers.iter().enumerate() {
150 let buf_dtype = buf.dtype();
151 let elem_type = match &buf_dtype {
152 DType::Ptr { base, .. } => c_dtype(base),
153 _ => c_dtype(&buf_dtype),
154 };
155 let name = format!("data{i}");
156 code_lines.push(format!(" {elem_type}* {name} = ({elem_type}*)args[{i}];"));
157 ctx.register(buf.id, name);
158 }
159
160 for node in &nodes {
162 if let Op::DefineLocal(id) = node.op() {
163 let (base, size) = match node.dtype() {
164 DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
165 other => (c_dtype(&other), 1),
166 };
167 let name = format!("local{id}");
168 code_lines.push(format!(" {base} {name}[{size}];"));
169 ctx.register(node.id, name);
170 }
171 }
172
173 for (i, var) in variables.iter().enumerate() {
175 if let Op::DefineVar { name, .. } = var.op() {
176 let var_dtype = &var.dtype();
177 let c_type = c_dtype(var_dtype);
178 if c_type == "long long" {
179 code_lines.push(format!(" long long {name} = vars[{i}];"));
180 } else {
181 code_lines.push(format!(" {c_type} {name} = ({c_type})vars[{i}];"));
182 }
183 ctx.register(var.id, name.clone());
184 }
185 }
186
187 if let Some((thread_range, _)) = &thread_info {
189 let thread_idx = variables.len();
190 let range_dtype = &thread_range.dtype();
191 let c_type = c_dtype(range_dtype);
192 if c_type == "long long" {
193 code_lines.push(format!(" long long thread_id = vars[{thread_idx}];"));
194 } else {
195 code_lines.push(format!(" {c_type} thread_id = ({c_type})vars[{thread_idx}];"));
196 }
197
198 if let Op::Range { axis_id, .. } = thread_range.op() {
199 ctx.register(thread_range.id, "thread_id".to_string());
200 let _ = axis_id; }
202 }
203
204 code_lines.push("".to_string());
205
206 for node in &nodes {
208 if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
209 if ranges.is_empty() {
210 continue;
211 }
212 let dtype = &node.dtype();
213 let c_type = c_dtype(dtype);
214 let identity = c_reduce_identity(*reduce_op, dtype);
215 let acc_name = format!("acc{}", node.id);
216 code_lines.push(format!(" {c_type} {acc_name} = {identity};"));
217 ctx.register(node.id, acc_name);
219 }
220 }
221
222 for node in &nodes {
224 match node.op() {
225 Op::Const(cv) => {
226 let val = c_const(&cv.0, &node.dtype());
227 ctx.register(node.id, val);
228 }
229 Op::VConst { values } => {
230 let val = c_vconst(values, &node.dtype());
231 ctx.register(node.id, val);
232 }
233 _ => {}
234 }
235 }
236
237 for node in &nodes {
239 if let Op::Range { axis_id, axis_type, .. } = node.op()
240 && !matches!(axis_type, AxisType::Thread)
241 {
242 let name = format!("ridx{}", axis_id.value());
243 ctx.register(node.id, name);
244 }
245 }
246
247 let mut kernel_body: Vec<String> = Vec::new();
249 for node in &nodes {
250 if let Op::Range { axis_type, .. } = node.op()
251 && matches!(axis_type, AxisType::Thread)
252 {
253 continue;
254 }
255 render_uop(node, &mut ctx, &mut kernel_body);
256 }
257
258 code_lines.extend(kernel_body);
259 code_lines.push("}".to_string());
260 code_lines.push("".to_string());
261
262 let code = code_lines.join("\n");
263
264 tracing::debug!(generated_c = code, "c codegen: final generated code");
265
266 let mut result = RenderedKernel::new(code, kernel_name.to_string());
267 result.buffer_args = buffer_args;
268 result.var_names = var_names;
269
270 if thread_count > 1 {
271 result.global_size = Some([thread_count, 1, 1]);
272 result.local_size = Some([1, 1, 1]);
273 }
274
275 Ok(result)
276 }
277
278 fn backend_name(&self) -> &str {
279 "clang"
280 }
281
282 fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
283 None
286 }
287}
288
289fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
290 let buffer_id = def_global.id;
291
292 for node in nodes {
293 if let Some(buffer) = node.store_buffer() {
294 if buffer.id == buffer_id {
295 return true;
296 }
297 if let Op::Index { buffer: idx_buf, .. } = buffer.op()
298 && idx_buf.id == buffer_id
299 {
300 return true;
301 }
302 }
303 }
304 false
305}
306
307pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
309 let renderer = CRenderer::new();
310 crate::Renderer::render(&renderer, uop, name)
311}