Skip to main content

runmat_plot/gpu/shaders/
histogram.rs

/// WGSL compute-shader templates that accumulate samples into histogram bins.
///
/// The six constants cover every combination of sample element type (`f32` or
/// `f64`) and weighting (`UNIFORM` adds 1.0 per sample; `WEIGHTS_F32` /
/// `WEIGHTS_F64` add a per-sample weight read from a second buffer).
///
/// Behaviour common to every template:
/// - `{{WORKGROUP_SIZE}}` is a textual placeholder; presumably the caller
///   substitutes a number before compiling (the substitution is not in this
///   file, so confirm at the call site).
/// - One invocation handles one sample; invocations with
///   `gid.x >= params.sample_count` return immediately.
/// - The bin index is `floor((value - min_value) * inv_bin_width)` clamped to
///   `[0, bin_count - 1]`, so samples below or above the binned range are
///   added to the first or last bin rather than being dropped.
/// - `counts` and `total_weight` are declared `atomic<u32>` but store `f32` bit
///   patterns: `atomic_add_f32` bitcasts, adds in floating point, and retries
///   with `atomicCompareExchangeWeak` until the exchange succeeds.
/// - `total_weight` is only updated when `params.accumulate_total != 0`.
///
/// NOTE(review): the clamp assumes `bin_count >= 1`; with `bin_count == 0` the
/// upper bound becomes `-1`. Confirm the host never dispatches with zero bins.
/// NOTE(review): `HistogramParams` ends in a `vec3<u32>` pad, so the host-side
/// uniform struct has to follow WGSL alignment rules for it; verify against the
/// Rust-side definition.
/// NOTE(review): `atomic_add_f32` takes a `ptr<storage, atomic<u32>>` parameter
/// without an explicit access mode; confirm the shader compiler in use accepts
/// storage-space pointer arguments that are written through.
/// NOTE(review): the `f64` variants declare `array<f64>` buffers, which relies
/// on 64-bit float shader support; confirm the feature is enabled on the device.
1pub mod counts {
    /// Unweighted histogram over `f32` samples: every sample adds 1.0 to its bin.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `counts` (read_write),
    /// 2 = `total_weight` (read_write), 3 = `params` (uniform).
2    pub const F32_UNIFORM: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
3
4struct HistogramParams {
5    min_value: f32,
6    inv_bin_width: f32,
7    sample_count: u32,
8    bin_count: u32,
9    accumulate_total: u32,
10    _pad: vec3<u32>,
11};
12
13@group(0) @binding(0)
14var<storage, read> samples: array<f32>;
15
16@group(0) @binding(1)
17var<storage, read_write> counts: array<atomic<u32>>;
18
19@group(0) @binding(2)
20var<storage, read_write> total_weight: atomic<u32>;
21
22@group(0) @binding(3)
23var<uniform> params: HistogramParams;
24
25fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
26    var old_bits = atomicLoad(target);
27    loop {
28        let old_value = bitcast<f32>(old_bits);
29        let new_bits = bitcast<u32>(old_value + value);
30        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
31        if (exchange.exchanged) {
32            break;
33        }
34        old_bits = exchange.old_value;
35    }
36}
37
38@compute @workgroup_size(WORKGROUP_SIZE)
39fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
40    let idx = gid.x;
41    if (idx >= params.sample_count) {
42        return;
43    }
44
45    let value = samples[idx];
46    let normalized = (value - params.min_value) * params.inv_bin_width;
47    let raw_bin = i32(floor(normalized));
48    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
49    let bin_index = u32(clamped);
50
51    atomic_add_f32(&counts[bin_index], 1.0);
52    if (params.accumulate_total != 0u) {
53        atomic_add_f32(&total_weight, 1.0);
54    }
55}
56"#;
57
    /// Weighted histogram over `f32` samples with `f32` weights: sample `idx`
    /// adds `weights[idx]` to its bin.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `weights` (read),
    /// 2 = `counts` (read_write), 3 = `total_weight` (read_write),
    /// 4 = `params` (uniform).
58    pub const F32_WEIGHTS_F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
59
60struct HistogramParams {
61    min_value: f32,
62    inv_bin_width: f32,
63    sample_count: u32,
64    bin_count: u32,
65    accumulate_total: u32,
66    _pad: vec3<u32>,
67};
68
69@group(0) @binding(0)
70var<storage, read> samples: array<f32>;
71
72@group(0) @binding(1)
73var<storage, read> weights: array<f32>;
74
75@group(0) @binding(2)
76var<storage, read_write> counts: array<atomic<u32>>;
77
78@group(0) @binding(3)
79var<storage, read_write> total_weight: atomic<u32>;
80
81@group(0) @binding(4)
82var<uniform> params: HistogramParams;
83
84fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
85    var old_bits = atomicLoad(target);
86    loop {
87        let old_value = bitcast<f32>(old_bits);
88        let new_bits = bitcast<u32>(old_value + value);
89        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
90        if (exchange.exchanged) {
91            break;
92        }
93        old_bits = exchange.old_value;
94    }
95}
96
97@compute @workgroup_size(WORKGROUP_SIZE)
98fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
99    let idx = gid.x;
100    if (idx >= params.sample_count) {
101        return;
102    }
103
104    let value = samples[idx];
105    let normalized = (value - params.min_value) * params.inv_bin_width;
106    let raw_bin = i32(floor(normalized));
107    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
108    let bin_index = u32(clamped);
109    let weight = weights[idx];
110
111    atomic_add_f32(&counts[bin_index], weight);
112    if (params.accumulate_total != 0u) {
113        atomic_add_f32(&total_weight, weight);
114    }
115}
116"#;
117
    /// Weighted histogram over `f32` samples with `f64` weights: each weight is
    /// converted to `f32` before it is accumulated.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `weights` (read),
    /// 2 = `counts` (read_write), 3 = `total_weight` (read_write),
    /// 4 = `params` (uniform).
118    pub const F32_WEIGHTS_F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
119
120struct HistogramParams {
121    min_value: f32,
122    inv_bin_width: f32,
123    sample_count: u32,
124    bin_count: u32,
125    accumulate_total: u32,
126    _pad: vec3<u32>,
127};
128
129@group(0) @binding(0)
130var<storage, read> samples: array<f32>;
131
132@group(0) @binding(1)
133var<storage, read> weights: array<f64>;
134
135@group(0) @binding(2)
136var<storage, read_write> counts: array<atomic<u32>>;
137
138@group(0) @binding(3)
139var<storage, read_write> total_weight: atomic<u32>;
140
141@group(0) @binding(4)
142var<uniform> params: HistogramParams;
143
144fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
145    var old_bits = atomicLoad(target);
146    loop {
147        let old_value = bitcast<f32>(old_bits);
148        let new_bits = bitcast<u32>(old_value + value);
149        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
150        if (exchange.exchanged) {
151            break;
152        }
153        old_bits = exchange.old_value;
154    }
155}
156
157@compute @workgroup_size(WORKGROUP_SIZE)
158fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
159    let idx = gid.x;
160    if (idx >= params.sample_count) {
161        return;
162    }
163
164    let value = samples[idx];
165    let normalized = (value - params.min_value) * params.inv_bin_width;
166    let raw_bin = i32(floor(normalized));
167    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
168    let bin_index = u32(clamped);
169    let weight = f32(weights[idx]);
170
171    atomic_add_f32(&counts[bin_index], weight);
172    if (params.accumulate_total != 0u) {
173        atomic_add_f32(&total_weight, weight);
174    }
175}
176"#;
177
    /// Unweighted histogram over `f64` samples: each sample is converted to
    /// `f32` before binning and adds 1.0 to its bin.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `counts` (read_write),
    /// 2 = `total_weight` (read_write), 3 = `params` (uniform).
178    pub const F64_UNIFORM: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
179
180struct HistogramParams {
181    min_value: f32,
182    inv_bin_width: f32,
183    sample_count: u32,
184    bin_count: u32,
185    accumulate_total: u32,
186    _pad: vec3<u32>,
187};
188
189@group(0) @binding(0)
190var<storage, read> samples: array<f64>;
191
192@group(0) @binding(1)
193var<storage, read_write> counts: array<atomic<u32>>;
194
195@group(0) @binding(2)
196var<storage, read_write> total_weight: atomic<u32>;
197
198@group(0) @binding(3)
199var<uniform> params: HistogramParams;
200
201fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
202    var old_bits = atomicLoad(target);
203    loop {
204        let old_value = bitcast<f32>(old_bits);
205        let new_bits = bitcast<u32>(old_value + value);
206        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
207        if (exchange.exchanged) {
208            break;
209        }
210        old_bits = exchange.old_value;
211    }
212}
213
214@compute @workgroup_size(WORKGROUP_SIZE)
215fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
216    let idx = gid.x;
217    if (idx >= params.sample_count) {
218        return;
219    }
220
221    let value = f32(samples[idx]);
222    let normalized = (value - params.min_value) * params.inv_bin_width;
223    let raw_bin = i32(floor(normalized));
224    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
225    let bin_index = u32(clamped);
226
227    atomic_add_f32(&counts[bin_index], 1.0);
228    if (params.accumulate_total != 0u) {
229        atomic_add_f32(&total_weight, 1.0);
230    }
231}
232"#;
233
    /// Weighted histogram over `f64` samples with `f32` weights: each sample is
    /// converted to `f32` before binning and adds `weights[idx]` to its bin.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `weights` (read),
    /// 2 = `counts` (read_write), 3 = `total_weight` (read_write),
    /// 4 = `params` (uniform).
234    pub const F64_WEIGHTS_F32: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
235
236struct HistogramParams {
237    min_value: f32,
238    inv_bin_width: f32,
239    sample_count: u32,
240    bin_count: u32,
241    accumulate_total: u32,
242    _pad: vec3<u32>,
243};
244
245@group(0) @binding(0)
246var<storage, read> samples: array<f64>;
247
248@group(0) @binding(1)
249var<storage, read> weights: array<f32>;
250
251@group(0) @binding(2)
252var<storage, read_write> counts: array<atomic<u32>>;
253
254@group(0) @binding(3)
255var<storage, read_write> total_weight: atomic<u32>;
256
257@group(0) @binding(4)
258var<uniform> params: HistogramParams;
259
260fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
261    var old_bits = atomicLoad(target);
262    loop {
263        let old_value = bitcast<f32>(old_bits);
264        let new_bits = bitcast<u32>(old_value + value);
265        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
266        if (exchange.exchanged) {
267            break;
268        }
269        old_bits = exchange.old_value;
270    }
271}
272
273@compute @workgroup_size(WORKGROUP_SIZE)
274fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
275    let idx = gid.x;
276    if (idx >= params.sample_count) {
277        return;
278    }
279
280    let value = f32(samples[idx]);
281    let normalized = (value - params.min_value) * params.inv_bin_width;
282    let raw_bin = i32(floor(normalized));
283    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
284    let bin_index = u32(clamped);
285    let weight = weights[idx];
286
287    atomic_add_f32(&counts[bin_index], weight);
288    if (params.accumulate_total != 0u) {
289        atomic_add_f32(&total_weight, weight);
290    }
291}
292"#;
293
    /// Weighted histogram over `f64` samples with `f64` weights: both the
    /// sample and its weight are converted to `f32` before use.
    ///
    /// Bindings (group 0): 0 = `samples` (read), 1 = `weights` (read),
    /// 2 = `counts` (read_write), 3 = `total_weight` (read_write),
    /// 4 = `params` (uniform).
294    pub const F64_WEIGHTS_F64: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
295
296struct HistogramParams {
297    min_value: f32,
298    inv_bin_width: f32,
299    sample_count: u32,
300    bin_count: u32,
301    accumulate_total: u32,
302    _pad: vec3<u32>,
303};
304
305@group(0) @binding(0)
306var<storage, read> samples: array<f64>;
307
308@group(0) @binding(1)
309var<storage, read> weights: array<f64>;
310
311@group(0) @binding(2)
312var<storage, read_write> counts: array<atomic<u32>>;
313
314@group(0) @binding(3)
315var<storage, read_write> total_weight: atomic<u32>;
316
317@group(0) @binding(4)
318var<uniform> params: HistogramParams;
319
320fn atomic_add_f32(target: ptr<storage, atomic<u32>>, value: f32) {
321    var old_bits = atomicLoad(target);
322    loop {
323        let old_value = bitcast<f32>(old_bits);
324        let new_bits = bitcast<u32>(old_value + value);
325        let exchange = atomicCompareExchangeWeak(target, old_bits, new_bits);
326        if (exchange.exchanged) {
327            break;
328        }
329        old_bits = exchange.old_value;
330    }
331}
332
333@compute @workgroup_size(WORKGROUP_SIZE)
334fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
335    let idx = gid.x;
336    if (idx >= params.sample_count) {
337        return;
338    }
339
340    let value = f32(samples[idx]);
341    let normalized = (value - params.min_value) * params.inv_bin_width;
342    let raw_bin = i32(floor(normalized));
343    let clamped = clamp(raw_bin, 0, i32(params.bin_count) - 1);
344    let bin_index = u32(clamped);
345    let weight = f32(weights[idx]);
346
347    atomic_add_f32(&counts[bin_index], weight);
348    if (params.accumulate_total != 0u) {
349        atomic_add_f32(&total_weight, weight);
350    }
351}
352"#;
353}
354
/// WGSL compute-shader template that converts per-bin `u32` words into scaled
/// `f32` output values.
355pub mod convert {
    /// Shader source with a `{{WORKGROUP_SIZE}}` placeholder; presumably the
    /// caller substitutes a number before compiling (confirm at the call site).
    ///
    /// For every `idx < params.bin_count` the shader reinterprets the bits of
    /// `counts[idx]` as an `f32` (a `bitcast`, not a numeric conversion),
    /// multiplies by `params.scale`, and stores the result in `values[idx]`.
    /// Invocations with `idx >= params.bin_count` return without writing.
    ///
    /// Bindings (group 0): 0 = `counts` (read), 1 = `values` (read_write),
    /// 2 = `params` (uniform).
    ///
    /// NOTE(review): the bitcast only makes sense if `counts` holds `f32` bit
    /// patterns, which is how the shaders in the `counts` module write their
    /// bins; confirm that is the buffer bound here.
    /// NOTE(review): `ConvertParams` interleaves `vec3<u32>` pads with scalar
    /// fields, so the host-side uniform struct has to follow WGSL alignment
    /// rules; verify against the Rust-side definition.
356    pub const TEMPLATE: &str = r#"const WORKGROUP_SIZE: u32 = {{WORKGROUP_SIZE}}u;
357
358struct ConvertParams {
359    bin_count: u32,
360    _pad0: vec3<u32>,
361    scale: f32,
362    _pad1: vec3<u32>,
363};
364
365@group(0) @binding(0)
366var<storage, read> counts: array<u32>;
367
368@group(0) @binding(1)
369var<storage, read_write> values: array<f32>;
370
371@group(0) @binding(2)
372var<uniform> params: ConvertParams;
373
374@compute @workgroup_size(WORKGROUP_SIZE)
375fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
376    let idx = gid.x;
377    if (idx >= params.bin_count) {
378        return;
379    }
380
381    let count_value = bitcast<f32>(counts[idx]);
382    values[idx] = count_value * params.scale;
383}
384"#;
385}