1use std::fmt::Write;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
/// A generated WGSL compute kernel plus the metadata needed to dispatch it.
///
/// The kernel indexes work by `global_invocation_id.x` only, so covering
/// `count` elements needs at least `count.div_ceil(workgroup_size)`
/// workgroups along x.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslKernel {
    /// Complete WGSL source text, including bindings and the `main` entry point.
    pub source: String,
    /// Number of `f32` values each invocation reads from the input buffer
    /// (the per-element stride into `inputs`).
    pub n_inputs: usize,
    /// Number of `f32` values each invocation writes to the output buffer
    /// (the per-element stride into `outputs`).
    pub n_outputs: usize,
    /// The x dimension used in `@workgroup_size(...)`.
    pub workgroup_size: u32,
}
19
20impl ExprGraph {
21 pub fn to_wgsl(&self, outputs: &[ExprId], n_inputs: usize) -> WgslKernel {
29 let workgroup_size = 256u32;
30 let n_outputs = outputs.len();
31
32 let live = self.live_set(outputs);
34 let max_id = if live.is_empty() {
35 0
36 } else {
37 *live.iter().max().unwrap()
38 };
39
40 let mut src = String::with_capacity(2048);
41
42 writeln!(src, "// Auto-generated by tang-expr").unwrap();
44 writeln!(src).unwrap();
45
46 writeln!(src, "struct Params {{").unwrap();
48 writeln!(src, " count: u32,").unwrap();
49 writeln!(src, " _pad1: u32,").unwrap();
50 writeln!(src, " _pad2: u32,").unwrap();
51 writeln!(src, " _pad3: u32,").unwrap();
52 writeln!(src, "}}").unwrap();
53 writeln!(src).unwrap();
54
55 writeln!(
57 src,
58 "@group(0) @binding(0) var<storage, read> inputs: array<f32>;"
59 )
60 .unwrap();
61 writeln!(
62 src,
63 "@group(0) @binding(1) var<storage, read_write> outputs: array<f32>;"
64 )
65 .unwrap();
66 writeln!(src, "@group(0) @binding(2) var<uniform> params: Params;").unwrap();
67 writeln!(src).unwrap();
68
69 writeln!(src, "@compute @workgroup_size({workgroup_size})").unwrap();
71 writeln!(
72 src,
73 "fn main(@builtin(global_invocation_id) gid: vec3<u32>) {{"
74 )
75 .unwrap();
76 writeln!(src, " let idx = gid.x;").unwrap();
77 writeln!(src, " if (idx >= params.count) {{ return; }}").unwrap();
78 writeln!(src).unwrap();
79
80 if n_inputs > 0 {
82 writeln!(src, " let base_in = idx * {n_inputs}u;").unwrap();
83 for i in 0..n_inputs {
84 writeln!(src, " let x{i} = inputs[base_in + {i}u];").unwrap();
85 }
86 writeln!(src).unwrap();
87 }
88
89 for i in 0..=max_id {
91 if !live.contains(&i) {
92 continue;
93 }
94 let node = self.node(ExprId(i as u32));
95 match node {
97 Node::Var(_) | Node::Lit(_) => continue,
98 _ => {}
99 }
100 let rhs = self.wgsl_expr(node);
101 writeln!(src, " let t{i} = {rhs};").unwrap();
102 }
103 writeln!(src).unwrap();
104
105 if n_outputs > 0 {
107 writeln!(src, " let base_out = idx * {n_outputs}u;").unwrap();
108 for (k, out) in outputs.iter().enumerate() {
109 let val = self.wgsl_ref(*out);
110 writeln!(src, " outputs[base_out + {k}u] = {val};").unwrap();
111 }
112 }
113
114 writeln!(src, "}}").unwrap();
115
116 WgslKernel {
117 source: src,
118 n_inputs,
119 n_outputs,
120 workgroup_size,
121 }
122 }
123
124 fn wgsl_expr(&self, node: Node) -> String {
126 match node {
127 Node::Var(n) => format!("x{n}"),
128 Node::Lit(bits) => {
129 let v = f64::from_bits(bits);
130 format_f32_literal(v)
131 }
132 Node::Add(a, b) => {
133 format!("({} + {})", self.wgsl_ref(a), self.wgsl_ref(b))
134 }
135 Node::Mul(a, b) => {
136 format!("({} * {})", self.wgsl_ref(a), self.wgsl_ref(b))
137 }
138 Node::Neg(a) => format!("(-{})", self.wgsl_ref(a)),
139 Node::Recip(a) => format!("(1.0 / {})", self.wgsl_ref(a)),
140 Node::Sqrt(a) => format!("sqrt({})", self.wgsl_ref(a)),
141 Node::Sin(a) => format!("sin({})", self.wgsl_ref(a)),
142 Node::Atan2(y, x) => {
143 format!("atan2({}, {})", self.wgsl_ref(y), self.wgsl_ref(x))
144 }
145 Node::Exp2(a) => format!("exp2({})", self.wgsl_ref(a)),
146 Node::Log2(a) => format!("log2({})", self.wgsl_ref(a)),
147 Node::Select(c, a, b) => {
148 format!(
150 "select({}, {}, {} > 0.0)",
151 self.wgsl_ref(b),
152 self.wgsl_ref(a),
153 self.wgsl_ref(c)
154 )
155 }
156 }
157 }
158
159 fn wgsl_ref(&self, id: ExprId) -> String {
161 match self.node(id) {
162 Node::Var(n) => format!("x{n}"),
163 Node::Lit(bits) => {
164 let v = f64::from_bits(bits);
165 format_f32_literal(v)
166 }
167 _ => format!("t{}", id.0),
168 }
169 }
170}
171
/// Formats `v` as a WGSL floating-point literal for use in an `f32` context.
///
/// How the value is narrowed and printed:
/// - It is first narrowed to `f32`, the precision the shader computes in.
/// - It is then printed with Rust's shortest round-trip `Debug` form.
/// - That form always contains a `.` or an exponent (`1.0`, `0.25`, `1e20`,
///   `1.5e-7`), and each of these is a valid WGSL decimal float literal.
/// - Negative values carry a leading `-`, which WGSL parses as unary negation.
/// - The sign of `-0.0` is preserved.
///
/// # Panics
///
/// Panics in these cases, because WGSL has no literal syntax for non-finite
/// values and any text emitted here would make the shader invalid:
/// - `v` is NaN or infinite;
/// - `v` is too large in magnitude to fit in an `f32`.
fn format_f32_literal(v: f64) -> String {
    // `f64 as f32` rounds to nearest and turns out-of-range finite values
    // into ±infinity, so one finiteness check covers overflow as well.
    let narrowed = v as f32;
    assert!(
        narrowed.is_finite(),
        "cannot express {v} as a finite WGSL f32 literal"
    );
    format!("{narrowed:?}")
}
192
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    /// Checks that the generated shader contains `needle`; on failure the
    /// whole shader is printed so the mismatch is easy to see.
    fn assert_emits(source: &str, needle: &str) {
        assert!(
            source.contains(needle),
            "expected generated WGSL to contain {needle:?}; full source:\n{source}"
        );
    }

    /// sqrt(x*x + y*y): entry point, input loads, intrinsic call and metadata.
    #[test]
    fn wgsl_basic() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let x_sq = graph.mul(x, x);
        let y_sq = graph.mul(y, y);
        let sum_sq = graph.add(x_sq, y_sq);
        let norm = graph.sqrt(sum_sq);

        let kernel = graph.to_wgsl(&[norm], 2);
        for needle in [
            "@compute",
            "@workgroup_size(256)",
            "let x0 = inputs[base_in + 0u];",
            "let x1 = inputs[base_in + 1u];",
            "sqrt(",
        ] {
            assert_emits(&kernel.source, needle);
        }
        assert_eq!(kernel.n_inputs, 2);
        assert_eq!(kernel.n_outputs, 1);
        assert_eq!(kernel.workgroup_size, 256);
    }

    /// Two outputs share one stride and land in consecutive slots.
    #[test]
    fn wgsl_multiple_outputs() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let total = graph.add(x, y);
        let product = graph.mul(x, y);

        let kernel = graph.to_wgsl(&[total, product], 2);
        assert_eq!(kernel.n_outputs, 2);
        for needle in [
            "let base_out = idx * 2u;",
            "outputs[base_out + 0u]",
            "outputs[base_out + 1u]",
        ] {
            assert_emits(&kernel.source, needle);
        }
    }

    /// A unary intrinsic applied directly to an input variable.
    #[test]
    fn wgsl_sin() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let sine = graph.sin(x);

        let kernel = graph.to_wgsl(&[sine], 1);
        assert_emits(&kernel.source, "sin(x0)");
    }

    /// Constants are written inline instead of getting their own binding.
    #[test]
    fn wgsl_lit_inline() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let constant = graph.lit(3.14);
        let scaled = graph.mul(x, constant);

        let kernel = graph.to_wgsl(&[scaled], 1);
        assert_emits(&kernel.source, "3.14");
    }

    /// `select` lowers to WGSL `select(..., cond > 0.0)`.
    #[test]
    fn wgsl_select() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let when_pos = graph.lit(3.0);
        let otherwise = graph.lit(7.0);
        let chosen = graph.select(x, when_pos, otherwise);

        let kernel = graph.to_wgsl(&[chosen], 1);
        assert_emits(&kernel.source, "select(");
        assert_emits(&kernel.source, "> 0.0)");
    }

    /// Differentiate, simplify, then emit both the value and its derivative.
    #[test]
    fn wgsl_full_pipeline() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let square = graph.mul(x, x);
        let raw_derivative = graph.diff(square, 0);
        let derivative = graph.simplify(raw_derivative);

        let kernel = graph.to_wgsl(&[square, derivative], 1);
        assert_eq!(kernel.n_outputs, 2);
        assert_emits(&kernel.source, "@compute");
    }
}