1mod amx;
15pub mod ops;
16pub mod types;
17
18use std::sync::Arc;
19
20use morok_ir::pattern::TypedPatternMatcher;
21use morok_ir::rewrite::graph_rewrite_bottom_up;
22use morok_ir::{AxisType, Op, prelude::*};
23use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
24use morok_schedule::rangeify::patterns::pm_bool_devectorize;
25
26use crate::{BufferArg, RenderedKernel, Result};
27
28use self::ops::{CContext, count_references, render_uop};
29use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
30
/// Renderer that turns a kernel `UOp` graph into C source code.
///
/// The type carries no state, so it is free to copy and every instance is
/// interchangeable with any other.
#[derive(Debug, Clone, Copy, Default)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a new renderer.
    ///
    /// Equivalent to [`CRenderer::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
45
46impl crate::Renderer for CRenderer {
47 fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
48 let kernel_name = name.unwrap_or("kernel");
49
50 let uop = graph_rewrite_bottom_up(pm_bool_devectorize(), uop.clone(), &mut ());
52
53 tracing::debug!(ast_after_pm_bool_devectorize = %uop.tree(), "c codegen: after pm_bool_devectorize");
54
55 let nodes = linearize_with_cfg(uop);
57
58 let nodes = line_rewrite_cleanups(nodes);
60
61 for (i, node) in nodes.iter().enumerate() {
62 tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
63 }
64
65 let mut buffers: Vec<Arc<UOp>> = Vec::new();
67 let mut variables: Vec<Arc<UOp>> = Vec::new();
68
69 for node in &nodes {
70 match node.op() {
71 Op::DefineGlobal(_) => buffers.push(node.clone()),
72 Op::DefineVar { .. } => variables.push(node.clone()),
73 _ => {}
74 }
75 }
76
77 buffers.sort_by_key(|b| if let Op::DefineGlobal(id) = b.op() { *id } else { usize::MAX });
78
79 let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
81 if let Op::Range { axis_type, end, .. } = n.op()
82 && matches!(axis_type, AxisType::Thread)
83 && let Op::Const(cv) = end.op()
84 && let ConstValue::Int(count) = cv.0
85 {
86 return Some((n.clone(), count as usize));
87 }
88 None
89 });
90
91 let has_threading = thread_info.is_some();
92 let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
93
94 let mut buffer_args: Vec<BufferArg> = Vec::new();
96 for (i, buf) in buffers.iter().enumerate() {
97 if let Op::DefineGlobal(id) = buf.op() {
98 let is_output = is_output_buffer(buf, &nodes);
99 buffer_args.push(BufferArg { index: *id, name: format!("data{i}"), dtype: buf.dtype(), is_output });
100 }
101 }
102
103 let mut var_names: Vec<String> = Vec::new();
105 for var in &variables {
106 if let Op::DefineVar { name, .. } = var.op() {
107 var_names.push(name.clone());
108 }
109 }
110 if has_threading {
111 var_names.push("thread_id".to_string());
112 }
113
114 let ref_counts = count_references(&nodes);
116 let mut ctx = CContext::new(ref_counts);
117
118 let mut code_lines: Vec<String> = Vec::new();
120
121 code_lines.push("#include <stdbool.h>".to_string());
123 code_lines.push("".to_string());
124
125 let typedefs = collect_vector_typedefs(&nodes);
127 for td in &typedefs {
128 code_lines.push(td.clone());
129 }
130 if !typedefs.is_empty() {
131 code_lines.push("".to_string());
132 }
133
134 let wmma_defines = amx::collect_wmma_defines(&nodes);
136 for def in &wmma_defines {
137 code_lines.push(def.clone());
138 }
139 if !wmma_defines.is_empty() {
140 code_lines.push("".to_string());
141 }
142
143 let mut params: Vec<String> = Vec::new();
145
146 for (i, buf) in buffers.iter().enumerate() {
148 let buf_dtype = buf.dtype();
149 let elem_type = match &buf_dtype {
150 DType::Ptr { base, .. } => c_dtype(base),
151 _ => c_dtype(&buf_dtype),
152 };
153 let name = format!("data{i}");
154 params.push(format!("{elem_type}* restrict {name}"));
155 ctx.register(buf.id, name);
156 }
157
158 for var in &variables {
160 if let Op::DefineVar { name, .. } = var.op() {
161 let var_dtype = &var.dtype();
162 let c_type = c_dtype(var_dtype);
163 params.push(format!("const {c_type} {name}"));
164 ctx.register(var.id, name.clone());
165 }
166 }
167
168 if let Some((thread_range, _)) = &thread_info {
170 let range_dtype = &thread_range.dtype();
171 let c_type = c_dtype(range_dtype);
172 params.push(format!("const {c_type} thread_id"));
173 ctx.register(thread_range.id, "thread_id".to_string());
174 }
175
176 code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
178
179 for node in &nodes {
181 if let Op::DefineLocal(id) = node.op() {
182 let (base, size) = match node.dtype() {
183 DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
184 other => (c_dtype(&other), 1),
185 };
186 let name = format!("local{id}");
187 code_lines.push(format!(" {base} {name}[{size}];"));
188 ctx.register(node.id, name);
189 }
190 }
191
192 code_lines.push("".to_string());
193
194 for node in &nodes {
196 if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
197 if ranges.is_empty() {
198 continue;
199 }
200 let dtype = &node.dtype();
201 let c_type = c_dtype(dtype);
202 let identity = c_reduce_identity(*reduce_op, dtype);
203 let acc_name = format!("acc{}", node.id);
204 code_lines.push(format!(" {c_type} {acc_name} = {identity};"));
205 ctx.register(node.id, acc_name);
207 }
208 }
209
210 for node in &nodes {
212 match node.op() {
213 Op::Const(cv) => {
214 let val = c_const(&cv.0, &node.dtype());
215 ctx.register(node.id, val);
216 }
217 Op::VConst { values } => {
218 let val = c_vconst(values, &node.dtype());
219 ctx.register(node.id, val);
220 }
221 _ => {}
222 }
223 }
224
225 for node in &nodes {
227 if let Op::Range { axis_id, axis_type, .. } = node.op()
228 && !matches!(axis_type, AxisType::Thread)
229 {
230 let name = format!("ridx{}", axis_id.value());
231 ctx.register(node.id, name);
232 }
233 }
234
235 let mut kernel_body: Vec<String> = Vec::new();
237 for node in &nodes {
238 if let Op::Range { axis_type, .. } = node.op()
239 && matches!(axis_type, AxisType::Thread)
240 {
241 continue;
242 }
243 render_uop(node, &mut ctx, &mut kernel_body);
244 }
245
246 code_lines.extend(kernel_body);
247 code_lines.push("}".to_string());
248 code_lines.push("".to_string());
249
250 let code = code_lines.join("\n");
251
252 tracing::debug!(generated_c = code, "c codegen: final generated code");
253
254 let mut result = RenderedKernel::new(code, kernel_name.to_string());
255 result.buffer_args = buffer_args;
256 result.var_names = var_names;
257
258 if thread_count > 1 {
259 result.global_size = Some([thread_count, 1, 1]);
260 result.local_size = Some([1, 1, 1]);
261 }
262
263 Ok(result)
264 }
265
266 fn backend_name(&self) -> &str {
267 "clang"
268 }
269
270 fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
271 None
274 }
275}
276
277fn is_output_buffer(def_global: &Arc<UOp>, nodes: &[Arc<UOp>]) -> bool {
278 let buffer_id = def_global.id;
279
280 for node in nodes {
281 if let Some(buffer) = node.store_buffer() {
282 if buffer.id == buffer_id {
283 return true;
284 }
285 if let Op::Index { buffer: idx_buf, .. } = buffer.op()
286 && idx_buf.id == buffer_id
287 {
288 return true;
289 }
290 }
291 }
292 false
293}
294
295pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
297 let renderer = CRenderer::new();
298 crate::Renderer::render(&renderer, uop, name)
299}