Skip to main content

runmat_plot/gpu/
histogram.rs

1use std::sync::Arc;
2use wgpu::util::DeviceExt;
3
4use crate::gpu::shaders;
5use crate::gpu::util::readback_f32;
6use crate::gpu::{tuning, ScalarType};
7
8#[derive(Clone)]
9pub struct HistogramGpuInputs {
10    pub samples: Arc<wgpu::Buffer>,
11    pub sample_count: u32,
12    pub scalar: ScalarType,
13    pub weights: HistogramGpuWeights,
14}
15
16#[derive(Clone)]
17pub enum HistogramGpuWeights {
18    Uniform { total_weight: f32 },
19    HostF32 { data: Vec<f32>, total_weight: f32 },
20    HostF64 { data: Vec<f64>, total_weight: f32 },
21    GpuF32 { buffer: Arc<wgpu::Buffer> },
22    GpuF64 { buffer: Arc<wgpu::Buffer> },
23}
24
/// How raw per-bin totals are scaled before being written out.
pub enum HistogramNormalizationMode {
    /// Raw totals, left unscaled.
    Count,
    /// Totals divided by the overall weight.
    Probability,
    /// Totals divided by the overall weight times `bin_width`.
    Pdf {
        /// Width of a single bin, in sample units.
        bin_width: f32,
    },
}
30
/// Binning geometry forwarded to the counting shader's uniform block.
pub struct HistogramGpuParams {
    /// Value the binning starts from (presumably the lower edge of bin 0).
    pub min_value: f32,
    /// Reciprocal of the bin width, precomputed on the host.
    pub inv_bin_width: f32,
    /// Number of bins; [`histogram_values_buffer`] rejects zero.
    pub bin_count: u32,
}
36
37pub struct HistogramGpuOutput {
38    pub values_buffer: Arc<wgpu::Buffer>,
39    pub total_weight: f32,
40}
41
/// Uniform block consumed by the counting shader (the last binding in both the
/// uniform-weight and weighted bind-group layouts).
///
/// `#[repr(C)]` + `Pod` so it can be uploaded with `bytemuck::bytes_of`. The field order
/// and explicit padding (32 bytes total) presumably mirror the WGSL uniform struct in the
/// `shaders::histogram::counts` templates. NOTE(review): confirm against the shader
/// before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HistogramUniforms {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    // 1 when the weight total is not known on the host (GPU-resident weights), so the
    // shader presumably sums weights into the total-weight buffer; 0 otherwise.
    accumulate_total: u32,
    // Explicit zero padding; `Pod` forbids implicit padding bytes.
    _pad: [u32; 3],
}
52
/// Uniform block for the conversion shader: the bin count and the normalization scale.
///
/// `_pad` places `scale` at byte offset 16, and `_pad2` rounds the struct up to 32 bytes.
/// This presumably mirrors the WGSL layout in `shaders::histogram::convert::TEMPLATE`.
/// NOTE(review): confirm before changing.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ConvertUniforms {
    bin_count: u32,
    _pad: [u32; 3],
    // Factor applied to every raw bin total (1.0 for plain counts).
    scale: f32,
    _pad2: [u32; 3],
}
61
62struct MaterializedWeights {
63    mode: WeightMode,
64    buffer: Option<Arc<wgpu::Buffer>>,
65    total_hint: Option<f32>,
66    accumulate_total: bool,
67}
68
69struct HistogramPassInputs<'a> {
70    device: &'a Arc<wgpu::Device>,
71    queue: &'a Arc<wgpu::Queue>,
72    samples: &'a Arc<wgpu::Buffer>,
73    sample_count: u32,
74    sample_scalar: ScalarType,
75    params: &'a HistogramGpuParams,
76    counts_buffer: &'a Arc<wgpu::Buffer>,
77    total_weight_buffer: &'a Arc<wgpu::Buffer>,
78}
79
80struct HistogramBindGroupInputs<'a> {
81    device: &'a Arc<wgpu::Device>,
82    samples: &'a Arc<wgpu::Buffer>,
83    counts_buffer: &'a Arc<wgpu::Buffer>,
84    total_weight_buffer: &'a Arc<wgpu::Buffer>,
85    params: &'a HistogramGpuParams,
86    sample_count: u32,
87    accumulate_total: bool,
88}
89
/// Storage type of the per-sample weights; picks the shader template and bind-group shape.
#[derive(Copy, Clone, Eq, PartialEq)]
enum WeightMode {
    /// No weight buffer is bound.
    Uniform,
    /// Weights stored as `f32`.
    F32,
    /// Weights stored as `f64`.
    F64,
}
96
97pub async fn histogram_values_buffer(
98    device: &Arc<wgpu::Device>,
99    queue: &Arc<wgpu::Queue>,
100    inputs: HistogramGpuInputs,
101    params: &HistogramGpuParams,
102    normalization: HistogramNormalizationMode,
103) -> Result<HistogramGpuOutput, String> {
104    if params.bin_count == 0 {
105        return Err("hist: bin count must be positive".to_string());
106    }
107
108    let bin_count_usize = params.bin_count as usize;
109    let zero_counts = vec![0u32; bin_count_usize];
110    let counts_buffer = Arc::new(
111        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
112            label: Some("histogram-counts"),
113            contents: bytemuck::cast_slice(&zero_counts),
114            usage: wgpu::BufferUsages::STORAGE
115                | wgpu::BufferUsages::COPY_DST
116                | wgpu::BufferUsages::COPY_SRC,
117        }),
118    );
119
120    let total_weight_buffer = Arc::new(device.create_buffer_init(
121        &wgpu::util::BufferInitDescriptor {
122            label: Some("histogram-total-weight"),
123            contents: bytemuck::cast_slice(&[0u32]),
124            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ,
125        },
126    ));
127
128    let HistogramGpuInputs {
129        samples,
130        sample_count,
131        scalar,
132        weights,
133    } = inputs;
134    let materialized = materialize_weights(device, weights)?;
135
136    if sample_count > 0 {
137        let pass_inputs = HistogramPassInputs {
138            device,
139            queue,
140            samples: &samples,
141            sample_count,
142            sample_scalar: scalar,
143            params,
144            counts_buffer: &counts_buffer,
145            total_weight_buffer: &total_weight_buffer,
146        };
147        run_histogram_pass(&pass_inputs, &materialized)?;
148    }
149
150    let total_weight = if let Some(hint) = materialized.total_hint {
151        hint
152    } else {
153        readback_f32(device, &total_weight_buffer)
154            .await
155            .map_err(|e| format!("hist: failed to read GPU weights total: {e}"))?
156    };
157
158    let normalization_scale = match normalization {
159        HistogramNormalizationMode::Count => 1.0,
160        HistogramNormalizationMode::Probability => {
161            if total_weight <= f32::EPSILON {
162                0.0
163            } else {
164                1.0 / total_weight
165            }
166        }
167        HistogramNormalizationMode::Pdf { bin_width } => {
168            if total_weight <= f32::EPSILON || bin_width <= f32::EPSILON {
169                0.0
170            } else {
171                1.0 / (total_weight * bin_width)
172            }
173        }
174    };
175
176    let values_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
177        label: Some("histogram-counts-f32"),
178        size: (bin_count_usize * std::mem::size_of::<f32>()) as u64,
179        usage: wgpu::BufferUsages::STORAGE
180            | wgpu::BufferUsages::COPY_DST
181            | wgpu::BufferUsages::COPY_SRC
182            | wgpu::BufferUsages::MAP_READ,
183        mapped_at_creation: false,
184    }));
185
186    run_convert_pass(
187        device,
188        queue,
189        params.bin_count,
190        normalization_scale,
191        &counts_buffer,
192        &values_buffer,
193    )?;
194
195    Ok(HistogramGpuOutput {
196        values_buffer,
197        total_weight,
198    })
199}
200
201fn materialize_weights(
202    device: &Arc<wgpu::Device>,
203    weights: HistogramGpuWeights,
204) -> Result<MaterializedWeights, String> {
205    match weights {
206        HistogramGpuWeights::Uniform { total_weight } => Ok(MaterializedWeights {
207            mode: WeightMode::Uniform,
208            buffer: None,
209            total_hint: Some(total_weight),
210            accumulate_total: false,
211        }),
212        HistogramGpuWeights::HostF32 { data, total_weight } => {
213            let buffer = Arc::new(
214                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
215                    label: Some("histogram-weights-host-f32"),
216                    contents: bytemuck::cast_slice(&data),
217                    usage: wgpu::BufferUsages::STORAGE,
218                }),
219            );
220            Ok(MaterializedWeights {
221                mode: WeightMode::F32,
222                buffer: Some(buffer),
223                total_hint: Some(total_weight),
224                accumulate_total: false,
225            })
226        }
227        HistogramGpuWeights::HostF64 { data, total_weight } => {
228            let buffer = Arc::new(
229                device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
230                    label: Some("histogram-weights-host-f64"),
231                    contents: bytemuck::cast_slice(&data),
232                    usage: wgpu::BufferUsages::STORAGE,
233                }),
234            );
235            Ok(MaterializedWeights {
236                mode: WeightMode::F64,
237                buffer: Some(buffer),
238                total_hint: Some(total_weight),
239                accumulate_total: false,
240            })
241        }
242        HistogramGpuWeights::GpuF32 { buffer } => Ok(MaterializedWeights {
243            mode: WeightMode::F32,
244            buffer: Some(buffer),
245            total_hint: None,
246            accumulate_total: true,
247        }),
248        HistogramGpuWeights::GpuF64 { buffer } => Ok(MaterializedWeights {
249            mode: WeightMode::F64,
250            buffer: Some(buffer),
251            total_hint: None,
252            accumulate_total: true,
253        }),
254    }
255}
256
257fn run_histogram_pass(
258    inputs: &HistogramPassInputs<'_>,
259    weights: &MaterializedWeights,
260) -> Result<(), String> {
261    let workgroup_size = tuning::effective_workgroup_size();
262    let shader = compile_counts_shader(
263        inputs.device,
264        workgroup_size,
265        inputs.sample_scalar,
266        weights.mode,
267    );
268
269    let bind_inputs = HistogramBindGroupInputs {
270        device: inputs.device,
271        samples: inputs.samples,
272        counts_buffer: inputs.counts_buffer,
273        total_weight_buffer: inputs.total_weight_buffer,
274        params: inputs.params,
275        sample_count: inputs.sample_count,
276        accumulate_total: weights.accumulate_total,
277    };
278
279    let (bind_group_layout, bind_group) = match weights.mode {
280        WeightMode::Uniform => build_uniform_bind_group(&bind_inputs),
281        WeightMode::F32 | WeightMode::F64 => build_weighted_bind_group(
282            &bind_inputs,
283            weights.buffer.as_ref().expect("weights buffer missing"),
284        ),
285    }?;
286
287    let pipeline_layout = inputs
288        .device
289        .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
290            label: Some("histogram-counts-pipeline-layout"),
291            bind_group_layouts: &[&bind_group_layout],
292            push_constant_ranges: &[],
293        });
294
295    let pipeline = inputs
296        .device
297        .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
298            label: Some("histogram-counts-pipeline"),
299            layout: Some(&pipeline_layout),
300            module: &shader,
301            entry_point: "main",
302        });
303
304    let mut encoder = inputs
305        .device
306        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
307            label: Some("histogram-counts-encoder"),
308        });
309    {
310        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
311            label: Some("histogram-counts-pass"),
312            timestamp_writes: None,
313        });
314        pass.set_pipeline(&pipeline);
315        pass.set_bind_group(0, &bind_group, &[]);
316        let workgroups = inputs.sample_count.div_ceil(workgroup_size);
317        pass.dispatch_workgroups(workgroups, 1, 1);
318    }
319    inputs.queue.submit(Some(encoder.finish()));
320
321    Ok(())
322}
323
324fn build_uniform_bind_group(
325    inputs: &HistogramBindGroupInputs<'_>,
326) -> Result<(Arc<wgpu::BindGroupLayout>, Arc<wgpu::BindGroup>), String> {
327    let layout = Arc::new(inputs.device.create_bind_group_layout(
328        &wgpu::BindGroupLayoutDescriptor {
329            label: Some("hist-counts-layout-uniform"),
330            entries: &[
331                storage_read_entry(0),
332                storage_read_write_entry(1),
333                storage_read_write_entry(2),
334                uniform_entry(3),
335            ],
336        },
337    ));
338
339    let uniforms = HistogramUniforms {
340        min_value: inputs.params.min_value,
341        inv_bin_width: inputs.params.inv_bin_width,
342        sample_count: inputs.sample_count,
343        bin_count: inputs.params.bin_count,
344        accumulate_total: if inputs.accumulate_total { 1 } else { 0 },
345        _pad: [0; 3],
346    };
347    let uniform_buffer = inputs
348        .device
349        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
350            label: Some("hist-counts-uniforms"),
351            contents: bytemuck::bytes_of(&uniforms),
352            usage: wgpu::BufferUsages::UNIFORM,
353        });
354
355    let bind_group = Arc::new(inputs.device.create_bind_group(&wgpu::BindGroupDescriptor {
356        label: Some("hist-counts-bind-group-uniform"),
357        layout: &layout,
358        entries: &[
359            wgpu::BindGroupEntry {
360                binding: 0,
361                resource: inputs.samples.as_entire_binding(),
362            },
363            wgpu::BindGroupEntry {
364                binding: 1,
365                resource: inputs.counts_buffer.as_entire_binding(),
366            },
367            wgpu::BindGroupEntry {
368                binding: 2,
369                resource: inputs.total_weight_buffer.as_entire_binding(),
370            },
371            wgpu::BindGroupEntry {
372                binding: 3,
373                resource: uniform_buffer.as_entire_binding(),
374            },
375        ],
376    }));
377
378    Ok((layout, bind_group))
379}
380
381fn build_weighted_bind_group(
382    inputs: &HistogramBindGroupInputs<'_>,
383    weights_buffer: &Arc<wgpu::Buffer>,
384) -> Result<(Arc<wgpu::BindGroupLayout>, Arc<wgpu::BindGroup>), String> {
385    let layout = Arc::new(inputs.device.create_bind_group_layout(
386        &wgpu::BindGroupLayoutDescriptor {
387            label: Some("hist-counts-layout-weighted"),
388            entries: &[
389                storage_read_entry(0),
390                storage_read_entry(1),
391                storage_read_write_entry(2),
392                storage_read_write_entry(3),
393                uniform_entry(4),
394            ],
395        },
396    ));
397
398    let uniforms = HistogramUniforms {
399        min_value: inputs.params.min_value,
400        inv_bin_width: inputs.params.inv_bin_width,
401        sample_count: inputs.sample_count,
402        bin_count: inputs.params.bin_count,
403        accumulate_total: if inputs.accumulate_total { 1 } else { 0 },
404        _pad: [0; 3],
405    };
406    let uniform_buffer = inputs
407        .device
408        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
409            label: Some("hist-counts-weighted-uniforms"),
410            contents: bytemuck::bytes_of(&uniforms),
411            usage: wgpu::BufferUsages::UNIFORM,
412        });
413
414    let bind_group = Arc::new(inputs.device.create_bind_group(&wgpu::BindGroupDescriptor {
415        label: Some("hist-counts-bind-group-weighted"),
416        layout: &layout,
417        entries: &[
418            wgpu::BindGroupEntry {
419                binding: 0,
420                resource: inputs.samples.as_entire_binding(),
421            },
422            wgpu::BindGroupEntry {
423                binding: 1,
424                resource: weights_buffer.as_entire_binding(),
425            },
426            wgpu::BindGroupEntry {
427                binding: 2,
428                resource: inputs.counts_buffer.as_entire_binding(),
429            },
430            wgpu::BindGroupEntry {
431                binding: 3,
432                resource: inputs.total_weight_buffer.as_entire_binding(),
433            },
434            wgpu::BindGroupEntry {
435                binding: 4,
436                resource: uniform_buffer.as_entire_binding(),
437            },
438        ],
439    }));
440
441    Ok((layout, bind_group))
442}
443
444fn storage_read_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
445    wgpu::BindGroupLayoutEntry {
446        binding,
447        visibility: wgpu::ShaderStages::COMPUTE,
448        ty: wgpu::BindingType::Buffer {
449            ty: wgpu::BufferBindingType::Storage { read_only: true },
450            has_dynamic_offset: false,
451            min_binding_size: None,
452        },
453        count: None,
454    }
455}
456
457fn storage_read_write_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
458    wgpu::BindGroupLayoutEntry {
459        binding,
460        visibility: wgpu::ShaderStages::COMPUTE,
461        ty: wgpu::BindingType::Buffer {
462            ty: wgpu::BufferBindingType::Storage { read_only: false },
463            has_dynamic_offset: false,
464            min_binding_size: None,
465        },
466        count: None,
467    }
468}
469
470fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
471    wgpu::BindGroupLayoutEntry {
472        binding,
473        visibility: wgpu::ShaderStages::COMPUTE,
474        ty: wgpu::BindingType::Buffer {
475            ty: wgpu::BufferBindingType::Uniform,
476            has_dynamic_offset: false,
477            min_binding_size: None,
478        },
479        count: None,
480    }
481}
482
483fn run_convert_pass(
484    device: &Arc<wgpu::Device>,
485    queue: &Arc<wgpu::Queue>,
486    bin_count: u32,
487    normalization_scale: f32,
488    counts_buffer: &Arc<wgpu::Buffer>,
489    values_buffer: &Arc<wgpu::Buffer>,
490) -> Result<(), String> {
491    let workgroup_size = tuning::effective_workgroup_size();
492    let shader = compile_convert_shader(device, workgroup_size);
493
494    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
495        label: Some("histogram-convert-bind-layout"),
496        entries: &[
497            storage_read_entry(0),
498            storage_read_write_entry(1),
499            uniform_entry(2),
500        ],
501    });
502
503    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
504        label: Some("histogram-convert-pipeline-layout"),
505        bind_group_layouts: &[&bind_group_layout],
506        push_constant_ranges: &[],
507    });
508
509    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
510        label: Some("histogram-convert-pipeline"),
511        layout: Some(&pipeline_layout),
512        module: &shader,
513        entry_point: "main",
514    });
515
516    let uniforms = ConvertUniforms {
517        bin_count,
518        _pad: [0; 3],
519        scale: normalization_scale,
520        _pad2: [0; 3],
521    };
522    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
523        label: Some("histogram-convert-uniforms"),
524        contents: bytemuck::bytes_of(&uniforms),
525        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
526    });
527
528    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
529        label: Some("histogram-convert-bind-group"),
530        layout: &bind_group_layout,
531        entries: &[
532            wgpu::BindGroupEntry {
533                binding: 0,
534                resource: counts_buffer.as_entire_binding(),
535            },
536            wgpu::BindGroupEntry {
537                binding: 1,
538                resource: values_buffer.as_entire_binding(),
539            },
540            wgpu::BindGroupEntry {
541                binding: 2,
542                resource: uniform_buffer.as_entire_binding(),
543            },
544        ],
545    });
546
547    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
548        label: Some("histogram-convert-encoder"),
549    });
550    {
551        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
552            label: Some("histogram-convert-pass"),
553            timestamp_writes: None,
554        });
555        pass.set_pipeline(&pipeline);
556        pass.set_bind_group(0, &bind_group, &[]);
557        let workgroups = bin_count.div_ceil(workgroup_size);
558        pass.dispatch_workgroups(workgroups, 1, 1);
559    }
560    queue.submit(Some(encoder.finish()));
561
562    Ok(())
563}
564
565fn compile_counts_shader(
566    device: &Arc<wgpu::Device>,
567    workgroup_size: u32,
568    scalar: ScalarType,
569    weight_mode: WeightMode,
570) -> wgpu::ShaderModule {
571    let template = match (scalar, weight_mode) {
572        (ScalarType::F32, WeightMode::Uniform) => shaders::histogram::counts::F32_UNIFORM,
573        (ScalarType::F32, WeightMode::F32) => shaders::histogram::counts::F32_WEIGHTS_F32,
574        (ScalarType::F32, WeightMode::F64) => shaders::histogram::counts::F32_WEIGHTS_F64,
575        (ScalarType::F64, WeightMode::Uniform) => shaders::histogram::counts::F64_UNIFORM,
576        (ScalarType::F64, WeightMode::F32) => shaders::histogram::counts::F64_WEIGHTS_F32,
577        (ScalarType::F64, WeightMode::F64) => shaders::histogram::counts::F64_WEIGHTS_F64,
578    };
579    let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
580    device.create_shader_module(wgpu::ShaderModuleDescriptor {
581        label: Some("histogram-counts-shader"),
582        source: wgpu::ShaderSource::Wgsl(source.into()),
583    })
584}
585
586fn compile_convert_shader(device: &Arc<wgpu::Device>, workgroup_size: u32) -> wgpu::ShaderModule {
587    let source = shaders::histogram::convert::TEMPLATE
588        .replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
589    device.create_shader_module(wgpu::ShaderModuleDescriptor {
590        label: Some("histogram-convert-shader"),
591        source: wgpu::ShaderSource::Wgsl(source.into()),
592    })
593}