mod amx;
pub mod ops;
pub mod types;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use morok_ir::pattern::TypedPatternMatcher;
use morok_ir::{AxisType, Op, prelude::*};
use morok_schedule::linearize::{line_rewrite_cleanups, linearize_with_cfg};
use crate::common::is_output_buffer;
use crate::{BufferArg, RenderedKernel, Result};
use self::ops::{CContext, count_references, render_uop};
use self::types::{c_const, c_dtype, c_reduce_identity, c_vconst, collect_vector_typedefs};
/// Stateless renderer that turns a UOp graph into C source code.
///
/// The type carries no data, so it is cheap to copy and every instance is
/// interchangeable; `CRenderer::new()` and `CRenderer::default()` produce the
/// same value.
#[derive(Debug, Clone, Copy, Default)]
pub struct CRenderer;

impl CRenderer {
    /// Creates a new renderer.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl crate::Renderer for CRenderer {
    /// Renders `uop` as a single C function and returns it as a [`RenderedKernel`].
    ///
    /// The graph is linearized first; the function is named `name`, or `"kernel"`
    /// when `name` is `None`. Parameters are emitted in this order:
    ///
    /// 1. one `T* restrict dataN` per `Op::Param` with `device: None`, sorted by slot
    ///    (`N` is the position after sorting),
    /// 2. one `const T <name>` per `Op::DefineVar`, in linearized order,
    /// 3. a trailing `const T thread_id` when a `Thread` range with a constant
    ///    integer end is present.
    ///
    /// `global_size` / `local_size` are only set on the result when that constant
    /// thread count is greater than one.
    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
        let kernel_name = name.unwrap_or("kernel");

        // Linearize the graph, then run the line-level cleanup pass over the result.
        let nodes = line_rewrite_cleanups(linearize_with_cfg(uop.clone()));
        for (idx, linearized) in nodes.iter().enumerate() {
            tracing::debug!(position = idx, op = linearized.op().as_ref(), id = linearized.id, "c linearized node");
        }

        // Kernel inputs: device-less params become buffers, DefineVar nodes become scalars.
        let mut buffers: Vec<Arc<UOp>> =
            nodes.iter().filter(|n| matches!(n.op(), Op::Param { device: None, .. })).cloned().collect();
        let variables: Vec<Arc<UOp>> =
            nodes.iter().filter(|n| matches!(n.op(), Op::DefineVar { .. })).cloned().collect();
        // Stable sort, so buffers sharing a slot keep their linearized order.
        buffers.sort_by_key(|b| match b.op() {
            Op::Param { slot, device: None, .. } => *slot,
            _ => usize::MAX,
        });

        // First Thread-axis range whose end is a constant integer, paired with that count.
        let thread_info: Option<(Arc<UOp>, usize)> = nodes.iter().find_map(|n| {
            let Op::Range { axis_type, end, .. } = n.op() else {
                return None;
            };
            if !matches!(axis_type, AxisType::Thread) {
                return None;
            }
            let Op::Const(cv) = end.op() else {
                return None;
            };
            match cv.0 {
                ConstValue::Int(count) => Some((n.clone(), count as usize)),
                _ => None,
            }
        });
        let has_threading = thread_info.is_some();
        let thread_count = thread_info.as_ref().map_or(1, |(_, c)| *c);

        // Metadata describing each buffer argument; names follow the sorted position.
        let buffer_args: Vec<BufferArg> = buffers
            .iter()
            .enumerate()
            .filter_map(|(i, buf)| match buf.op() {
                Op::Param { slot, device: None, .. } => {
                    let is_output = is_output_buffer(buf, &nodes);
                    Some(BufferArg { index: *slot, name: format!("data{i}"), dtype: buf.dtype(), is_output })
                }
                _ => None,
            })
            .collect();

        // Scalar argument names, with the thread id appended last when threading is used.
        let mut var_names: Vec<String> = variables
            .iter()
            .filter_map(|var| match var.op() {
                Op::DefineVar { name, .. } => Some(name.clone()),
                _ => None,
            })
            .collect();
        if has_threading {
            var_names.push("thread_id".to_string());
        }

        let ref_counts = count_references(&nodes);
        let scope_escaping = find_scope_escaping_vars(&nodes, &ref_counts);
        let mut ctx = CContext::new(ref_counts, scope_escaping);

        // File preamble: include line, then vector typedefs and WMMA defines, each
        // group followed by a blank line when it is non-empty.
        let mut code_lines: Vec<String> = vec!["#include <stdbool.h>".to_string(), String::new()];

        let typedefs = collect_vector_typedefs(&nodes);
        let has_typedefs = !typedefs.is_empty();
        code_lines.extend(typedefs);
        if has_typedefs {
            code_lines.push(String::new());
        }

        let wmma_defines = amx::collect_wmma_defines(&nodes);
        let has_wmma_defines = !wmma_defines.is_empty();
        code_lines.extend(wmma_defines);
        if has_wmma_defines {
            code_lines.push(String::new());
        }

        // Function signature: buffers, then variables, then the optional thread id.
        let mut params: Vec<String> = Vec::new();
        for (i, buf) in buffers.iter().enumerate() {
            let name = format!("data{i}");
            let buf_dtype = buf.dtype();
            // Pointer dtypes are declared by their pointee type.
            let elem_type = if let DType::Ptr { base, .. } = &buf_dtype { c_dtype(base) } else { c_dtype(&buf_dtype) };
            params.push(format!("{elem_type}* restrict {name}"));
            ctx.register(buf.id, name);
        }
        for var in &variables {
            let Op::DefineVar { name, .. } = var.op() else {
                continue;
            };
            let var_dtype = var.dtype();
            let c_type = c_dtype(&var_dtype);
            params.push(format!("const {c_type} {name}"));
            ctx.register(var.id, name.clone());
        }
        if let Some((thread_range, _)) = thread_info.as_ref() {
            let range_dtype = thread_range.dtype();
            let c_type = c_dtype(&range_dtype);
            params.push(format!("const {c_type} thread_id"));
            ctx.register(thread_range.id, "thread_id".to_string());
        }
        code_lines.push(format!("void {kernel_name}({}) {{", params.join(", ")));

        // Local arrays are declared at the top of the function body.
        for node in &nodes {
            let Op::DefineLocal(id) = node.op() else {
                continue;
            };
            let (base, size) = match node.dtype() {
                DType::Ptr { base, size, .. } => (c_dtype(&base), size.unwrap_or(1)),
                other => (c_dtype(&other), 1),
            };
            let name = format!("local{id}");
            code_lines.push(format!(" {base} {name}[{size}];"));
            ctx.register(node.id, name);
        }
        code_lines.push(String::new());

        // One accumulator per reduction that has at least one range, seeded with
        // the identity value for its reduce op.
        for node in &nodes {
            match node.op() {
                Op::Reduce { reduce_op, ranges, .. } if !ranges.is_empty() => {
                    let dtype = node.dtype();
                    let c_type = c_dtype(&dtype);
                    let identity = c_reduce_identity(*reduce_op, &dtype);
                    let acc_name = format!("acc{}", node.id);
                    code_lines.push(format!(" {c_type} {acc_name} = {identity};"));
                    ctx.register(node.id, acc_name);
                }
                _ => {}
            }
        }

        // Constants are registered under their rendered literal text.
        for node in &nodes {
            match node.op() {
                Op::Const(cv) => {
                    let literal = c_const(&cv.0, &node.dtype());
                    ctx.register(node.id, literal);
                }
                Op::VConst { values } => {
                    let literal = c_vconst(values, &node.dtype());
                    ctx.register(node.id, literal);
                }
                _ => {}
            }
        }

        // Non-thread ranges are named after their axis id; the thread range was
        // already registered as `thread_id` above (when one was found).
        for node in &nodes {
            match node.op() {
                Op::Range { axis_id, axis_type, .. } if !matches!(axis_type, AxisType::Thread) => {
                    ctx.register(node.id, format!("ridx{}", axis_id.value()));
                }
                _ => {}
            }
        }

        // Render the body. Noop/Group nodes get an empty name and produce no code;
        // Thread ranges are skipped entirely.
        let mut kernel_body: Vec<String> = Vec::new();
        for node in &nodes {
            match node.op() {
                Op::Noop | Op::Group { .. } => {
                    ctx.register(node.id, String::new());
                }
                Op::Range { axis_type, .. } if matches!(axis_type, AxisType::Thread) => {}
                _ => {
                    render_uop(node, &mut ctx, &mut kernel_body);
                }
            }
        }

        // Hoisted declarations collected while rendering go ahead of the body.
        code_lines.append(&mut ctx.hoisted_declarations);
        code_lines.extend(kernel_body);
        code_lines.extend(["}".to_string(), String::new()]);

        let code = code_lines.join("\n");
        tracing::debug!(generated_c = code, "c codegen: final generated code");

        let mut result = RenderedKernel::new(code, kernel_name.to_string());
        result.buffer_args = buffer_args;
        result.var_names = var_names;
        if thread_count > 1 {
            result.global_size = Some([thread_count, 1, 1]);
            result.local_size = Some([1, 1, 1]);
        }
        Ok(result)
    }

    /// Name reported for this backend.
    fn backend_name(&self) -> &str {
        "clang"
    }

    /// This renderer supplies no decomposition patterns.
    fn decompositor(&self) -> Option<TypedPatternMatcher<()>> {
        None
    }
}
/// Finds nodes whose value is consumed at a shallower nesting depth than the
/// one at which the node itself appears.
///
/// Walks `nodes` in order while tracking a nesting depth: `Range`/`If` raise
/// the depth for the nodes that follow them, `End`/`EndIf` lower it (saturating
/// at zero) before the closing node itself is accounted for. Only nodes with
/// more than one reference in `ref_counts` are candidates. A candidate is
/// returned when some node lists it as a source at a depth strictly below the
/// depth recorded for the candidate.
fn find_scope_escaping_vars(nodes: &[Arc<UOp>], ref_counts: &HashMap<u64, usize>) -> HashSet<u64> {
    let mut depth = 0usize;
    // Depth at which each multiply-referenced node was first encountered.
    let mut first_seen_at: HashMap<u64, usize> = HashMap::new();
    // Shallowest depth at which each node id was seen as a source of another node.
    let mut shallowest_use: HashMap<u64, usize> = HashMap::new();

    for node in nodes {
        let op = node.op();
        let opens_scope = matches!(op, Op::Range { .. } | Op::If { .. });
        let closes_scope = matches!(op, Op::End { .. } | Op::EndIf { .. });

        // A closing node is accounted for at the depth of the enclosing scope.
        if closes_scope {
            depth = depth.saturating_sub(1);
        }

        let is_shared = ref_counts.get(&node.id).is_some_and(|&refs| refs > 1);
        if is_shared {
            first_seen_at.entry(node.id).or_insert(depth);
        }
        for src in op.sources() {
            let slot = shallowest_use.entry(src.id).or_insert(depth);
            *slot = (*slot).min(depth);
        }

        // An opening node is accounted for at the outer depth; only the nodes
        // after it are one level deeper.
        if opens_scope {
            depth += 1;
        }
    }

    first_seen_at
        .into_iter()
        .filter_map(|(id, def_d)| {
            // Ids never used as a source cannot escape.
            let escapes = shallowest_use.get(&id).is_some_and(|&use_d| use_d < def_d);
            escapes.then_some(id)
        })
        .collect()
}
pub fn render(uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel> {
let renderer = CRenderer::new();
crate::Renderer::render(&renderer, uop, name)
}