1mod amx;
15pub mod ops;
16pub mod types;
17
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use svod_ir::pattern::TypedPatternMatcher;
22use svod_ir::{Op, prelude::*};
23
24use crate::common::{is_output_buffer, validate_custom_template_strict};
25use crate::{BufferArg, Error, RenderedKernel, Result};
26
27use self::ops::{CContext, count_references, render_uop};
28use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
29
/// Stateless renderer that turns a linearized UOp program into C source text.
///
/// The type carries no data, so it is free to copy and every instance is
/// interchangeable. `Default` is derived and yields the same value as
/// [`CRenderer::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a renderer. Equivalent to `CRenderer::default()`.
    pub fn new() -> Self {
        Self
    }
}
44
45impl crate::Renderer for CRenderer {
46 fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
47 let kernel_name = name.unwrap_or("kernel");
48
49 let nodes: Vec<Arc<UOp>> = match uop.op() {
50 Op::Linear { ops } => ops.iter().cloned().collect(),
51 other => {
52 return Err(Error::InvalidGraph { reason: format!("C renderer expects LINEAR input, got {other:?}") });
53 }
54 };
55
56 for (i, node) in nodes.iter().enumerate() {
57 tracing::debug!(position = i, op = node.op().as_ref(), id = node.id, "c linearized node");
58 match node.op() {
59 Op::Custom { deps, code } | Op::CustomI { deps, code } => {
60 validate_custom_template_strict(code, deps.len())?;
61 }
62 _ => {}
63 }
64 }
65
66 let mut buffers: Vec<Arc<UOp>> = Vec::new();
68 let mut variables: Vec<Arc<UOp>> = Vec::new();
69
70 for node in &nodes {
71 match node.op() {
72 Op::Param { device: None, .. } => buffers.push(node.clone()),
73 Op::DefineVar { .. } => variables.push(node.clone()),
74 _ => {}
75 }
76 }
77
78 buffers.sort_by_key(|b| if let Op::Param { slot, device: None, .. } = b.op() { *slot } else { usize::MAX });
79
80 let mut buffer_args: Vec<BufferArg> = Vec::new();
82 for (i, buf) in buffers.iter().enumerate() {
83 if let Op::Param { slot, device: None, .. } = buf.op() {
84 let is_output = is_output_buffer(buf, &nodes);
85 buffer_args.push(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output });
86 }
87 }
88
89 let mut var_names: Vec<String> = Vec::new();
91 for var in &variables {
92 if let Op::DefineVar { name, .. } = var.op() {
93 var_names.push(name.clone());
94 }
95 }
96 let ref_counts = count_references(&nodes);
98 let scope_escaping = find_scope_escaping_vars(&nodes, &ref_counts);
99 let mut ctx = CContext::new(ref_counts, scope_escaping);
100
101 let mut code_lines: Vec<String> = Vec::new();
103
104 code_lines.push("#include <stdbool.h>".to_string());
106 code_lines.push("".to_string());
107
108 let typedefs = collect_vector_typedefs(&nodes);
110 for td in &typedefs {
111 code_lines.push(td.clone());
112 }
113 if !typedefs.is_empty() {
114 code_lines.push("".to_string());
115 }
116
117 let wmma_defines = amx::collect_wmma_defines(&nodes);
119 for def in &wmma_defines {
120 code_lines.push(def.clone());
121 }
122 if !wmma_defines.is_empty() {
123 code_lines.push("".to_string());
124 }
125
126 let mut params: Vec<String> = Vec::new();
128
129 for (i, buf) in buffers.iter().enumerate() {
131 let buf_dtype = buf.dtype();
132 let elem_type = match &buf_dtype {
133 DType::Ptr { base, .. } => c_dtype(base),
134 _ => c_dtype(&buf_dtype),
135 };
136 let name = format!("data{i}");
137 params.push(format!("{elem_type}* restrict {name}"));
138 ctx.register(buf.id, name);
139 }
140
141 for var in &variables {
143 if let Op::DefineVar { name, .. } = var.op() {
144 let var_dtype = &var.dtype();
145 let c_type = c_dtype(var_dtype);
146 params.push(format!("const {c_type} {name}"));
147 ctx.register(var.id, name.clone());
148 }
149 }
150
151 code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));
153
154 for node in &nodes {
156 if let Op::DefineLocal(id) = node.op() {
157 let (base, size) = match node.dtype() {
158 DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
159 other => (c_dtype(&other), 1),
160 };
161 let name = format!("local{id}");
162 code_lines.push(format!(" {base} {name}[{size}];"));
163 ctx.register(node.id, name);
164 }
165 }
166
167 code_lines.push("".to_string());
168
169 for node in &nodes {
171 if let Op::Reduce { reduce_op, ranges, .. } = node.op() {
172 if ranges.is_empty() {
173 continue;
174 }
175 let dtype = &node.dtype();
176 let c_type = c_dtype(dtype);
177 let identity = c_reduce_identity(*reduce_op, dtype);
178 let acc_name = format!("acc{}", node.id);
179 code_lines.push(format!(" {c_type} {acc_name} = {identity};"));
180 ctx.register(node.id, acc_name);
182 }
183 }
184
185 for node in &nodes {
187 match node.op() {
188 Op::Const(cv) => {
189 let val = c_const(&cv.0, &node.dtype());
190 ctx.register(node.id, val);
191 }
192 Op::VConst { values } => {
193 let val = c_vconst(values, &node.dtype());
194 ctx.register(node.id, val);
195 }
196 _ => {}
197 }
198 }
199
200 for node in &nodes {
202 if let Op::Range { axis_id, .. } = node.op() {
203 let name = format!("ridx{}", axis_id.value());
204 ctx.register(node.id, name);
205 }
206 }
207
208 let mut kernel_body: Vec<String> = Vec::new();
211 for node in &nodes {
212 if matches!(node.op(), Op::Noop | Op::Group { .. }) {
213 ctx.register(node.id, String::new());
216 continue;
217 }
218 render_uop(node, &mut ctx, &mut kernel_body);
219 if let Some(err) = ctx.take_error() {
220 return Err(err);
221 }
222 }
223
224 if !ctx.hoisted_declarations.is_empty() {
226 code_lines.append(&mut ctx.hoisted_declarations);
227 }
228 code_lines.extend(kernel_body);
229 code_lines.push("}".to_string());
230 code_lines.push("".to_string());
231
232 let code = code_lines.join("\n");
233
234 tracing::debug!(generated_c = code, "c codegen: final generated code");
235
236 let mut result = RenderedKernel::new(code, kernel_name.to_string());
237 result.buffer_args = buffer_args;
238 result.var_names = var_names;
239
240 Ok(result)
241 }
242
243 fn backend_name(&self) -> &str {
244 "clang"
245 }
246
247 fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
248 None
251 }
252}
253
254fn find_scope_escaping_vars(nodes: &[Arc<UOp>], ref_counts: &HashMap<u64, usize>) -> HashSet<u64> {
263 let mut depth = 0usize;
264 let mut def_depth: HashMap<u64, usize> = HashMap::new();
265 let mut min_use_depth: HashMap<u64, usize> = HashMap::new();
266
267 for node in nodes {
268 match node.op() {
270 Op::Range { .. } | Op::If { .. } => {
271 if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
273 def_depth.entry(node.id).or_insert(depth);
274 }
275 for src in node.op().sources() {
277 min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
278 }
279 depth += 1;
280 continue;
281 }
282 Op::End { .. } | Op::EndIf { .. } => {
283 depth = depth.saturating_sub(1);
284 }
285 _ => {}
286 }
287
288 if ref_counts.get(&node.id).copied().unwrap_or(0) > 1 {
290 def_depth.entry(node.id).or_insert(depth);
291 }
292
293 for src in node.op().sources() {
295 min_use_depth.entry(src.id).and_modify(|d| *d = (*d).min(depth)).or_insert(depth);
296 }
297 }
298
299 def_depth
301 .into_iter()
302 .filter(|(id, def_d)| min_use_depth.get(id).copied().unwrap_or(*def_d) < *def_d)
303 .map(|(id, _)| id)
304 .collect()
305}
306
307pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
309 let renderer = CRenderer::new();
310 crate::Renderer::render(&renderer, uop, name)
311}