1use std::sync::Arc;
2use wgpu::util::DeviceExt;
3
4use crate::gpu::shaders;
5use crate::gpu::util::readback_f32;
6use crate::gpu::{tuning, ScalarType};
7
/// Inputs for a GPU histogram computation.
///
/// `samples` holds `sample_count` values of type `scalar`; `weights` selects
/// how (and whether) individual samples are weighted.
#[derive(Clone)]
pub struct HistogramGpuInputs {
    /// GPU buffer containing the sample values.
    pub samples: Arc<wgpu::Buffer>,
    /// Number of samples to bin; when zero the counting pass is skipped.
    pub sample_count: u32,
    /// Element type of `samples`; selects the counting shader variant.
    pub scalar: ScalarType,
    /// Per-sample weights, or a uniform total weight.
    pub weights: HistogramGpuWeights,
}
15
/// Source of the weights applied to the histogram samples.
#[derive(Clone)]
pub enum HistogramGpuWeights {
    /// No weights buffer is bound; `total_weight` is the caller-supplied total
    /// used for normalization.
    Uniform { total_weight: f32 },
    /// Host-side `f32` weights, uploaded into a new storage buffer;
    /// `total_weight` is used as-is for normalization.
    HostF32 { data: Vec<f32>, total_weight: f32 },
    /// Host-side `f64` weights, uploaded as raw `f64` bytes; `total_weight` is
    /// used as-is for normalization.
    HostF64 { data: Vec<f64>, total_weight: f32 },
    /// Weights already resident on the GPU as `f32`; the counting pass is asked
    /// to accumulate the total weight, which is then read back.
    GpuF32 { buffer: Arc<wgpu::Buffer> },
    /// Weights already resident on the GPU as `f64`; the counting pass is asked
    /// to accumulate the total weight, which is then read back.
    GpuF64 { buffer: Arc<wgpu::Buffer> },
}
24
/// How raw bin counts are scaled into the output values.
///
/// Holds only plain data, so it derives the usual value-type traits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HistogramNormalizationMode {
    /// Raw (weighted) counts, unscaled.
    Count,
    /// Counts divided by the total weight.
    Probability,
    /// Counts divided by `total_weight * bin_width`, i.e. a density estimate.
    Pdf { bin_width: f32 },
}
30
/// Binning parameters forwarded to the counting shader.
///
/// Holds only plain data, so it derives the usual value-type traits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HistogramGpuParams {
    /// Value presumably mapped to the lower edge of the first bin.
    pub min_value: f32,
    /// Reciprocal of the bin width.
    pub inv_bin_width: f32,
    /// Number of bins; `histogram_values_buffer` rejects zero.
    pub bin_count: u32,
}
36
/// Result of [`histogram_values_buffer`].
pub struct HistogramGpuOutput {
    /// One `f32` per bin, already scaled by the requested normalization.
    pub values_buffer: Arc<wgpu::Buffer>,
    /// Total weight used for normalization (caller hint or GPU readback).
    pub total_weight: f32,
}
41
/// Uniform block for the counting shader (32 bytes).
///
/// `#[repr(C)]` with explicit padding; the field order and padding are
/// expected to mirror the uniform struct in the WGSL counting templates
/// (not visible here) — keep the two in sync.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct HistogramUniforms {
    min_value: f32,
    inv_bin_width: f32,
    sample_count: u32,
    bin_count: u32,
    // 1 when the total weight should be accumulated on the GPU, 0 otherwise.
    accumulate_total: u32,
    // Rounds the struct up to 32 bytes (a multiple of 16 for uniform layout).
    _pad: [u32; 3],
}
52
/// Uniform block for the conversion shader (32 bytes).
///
/// `bin_count` sits at byte offset 0 and `scale` at byte offset 16; the
/// padding presumably mirrors 16-byte member alignment in the WGSL struct of
/// the convert template — keep the two in sync.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ConvertUniforms {
    bin_count: u32,
    _pad: [u32; 3],
    // Multiplier applied to every bin count (see the normalization mode).
    scale: f32,
    _pad2: [u32; 3],
}
61
/// Weights resolved into the pieces the counting pass binds and reads.
struct MaterializedWeights {
    /// Selects the counting shader variant and bind-group layout.
    mode: WeightMode,
    /// Weights storage buffer; `materialize_weights` leaves it `None` only for
    /// `WeightMode::Uniform`.
    buffer: Option<Arc<wgpu::Buffer>>,
    /// Host-known total weight; `None` means the total is read back from the GPU.
    total_hint: Option<f32>,
    /// Forwarded to the shader uniforms to request GPU accumulation of the total.
    accumulate_total: bool,
}
68
/// Borrowed resources needed to record and submit the counting pass.
struct HistogramPassInputs<'a> {
    device: &'a Arc<wgpu::Device>,
    queue: &'a Arc<wgpu::Queue>,
    samples: &'a Arc<wgpu::Buffer>,
    sample_count: u32,
    /// Element type of `samples`; selects the shader variant.
    sample_scalar: ScalarType,
    params: &'a HistogramGpuParams,
    /// Zero-initialised per-bin counts written by the counting shader.
    counts_buffer: &'a Arc<wgpu::Buffer>,
    /// Single-value buffer the shader may accumulate the total weight into.
    total_weight_buffer: &'a Arc<wgpu::Buffer>,
}
79
/// Borrowed resources used to build the counting pass's bind group and its
/// uniform buffer.
struct HistogramBindGroupInputs<'a> {
    device: &'a Arc<wgpu::Device>,
    samples: &'a Arc<wgpu::Buffer>,
    counts_buffer: &'a Arc<wgpu::Buffer>,
    total_weight_buffer: &'a Arc<wgpu::Buffer>,
    params: &'a HistogramGpuParams,
    sample_count: u32,
    /// Encoded as 0/1 in `HistogramUniforms::accumulate_total`.
    accumulate_total: bool,
}
89
/// Weighting scheme; selects the counting shader variant and bind-group layout.
#[derive(Clone, Copy, PartialEq, Eq)]
enum WeightMode {
    /// No weights buffer is bound.
    Uniform,
    /// Weights buffer holds `f32` values.
    F32,
    /// Weights buffer holds `f64` values.
    F64,
}
96
97pub async fn histogram_values_buffer(
98 device: &Arc<wgpu::Device>,
99 queue: &Arc<wgpu::Queue>,
100 inputs: HistogramGpuInputs,
101 params: &HistogramGpuParams,
102 normalization: HistogramNormalizationMode,
103) -> Result<HistogramGpuOutput, String> {
104 if params.bin_count == 0 {
105 return Err("hist: bin count must be positive".to_string());
106 }
107
108 let bin_count_usize = params.bin_count as usize;
109 let zero_counts = vec![0u32; bin_count_usize];
110 let counts_buffer = Arc::new(
111 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
112 label: Some("histogram-counts"),
113 contents: bytemuck::cast_slice(&zero_counts),
114 usage: wgpu::BufferUsages::STORAGE
115 | wgpu::BufferUsages::COPY_DST
116 | wgpu::BufferUsages::COPY_SRC,
117 }),
118 );
119
120 let total_weight_buffer = Arc::new(device.create_buffer_init(
121 &wgpu::util::BufferInitDescriptor {
122 label: Some("histogram-total-weight"),
123 contents: bytemuck::cast_slice(&[0u32]),
124 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ,
125 },
126 ));
127
128 let HistogramGpuInputs {
129 samples,
130 sample_count,
131 scalar,
132 weights,
133 } = inputs;
134 let materialized = materialize_weights(device, weights)?;
135
136 if sample_count > 0 {
137 let pass_inputs = HistogramPassInputs {
138 device,
139 queue,
140 samples: &samples,
141 sample_count,
142 sample_scalar: scalar,
143 params,
144 counts_buffer: &counts_buffer,
145 total_weight_buffer: &total_weight_buffer,
146 };
147 run_histogram_pass(&pass_inputs, &materialized)?;
148 }
149
150 let total_weight = if let Some(hint) = materialized.total_hint {
151 hint
152 } else {
153 readback_f32(device, &total_weight_buffer)
154 .await
155 .map_err(|e| format!("hist: failed to read GPU weights total: {e}"))?
156 };
157
158 let normalization_scale = match normalization {
159 HistogramNormalizationMode::Count => 1.0,
160 HistogramNormalizationMode::Probability => {
161 if total_weight <= f32::EPSILON {
162 0.0
163 } else {
164 1.0 / total_weight
165 }
166 }
167 HistogramNormalizationMode::Pdf { bin_width } => {
168 if total_weight <= f32::EPSILON || bin_width <= f32::EPSILON {
169 0.0
170 } else {
171 1.0 / (total_weight * bin_width)
172 }
173 }
174 };
175
176 let values_buffer = Arc::new(device.create_buffer(&wgpu::BufferDescriptor {
177 label: Some("histogram-counts-f32"),
178 size: (bin_count_usize * std::mem::size_of::<f32>()) as u64,
179 usage: wgpu::BufferUsages::STORAGE
180 | wgpu::BufferUsages::COPY_DST
181 | wgpu::BufferUsages::COPY_SRC
182 | wgpu::BufferUsages::MAP_READ,
183 mapped_at_creation: false,
184 }));
185
186 run_convert_pass(
187 device,
188 queue,
189 params.bin_count,
190 normalization_scale,
191 &counts_buffer,
192 &values_buffer,
193 )?;
194
195 Ok(HistogramGpuOutput {
196 values_buffer,
197 total_weight,
198 })
199}
200
201fn materialize_weights(
202 device: &Arc<wgpu::Device>,
203 weights: HistogramGpuWeights,
204) -> Result<MaterializedWeights, String> {
205 match weights {
206 HistogramGpuWeights::Uniform { total_weight } => Ok(MaterializedWeights {
207 mode: WeightMode::Uniform,
208 buffer: None,
209 total_hint: Some(total_weight),
210 accumulate_total: false,
211 }),
212 HistogramGpuWeights::HostF32 { data, total_weight } => {
213 let buffer = Arc::new(
214 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
215 label: Some("histogram-weights-host-f32"),
216 contents: bytemuck::cast_slice(&data),
217 usage: wgpu::BufferUsages::STORAGE,
218 }),
219 );
220 Ok(MaterializedWeights {
221 mode: WeightMode::F32,
222 buffer: Some(buffer),
223 total_hint: Some(total_weight),
224 accumulate_total: false,
225 })
226 }
227 HistogramGpuWeights::HostF64 { data, total_weight } => {
228 let buffer = Arc::new(
229 device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
230 label: Some("histogram-weights-host-f64"),
231 contents: bytemuck::cast_slice(&data),
232 usage: wgpu::BufferUsages::STORAGE,
233 }),
234 );
235 Ok(MaterializedWeights {
236 mode: WeightMode::F64,
237 buffer: Some(buffer),
238 total_hint: Some(total_weight),
239 accumulate_total: false,
240 })
241 }
242 HistogramGpuWeights::GpuF32 { buffer } => Ok(MaterializedWeights {
243 mode: WeightMode::F32,
244 buffer: Some(buffer),
245 total_hint: None,
246 accumulate_total: true,
247 }),
248 HistogramGpuWeights::GpuF64 { buffer } => Ok(MaterializedWeights {
249 mode: WeightMode::F64,
250 buffer: Some(buffer),
251 total_hint: None,
252 accumulate_total: true,
253 }),
254 }
255}
256
257fn run_histogram_pass(
258 inputs: &HistogramPassInputs<'_>,
259 weights: &MaterializedWeights,
260) -> Result<(), String> {
261 let workgroup_size = tuning::effective_workgroup_size();
262 let shader = compile_counts_shader(
263 inputs.device,
264 workgroup_size,
265 inputs.sample_scalar,
266 weights.mode,
267 );
268
269 let bind_inputs = HistogramBindGroupInputs {
270 device: inputs.device,
271 samples: inputs.samples,
272 counts_buffer: inputs.counts_buffer,
273 total_weight_buffer: inputs.total_weight_buffer,
274 params: inputs.params,
275 sample_count: inputs.sample_count,
276 accumulate_total: weights.accumulate_total,
277 };
278
279 let (bind_group_layout, bind_group) = match weights.mode {
280 WeightMode::Uniform => build_uniform_bind_group(&bind_inputs),
281 WeightMode::F32 | WeightMode::F64 => build_weighted_bind_group(
282 &bind_inputs,
283 weights.buffer.as_ref().expect("weights buffer missing"),
284 ),
285 }?;
286
287 let pipeline_layout = inputs
288 .device
289 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
290 label: Some("histogram-counts-pipeline-layout"),
291 bind_group_layouts: &[&bind_group_layout],
292 push_constant_ranges: &[],
293 });
294
295 let pipeline = inputs
296 .device
297 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
298 label: Some("histogram-counts-pipeline"),
299 layout: Some(&pipeline_layout),
300 module: &shader,
301 entry_point: "main",
302 });
303
304 let mut encoder = inputs
305 .device
306 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
307 label: Some("histogram-counts-encoder"),
308 });
309 {
310 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
311 label: Some("histogram-counts-pass"),
312 timestamp_writes: None,
313 });
314 pass.set_pipeline(&pipeline);
315 pass.set_bind_group(0, &bind_group, &[]);
316 let workgroups = inputs.sample_count.div_ceil(workgroup_size);
317 pass.dispatch_workgroups(workgroups, 1, 1);
318 }
319 inputs.queue.submit(Some(encoder.finish()));
320
321 Ok(())
322}
323
324fn build_uniform_bind_group(
325 inputs: &HistogramBindGroupInputs<'_>,
326) -> Result<(Arc<wgpu::BindGroupLayout>, Arc<wgpu::BindGroup>), String> {
327 let layout = Arc::new(inputs.device.create_bind_group_layout(
328 &wgpu::BindGroupLayoutDescriptor {
329 label: Some("hist-counts-layout-uniform"),
330 entries: &[
331 storage_read_entry(0),
332 storage_read_write_entry(1),
333 storage_read_write_entry(2),
334 uniform_entry(3),
335 ],
336 },
337 ));
338
339 let uniforms = HistogramUniforms {
340 min_value: inputs.params.min_value,
341 inv_bin_width: inputs.params.inv_bin_width,
342 sample_count: inputs.sample_count,
343 bin_count: inputs.params.bin_count,
344 accumulate_total: if inputs.accumulate_total { 1 } else { 0 },
345 _pad: [0; 3],
346 };
347 let uniform_buffer = inputs
348 .device
349 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
350 label: Some("hist-counts-uniforms"),
351 contents: bytemuck::bytes_of(&uniforms),
352 usage: wgpu::BufferUsages::UNIFORM,
353 });
354
355 let bind_group = Arc::new(inputs.device.create_bind_group(&wgpu::BindGroupDescriptor {
356 label: Some("hist-counts-bind-group-uniform"),
357 layout: &layout,
358 entries: &[
359 wgpu::BindGroupEntry {
360 binding: 0,
361 resource: inputs.samples.as_entire_binding(),
362 },
363 wgpu::BindGroupEntry {
364 binding: 1,
365 resource: inputs.counts_buffer.as_entire_binding(),
366 },
367 wgpu::BindGroupEntry {
368 binding: 2,
369 resource: inputs.total_weight_buffer.as_entire_binding(),
370 },
371 wgpu::BindGroupEntry {
372 binding: 3,
373 resource: uniform_buffer.as_entire_binding(),
374 },
375 ],
376 }));
377
378 Ok((layout, bind_group))
379}
380
381fn build_weighted_bind_group(
382 inputs: &HistogramBindGroupInputs<'_>,
383 weights_buffer: &Arc<wgpu::Buffer>,
384) -> Result<(Arc<wgpu::BindGroupLayout>, Arc<wgpu::BindGroup>), String> {
385 let layout = Arc::new(inputs.device.create_bind_group_layout(
386 &wgpu::BindGroupLayoutDescriptor {
387 label: Some("hist-counts-layout-weighted"),
388 entries: &[
389 storage_read_entry(0),
390 storage_read_entry(1),
391 storage_read_write_entry(2),
392 storage_read_write_entry(3),
393 uniform_entry(4),
394 ],
395 },
396 ));
397
398 let uniforms = HistogramUniforms {
399 min_value: inputs.params.min_value,
400 inv_bin_width: inputs.params.inv_bin_width,
401 sample_count: inputs.sample_count,
402 bin_count: inputs.params.bin_count,
403 accumulate_total: if inputs.accumulate_total { 1 } else { 0 },
404 _pad: [0; 3],
405 };
406 let uniform_buffer = inputs
407 .device
408 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
409 label: Some("hist-counts-weighted-uniforms"),
410 contents: bytemuck::bytes_of(&uniforms),
411 usage: wgpu::BufferUsages::UNIFORM,
412 });
413
414 let bind_group = Arc::new(inputs.device.create_bind_group(&wgpu::BindGroupDescriptor {
415 label: Some("hist-counts-bind-group-weighted"),
416 layout: &layout,
417 entries: &[
418 wgpu::BindGroupEntry {
419 binding: 0,
420 resource: inputs.samples.as_entire_binding(),
421 },
422 wgpu::BindGroupEntry {
423 binding: 1,
424 resource: weights_buffer.as_entire_binding(),
425 },
426 wgpu::BindGroupEntry {
427 binding: 2,
428 resource: inputs.counts_buffer.as_entire_binding(),
429 },
430 wgpu::BindGroupEntry {
431 binding: 3,
432 resource: inputs.total_weight_buffer.as_entire_binding(),
433 },
434 wgpu::BindGroupEntry {
435 binding: 4,
436 resource: uniform_buffer.as_entire_binding(),
437 },
438 ],
439 }));
440
441 Ok((layout, bind_group))
442}
443
444fn storage_read_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
445 wgpu::BindGroupLayoutEntry {
446 binding,
447 visibility: wgpu::ShaderStages::COMPUTE,
448 ty: wgpu::BindingType::Buffer {
449 ty: wgpu::BufferBindingType::Storage { read_only: true },
450 has_dynamic_offset: false,
451 min_binding_size: None,
452 },
453 count: None,
454 }
455}
456
457fn storage_read_write_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
458 wgpu::BindGroupLayoutEntry {
459 binding,
460 visibility: wgpu::ShaderStages::COMPUTE,
461 ty: wgpu::BindingType::Buffer {
462 ty: wgpu::BufferBindingType::Storage { read_only: false },
463 has_dynamic_offset: false,
464 min_binding_size: None,
465 },
466 count: None,
467 }
468}
469
470fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
471 wgpu::BindGroupLayoutEntry {
472 binding,
473 visibility: wgpu::ShaderStages::COMPUTE,
474 ty: wgpu::BindingType::Buffer {
475 ty: wgpu::BufferBindingType::Uniform,
476 has_dynamic_offset: false,
477 min_binding_size: None,
478 },
479 count: None,
480 }
481}
482
483fn run_convert_pass(
484 device: &Arc<wgpu::Device>,
485 queue: &Arc<wgpu::Queue>,
486 bin_count: u32,
487 normalization_scale: f32,
488 counts_buffer: &Arc<wgpu::Buffer>,
489 values_buffer: &Arc<wgpu::Buffer>,
490) -> Result<(), String> {
491 let workgroup_size = tuning::effective_workgroup_size();
492 let shader = compile_convert_shader(device, workgroup_size);
493
494 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
495 label: Some("histogram-convert-bind-layout"),
496 entries: &[
497 storage_read_entry(0),
498 storage_read_write_entry(1),
499 uniform_entry(2),
500 ],
501 });
502
503 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
504 label: Some("histogram-convert-pipeline-layout"),
505 bind_group_layouts: &[&bind_group_layout],
506 push_constant_ranges: &[],
507 });
508
509 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
510 label: Some("histogram-convert-pipeline"),
511 layout: Some(&pipeline_layout),
512 module: &shader,
513 entry_point: "main",
514 });
515
516 let uniforms = ConvertUniforms {
517 bin_count,
518 _pad: [0; 3],
519 scale: normalization_scale,
520 _pad2: [0; 3],
521 };
522 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
523 label: Some("histogram-convert-uniforms"),
524 contents: bytemuck::bytes_of(&uniforms),
525 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
526 });
527
528 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
529 label: Some("histogram-convert-bind-group"),
530 layout: &bind_group_layout,
531 entries: &[
532 wgpu::BindGroupEntry {
533 binding: 0,
534 resource: counts_buffer.as_entire_binding(),
535 },
536 wgpu::BindGroupEntry {
537 binding: 1,
538 resource: values_buffer.as_entire_binding(),
539 },
540 wgpu::BindGroupEntry {
541 binding: 2,
542 resource: uniform_buffer.as_entire_binding(),
543 },
544 ],
545 });
546
547 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
548 label: Some("histogram-convert-encoder"),
549 });
550 {
551 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
552 label: Some("histogram-convert-pass"),
553 timestamp_writes: None,
554 });
555 pass.set_pipeline(&pipeline);
556 pass.set_bind_group(0, &bind_group, &[]);
557 let workgroups = bin_count.div_ceil(workgroup_size);
558 pass.dispatch_workgroups(workgroups, 1, 1);
559 }
560 queue.submit(Some(encoder.finish()));
561
562 Ok(())
563}
564
565fn compile_counts_shader(
566 device: &Arc<wgpu::Device>,
567 workgroup_size: u32,
568 scalar: ScalarType,
569 weight_mode: WeightMode,
570) -> wgpu::ShaderModule {
571 let template = match (scalar, weight_mode) {
572 (ScalarType::F32, WeightMode::Uniform) => shaders::histogram::counts::F32_UNIFORM,
573 (ScalarType::F32, WeightMode::F32) => shaders::histogram::counts::F32_WEIGHTS_F32,
574 (ScalarType::F32, WeightMode::F64) => shaders::histogram::counts::F32_WEIGHTS_F64,
575 (ScalarType::F64, WeightMode::Uniform) => shaders::histogram::counts::F64_UNIFORM,
576 (ScalarType::F64, WeightMode::F32) => shaders::histogram::counts::F64_WEIGHTS_F32,
577 (ScalarType::F64, WeightMode::F64) => shaders::histogram::counts::F64_WEIGHTS_F64,
578 };
579 let source = template.replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
580 device.create_shader_module(wgpu::ShaderModuleDescriptor {
581 label: Some("histogram-counts-shader"),
582 source: wgpu::ShaderSource::Wgsl(source.into()),
583 })
584}
585
586fn compile_convert_shader(device: &Arc<wgpu::Device>, workgroup_size: u32) -> wgpu::ShaderModule {
587 let source = shaders::histogram::convert::TEMPLATE
588 .replace("{{WORKGROUP_SIZE}}", &workgroup_size.to_string());
589 device.create_shader_module(wgpu::ShaderModuleDescriptor {
590 label: Some("histogram-convert-shader"),
591 source: wgpu::ShaderSource::Wgsl(source.into()),
592 })
593}