/// WGSL templates for the histogram accumulation kernels.
///
/// All six templates share the same shape:
///
/// * `{{WORKGROUP_SIZE}}` on the first line is a textual placeholder; it is
///   presumably substituted by the host before the shader is compiled
///   (TODO confirm against the pipeline builder).
/// * One invocation handles one sample (`gid.x`), returning early when
///   `gid.x >= params.sample_count`.
/// * The bin is `floor((value - min_value) * inv_bin_width)`, clamped to
///   `[0, bin_count - 1]`, so out-of-range samples land in the first or last bin.
/// * `counts` and `total_weight` are `atomic<u32>` cells that hold `f32` bit
///   patterns. Floating-point addition is emulated with a compare-exchange loop:
///   load the bits, add in `f32`, and retry until the exchange succeeds.
/// * `total_weight` is only updated when `params.accumulate_total != 0`.
///
/// The atomic helpers (`atomic_add_count`, `atomic_add_total`) address the
/// module-scope storage variables directly instead of taking a
/// `ptr<storage, ...>` parameter. A storage pointer without an explicit access
/// mode is read-only in WGSL, which the atomic built-ins reject, and storage
/// pointers are not portably accepted as user-function parameters.
///
/// NOTE(review): `_pad: vec3<u32>` has 16-byte alignment under WGSL layout
/// rules, so it starts at byte 32 and `HistogramParams` occupies 48 bytes.
/// Confirm the host-side uniform struct matches.
///
/// NOTE(review): the `F64_*` templates declare `array<f64>` storage buffers.
/// `f64` is not part of core WGSL, so these presumably rely on an
/// implementation extension being enabled — verify at pipeline creation.
///
/// NOTE(review): `bin_count == 0` makes the clamp range `[0, -1]`; callers are
/// assumed never to dispatch with zero bins.
pub mod counts {
    /// `f32` samples, unweighted: every in-range invocation adds `1.0` to its bin.
    ///
    /// Bindings: 0 = samples, 1 = counts, 2 = total_weight, 3 = params.
    pub const F32_UNIFORM: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f32>;

@group(0) @binding(1)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(2)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(3)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = samples[idx];
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);

    atomic_add_count(bin_index, 1.0);
    if (params.accumulate_total != 0u) {
        atomic_add_total(1.0);
    }
}
"#;

    /// `f32` samples with `f32` weights: each invocation adds `weights[idx]` to its bin.
    ///
    /// Bindings: 0 = samples, 1 = weights, 2 = counts, 3 = total_weight, 4 = params.
    pub const F32_WEIGHTS_F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f32>;

@group(0) @binding(1)
var<storage, read> weights: array<f32>;

@group(0) @binding(2)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(3)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(4)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = samples[idx];
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);
    let weight = weights[idx];

    atomic_add_count(bin_index, weight);
    if (params.accumulate_total != 0u) {
        atomic_add_total(weight);
    }
}
"#;

    /// `f32` samples with `f64` weights: each weight is narrowed to `f32` before
    /// being accumulated.
    ///
    /// Bindings: 0 = samples, 1 = weights, 2 = counts, 3 = total_weight, 4 = params.
    pub const F32_WEIGHTS_F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f32>;

@group(0) @binding(1)
var<storage, read> weights: array<f64>;

@group(0) @binding(2)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(3)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(4)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = samples[idx];
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);
    let weight = f32(weights[idx]);

    atomic_add_count(bin_index, weight);
    if (params.accumulate_total != 0u) {
        atomic_add_total(weight);
    }
}
"#;

    /// `f64` samples, unweighted: each sample is narrowed to `f32` before binning
    /// and adds `1.0` to its bin.
    ///
    /// Bindings: 0 = samples, 1 = counts, 2 = total_weight, 3 = params.
    pub const F64_UNIFORM: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f64>;

@group(0) @binding(1)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(2)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(3)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = f32(samples[idx]);
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);

    atomic_add_count(bin_index, 1.0);
    if (params.accumulate_total != 0u) {
        atomic_add_total(1.0);
    }
}
"#;

    /// `f64` samples (narrowed to `f32` before binning) with `f32` weights.
    ///
    /// Bindings: 0 = samples, 1 = weights, 2 = counts, 3 = total_weight, 4 = params.
    pub const F64_WEIGHTS_F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f64>;

@group(0) @binding(1)
var<storage, read> weights: array<f32>;

@group(0) @binding(2)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(3)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(4)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = f32(samples[idx]);
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);
    let weight = weights[idx];

    atomic_add_count(bin_index, weight);
    if (params.accumulate_total != 0u) {
        atomic_add_total(weight);
    }
}
"#;

    /// `f64` samples with `f64` weights: both are narrowed to `f32` before use.
    ///
    /// Bindings: 0 = samples, 1 = weights, 2 = counts, 3 = total_weight, 4 = params.
    pub const F64_WEIGHTS_F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct HistogramParams {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    accumulate_total: u32,
    _pad: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> samples: array<f64>;

@group(0) @binding(1)
var<storage, read> weights: array<f64>;

@group(0) @binding(2)
var<storage, read_write> counts: array<atomic<u32>>;

@group(0) @binding(3)
var<storage, read_write> total_weight: atomic<u32>;

@group(0) @binding(4)
var<uniform> params: HistogramParams;

fn atomic_add_count(bin_index: u32, value: f32) {
    var old_bits = atomicLoad(&counts[bin_index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&counts[bin_index], old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

fn atomic_add_total(value: f32) {
    var old_bits = atomicLoad(&total_weight);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let exchange = atomicCompareExchangeWeak(&total_weight, old_bits, new_bits);
        if (exchange.exchanged) {
            break;
        }
        old_bits = exchange.old_value;
    }
}

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.sample_count) {
        return;
    }

    let value = f32(samples[idx]);
    let normalized = (value - params.min_value) * params.inv_bin_width;
    let raw_bin = i32(floor(normalized));
    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
    let bin_index = u32(clamped);
    let weight = f32(weights[idx]);

    atomic_add_count(bin_index, weight);
    if (params.accumulate_total != 0u) {
        atomic_add_total(weight);
    }
}
"#;
}
354
/// WGSL template for turning accumulated histogram cells into scaled `f32` values.
pub mod convert {
    /// Compute kernel that, for each `idx < params.bin_count`, reinterprets
    /// `counts[idx]` as an `f32` bit pattern and writes
    /// `values[idx] = that_f32 * params.scale`.
    ///
    /// `{{WORKGROUP_SIZE}}` on the first line is a textual placeholder; it is
    /// presumably substituted by the host before compilation (TODO confirm).
    ///
    /// Bindings: 0 = counts (read-only `array<u32>`), 1 = values
    /// (read-write `array<f32>`), 2 = params (uniform).
    ///
    /// NOTE(review): under WGSL layout rules `vec3<u32>` is 16-byte aligned, so
    /// `ConvertParams` places `bin_count` at byte 0, `_pad0` at 16, `scale` at 28
    /// and `_pad1` at 32, for a total size of 48 bytes. Confirm the host-side
    /// uniform struct uses the same offsets.
    pub const TEMPLATE: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;

struct ConvertParams {
    bin_count: u32,
    _pad0: vec3<u32>,
    scale: f32,
    _pad1: vec3<u32>,
};

@group(0) @binding(0)
var<storage, read> counts: array<u32>;

@group(0) @binding(1)
var<storage, read_write> values: array<f32>;

@group(0) @binding(2)
var<uniform> params: ConvertParams;

@compute @workgroup_size(WORKGROUP_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.bin_count) {
        return;
    }

    let count_value = bitcast<f32>(counts[idx]);
    values[idx] = count_value * params.scale;
}
"#;
}