use anyhow::Result;
use ug::lang::ssa;
/// Display wrapper that prints an SSA variable id as a C identifier.
struct V(ssa::VarId);

impl std::fmt::Display for V {
    /// Writes `__var` followed by the numeric value of the wrapped id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let index = self.0.as_usize();
        write!(f, "__var{index}")
    }
}
/// Display wrapper that prints an SSA constant as a C expression.
struct C(ssa::Const);

impl std::fmt::Display for C {
    /// Writes the constant as source text.
    ///
    /// Integers use their decimal form and finite floats use Rust's `Debug` formatting.
    /// NaN and the infinities are spelled as divisions by zero. Those are wrapped in
    /// parentheses because this text is spliced into larger expressions (for example as the
    /// right-hand side of a division), where a bare `1. / 0.` would be regrouped by operator
    /// precedence: `x / 1. / 0.` parses as `(x / 1.) / 0.`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::num::FpCategory;
        match self.0 {
            ssa::Const::I32(v) => write!(f, "{v}"),
            ssa::Const::F32(v) => match v.classify() {
                FpCategory::Nan => write!(f, "(0. / 0.)"),
                FpCategory::Infinite if v > 0. => write!(f, "(1. / 0.)"),
                FpCategory::Infinite => write!(f, "(-1. / 0.)"),
                FpCategory::Zero | FpCategory::Normal | FpCategory::Subnormal => {
                    write!(f, "{v:?}")
                }
            },
        }
    }
}
/// Display wrapper for an SSA operand, which is either a variable or a constant.
struct A(ssa::A);

impl std::fmt::Display for A {
    /// Forwards to the `V` or `C` formatter depending on the operand variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            ssa::A::Var(var) => std::fmt::Display::fmt(&V(var), f),
            ssa::A::Const(cst) => std::fmt::Display::fmt(&C(cst), f),
        }
    }
}
/// Display wrapper that prints an SSA dtype as the matching C type name.
struct D(ssa::DType);

impl std::fmt::Display for D {
    /// Writes `float`, `int`, `float*` or `int*` for the wrapped dtype.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self.0 {
            ssa::DType::F32 => "float",
            ssa::DType::I32 => "int",
            ssa::DType::PtrF32 => "float*",
            ssa::DType::PtrI32 => "int*",
        })
    }
}
/// Writes CUDA C source for `kernel` to `w` as an `extern "C" __global__` function named
/// `func_name`.
///
/// Each `DefineGlobal` instruction becomes one kernel parameter, and every other instruction
/// is emitted as one statement in the function body. The value produced by the instruction at
/// position `i` is named `__var{i}`.
///
/// # Errors
///
/// Returns an error when:
/// - the `index` fields of the `DefineGlobal` instructions, taken in instruction order, are
///   not exactly `0, 1, 2, ...`;
/// - a `DefineLocal` has a non-pointer dtype;
/// - an `EndRange` has no open `Range`, or a `Range` is never closed;
/// - writing to `w` fails.
///
/// Text written to `w` before the error was detected is not rolled back.
pub fn gen<W: std::io::Write>(w: &mut W, func_name: &str, kernel: &ssa::Kernel) -> Result<()> {
    // Collect (instruction index, (argument index, dtype)) for every global definition.
    // Instructions are visited in order, so the list is already sorted by instruction index.
    let args: Vec<_> = kernel
        .instrs
        .iter()
        .enumerate()
        .filter_map(|(idx, instr)| match instr {
            ssa::Instr::DefineGlobal { index, dtype } => Some((idx, (*index, *dtype))),
            _ => None,
        })
        .collect();

    writeln!(w, "extern \"C\" __global__ void {func_name}(")?;
    for (arg_idx2, &(var_id, (arg_idx, dtype))) in args.iter().enumerate() {
        // The declared argument index has to match the position in the parameter list.
        if arg_idx != arg_idx2 {
            anyhow::bail!("unexpected arguments in kernel {args:?}")
        }
        let is_last = arg_idx + 1 == args.len();
        let delim = if is_last { "" } else { "," };
        writeln!(w, " {} {}{delim}", D(dtype), V(ssa::VarId::new(var_id)))?;
    }
    writeln!(w, ") {{")?;

    // Number of currently open `for` loops; drives the indentation of the emitted code.
    let mut depth = 0;
    for (var_id, instr) in kernel.instrs.iter().enumerate() {
        use ssa::Instr as I;
        let var_id = V(ssa::VarId::new(var_id));
        let indent = " ".repeat(2 * depth + 2);
        match instr {
            // Globals were already emitted as kernel parameters.
            I::DefineGlobal { index: _, dtype: _ } => {}
            I::DefineLocal { dtype, size } => match dtype {
                ssa::DType::PtrI32 => writeln!(w, "{indent}__shared__ int {var_id}[{size}];")?,
                ssa::DType::PtrF32 => writeln!(w, "{indent}__shared__ float {var_id}[{size}];")?,
                ssa::DType::F32 | ssa::DType::I32 => {
                    anyhow::bail!("unsupported dtype for DefineLocal {dtype:?}")
                }
            },
            I::DefineAcc(cst) | I::Const(cst) => match cst {
                ssa::Const::I32(v) => writeln!(w, "{indent}int {var_id} = {v};")?,
                ssa::Const::F32(v) => {
                    use std::num::FpCategory;
                    // Non-finite values have no literal form here, so they are spelled as
                    // divisions by zero.
                    let v = match v.classify() {
                        FpCategory::Nan => "0. / 0.".to_string(),
                        FpCategory::Infinite if *v > 0. => "1. / 0.".to_string(),
                        FpCategory::Infinite => "-1. / 0.".to_string(),
                        FpCategory::Zero | FpCategory::Normal | FpCategory::Subnormal => {
                            format!("{v:?}")
                        }
                    };
                    writeln!(w, "{indent}float {var_id} = {v};")?
                }
            },
            I::Range { lo, up, end_idx: _ } => {
                writeln!(
                    w,
                    "{indent}for (int {var_id} = {}; {var_id} < {}; ++{var_id}) {{",
                    A(*lo),
                    A(*up)
                )?;
                depth += 1;
            }
            I::EndRange { start_idx: _ } => {
                if depth == 0 {
                    anyhow::bail!("unmatched EndRange")
                }
                depth -= 1;
                // The closing brace lines up with the `for` that opened the loop.
                let indent = " ".repeat(2 * depth + 2);
                writeln!(w, "{indent}}}")?;
            }
            I::Load { src, offset, dtype } => {
                writeln!(w, "{indent}{} {var_id} = {}[{}];", D(*dtype), V(*src), A(*offset))?;
            }
            I::Assign { dst, src } => {
                writeln!(w, "{indent}{} = {};", V(*dst), A(*src))?;
            }
            I::Store { dst, offset, value, dtype: _ } => {
                writeln!(w, "{indent}{}[{}] = {};", V(*dst), A(*offset), A(*value))?;
            }
            I::Binary { op, lhs, rhs, dtype } => {
                let op = match op {
                    ssa::BinaryOp::Add => format!("{} + {}", A(*lhs), A(*rhs)),
                    ssa::BinaryOp::Mul => format!("{} * {}", A(*lhs), A(*rhs)),
                    ssa::BinaryOp::Sub => format!("{} - {}", A(*lhs), A(*rhs)),
                    ssa::BinaryOp::Div => format!("{} / {}", A(*lhs), A(*rhs)),
                    ssa::BinaryOp::Min => format!("min({}, {})", A(*lhs), A(*rhs)),
                    ssa::BinaryOp::Max => format!("max({}, {})", A(*lhs), A(*rhs)),
                };
                writeln!(w, "{indent}{} {var_id} = {op};", D(*dtype),)?;
            }
            I::Unary { op, arg, dtype } => {
                let op = match op {
                    ssa::UnaryOp::Exp => "exp",
                    // Emitted as `-(arg)`: negation is an operator, not a function, and the
                    // parentheses keep a negative constant from producing `--`.
                    ssa::UnaryOp::Neg => "-",
                };
                writeln!(w, "{indent}{} {var_id} = {op}({});", D(*dtype), A(*arg))?;
            }
            I::Special(ssa::Special::LocalIdx) => {
                writeln!(w, "{indent}int {var_id} = threadIdx.x;")?
            }
            I::Special(ssa::Special::GridIdx) => writeln!(w, "{indent}int {var_id} = blockIdx.x;")?,
            I::Barrier => writeln!(w, "{indent}__syncthreads();")?,
        }
    }
    writeln!(w, "}}")?;
    if depth > 0 {
        anyhow::bail!("unmatched Range")
    }
    Ok(())
}