use std::collections::BTreeSet;
use std::sync::Arc;
use morok_dtype::ScalarDType;
use morok_ir::{Op, UOp, WmmaMetadata};
use super::types::c_dtype;
const AMX_FMA_Z_F32: u64 = 1 << 62;
const AMX_LOAD_PAIR_BIT: u64 = 1 << 62;
pub fn collect_wmma_defines(nodes: &[Arc<UOp>]) -> Vec<String> {
let mut seen = BTreeSet::new();
for node in nodes {
if let Op::Wmma { metadata, .. } = node.op() {
seen.insert(metadata.name.clone());
}
}
if seen.is_empty() {
return Vec::new();
}
let mut lines = Vec::new();
lines.push(r#"#define AMX_SET(imm5) __asm("nop\nnop\nnop\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")"#.to_string());
lines.push(r#"#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")"#.to_string());
for node in nodes {
if let Op::Wmma { metadata, .. } = node.op()
&& seen.remove(&metadata.name)
{
lines.push(render_amx_wmma_function(metadata));
}
}
lines
}
fn render_amx_wmma_function(metadata: &WmmaMetadata) -> String {
let (n, m, _k) = metadata.dims;
let (tile_y_count, tile_x_count) = metadata.tile_grid;
let use_tile_grid = tile_x_count > 1 || tile_y_count > 1;
let in_scalar = c_dtype(&metadata.dtype_in.scalar_dtype());
let out_type = format!("{}{}", in_scalar, n * m);
let a_type = format!("{}{}", in_scalar, n);
let b_type = format!("{}{}", in_scalar, m);
let bytes_per_elem = metadata.dtype_in.bytes();
let fma_op: u32 = match metadata.dtype_in.base() {
ScalarDType::Float64 => 10, ScalarDType::Float32 => 12, ScalarDType::Int16 => 14, ScalarDType::Float16 => 15, _ => 12,
};
let fma_flags: u64 =
if metadata.dtype_in.base() == ScalarDType::Float16 && metadata.dtype_out.base() == ScalarDType::Float32 {
AMX_FMA_Z_F32
} else {
0
};
let (ldx_encoding, ldy_encoding) = if use_tile_grid { (AMX_LOAD_PAIR_BIT, AMX_LOAD_PAIR_BIT) } else { (0, 0) };
let fma_calls = if use_tile_grid {
let bytes_per_tile_row: usize = 64;
let mut calls = Vec::new();
for ty in 0..tile_y_count {
for tx in 0..tile_x_count {
let z_row = (ty * tile_x_count + tx) as u64;
let x_off = (tx * bytes_per_tile_row) as u64;
let y_off = (ty * bytes_per_tile_row) as u64;
let encoding = fma_flags | (z_row << 20) | (x_off << 10) | y_off;
calls.push(format!(" AMX({fma_op}, 0, {encoding}ull);"));
}
}
calls.join("\n")
} else {
format!(" AMX({fma_op}, 0, {fma_flags}ull);")
};
format!(
"static {out_type} __{name}({a_type} data1, {b_type} data2, {out_type} data0){{\n \
AMX_SET(0);\n \
for(int ridx0 = 0; ridx0 < {n}; ridx0++){{ \
AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*{bytes_per_elem}ull)<<56 | ridx0*64ull); }}\n \
AMX(0, (int *)(&data2), {ldx_encoding}ull); \
AMX(1, (int *)(&data1), {ldy_encoding}ull);\n\
{fma_calls}\n \
for(int ridx0 = 0; ridx0 < {n}; ridx0++){{ \
AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*{bytes_per_elem}ull)<<56 | ridx0*64ull); }}\n \
AMX_SET(1);\n \
return data0;\n}}",
name = metadata.name,
)
}