// tang 0.2.0
//
// Math library for physical reality — geometry, spatial algebra, tensor, training, GPU compute, and 3D gaussian splatting
// Documentation
//! WGSL compute shader code generation.

use std::fmt::Write;

use super::graph::ExprGraph;
use super::node::{ExprId, Node};

/// A generated WGSL compute shader plus the per-work-item layout a caller
/// needs in order to size its buffers and dispatch it.
///
/// Produced by `ExprGraph::to_wgsl`. The struct is plain data: it owns the
/// shader text and carries no GPU handles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslKernel {
    /// Complete WGSL shader source.
    pub source: String,
    /// Number of input values per work item.
    pub n_inputs: usize,
    /// Number of output values per work item.
    pub n_outputs: usize,
    /// Workgroup size written into the shader's `@workgroup_size` attribute.
    ///
    /// `to_wgsl` currently always uses 256; there is no way to override it.
    pub workgroup_size: u32,
}

impl ExprGraph {
    /// Generate a WGSL compute shader that evaluates `outputs` once per work item.
    ///
    /// The generated shader uses f32 (GPU native) and has this interface:
    ///
    /// * `@group(0) @binding(0)` — read-only storage `array<f32>` of inputs;
    ///   work item `idx` reads `inputs[idx * n_inputs + i]` into `x{i}`.
    /// * `@group(0) @binding(1)` — read-write storage `array<f32>` of outputs;
    ///   work item `idx` writes `outputs[idx * outputs.len() + k]`.
    /// * `@group(0) @binding(2)` — uniform `Params { count, _pad1, _pad2, _pad3 }`
    ///   (four `u32`s); invocations with `idx >= count` return immediately.
    ///
    /// Only nodes in the live set of `outputs` are emitted. Every live
    /// non-leaf node becomes exactly one `let t{id} = ...;`, so shared
    /// subexpressions are computed once per thread. Statements are emitted in
    /// ascending node-id order, which assumes operands always have smaller
    /// ids than the nodes that use them.
    ///
    /// `Var(n)` nodes are referenced as `x{n}`, and only `x0..x{n_inputs - 1}`
    /// are declared, so `n_inputs` must be greater than every variable index
    /// reachable from `outputs`; otherwise the shader refers to an undeclared
    /// name.
    ///
    /// The caller handles device/pipeline/dispatch (no wgpu dependency here).
    pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel {
        let workgroup_size = 256u32;
        let n_outputs = outputs.len();

        // Find all live nodes (shared with codegen.rs and compile.rs)
        let live = self.live_set(outputs);
        // An empty live set falls back to 0; the loop below then visits id 0,
        // finds it is not live, and emits nothing.
        let max_id = live.iter().max().copied().unwrap_or(0);

        // `fmt::Write` for `String` never fails, so the `unwrap`s below
        // cannot panic.
        let mut src = String::with_capacity(2048);

        // Header
        writeln!(src, "// Auto-generated by tang-expr").unwrap();
        writeln!(src).unwrap();

        // Params struct: `count` plus three padding words (16 bytes total).
        writeln!(src, "struct Params {{").unwrap();
        writeln!(src, "    count: u32,").unwrap();
        writeln!(src, "    _pad1: u32,").unwrap();
        writeln!(src, "    _pad2: u32,").unwrap();
        writeln!(src, "    _pad3: u32,").unwrap();
        writeln!(src, "}}").unwrap();
        writeln!(src).unwrap();

        // Bindings
        writeln!(
            src,
            "@group(0) @binding(0) var<storage, read> inputs: array<f32>;"
        )
        .unwrap();
        writeln!(
            src,
            "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;"
        )
        .unwrap();
        writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
        writeln!(src).unwrap();

        // Entry point; the guard drops invocations past `params.count`.
        writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
        writeln!(
            src,
            "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{"
        )
        .unwrap();
        writeln!(src, "    let idx = gid.x;").unwrap();
        writeln!(src, "    if (idx >= params.count) {{ return; }}").unwrap();
        writeln!(src).unwrap();

        // Load inputs: each work item owns a contiguous run of `n_inputs` floats.
        if n_inputs > 0 {
            writeln!(src, "    let base_in = idx * {n_inputs}u;").unwrap();
            for i in 0..n_inputs {
                writeln!(src, "    let x{i} = inputs[base_in + {i}u];").unwrap();
            }
            writeln!(src).unwrap();
        }

        // Evaluate in topological order (SSA form)
        for i in 0..=max_id {
            if !live.contains(&i) {
                continue;
            }
            let node = self.node(ExprId(i as u32));
            // Var and Lit nodes are inlined at their use sites by `wgsl_ref`,
            // so they get no `let` of their own.
            if matches!(node, Node::Var(_) | Node::Lit(_)) {
                continue;
            }
            let rhs = self.wgsl_expr(node);
            writeln!(src, "    let t{i} = {rhs};").unwrap();
        }
        writeln!(src).unwrap();

        // Store outputs, in the order the caller listed them.
        if n_outputs > 0 {
            writeln!(src, "    let base_out = idx * {n_outputs}u;").unwrap();
            for (k, out) in outputs.iter().enumerate() {
                let val = self.wgsl_ref(*out);
                writeln!(src, "    outputs[base_out + {k}u] = {val};").unwrap();
            }
        }

        writeln!(src, "}}").unwrap();

        WgslKernel {
            source: src,
            n_inputs,
            n_outputs,
            workgroup_size,
        }
    }

    /// Generate the WGSL right-hand-side expression for a single node.
    ///
    /// Operands are rendered through [`Self::wgsl_ref`], i.e. as `x{n}`, an
    /// inline literal, or the `t{id}` temporary of an earlier statement.
    fn wgsl_expr(&self, node: Node) -> String {
        match node {
            Node::Var(n) => format!("x{n}"),
            Node::Lit(bits) => {
                let v = f64::from_bits(bits);
                format_f32_literal(v)
            }
            Node::Add(a, b) => {
                format!("({} + {})", self.wgsl_ref(a), self.wgsl_ref(b))
            }
            Node::Mul(a, b) => {
                format!("({} * {})", self.wgsl_ref(a), self.wgsl_ref(b))
            }
            Node::Neg(a) => {
                let operand = self.wgsl_ref(a);
                // An inlined negative literal would otherwise produce `--1.0`,
                // and WGSL lexes `--` as the decrement token, which is not
                // valid inside an expression. Parenthesise only in that case
                // so all other output stays unchanged.
                if operand.starts_with('-') {
                    format!("(-({operand}))")
                } else {
                    format!("(-{operand})")
                }
            }
            Node::Recip(a) => format!("(1.0 / {})", self.wgsl_ref(a)),
            Node::Sqrt(a) => format!("sqrt({})", self.wgsl_ref(a)),
            Node::Sin(a) => format!("sin({})", self.wgsl_ref(a)),
            Node::Atan2(y, x) => {
                format!("atan2({}, {})", self.wgsl_ref(y), self.wgsl_ref(x))
            }
            Node::Exp2(a) => format!("exp2({})", self.wgsl_ref(a)),
            Node::Log2(a) => format!("log2({})", self.wgsl_ref(a)),
            Node::Select(c, a, b) => {
                // WGSL select(false_val, true_val, cond) — false value FIRST.
                // The condition is "c is strictly positive": `a` when
                // `c > 0.0`, otherwise `b`.
                format!(
                    "select({}, {}, {} > 0.0)",
                    self.wgsl_ref(b),
                    self.wgsl_ref(a),
                    self.wgsl_ref(c)
                )
            }
        }
    }

    /// Reference a node: Var → x{n}, Lit → literal, others → t{index}.
    fn wgsl_ref(&self, id: ExprId) -> String {
        let node = self.node(id);
        if matches!(node, Node::Var(_) | Node::Lit(_)) {
            // Leaves have no temporary; reuse the leaf formatting in `wgsl_expr`.
            self.wgsl_expr(node)
        } else {
            format!("t{}", id.0)
        }
    }
}

/// Format an f64 as an f32 WGSL literal.
///
/// The result always contains a decimal point (or exponent) so WGSL parses it
/// as a float rather than an integer, e.g. `3.0` → `"3.0"`, `3.14` → `"3.14"`.
/// Both `0.0` and `-0.0` are written as `"0.0"`.
///
/// WGSL has no spelling for infinities, and a constant that does not fit in
/// f32 is rejected when the shader is created. Values outside the finite f32
/// range (including `±inf`) are therefore saturated to `±f32::MAX` instead of
/// being emitted as the unparsable `inf.0` or as an overflowing constant.
///
/// # Panics
///
/// Panics if `v` is NaN: WGSL cannot express a NaN constant, and there is no
/// value that could stand in for it without silently changing results.
fn format_f32_literal(v: f64) -> String {
    assert!(
        !v.is_nan(),
        "cannot emit NaN as a WGSL f32 literal: WGSL has no NaN constant"
    );

    // Saturate to the finite f32 range. The bound is compared in f64 so a
    // value just above `f32::MAX` is clamped rather than rounded.
    let limit = f64::from(f32::MAX);
    let v = v.clamp(-limit, limit);

    // `==` is true for both signed zeros, so -0.0 is written without a sign.
    if v == 0.0 {
        return "0.0".to_string();
    }

    // `Display` for f64 prints whole numbers without a fractional part
    // ("2", "-1"); append ".0" so WGSL sees a float literal.
    let mut s = v.to_string();
    if !s.contains(|c: char| matches!(c, '.' | 'e' | 'E')) {
        s.push_str(".0");
    }
    s
}

#[cfg(test)]
mod tests {
    // Inside `mod tests`, `super` is the enclosing module, which has no
    // `graph` child of its own. `ExprGraph` is already imported there, so
    // pull it in through that import.
    use super::ExprGraph;

    /// Checks the shader scaffolding and kernel metadata for sqrt(x² + y²).
    #[test]
    fn wgsl_basic() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let kernel = g.to_wgsl(&[dist], 2);
        assert!(kernel.source.contains("@compute"));
        assert!(kernel.source.contains("@workgroup_size(256)"));
        assert!(kernel.source.contains("let x0 = inputs[base_in + 0u];"));
        assert!(kernel.source.contains("let x1 = inputs[base_in + 1u];"));
        assert!(kernel.source.contains("sqrt("));
        assert_eq!(kernel.n_inputs, 2);
        assert_eq!(kernel.n_outputs, 1);
        assert_eq!(kernel.workgroup_size, 256);
    }

    /// Each output gets its own slot at `base_out + k`.
    #[test]
    fn wgsl_multiple_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);

        let kernel = g.to_wgsl(&[sum, prod], 2);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("let base_out = idx * 2u;"));
        assert!(kernel.source.contains("outputs[base_out + 0u]"));
        assert!(kernel.source.contains("outputs[base_out + 1u]"));
    }

    /// A variable operand is referenced inline as `x0`.
    #[test]
    fn wgsl_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);

        let kernel = g.to_wgsl(&[s], 1);
        assert!(kernel.source.contains("sin(x0)"));
    }

    #[test]
    fn wgsl_lit_inline() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let c = g.lit(3.14);
        let prod = g.mul(x, c);

        let kernel = g.to_wgsl(&[prod], 1);
        // Literal should be inlined, not assigned to a t variable
        assert!(kernel.source.contains("3.14"));
        assert!(!kernel.source.contains("= 3.14;"));
    }

    #[test]
    fn wgsl_select() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let a = g.lit(3.0);
        let b = g.lit(7.0);
        let s = g.select(x, a, b);

        let kernel = g.to_wgsl(&[s], 1);
        // WGSL select(false_val, true_val, cond)
        assert!(kernel.source.contains("select("));
        assert!(kernel.source.contains("> 0.0)"));
    }

    #[test]
    fn wgsl_full_pipeline() {
        // Build a small expression, differentiate, simplify, compile to WGSL
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let dx = g.diff(xx, 0);
        let dx = g.simplify(dx);

        let kernel = g.to_wgsl(&[xx, dx], 1);
        assert_eq!(kernel.n_outputs, 2);
        assert!(kernel.source.contains("@compute"));
    }
}