use std::fmt::Write;
use super::graph::ExprGraph;
use super::node::{ExprId, Node};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dialect {
Wgsl,
Msl,
Cuda,
C,
}
pub struct ComputeKernel {
pub source: String,
pub n_inputs: usize,
pub n_outputs: usize,
pub workgroup_size: u32,
pub dialect: Dialect,
pub entry_point: &'static str,
}
impl ExprGraph {
pub fn to_kernel(&self, outputs: &[ExprId], n_inputs: usize, dialect: Dialect) -> ComputeKernel {
let workgroup_size = 256u32;
let n_outputs = outputs.len();
let live = self.live_set(outputs);
let max_id = if live.is_empty() {
0
} else {
*live.iter().max().unwrap()
};
let mut src = String::with_capacity(2048);
match dialect {
Dialect::Wgsl => emit_wgsl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id, workgroup_size),
Dialect::Msl => emit_msl(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
Dialect::Cuda => emit_cuda(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
Dialect::C => emit_c(&mut src, self, outputs, n_inputs, n_outputs, &live, max_id),
}
ComputeKernel {
source: src,
n_inputs,
n_outputs,
workgroup_size,
dialect,
entry_point: "k0",
}
}
}
fn emit_body(
src: &mut String,
graph: &ExprGraph,
outputs: &[ExprId],
n_inputs: usize,
n_outputs: usize,
live: &std::collections::HashSet<usize>,
max_id: usize,
indent: &str,
decl: &str,
lit_suffix: &str,
thread_id: &str,
dialect: Dialect,
) {
if n_inputs > 0 {
let base = format!("{thread_id} * {n_inputs}u");
let base_wgsl = format!("{thread_id} * {n_inputs}");
for i in 0..n_inputs {
match dialect {
Dialect::Wgsl => {
writeln!(src, "{indent}{decl} x{i} = inputs[{base} + {i}u];").unwrap();
}
Dialect::Msl => {
writeln!(src, "{indent}{decl} x{i} = inputs[{base_wgsl} + {i}];").unwrap();
}
Dialect::Cuda => {
writeln!(src, "{indent}{decl} x{i} = in{i}[{thread_id}];").unwrap();
}
Dialect::C => {
writeln!(src, "{indent}{decl} x{i} = inputs[i * {n_inputs} + {i}];").unwrap();
}
}
}
writeln!(src).unwrap();
}
for i in 0..=max_id {
if !live.contains(&i) {
continue;
}
let node = graph.node(ExprId(i as u32));
match node {
Node::Var(_) | Node::Lit(_) => continue,
_ => {}
}
let rhs = expr_str(graph, node, lit_suffix, dialect);
writeln!(src, "{indent}{decl} t{i} = {rhs};").unwrap();
}
writeln!(src).unwrap();
if n_outputs > 0 {
for (k, out) in outputs.iter().enumerate() {
let val = ref_str(graph, *out, lit_suffix);
match dialect {
Dialect::Wgsl => {
let base = format!("{thread_id} * {n_outputs}u");
writeln!(src, "{indent}outputs[{base} + {k}u] = {val};").unwrap();
}
Dialect::Msl => {
let base = format!("{thread_id} * {n_outputs}");
writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
}
Dialect::Cuda => {
let base = format!("{thread_id} * {n_outputs}");
writeln!(src, "{indent}outputs[{base} + {k}] = {val};").unwrap();
}
Dialect::C => {
writeln!(src, "{indent}outputs[i * {n_outputs} + {k}] = {val};").unwrap();
}
}
}
}
}
fn expr_str(graph: &ExprGraph, node: Node, suffix: &str, dialect: Dialect) -> String {
match node {
Node::Var(n) => format!("x{n}"),
Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
Node::Add(a, b) => format!("({} + {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
Node::Mul(a, b) => format!("({} * {})", ref_str(graph, a, suffix), ref_str(graph, b, suffix)),
Node::Neg(a) => format!("(-{})", ref_str(graph, a, suffix)),
Node::Recip(a) => format!("(1.0{suffix} / {})", ref_str(graph, a, suffix)),
Node::Sqrt(a) => format!("sqrt({})", ref_str(graph, a, suffix)),
Node::Sin(a) => format!("sin({})", ref_str(graph, a, suffix)),
Node::Atan2(y, x) => format!("atan2({}, {})", ref_str(graph, y, suffix), ref_str(graph, x, suffix)),
Node::Exp2(a) => format!("exp2({})", ref_str(graph, a, suffix)),
Node::Log2(a) => format!("log2({})", ref_str(graph, a, suffix)),
Node::Select(c, a, b) => {
match dialect {
Dialect::Wgsl => {
format!(
"select({}, {}, {} > 0.0)",
ref_str(graph, b, suffix),
ref_str(graph, a, suffix),
ref_str(graph, c, suffix),
)
}
_ => {
format!(
"({} > 0.0{suffix} ? {} : {})",
ref_str(graph, c, suffix),
ref_str(graph, a, suffix),
ref_str(graph, b, suffix),
)
}
}
}
}
}
fn ref_str(graph: &ExprGraph, id: ExprId, suffix: &str) -> String {
match graph.node(id) {
Node::Var(n) => format!("x{n}"),
Node::Lit(bits) => format_literal(f64::from_bits(bits), suffix),
_ => format!("t{}", id.0),
}
}
fn format_literal(v: f64, suffix: &str) -> String {
let base = if v == 0.0 {
"0.0".to_string()
} else if v == 1.0 {
"1.0".to_string()
} else if v == -1.0 {
"-1.0".to_string()
} else if v == 2.0 {
"2.0".to_string()
} else {
let s = format!("{v}");
if s.contains('.') || s.contains('e') || s.contains('E') {
s
} else {
format!("{s}.0")
}
};
format!("{base}{suffix}")
}
fn emit_wgsl(
src: &mut String,
graph: &ExprGraph,
outputs: &[ExprId],
n_inputs: usize,
n_outputs: usize,
live: &std::collections::HashSet<usize>,
max_id: usize,
workgroup_size: u32,
) {
writeln!(src, "// Auto-generated by tang-expr").unwrap();
writeln!(src).unwrap();
writeln!(src, "struct Params {{").unwrap();
writeln!(src, " count: u32,").unwrap();
writeln!(src, " _pad1: u32,").unwrap();
writeln!(src, " _pad2: u32,").unwrap();
writeln!(src, " _pad3: u32,").unwrap();
writeln!(src, "}}").unwrap();
writeln!(src).unwrap();
writeln!(src, "@group(0) @binding(0) var<storage, read> inputs: array<f32>;").unwrap();
writeln!(src, "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;").unwrap();
writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
writeln!(src).unwrap();
writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
writeln!(src, "fn k0(@builtin(global_invocation_id) gid: vec3<u32>) {{").unwrap();
writeln!(src, " let idx = gid.x;").unwrap();
writeln!(src, " if (idx >= params.count) {{ return; }}").unwrap();
writeln!(src).unwrap();
emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "let", "", "idx", Dialect::Wgsl);
writeln!(src, "}}").unwrap();
}
fn emit_msl(
src: &mut String,
graph: &ExprGraph,
outputs: &[ExprId],
n_inputs: usize,
n_outputs: usize,
live: &std::collections::HashSet<usize>,
max_id: usize,
) {
writeln!(src, "// Auto-generated by tang-expr").unwrap();
writeln!(src, "#include <metal_stdlib>").unwrap();
writeln!(src, "using namespace metal;").unwrap();
writeln!(src).unwrap();
write!(src, "kernel void k0(").unwrap();
writeln!(src, "device const float* inputs [[buffer(0)]],").unwrap();
writeln!(src, " device float* outputs [[buffer(1)]],").unwrap();
writeln!(src, " device const uint& count [[buffer(2)]],").unwrap();
writeln!(src, " uint gid [[thread_position_in_grid]]) {{").unwrap();
writeln!(src, " if (gid >= count) {{ return; }}").unwrap();
writeln!(src).unwrap();
emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "gid", Dialect::Msl);
writeln!(src, "}}").unwrap();
}
fn emit_cuda(
src: &mut String,
graph: &ExprGraph,
outputs: &[ExprId],
n_inputs: usize,
n_outputs: usize,
live: &std::collections::HashSet<usize>,
max_id: usize,
) {
writeln!(src, "// Auto-generated by tang-expr").unwrap();
writeln!(src).unwrap();
write!(src, "extern \"C\" __global__ void k0(").unwrap();
for i in 0..n_inputs {
writeln!(src, "const float* __restrict__ in{i},").unwrap();
write!(src, " ").unwrap();
}
writeln!(src, "float* __restrict__ outputs,").unwrap();
writeln!(src, " const unsigned int count) {{").unwrap();
writeln!(src, " unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x;").unwrap();
writeln!(src, " if (gid >= count) {{ return; }}").unwrap();
writeln!(src).unwrap();
emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "gid", Dialect::Cuda);
writeln!(src, "}}").unwrap();
}
fn emit_c(
src: &mut String,
graph: &ExprGraph,
outputs: &[ExprId],
n_inputs: usize,
n_outputs: usize,
live: &std::collections::HashSet<usize>,
max_id: usize,
) {
writeln!(src, "// Auto-generated by tang-expr").unwrap();
writeln!(src, "#include <math.h>").unwrap();
writeln!(src).unwrap();
write!(src, "void k0(").unwrap();
writeln!(src, "const float* inputs,").unwrap();
writeln!(src, " float* outputs,").unwrap();
writeln!(src, " int count) {{").unwrap();
writeln!(src, " for (int i = 0; i < count; i++) {{").unwrap();
emit_body(src, graph, outputs, n_inputs, n_outputs, live, max_id, " ", "float", "f", "i", Dialect::C);
writeln!(src, " }}").unwrap();
writeln!(src, "}}").unwrap();
}
#[cfg(test)]
mod tests {
use super::*;
use super::graph::ExprGraph;
#[test]
fn wgsl_matches_original() {
let mut g = ExprGraph::new();
let x = g.var(0);
let y = g.var(1);
let xx = g.mul(x, x);
let yy = g.mul(y, y);
let sum = g.add(xx, yy);
let dist = g.sqrt(sum);
let old = g.to_wgsl(&[dist], 2);
let new = g.to_kernel(&[dist], 2, Dialect::Wgsl);
assert_eq!(new.n_inputs, old.n_inputs);
assert_eq!(new.n_outputs, old.n_outputs);
assert_eq!(new.workgroup_size, old.workgroup_size);
assert!(new.source.contains("@compute"));
assert!(new.source.contains("sqrt("));
}
#[test]
fn msl_entry_point() {
let mut g = ExprGraph::new();
let x = g.var(0);
let y = g.var(1);
let sum = g.add(x, y);
let kernel = g.to_kernel(&[sum], 2, Dialect::Msl);
assert!(kernel.source.contains("kernel void k0("));
assert!(kernel.source.contains("thread_position_in_grid"));
assert!(kernel.source.contains("#include <metal_stdlib>"));
assert_eq!(kernel.entry_point, "k0");
}
#[test]
fn cuda_entry_point() {
let mut g = ExprGraph::new();
let x = g.var(0);
let y = g.var(1);
let prod = g.mul(x, y);
let kernel = g.to_kernel(&[prod], 2, Dialect::Cuda);
assert!(kernel.source.contains("extern \"C\" __global__ void k0("));
assert!(kernel.source.contains("blockIdx.x * blockDim.x + threadIdx.x"));
assert!(kernel.source.contains("const float* __restrict__ in0,"));
assert!(kernel.source.contains("const float* __restrict__ in1,"));
assert!(!kernel.source.contains("inputs["));
assert!(kernel.source.contains("in0[gid]"));
assert!(kernel.source.contains("in1[gid]"));
}
#[test]
fn c_loop() {
let mut g = ExprGraph::new();
let x = g.var(0);
let s = g.sin(x);
let kernel = g.to_kernel(&[s], 1, Dialect::C);
assert!(kernel.source.contains("for (int i = 0; i < count; i++)"));
assert!(kernel.source.contains("sin("));
}
#[test]
fn msl_select_ternary() {
let mut g = ExprGraph::new();
let x = g.var(0);
let a = g.lit(3.0);
let b = g.lit(7.0);
let s = g.select(x, a, b);
let kernel = g.to_kernel(&[s], 1, Dialect::Msl);
assert!(kernel.source.contains("?"));
assert!(!kernel.source.contains("select("));
}
#[test]
fn wgsl_select_builtin() {
let mut g = ExprGraph::new();
let x = g.var(0);
let a = g.lit(3.0);
let b = g.lit(7.0);
let s = g.select(x, a, b);
let kernel = g.to_kernel(&[s], 1, Dialect::Wgsl);
assert!(kernel.source.contains("select("));
}
#[test]
fn msl_literal_suffix() {
let mut g = ExprGraph::new();
let x = g.var(0);
let c = g.lit(3.14);
let prod = g.mul(x, c);
let kernel = g.to_kernel(&[prod], 1, Dialect::Msl);
assert!(kernel.source.contains("3.14f"));
}
#[test]
fn multiple_outputs_all_dialects() {
let mut g = ExprGraph::new();
let x = g.var(0);
let y = g.var(1);
let sum = g.add(x, y);
let prod = g.mul(x, y);
for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
let kernel = g.to_kernel(&[sum, prod], 2, dialect);
assert_eq!(kernel.n_outputs, 2);
assert_eq!(kernel.n_inputs, 2);
}
}
#[test]
fn full_pipeline_all_dialects() {
let mut g = ExprGraph::new();
let x = g.var(0);
let xx = g.mul(x, x);
let dx = g.diff(xx, 0);
let dx = g.simplify(dx);
for dialect in [Dialect::Wgsl, Dialect::Msl, Dialect::Cuda, Dialect::C] {
let kernel = g.to_kernel(&[xx, dx], 1, dialect);
assert_eq!(kernel.n_outputs, 2);
assert!(!kernel.source.is_empty());
}
}
}