1use std::sync::mpsc;
2use std::time::Instant;
3
4use bytemuck::{Pod, Zeroable};
5use wgpu::util::DeviceExt;
6
7use crate::error::{Error, Result};
8use crate::external::OwnedHostTemporalInputs;
9use crate::frame::ScalarField;
10use crate::parameters::HostSupervisionParameters;
11
/// Output of one host-minimum supervision run on the GPU.
///
/// `trust`, `alpha` and `intervention` each hold one value per input pixel in
/// row-major order (`y * width + x`, matching `index_of` in the shader).
#[derive(Clone, Debug)]
pub struct GpuKernelResult {
    /// `name` reported by the adapter's `wgpu::AdapterInfo`.
    pub adapter_name: String,
    /// `Debug` rendering of the adapter's wgpu backend.
    pub backend: String,
    /// Per-pixel trust, written by the kernel as `1.0 - hazard`.
    pub trust: Vec<f32>,
    /// Per-pixel blend alpha, `alpha_range.min + (alpha_range.max - alpha_range.min) * hazard`.
    pub alpha: Vec<f32>,
    /// Per-pixel hazard value as written by the kernel.
    pub intervention: Vec<f32>,
    /// Milliseconds from command encoding through readback (summed over stripes
    /// when the image was tiled); buffer creation and upload are not included.
    pub total_ms: f64,
    /// Milliseconds spent encoding, submitting and waiting for the compute pass
    /// and the output-to-staging copies.
    pub dispatch_ms: f64,
    /// Milliseconds spent mapping the staging buffers and copying them to host memory.
    pub readback_ms: f64,
    /// Workgroup size the kernel was dispatched with.
    pub workgroup_size: (u32, u32, u32),
}
24
/// Per-dispatch (or summed per-stripe) outputs and timings, before adapter
/// metadata is attached to form a [`GpuKernelResult`].
#[derive(Clone, Debug)]
struct ChunkExecutionResult {
    // Row-major per-pixel outputs; see the matching fields on `GpuKernelResult`.
    trust: Vec<f32>,
    alpha: Vec<f32>,
    intervention: Vec<f32>,
    // Wall-clock timings in milliseconds.
    total_ms: f64,
    dispatch_ms: f64,
    readback_ms: f64,
}
34
/// Host-side mirror of the WGSL `Params` uniform (binding 4).
///
/// Every field is a 4-lane array, so each occupies exactly one 16-byte `vec4`
/// slot, field-for-field with the WGSL struct; `#[repr(C)]` keeps that order.
/// Keep this in sync with `Params` in `SHADER_SOURCE`. Lanes are filled by
/// `pack_params`; unused lanes are zero.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GpuParams {
    /// `[width, height, 0, 0]` of the dispatched image.
    size: [u32; 4],
    /// `[min, max, _, _]`; output alpha interpolates between them by hazard.
    alpha_range: [f32; 4],
    /// `[low, high, _, _]` smoothstep band for the mean absolute RGB residual.
    residual_threshold: [f32; 4],
    /// `[low, high, _, _]` band for `|current depth - reprojected depth|`.
    depth_threshold: [f32; 4],
    /// `[low, high, _, _]` band for `1 - dot(n_current, n_history)`.
    normal_threshold: [f32; 4],
    /// `[low, high, _, _]` band for how far history luma falls outside the
    /// current 3x3 luma range.
    neighborhood_threshold: [f32; 4],
    /// `[low, high, _, _]` band for the strongest 3x3 luma difference from the centre.
    local_contrast_threshold: [f32; 4],
    /// `[low, high, _, _]` band applied to the clamped weighted hazard sum.
    hazard_curve_threshold: [f32; 4],
    /// Weights for the residual, depth, normal and neighbourhood gates.
    weights_a: [f32; 4],
    /// Weights for the thin (local contrast), history-instability and grammar
    /// terms; lane 3 is unused.
    weights_b: [f32; 4],
    /// `[residual_mix, neighborhood_mix, _, _]` for the history-instability term.
    history_instability_mix: [f32; 4],
    /// `[disocclusion_like, unstable_residual, unstable_neighborhood, _]` grammar-tier thresholds.
    structural_a: [f32; 4],
    /// `[thin_edge, thin_residual, _, _]` grammar-tier thresholds.
    structural_b: [f32; 4],
}
52
/// One RGBA element of the `array<vec4<f32>>` colour bindings (0 and 1).
/// `pack_colors` writes alpha as 1.0; the shader reads only `.xyz`.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GpuColor {
    value: [f32; 4],
}
58
/// `[current depth, reprojected depth]` for one pixel; one element of the
/// `array<vec2<f32>>` bound at binding 2.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GpuDepthPair {
    value: [f32; 2],
}
64
/// Current and reprojected normals for one pixel, mirroring the WGSL
/// `NormalPair` struct at binding 3. Lane `w` is padding (written as 0.0).
///
/// At 32 bytes per pixel this is the widest input binding, so its size drives
/// the tiling decision.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GpuNormalPair {
    current: [f32; 4],
    history: [f32; 4],
}
71
/// Plain `vec4<f32>` wrapper.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// references `GpuVec4` — confirm it is still needed, or remove it.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GpuVec4 {
    value: [f32; 4],
}
77
/// WGSL source for the host-minimum supervision kernel.
///
/// Bindings (group 0):
/// - 0 `current_color` and 1 `reprojected_history`: `array<vec4<f32>>` (see `GpuColor`).
/// - 2 `depth_pairs`: `array<vec2<f32>>` (see `GpuDepthPair`).
/// - 3 `normal_pairs`: see `GpuNormalPair`.
/// - 4 the `Params` uniform: see `GpuParams`.
/// - 5-7 the read-write `trust_out`, `alpha_out` and `intervention_out` arrays.
///
/// One invocation handles one pixel (`@workgroup_size(1, 1, 1)`, dispatched as
/// `width x height` workgroups).
///
/// Per pixel, the kernel:
/// - builds smoothstep gates for the colour residual, depth and normal
///   disagreement, history luma lying outside the current 3x3 luma range, and
///   3x3 local contrast (neighbour reads clamp to the image edge);
/// - blends them with `weights_a` / `weights_b`, plus a tiered "grammar" term
///   chosen by the `structural_*` thresholds;
/// - maps the clamped sum through the `hazard_curve_threshold` smoothstep.
///
/// Outputs are `trust = 1 - hazard`,
/// `alpha = alpha_range.x + (alpha_range.y - alpha_range.x) * hazard` and
/// `intervention = hazard`.
///
/// NOTE(review): `normalize` of a zero-length normal is indeterminate in WGSL
/// (typically NaN), which would propagate into all three outputs — confirm the
/// inputs never carry zero normals.
const SHADER_SOURCE: &str = r#"
struct Params {
    size: vec4<u32>,
    alpha_range: vec4<f32>,
    residual_threshold: vec4<f32>,
    depth_threshold: vec4<f32>,
    normal_threshold: vec4<f32>,
    neighborhood_threshold: vec4<f32>,
    local_contrast_threshold: vec4<f32>,
    hazard_curve_threshold: vec4<f32>,
    weights_a: vec4<f32>,
    weights_b: vec4<f32>,
    history_instability_mix: vec4<f32>,
    structural_a: vec4<f32>,
    structural_b: vec4<f32>,
}

@group(0) @binding(0) var<storage, read> current_color: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> reprojected_history: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> depth_pairs: array<vec2<f32>>;

struct NormalPair {
    current: vec4<f32>,
    history: vec4<f32>,
}

@group(0) @binding(3) var<storage, read> normal_pairs: array<NormalPair>;
@group(0) @binding(4) var<uniform> params: Params;
@group(0) @binding(5) var<storage, read_write> trust_out: array<f32>;
@group(0) @binding(6) var<storage, read_write> alpha_out: array<f32>;
@group(0) @binding(7) var<storage, read_write> intervention_out: array<f32>;

fn index_of(x: u32, y: u32) -> u32 {
    return y * params.size.x + x;
}

fn luma(color: vec3<f32>) -> f32 {
    return dot(color, vec3<f32>(0.2126, 0.7152, 0.0722));
}

fn smoothstep_threshold(low: f32, high: f32, value: f32) -> f32 {
    let edge_span = max(high - low, 1e-6);
    let t = clamp((value - low) / edge_span, 0.0, 1.0);
    return t * t * (3.0 - 2.0 * t);
}

fn color_at(x: i32, y: i32) -> vec3<f32> {
    let width = i32(params.size.x);
    let height = i32(params.size.y);
    let clamped_x = clamp(x, 0, width - 1);
    let clamped_y = clamp(y, 0, height - 1);
    let idx = index_of(u32(clamped_x), u32(clamped_y));
    return current_color[idx].xyz;
}

fn local_contrast_gate(x: i32, y: i32) -> f32 {
    let center = luma(color_at(x, y));
    var strongest = 0.0;
    for (var oy: i32 = -1; oy <= 1; oy = oy + 1) {
        for (var ox: i32 = -1; ox <= 1; ox = ox + 1) {
            if (ox == 0 && oy == 0) {
                continue;
            }
            strongest = max(strongest, abs(center - luma(color_at(x + ox, y + oy))));
        }
    }
    return smoothstep_threshold(
        params.local_contrast_threshold.x,
        params.local_contrast_threshold.y,
        strongest
    );
}

fn neighborhood_gate(x: i32, y: i32, history_luma: f32) -> f32 {
    var min_luma = 1e9;
    var max_luma = -1e9;
    for (var oy: i32 = -1; oy <= 1; oy = oy + 1) {
        for (var ox: i32 = -1; ox <= 1; ox = ox + 1) {
            let sample = luma(color_at(x + ox, y + oy));
            min_luma = min(min_luma, sample);
            max_luma = max(max_luma, sample);
        }
    }
    var distance = 0.0;
    if (history_luma < min_luma) {
        distance = min_luma - history_luma;
    } else if (history_luma > max_luma) {
        distance = history_luma - max_luma;
    }
    return smoothstep_threshold(
        params.neighborhood_threshold.x,
        params.neighborhood_threshold.y,
        distance
    );
}

@compute @workgroup_size(1, 1, 1)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    if (gid.x >= params.size.x || gid.y >= params.size.y) {
        return;
    }
    let idx = index_of(gid.x, gid.y);
    let pixel_x = i32(gid.x);
    let pixel_y = i32(gid.y);
    let current = current_color[idx].xyz;
    let history = reprojected_history[idx].xyz;
    let residual = (abs(current.x - history.x) + abs(current.y - history.y) + abs(current.z - history.z)) / 3.0;
    let residual_gate = smoothstep_threshold(params.residual_threshold.x, params.residual_threshold.y, residual);
    let depth_pair = depth_pairs[idx];
    let depth_gate = smoothstep_threshold(
        params.depth_threshold.x,
        params.depth_threshold.y,
        abs(depth_pair.x - depth_pair.y)
    );
    let normal_pair = normal_pairs[idx];
    let n0 = normalize(normal_pair.current.xyz);
    let n1 = normalize(normal_pair.history.xyz);
    let normal_gate = smoothstep_threshold(
        params.normal_threshold.x,
        params.normal_threshold.y,
        1.0 - clamp(dot(n0, n1), -1.0, 1.0)
    );
    let history_luma = luma(history);
    let neighbor_gate = neighborhood_gate(pixel_x, pixel_y, history_luma);
    let thin_gate = local_contrast_gate(pixel_x, pixel_y);
    let history_instability = clamp(
        params.history_instability_mix.x * residual_gate +
        params.history_instability_mix.y * neighbor_gate,
        0.0,
        1.0
    );
    let structural_disagreement = max(depth_gate, normal_gate);
    var grammar_component = 0.0;
    if (structural_disagreement >= params.structural_a.x) {
        grammar_component = 0.88;
    } else if (residual_gate >= params.structural_a.y && neighbor_gate >= params.structural_a.z) {
        grammar_component = 0.62;
    } else if (thin_gate >= params.structural_b.x && residual_gate >= params.structural_b.y) {
        grammar_component = 0.32;
    }
    let hazard_raw =
        params.weights_a.x * residual_gate +
        params.weights_a.y * depth_gate +
        params.weights_a.z * normal_gate +
        params.weights_a.w * neighbor_gate +
        params.weights_b.x * thin_gate +
        params.weights_b.y * history_instability +
        params.weights_b.z * grammar_component;
    let hazard = smoothstep_threshold(
        params.hazard_curve_threshold.x,
        params.hazard_curve_threshold.y,
        clamp(hazard_raw, 0.0, 1.0)
    );
    trust_out[idx] = 1.0 - hazard;
    alpha_out[idx] = params.alpha_range.x + (params.alpha_range.y - params.alpha_range.x) * hazard;
    intervention_out[idx] = hazard;
}
"#;
238
239pub fn try_execute_host_minimum_kernel(
240 inputs: &OwnedHostTemporalInputs,
241 parameters: HostSupervisionParameters,
242) -> Result<Option<GpuKernelResult>> {
243 pollster::block_on(try_execute_host_minimum_kernel_async(inputs, parameters))
244}
245
246async fn try_execute_host_minimum_kernel_async(
247 inputs: &OwnedHostTemporalInputs,
248 parameters: HostSupervisionParameters,
249) -> Result<Option<GpuKernelResult>> {
250 let instance = wgpu::Instance::default();
251 let adapter = match instance
252 .request_adapter(&wgpu::RequestAdapterOptions {
253 power_preference: wgpu::PowerPreference::HighPerformance,
254 compatible_surface: None,
255 force_fallback_adapter: false,
256 })
257 .await
258 {
259 Some(adapter) => adapter,
260 None => return Ok(None),
261 };
262
263 let adapter_info = adapter.get_info();
264 let adapter_limits = adapter.limits();
265 let (device, queue) = adapter
266 .request_device(
267 &wgpu::DeviceDescriptor {
268 label: Some("dsfb-computer-graphics-gpu-path"),
269 required_features: wgpu::Features::empty(),
270 required_limits: wgpu::Limits {
271 max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
272 max_buffer_size: adapter_limits.max_buffer_size,
273 ..wgpu::Limits::default()
274 },
275 },
276 None,
277 )
278 .await
279 .map_err(|error| Error::Message(format!("failed to request wgpu device: {error}")))?;
280
281 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
282 label: Some("dsfb-host-minimum-wgsl"),
283 source: wgpu::ShaderSource::Wgsl(SHADER_SOURCE.into()),
284 });
285 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
286 label: Some("dsfb-host-minimum-layout"),
287 entries: &[
288 storage_layout_entry(0, true),
289 storage_layout_entry(1, true),
290 storage_layout_entry(2, true),
291 storage_layout_entry(3, true),
292 uniform_layout_entry(4),
293 storage_layout_entry(5, false),
294 storage_layout_entry(6, false),
295 storage_layout_entry(7, false),
296 ],
297 });
298 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
299 label: Some("dsfb-host-minimum-pipeline-layout"),
300 bind_group_layouts: &[&bind_group_layout],
301 push_constant_ranges: &[],
302 });
303 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
304 label: Some("dsfb-host-minimum-pipeline"),
305 layout: Some(&pipeline_layout),
306 module: &shader,
307 entry_point: "main",
308 });
309 let max_binding_size = device.limits().max_storage_buffer_binding_size as usize;
310 let chunk = if requires_tiled_dispatch(inputs, max_binding_size) {
311 execute_host_minimum_tiled(
312 &device,
313 &queue,
314 &pipeline,
315 &bind_group_layout,
316 inputs,
317 parameters,
318 max_binding_size,
319 )?
320 } else {
321 execute_host_minimum_chunk(
322 &device,
323 &queue,
324 &pipeline,
325 &bind_group_layout,
326 inputs,
327 parameters,
328 )?
329 };
330
331 Ok(Some(GpuKernelResult {
332 adapter_name: adapter_info.name,
333 backend: format!("{:?}", adapter_info.backend),
334 trust: chunk.trust,
335 alpha: chunk.alpha,
336 intervention: chunk.intervention,
337 total_ms: chunk.total_ms,
338 dispatch_ms: chunk.dispatch_ms,
339 readback_ms: chunk.readback_ms,
340 workgroup_size: (1, 1, 1),
341 }))
342}
343
344fn requires_tiled_dispatch(inputs: &OwnedHostTemporalInputs, max_binding_size: usize) -> bool {
345 let pixel_count = inputs.width().saturating_mul(inputs.height());
346 let largest_binding_bytes = pixel_count.saturating_mul(std::mem::size_of::<GpuNormalPair>());
347 largest_binding_bytes > max_binding_size
348}
349
350fn execute_host_minimum_tiled(
351 device: &wgpu::Device,
352 queue: &wgpu::Queue,
353 pipeline: &wgpu::ComputePipeline,
354 bind_group_layout: &wgpu::BindGroupLayout,
355 inputs: &OwnedHostTemporalInputs,
356 parameters: HostSupervisionParameters,
357 max_binding_size: usize,
358) -> Result<ChunkExecutionResult> {
359 let width = inputs.width();
360 let height = inputs.height();
361 let bytes_per_row = width
362 .saturating_mul(std::mem::size_of::<GpuNormalPair>())
363 .max(1);
364 let max_rows_with_padding = max_binding_size / bytes_per_row;
365 let stripe_rows = max_rows_with_padding.saturating_sub(2).max(1);
366 if stripe_rows == 0 {
367 return Err(Error::Message(
368 "GPU tiled dispatch could not derive a non-zero stripe height".to_string(),
369 ));
370 }
371
372 let pixel_count = width * height;
373 let mut trust = Vec::with_capacity(pixel_count);
374 let mut alpha = Vec::with_capacity(pixel_count);
375 let mut intervention = Vec::with_capacity(pixel_count);
376 let mut total_ms = 0.0;
377 let mut dispatch_ms = 0.0;
378 let mut readback_ms = 0.0;
379 let mut output_row_start = 0usize;
380
381 while output_row_start < height {
382 let output_rows = stripe_rows.min(height - output_row_start);
383 let pad_top = usize::from(output_row_start > 0);
384 let pad_bottom = usize::from(output_row_start + output_rows < height);
385 let sub_start = output_row_start.saturating_sub(pad_top);
386 let sub_end = (output_row_start + output_rows + pad_bottom).min(height);
387 let sub_inputs = slice_inputs_rows(inputs, sub_start, sub_end);
388 let sub_result = execute_host_minimum_chunk(
389 device,
390 queue,
391 pipeline,
392 bind_group_layout,
393 &sub_inputs,
394 parameters,
395 )?;
396 let row_stride = width;
397 let kept_start = pad_top * row_stride;
398 let kept_len = output_rows * row_stride;
399 let kept_end = kept_start + kept_len;
400 trust.extend_from_slice(&sub_result.trust[kept_start..kept_end]);
401 alpha.extend_from_slice(&sub_result.alpha[kept_start..kept_end]);
402 intervention.extend_from_slice(&sub_result.intervention[kept_start..kept_end]);
403 total_ms += sub_result.total_ms;
404 dispatch_ms += sub_result.dispatch_ms;
405 readback_ms += sub_result.readback_ms;
406 output_row_start += output_rows;
407 }
408
409 Ok(ChunkExecutionResult {
410 trust,
411 alpha,
412 intervention,
413 total_ms,
414 dispatch_ms,
415 readback_ms,
416 })
417}
418
419fn execute_host_minimum_chunk(
420 device: &wgpu::Device,
421 queue: &wgpu::Queue,
422 pipeline: &wgpu::ComputePipeline,
423 bind_group_layout: &wgpu::BindGroupLayout,
424 inputs: &OwnedHostTemporalInputs,
425 parameters: HostSupervisionParameters,
426) -> Result<ChunkExecutionResult> {
427 let pixel_count = inputs.width() * inputs.height();
428 let color_current = pack_colors(&inputs.current_color);
429 let color_history = pack_colors(&inputs.reprojected_history);
430 let depth_pairs = pack_depth_pairs(&inputs.current_depth, &inputs.reprojected_depth);
431 let normal_pairs = pack_normal_pairs(&inputs.current_normals, &inputs.reprojected_normals);
432 let params = pack_params(inputs.width(), inputs.height(), parameters);
433
434 let current_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
435 label: Some("current-color"),
436 contents: bytemuck::cast_slice(&color_current),
437 usage: wgpu::BufferUsages::STORAGE,
438 });
439 let history_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
440 label: Some("reprojected-history"),
441 contents: bytemuck::cast_slice(&color_history),
442 usage: wgpu::BufferUsages::STORAGE,
443 });
444 let depth_pairs_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
445 label: Some("depth-pairs"),
446 contents: bytemuck::cast_slice(&depth_pairs),
447 usage: wgpu::BufferUsages::STORAGE,
448 });
449 let normal_pairs_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
450 label: Some("normal-pairs"),
451 contents: bytemuck::cast_slice(&normal_pairs),
452 usage: wgpu::BufferUsages::STORAGE,
453 });
454 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
455 label: Some("params"),
456 contents: bytemuck::bytes_of(¶ms),
457 usage: wgpu::BufferUsages::UNIFORM,
458 });
459
460 let output_size = (pixel_count * std::mem::size_of::<f32>()) as u64;
461 let trust_buffer = device.create_buffer(&wgpu::BufferDescriptor {
462 label: Some("trust-output"),
463 size: output_size,
464 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
465 mapped_at_creation: false,
466 });
467 let alpha_buffer = device.create_buffer(&wgpu::BufferDescriptor {
468 label: Some("alpha-output"),
469 size: output_size,
470 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
471 mapped_at_creation: false,
472 });
473 let intervention_buffer = device.create_buffer(&wgpu::BufferDescriptor {
474 label: Some("intervention-output"),
475 size: output_size,
476 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
477 mapped_at_creation: false,
478 });
479
480 let trust_staging = create_staging_buffer(device, output_size, "trust-staging");
481 let alpha_staging = create_staging_buffer(device, output_size, "alpha-staging");
482 let intervention_staging = create_staging_buffer(device, output_size, "intervention-staging");
483 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
484 label: Some("dsfb-host-minimum-bind-group"),
485 layout: bind_group_layout,
486 entries: &[
487 storage_binding(0, ¤t_buffer),
488 storage_binding(1, &history_buffer),
489 storage_binding(2, &depth_pairs_buffer),
490 storage_binding(3, &normal_pairs_buffer),
491 uniform_binding(4, ¶ms_buffer),
492 storage_binding(5, &trust_buffer),
493 storage_binding(6, &alpha_buffer),
494 storage_binding(7, &intervention_buffer),
495 ],
496 });
497
498 let total_start = Instant::now();
499 let dispatch_start = Instant::now();
500 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
501 label: Some("dsfb-host-minimum-encoder"),
502 });
503 {
504 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
505 label: Some("dsfb-host-minimum-pass"),
506 timestamp_writes: None,
507 });
508 pass.set_pipeline(pipeline);
509 pass.set_bind_group(0, &bind_group, &[]);
510 let groups_x = inputs.width() as u32;
511 let groups_y = inputs.height() as u32;
512 pass.dispatch_workgroups(groups_x, groups_y, 1);
513 }
514 encoder.copy_buffer_to_buffer(&trust_buffer, 0, &trust_staging, 0, output_size);
515 encoder.copy_buffer_to_buffer(&alpha_buffer, 0, &alpha_staging, 0, output_size);
516 encoder.copy_buffer_to_buffer(
517 &intervention_buffer,
518 0,
519 &intervention_staging,
520 0,
521 output_size,
522 );
523 queue.submit(Some(encoder.finish()));
524 device.poll(wgpu::Maintain::Wait);
525 let dispatch_ms = dispatch_start.elapsed().as_secs_f64() * 1000.0;
526
527 let readback_start = Instant::now();
528 let trust = read_f32_buffer(device, &trust_staging, pixel_count)?;
529 let alpha = read_f32_buffer(device, &alpha_staging, pixel_count)?;
530 let intervention = read_f32_buffer(device, &intervention_staging, pixel_count)?;
531 let readback_ms = readback_start.elapsed().as_secs_f64() * 1000.0;
532
533 Ok(ChunkExecutionResult {
534 trust,
535 alpha,
536 intervention,
537 total_ms: total_start.elapsed().as_secs_f64() * 1000.0,
538 dispatch_ms,
539 readback_ms,
540 })
541}
542
543fn slice_inputs_rows(
544 inputs: &OwnedHostTemporalInputs,
545 row_start: usize,
546 row_end: usize,
547) -> OwnedHostTemporalInputs {
548 let height = row_end.saturating_sub(row_start);
549 let width = inputs.width();
550 OwnedHostTemporalInputs {
551 current_color: slice_frame_rows(&inputs.current_color, row_start, row_end),
552 reprojected_history: slice_frame_rows(&inputs.reprojected_history, row_start, row_end),
553 motion_vectors: slice_rows(&inputs.motion_vectors, width, row_start, row_end),
554 current_depth: slice_rows(&inputs.current_depth, width, row_start, row_end),
555 reprojected_depth: slice_rows(&inputs.reprojected_depth, width, row_start, row_end),
556 current_normals: slice_rows(&inputs.current_normals, width, row_start, row_end),
557 reprojected_normals: slice_rows(&inputs.reprojected_normals, width, row_start, row_end),
558 visibility_hint: inputs
559 .visibility_hint
560 .as_ref()
561 .map(|mask| slice_rows(mask, width, row_start, row_end)),
562 thin_hint: inputs
563 .thin_hint
564 .as_ref()
565 .map(|field| ScalarField::from_values(width, height, slice_rows(field.values(), width, row_start, row_end))),
566 }
567}
568
569fn slice_frame_rows(frame: &crate::frame::ImageFrame, row_start: usize, row_end: usize) -> crate::frame::ImageFrame {
570 let width = frame.width();
571 let height = row_end.saturating_sub(row_start);
572 let mut pixels = Vec::with_capacity(width * height);
573 for y in row_start..row_end {
574 for x in 0..width {
575 pixels.push(frame.get(x, y));
576 }
577 }
578 crate::frame::ImageFrame::from_pixels(width, height, pixels)
579}
580
/// Returns a copy of rows `row_start..row_end` from a row-major buffer with
/// `width` elements per row.
///
/// Panics if the row range falls outside `values` (slice indexing).
fn slice_rows<T: Copy>(values: &[T], width: usize, row_start: usize, row_end: usize) -> Vec<T> {
    let element_range = (row_start * width)..(row_end * width);
    values[element_range].to_vec()
}
586
587fn create_staging_buffer(device: &wgpu::Device, size: u64, label: &str) -> wgpu::Buffer {
588 device.create_buffer(&wgpu::BufferDescriptor {
589 label: Some(label),
590 size,
591 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
592 mapped_at_creation: false,
593 })
594}
595
596fn read_f32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, count: usize) -> Result<Vec<f32>> {
597 let slice = buffer.slice(..);
598 let (sender, receiver) = mpsc::channel();
599 slice.map_async(wgpu::MapMode::Read, move |result| {
600 let _ = sender.send(result);
601 });
602 device.poll(wgpu::Maintain::Wait);
603 receiver
604 .recv()
605 .map_err(|_| Error::Message("failed to receive GPU map_async status".to_string()))?
606 .map_err(|error| Error::Message(format!("failed to map GPU staging buffer: {error}")))?;
607 let mapped = slice.get_mapped_range();
608 let values = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
609 drop(mapped);
610 buffer.unmap();
611 if values.len() != count {
612 return Err(Error::Message(format!(
613 "GPU readback size mismatch: expected {count} floats, got {}",
614 values.len()
615 )));
616 }
617 Ok(values)
618}
619
620fn pack_colors(frame: &crate::frame::ImageFrame) -> Vec<GpuColor> {
621 frame
622 .pixels()
623 .iter()
624 .map(|pixel| GpuColor {
625 value: [pixel.r, pixel.g, pixel.b, 1.0],
626 })
627 .collect()
628}
629
630fn pack_depth_pairs(current: &[f32], history: &[f32]) -> Vec<GpuDepthPair> {
631 current
632 .iter()
633 .zip(history.iter())
634 .map(|(current, history)| GpuDepthPair {
635 value: [*current, *history],
636 })
637 .collect()
638}
639
640fn pack_normal_pairs(
641 current: &[crate::scene::Normal3],
642 history: &[crate::scene::Normal3],
643) -> Vec<GpuNormalPair> {
644 current
645 .iter()
646 .zip(history.iter())
647 .map(|(current, history)| GpuNormalPair {
648 current: [current.x, current.y, current.z, 0.0],
649 history: [history.x, history.y, history.z, 0.0],
650 })
651 .collect()
652}
653
654fn pack_params(width: usize, height: usize, parameters: HostSupervisionParameters) -> GpuParams {
655 GpuParams {
656 size: [width as u32, height as u32, 0, 0],
657 alpha_range: [
658 parameters.alpha_range.min,
659 parameters.alpha_range.max,
660 0.0,
661 0.0,
662 ],
663 residual_threshold: [
664 parameters.thresholds.residual.low,
665 parameters.thresholds.residual.high,
666 0.0,
667 0.0,
668 ],
669 depth_threshold: [
670 parameters.thresholds.depth.low,
671 parameters.thresholds.depth.high,
672 0.0,
673 0.0,
674 ],
675 normal_threshold: [
676 parameters.thresholds.normal.low,
677 parameters.thresholds.normal.high,
678 0.0,
679 0.0,
680 ],
681 neighborhood_threshold: [
682 parameters.thresholds.neighborhood.low,
683 parameters.thresholds.neighborhood.high,
684 0.0,
685 0.0,
686 ],
687 local_contrast_threshold: [
688 parameters.thresholds.local_contrast.low,
689 parameters.thresholds.local_contrast.high,
690 0.0,
691 0.0,
692 ],
693 hazard_curve_threshold: [
694 parameters.thresholds.hazard_curve.low,
695 parameters.thresholds.hazard_curve.high,
696 0.0,
697 0.0,
698 ],
699 weights_a: [
700 parameters.weights.residual,
701 parameters.weights.depth,
702 parameters.weights.normal,
703 parameters.weights.neighborhood,
704 ],
705 weights_b: [
706 parameters.weights.thin,
707 parameters.weights.history_instability,
708 parameters.weights.grammar,
709 0.0,
710 ],
711 history_instability_mix: [
712 parameters.thresholds.history_instability_residual_mix,
713 parameters.thresholds.history_instability_neighborhood_mix,
714 0.0,
715 0.0,
716 ],
717 structural_a: [
718 parameters.structural.disocclusion_like,
719 parameters.structural.unstable_residual,
720 parameters.structural.unstable_neighborhood,
721 0.0,
722 ],
723 structural_b: [
724 parameters.structural.thin_edge,
725 parameters.structural.thin_residual,
726 0.0,
727 0.0,
728 ],
729 }
730}
731
732fn storage_layout_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
733 wgpu::BindGroupLayoutEntry {
734 binding,
735 visibility: wgpu::ShaderStages::COMPUTE,
736 ty: wgpu::BindingType::Buffer {
737 ty: wgpu::BufferBindingType::Storage { read_only },
738 has_dynamic_offset: false,
739 min_binding_size: None,
740 },
741 count: None,
742 }
743}
744
745fn uniform_layout_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
746 wgpu::BindGroupLayoutEntry {
747 binding,
748 visibility: wgpu::ShaderStages::COMPUTE,
749 ty: wgpu::BindingType::Buffer {
750 ty: wgpu::BufferBindingType::Uniform,
751 has_dynamic_offset: false,
752 min_binding_size: None,
753 },
754 count: None,
755 }
756}
757
758fn storage_binding<'a>(binding: u32, buffer: &'a wgpu::Buffer) -> wgpu::BindGroupEntry<'a> {
759 wgpu::BindGroupEntry {
760 binding,
761 resource: buffer.as_entire_binding(),
762 }
763}
764
765fn uniform_binding<'a>(binding: u32, buffer: &'a wgpu::Buffer) -> wgpu::BindGroupEntry<'a> {
766 wgpu::BindGroupEntry {
767 binding,
768 resource: buffer.as_entire_binding(),
769 }
770}