use std::fmt::Write;
use super::graph::ExprGraph;
use super::node::{ExprId, Node};
/// A generated WGSL compute kernel together with the metadata needed to
/// bind and dispatch it.
///
/// Produced by `ExprGraph::to_wgsl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslKernel {
    /// Complete WGSL source text of the shader module.
    pub source: String,
    /// Number of `f32` values each invocation reads from the `inputs` buffer
    /// (the `n_inputs` argument that was passed to `to_wgsl`).
    pub n_inputs: usize,
    /// Number of `f32` values each invocation writes to the `outputs` buffer
    /// (the length of the `outputs` slice that was passed to `to_wgsl`).
    pub n_outputs: usize,
    /// Value placed in the shader's `@workgroup_size(...)` attribute.
    pub workgroup_size: u32,
}
impl ExprGraph {
    /// Generates a WGSL compute shader that evaluates `outputs` once per
    /// invocation.
    ///
    /// Layout of the generated module:
    /// - binding 0: read-only `array<f32>` named `inputs`,
    /// - binding 1: read-write `array<f32>` named `outputs`,
    /// - binding 2: uniform `Params` whose `count` field bounds the work;
    ///   invocations with `gid.x >= params.count` return immediately.
    ///
    /// Invocation `idx` loads `n_inputs` consecutive values starting at
    /// `idx * n_inputs` into `x0, x1, ...`, and stores one value per entry of
    /// `outputs` starting at `idx * outputs.len()`.
    ///
    /// Every live node that is neither a variable nor a literal becomes a
    /// `let t<id> = ...;` statement. Variables and literals are inlined at
    /// their use sites instead.
    pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel {
        /// Declarations that precede the entry point, one element per line.
        const DECLARATIONS: &[&str] = &[
            "// Auto-generated by tang-expr",
            "",
            "struct Params {",
            " count: u32,",
            " _pad1: u32,",
            " _pad2: u32,",
            " _pad3: u32,",
            "}",
            "",
            "@group(0) @binding(0) var<storage, read> inputs: array<f32>;",
            "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;",
            "@group(0) @binding(2) var<uniform> params: Params;",
            "",
        ];
        /// Entry-point signature plus the bounds guard, one element per line.
        const ENTRY_PROLOGUE: &[&str] = &[
            "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {",
            " let idx = gid.x;",
            " if (idx >= params.count) { return; }",
            "",
        ];

        /// Appends every element of `lines` to `dst`, each followed by a newline.
        fn push_lines(dst: &mut String, lines: &[&str]) {
            for line in lines {
                dst.push_str(line);
                dst.push('\n');
            }
        }

        let workgroup_size = 256u32;
        let n_outputs = outputs.len();
        let live = self.live_set(outputs);
        // An empty live set falls back to 0; the filter below then emits nothing.
        let max_id = live.iter().max().map_or(0, |id| *id);

        let mut src = String::with_capacity(2048);
        push_lines(&mut src, DECLARATIONS);
        writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
        push_lines(&mut src, ENTRY_PROLOGUE);

        // Load this invocation's inputs into x0, x1, ...
        if n_inputs > 0 {
            writeln!(src, " let base_in = idx * {n_inputs}u;").unwrap();
            for slot in 0..n_inputs {
                writeln!(src, " let x{slot} = inputs[base_in + {slot}u];").unwrap();
            }
            src.push('\n');
        }

        // Temporaries are emitted in ascending id order.
        // NOTE(review): this relies on operands having smaller ids than the
        // nodes that use them - confirm against the graph's construction.
        for i in (0..=max_id).filter(|i| live.contains(i)) {
            let node = self.node(ExprId(i as u32));
            if matches!(node, Node::Var(_) | Node::Lit(_)) {
                // Leaves are inlined by `wgsl_ref`; they get no temporary.
                continue;
            }
            writeln!(src, " let t{i} = {};", self.wgsl_expr(node)).unwrap();
        }
        src.push('\n');

        // Store the requested results.
        if !outputs.is_empty() {
            writeln!(src, " let base_out = idx * {n_outputs}u;").unwrap();
            for (slot, &out) in outputs.iter().enumerate() {
                let value = self.wgsl_ref(out);
                writeln!(src, " outputs[base_out + {slot}u] = {value};").unwrap();
            }
        }
        src.push_str("}\n");

        WgslKernel {
            source: src,
            n_inputs,
            n_outputs,
            workgroup_size,
        }
    }

    /// Renders the right-hand side for `node`, referring to its operands
    /// through [`Self::wgsl_ref`].
    ///
    /// `Select(c, a, b)` is emitted as `select(b, a, c > 0.0)`, i.e. `a` when
    /// `c > 0.0` and `b` otherwise.
    fn wgsl_expr(&self, node: Node) -> String {
        let operand = |id: ExprId| self.wgsl_ref(id);
        let call = |func: &str, arg: ExprId| format!("{func}({})", operand(arg));

        match node {
            Node::Var(index) => format!("x{index}"),
            Node::Lit(bits) => format_f32_literal(f64::from_bits(bits)),
            Node::Add(lhs, rhs) => format!("({} + {})", operand(lhs), operand(rhs)),
            Node::Mul(lhs, rhs) => format!("({} * {})", operand(lhs), operand(rhs)),
            Node::Neg(arg) => format!("(-{})", operand(arg)),
            Node::Recip(arg) => format!("(1.0 / {})", operand(arg)),
            Node::Sqrt(arg) => call("sqrt", arg),
            Node::Sin(arg) => call("sin", arg),
            Node::Atan2(y, x) => format!("atan2({}, {})", operand(y), operand(x)),
            Node::Exp2(arg) => call("exp2", arg),
            Node::Log2(arg) => call("log2", arg),
            Node::Select(cond, if_pos, otherwise) => {
                let on_false = operand(otherwise);
                let on_true = operand(if_pos);
                let test = operand(cond);
                format!("select({on_false}, {on_true}, {test} > 0.0)")
            }
        }
    }

    /// Returns the WGSL text used to refer to node `id` from another
    /// expression: `x<n>` for a variable, an inline literal for a constant,
    /// and `t<id>` for anything else.
    fn wgsl_ref(&self, id: ExprId) -> String {
        let node = self.node(id);
        if matches!(node, Node::Var(_) | Node::Lit(_)) {
            // Leaves have no temporary; their expression text is the reference.
            self.wgsl_expr(node)
        } else {
            format!("t{}", id.0)
        }
    }
}
/// Formats `v` as WGSL source text for a floating-point constant that can be
/// spliced into any expression position.
///
/// - The text always contains a `.` (or an exponent), so WGSL reads it as a
///   float rather than an integer.
/// - Negative values are wrapped in parentheses, e.g. `(-1.5)`. A bare
///   `-1.5` placed directly after a unary minus would produce `--1.5`, and
///   WGSL tokenizes `--` as the decrement operator, which is not valid inside
///   an expression.
/// - Both `0.0` and `-0.0` are emitted as `0.0`.
///
/// NOTE(review): NaN and the infinities have no WGSL literal spelling, so the
/// text returned for them is not valid WGSL. Values whose magnitude exceeds
/// the `f32` range are written out in full and will presumably be rejected
/// when the shader is compiled - confirm whether callers can produce them.
fn format_f32_literal(v: f64) -> String {
    // `-0.0 == 0.0` is true, so this branch covers both zeros.
    if v == 0.0 {
        return "0.0".to_string();
    }

    let mut text = v.to_string();
    let looks_like_float = text.contains(|c: char| matches!(c, '.' | 'e' | 'E'));
    if !looks_like_float {
        // Integral values print without a fractional part ("2" -> "2.0").
        text.push_str(".0");
    }

    if text.starts_with('-') {
        format!("({text})")
    } else {
        text
    }
}
#[cfg(test)]
mod tests {
    // Inside this module `super` is the enclosing file's module, which has no
    // `graph` child; `ExprGraph` is reachable through the parent's own import.
    use super::ExprGraph;

    /// The kernel for sqrt(x*x + y*y) has the expected scaffolding, input
    /// loads and metadata.
    #[test]
    fn wgsl_basic() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);
        let kernel = g.to_wgsl(&[dist], 2);
        assert!(kernel.source.contains("@compute"));
        assert!(kernel.source.contains("@workgroup_size(256)"));
        assert!(kernel.source.contains("let x0 = inputs[base_in + 0u];"));
        assert!(kernel.source.contains("let x1 = inputs[base_in + 1u];"));
        assert!(kernel.source.contains("sqrt("));
        assert_eq!(kernel.n_inputs, 2);
        assert_eq!(kernel.n_outputs, 1);
        assert_eq!(kernel.workgroup_size, 256);
    }

    /// Each requested output gets its own slot in the output stride.
    #[test]
    fn wgsl_multiple_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let kernel = g.to_wgsl(&[sum, prod], 2);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("let base_out = idx * 2u;"));
        assert!(kernel.source.contains("outputs[base_out + 0u]"));
        assert!(kernel.source.contains("outputs[base_out + 1u]"));
    }

    /// A variable operand is referenced by name rather than via a temporary.
    #[test]
    fn wgsl_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let kernel = g.to_wgsl(&[s], 1);
        assert!(kernel.source.contains("sin(x0)"));
    }

    /// Literals appear inline in the generated source.
    #[test]
    fn wgsl_lit_inline() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.14);
        let prod = g.mul(x, c);
        let kernel = g.to_wgsl(&[prod], 1);
        assert!(kernel.source.contains("3.14"));
    }

    /// Select is lowered to WGSL `select` with a `> 0.0` condition.
    #[test]
    fn wgsl_select() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);
        let kernel = g.to_wgsl(&[s], 1);
        assert!(kernel.source.contains("select("));
        assert!(kernel.source.contains("> 0.0)"));
    }

    /// Build, differentiate, simplify, then generate a kernel for both the
    /// expression and its derivative.
    #[test]
    fn wgsl_full_pipeline() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let dx = g.diff(xx, 0);
        let dx = g.simplify(dx);
        let kernel = g.to_wgsl(&[xx, dx], 1);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("@compute"));
    }
}