1mod amx;
15pub mod ops;
16pub mod types;
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use morok_ir::pattern::TypedPatternMatcher;
22use morok_ir::{AxisType, Op, prelude::*};
23use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
24
25use crate::common::is_output_buffer;
26use crate::{BufferArg, RenderedKernel, Result};
27
28use self::ops::{CContext, count_references, render_uop};
29use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
30
/// Renderer that emits a kernel as C source code.
///
/// The type carries no state; every value of it behaves the same. The
/// common traits are derived so it can be logged, copied into closures, and
/// embedded in other `Debug`/`Clone` types without extra boilerplate.
#[derive(Debug, Clone, Copy)]
pub struct CRenderer;
33
34impl CRenderer {
35 pub fn new() -> Self {
36 Self
37 }
38}
39
40impl Default for CRenderer {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl crate::Renderer for CRenderer {
47 fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
48 let kernel_name = name.unwrap_or("kernel");
49
50 let nodes = linearize_with_cfg(uop.clone());
51
52 let nodes = line_rewrite_cleanups(nodes);
54
55 for (i, node) in nodes.iter().enumerate() {
56 tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
57 }
58
59 let mut buffers: Vec<Arc<UOp>> = Vec::new();
61 let mut variables: Vec<Arc<UOp>> = Vec::new();
62
63 for node in &nodes {
64 match node.op() {
65 Op::Param { device: None, .. } => buffers.push(node.clone()),
66 Op::DefineVar { .. } => variables.push(node.clone()),
67 _ => {}
68 }
69 }
70
71 buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
72
73 let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
75 if let Op::Range { axis_type, end, .. } = n.op()
76 && matches!(axis_type, AxisType::Thread)
77 && let Op::Const(cv) = end.op()
78 && let ConstValue::Int(count) = cv.0
79 {
80 return Some((n.clone(), count as usize));
81 }
82 None
83 });
84
85 let has_threading = thread_info.is_some();
86 let thread_count = thread_info.as_ref().map(|(_, c)| *c).unwrap_or(1);
87
88 let mut buffer_args: Vec<BufferArg> = Vec::new();
90 for (i, buf) in buffers.iter().enumerate() {
91 if let Op::Param { slot, device: None, .. } = buf.op() {
92 let is_output = is_output_buffer(buf, &nodes);
93 buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
94 }
95 }
96
97 let mut var_names: Vec<String> = Vec::new();
99 for var in &variables {
100 if let Op::DefineVar { name, .. } = var.op() {
101 var_names.push(name.clone());
102 }
103 }
104 if has_threading {
105 var_names.push("thread_id".to_string());
106 }
107
108 let ref_counts = count_references(&nodes);
110 let scope_escaping = find_scope_escaping_vars(&nodes, &ref_counts);
111 let mut ctx = CContext::new(ref_counts, scope_escaping);
112
113 let mut code_lines: Vec<String> = Vec::new();
115
116 code_lines.push("#include <stdbool.h>".to_string());
118 code_lines.push("".to_string());
119
120 let typedefs = collect_vector_typedefs(&nodes);
122 for td in &typedefs {
123 code_lines.push(td.clone());
124 }
125 if !typedefs.is_empty() {
126 code_lines.push("".to_string());
127 }
128
129 let wmma_defines = amx::collect_wmma_defines(&nodes);
131 for def in &wmma_defines {
132 code_lines.push(def.clone());
133 }
134 if !wmma_defines.is_empty() {
135 code_lines.push("".to_string());
136 }
137
138 let mut params: Vec<String> = Vec::new();
140
141 for (i, buf) in buffers.iter().enumerate() {
143 let buf_dtype = buf.dtype();
144 let elem_type = match &buf_dtype {
145 DType::Ptr { base, .. } => c_dtype(base),
146 _ => c_dtype(&buf_dtype),
147 };
148 let name = format!("data{i}");
149 params.push(format!("{elem_type}* restrict {name}"));
150 ctx.register(buf.id, name);
151 }
152
153 for var in &variables {
155 if let Op::DefineVar { name, .. } = var.op() {
156 let var_dtype = &var.dtype();
157 let c_type = c_dtype(var_dtype);
158 params.push(format!("const {c_type} {name}"));
159 ctx.register(var.id, name.clone());
160 }
161 }
162
163 if let Some((thread_range, _)) = &thread_info {
165 let range_dtype = &thread_range.dtype();
166 let c_type = c_dtype(range_dtype);
167 params.push(format!("const {c_type} thread_id"));
168 ctx.register(thread_range.id, "thread_id".to_string());
169 }
170
171 code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
173
174 for node in &nodes {
176 if let Op::DefineLocal(id) = node.op() {
177 let (base, size) = match node.dtype() {
178 DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
179 other => (c_dtype(&other), 1),
180 };
181 let name = format!("local{id}");
182 code_lines.push(format!(" {base} {name}[{size}];"));
183 ctx.register(node.id, name);
184 }
185 }
186
187 code_lines.push("".to_string());
188
189 for node in &nodes {
191 if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
192 if ranges.is_empty() {
193 continue;
194 }
195 let dtype = &node.dtype();
196 let c_type = c_dtype(dtype);
197 let identity = c_reduce_identity(*reduce_op, dtype);
198 let acc_name = format!("acc{}", node.id);
199 code_lines.push(format!(" {c_type} {acc_name} = {identity};"));
200 ctx.register(node.id, acc_name);
202 }
203 }
204
205 for node in &nodes {
207 match node.op() {
208 Op::Const(cv) => {
209 let val = c_const(&cv.0, &node.dtype());
210 ctx.register(node.id, val);
211 }
212 Op::VConst { values } => {
213 let val = c_vconst(values, &node.dtype());
214 ctx.register(node.id, val);
215 }
216 _ => {}
217 }
218 }
219
220 for node in &nodes {
222 if let Op::Range { axis_id, axis_type, .. } = node.op()
223 && !matches!(axis_type, AxisType::Thread)
224 {
225 let name = format!("ridx{}", axis_id.value());
226 ctx.register(node.id, name);
227 }
228 }
229
230 let mut kernel_body: Vec<String> = Vec::new();
233 for node in &nodes {
234 if matches!(node.op(), Op::Noop | Op::Group { .. }) {
235 ctx.register(node.id, String::new());
238 continue;
239 }
240 if let Op::Range { axis_type, .. } = node.op()
241 && matches!(axis_type, AxisType::Thread)
242 {
243 continue;
244 }
245 render_uop(node, &mut ctx, &mut kernel_body);
246 }
247
248 if !ctx.hoisted_declarations.is_empty() {
250 code_lines.append(&mut ctx.hoisted_declarations);
251 }
252 code_lines.extend(kernel_body);
253 code_lines.push("}".to_string());
254 code_lines.push("".to_string());
255
256 let code = code_lines.join("\n");
257
258 tracing::debug!(generated_c = code, "c codegen: final generated code");
259
260 let mut result = RenderedKernel::new(code, kernel_name.to_string());
261 result.buffer_args = buffer_args;
262 result.var_names = var_names;
263
264 if thread_count > 1 {
265 result.global_size = Some([thread_count, 1, 1]);
266 result.local_size = Some([1, 1, 1]);
267 }
268
269 Ok(result)
270 }
271
272 fn backend_name(&self) -> &str {
273 "clang"
274 }
275
276 fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
277 None
280 }
281}
282
283fn find_scope_escaping_vars(nodes: &[Arc<UOp>], ref_counts: &HashMap<u64, usize>) -> HashSet<u64> {
292 let mut depth = 0usize;
293 let mut def_depth: HashMap<u64, usize> = HashMap::new();
294 let mut min_use_depth: HashMap<u64, usize> = HashMap::new();
295
296 for node in nodes {
297 match node.op() {
299 Op::Range { .. } | Op::If { .. } => {
300 if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
302 def_depth.entry(node.id).or_insert(depth);
303 }
304 for src in node.op().sources() {
306 min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
307 }
308 depth += 1;
309 continue;
310 }
311 Op::End { .. } | Op::EndIf { .. } => {
312 depth = depth.saturating_sub(1);
313 }
314 _ => {}
315 }
316
317 if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
319 def_depth.entry(node.id).or_insert(depth);
320 }
321
322 for src in node.op().sources() {
324 min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
325 }
326 }
327
328 def_depth
330 .into_iter()
331 .filter(|(id, def_d)| min_use_depth.get(id).copied().unwrap_or(*def_d) < *def_d)
332 .map(|(id, _)| id)
333 .collect()
334}
335
336pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
338 let renderer = CRenderer::new();
339 crate::Renderer::render(&renderer, uop, name)
340}