Skip to main content

dsfb_computer_graphics/
gpu.rs

1use std::sync::mpsc;
2use std::time::Instant;
3
4use bytemuck::{Pod, Zeroable};
5use wgpu::util::DeviceExt;
6
7use crate::error::{Error, Result};
8use crate::external::OwnedHostTemporalInputs;
9use crate::frame::ScalarField;
10use crate::parameters::HostSupervisionParameters;
11
/// Outcome of one GPU execution of the host-minimum supervision kernel.
#[derive(Debug, Clone)]
pub struct GpuKernelResult {
    /// Name reported by the wgpu adapter that ran the kernel.
    pub adapter_name: String,
    /// `Debug` rendering of the adapter's backend.
    pub backend: String,
    /// Per-pixel trust values (`1 - hazard` in the shader), row-major.
    pub trust: Vec<f32>,
    /// Per-pixel blend alpha derived from the hazard and the alpha range.
    pub alpha: Vec<f32>,
    /// Per-pixel intervention values (the shader's hazard).
    pub intervention: Vec<f32>,
    /// Wall-clock milliseconds covering dispatch plus readback.
    pub total_ms: f64,
    /// Milliseconds spent encoding, submitting and waiting for the GPU work.
    pub dispatch_ms: f64,
    /// Milliseconds spent mapping and copying the staging buffers.
    pub readback_ms: f64,
    /// Compute workgroup size used by the shader entry point.
    pub workgroup_size: (u32, u32, u32),
}
24
/// Outputs and timings of a single GPU dispatch (or the concatenation of
/// several stripes when the tiled path is used).
#[derive(Debug, Clone)]
struct ChunkExecutionResult {
    /// Row-major trust values read back from the GPU.
    trust: Vec<f32>,
    /// Row-major alpha values read back from the GPU.
    alpha: Vec<f32>,
    /// Row-major intervention values read back from the GPU.
    intervention: Vec<f32>,
    /// Milliseconds for dispatch plus readback.
    total_ms: f64,
    /// Milliseconds for encode, submit and the blocking device poll.
    dispatch_ms: f64,
    /// Milliseconds for mapping and copying the staging buffers.
    readback_ms: f64,
}
34
35#[repr(C)]
36#[derive(Clone, Copy, Pod, Zeroable)]
37struct GpuParams {
38    size: [u32; 4],
39    alpha_range: [f32; 4],
40    residual_threshold: [f32; 4],
41    depth_threshold: [f32; 4],
42    normal_threshold: [f32; 4],
43    neighborhood_threshold: [f32; 4],
44    local_contrast_threshold: [f32; 4],
45    hazard_curve_threshold: [f32; 4],
46    weights_a: [f32; 4],
47    weights_b: [f32; 4],
48    history_instability_mix: [f32; 4],
49    structural_a: [f32; 4],
50    structural_b: [f32; 4],
51}
52
53#[repr(C)]
54#[derive(Clone, Copy, Pod, Zeroable)]
55struct GpuColor {
56    value: [f32; 4],
57}
58
59#[repr(C)]
60#[derive(Clone, Copy, Pod, Zeroable)]
61struct GpuDepthPair {
62    value: [f32; 2],
63}
64
65#[repr(C)]
66#[derive(Clone, Copy, Pod, Zeroable)]
67struct GpuNormalPair {
68    current: [f32; 4],
69    history: [f32; 4],
70}
71
72#[repr(C)]
73#[derive(Clone, Copy, Pod, Zeroable)]
74struct GpuVec4 {
75    value: [f32; 4],
76}
77
/// WGSL source of the host-minimum supervision compute kernel.
///
/// Bind group 0 layout expected by the shader:
/// - bindings 0-3: read-only storage (current colour, reprojected history,
///   depth pairs, normal pairs),
/// - binding 4: the `Params` uniform, whose member list mirrors `GpuParams`,
/// - bindings 5-7: read-write storage outputs (trust, alpha, intervention).
///
/// The entry point `main` uses a workgroup size of (1, 1, 1) and treats
/// `global_invocation_id.xy` as the pixel coordinate, returning early for
/// coordinates outside `params.size`. For each pixel it evaluates smoothstep
/// gates for colour residual, depth difference, normal disagreement, the 3x3
/// neighbourhood luma range and 3x3 local contrast (neighbour lookups are
/// clamped to the image edge), mixes them with the configured weights into a
/// hazard value, and writes `trust = 1 - hazard`,
/// `alpha = lerp(alpha_range.x, alpha_range.y, hazard)` and
/// `intervention = hazard`.
const SHADER_SOURCE: &str = r#"
struct Params {
    size: vec4<u32>,
    alpha_range: vec4<f32>,
    residual_threshold: vec4<f32>,
    depth_threshold: vec4<f32>,
    normal_threshold: vec4<f32>,
    neighborhood_threshold: vec4<f32>,
    local_contrast_threshold: vec4<f32>,
    hazard_curve_threshold: vec4<f32>,
    weights_a: vec4<f32>,
    weights_b: vec4<f32>,
    history_instability_mix: vec4<f32>,
    structural_a: vec4<f32>,
    structural_b: vec4<f32>,
}

@group(0) @binding(0) var<storage, read> current_color: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> reprojected_history: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> depth_pairs: array<vec2<f32>>;

struct NormalPair {
    current: vec4<f32>,
    history: vec4<f32>,
}

@group(0) @binding(3) var<storage, read> normal_pairs: array<NormalPair>;
@group(0) @binding(4) var<uniform> params: Params;
@group(0) @binding(5) var<storage, read_write> trust_out: array<f32>;
@group(0) @binding(6) var<storage, read_write> alpha_out: array<f32>;
@group(0) @binding(7) var<storage, read_write> intervention_out: array<f32>;

fn index_of(x: u32, y: u32) -> u32 {
    return y * params.size.x + x;
}

fn luma(color: vec3<f32>) -> f32 {
    return dot(color, vec3<f32>(0.2126, 0.7152, 0.0722));
}

fn smoothstep_threshold(low: f32, high: f32, value: f32) -> f32 {
    let edge_span = max(high - low, 1e-6);
    let t = clamp((value - low) / edge_span, 0.0, 1.0);
    return t * t * (3.0 - 2.0 * t);
}

fn color_at(x: i32, y: i32) -> vec3<f32> {
    let width = i32(params.size.x);
    let height = i32(params.size.y);
    let clamped_x = clamp(x, 0, width - 1);
    let clamped_y = clamp(y, 0, height - 1);
    let idx = index_of(u32(clamped_x), u32(clamped_y));
    return current_color[idx].xyz;
}

fn local_contrast_gate(x: i32, y: i32) -> f32 {
    let center = luma(color_at(x, y));
    var strongest = 0.0;
    for (var oy: i32 = -1; oy <= 1; oy = oy + 1) {
        for (var ox: i32 = -1; ox <= 1; ox = ox + 1) {
            if (ox == 0 && oy == 0) {
                continue;
            }
            strongest = max(strongest, abs(center - luma(color_at(x + ox, y + oy))));
        }
    }
    return smoothstep_threshold(
        params.local_contrast_threshold.x,
        params.local_contrast_threshold.y,
        strongest
    );
}

fn neighborhood_gate(x: i32, y: i32, history_luma: f32) -> f32 {
    var min_luma = 1e9;
    var max_luma = -1e9;
    for (var oy: i32 = -1; oy <= 1; oy = oy + 1) {
        for (var ox: i32 = -1; ox <= 1; ox = ox + 1) {
            let sample = luma(color_at(x + ox, y + oy));
            min_luma = min(min_luma, sample);
            max_luma = max(max_luma, sample);
        }
    }
    var distance = 0.0;
    if (history_luma < min_luma) {
        distance = min_luma - history_luma;
    } else if (history_luma > max_luma) {
        distance = history_luma - max_luma;
    }
    return smoothstep_threshold(
        params.neighborhood_threshold.x,
        params.neighborhood_threshold.y,
        distance
    );
}

@compute @workgroup_size(1, 1, 1)
fn main(
    @builtin(global_invocation_id) gid: vec3<u32>,
) {
    if (gid.x >= params.size.x || gid.y >= params.size.y) {
        return;
    }
    let idx = index_of(gid.x, gid.y);
    let pixel_x = i32(gid.x);
    let pixel_y = i32(gid.y);
    let current = current_color[idx].xyz;
    let history = reprojected_history[idx].xyz;
    let residual = (abs(current.x - history.x) + abs(current.y - history.y) + abs(current.z - history.z)) / 3.0;
    let residual_gate = smoothstep_threshold(params.residual_threshold.x, params.residual_threshold.y, residual);
    let depth_pair = depth_pairs[idx];
    let depth_gate = smoothstep_threshold(
        params.depth_threshold.x,
        params.depth_threshold.y,
        abs(depth_pair.x - depth_pair.y)
    );
    let normal_pair = normal_pairs[idx];
    let n0 = normalize(normal_pair.current.xyz);
    let n1 = normalize(normal_pair.history.xyz);
    let normal_gate = smoothstep_threshold(
        params.normal_threshold.x,
        params.normal_threshold.y,
        1.0 - clamp(dot(n0, n1), -1.0, 1.0)
    );
    let history_luma = luma(history);
    let neighbor_gate = neighborhood_gate(pixel_x, pixel_y, history_luma);
    let thin_gate = local_contrast_gate(pixel_x, pixel_y);
    let history_instability = clamp(
        params.history_instability_mix.x * residual_gate +
        params.history_instability_mix.y * neighbor_gate,
        0.0,
        1.0
    );
    let structural_disagreement = max(depth_gate, normal_gate);
    var grammar_component = 0.0;
    if (structural_disagreement >= params.structural_a.x) {
        grammar_component = 0.88;
    } else if (residual_gate >= params.structural_a.y && neighbor_gate >= params.structural_a.z) {
        grammar_component = 0.62;
    } else if (thin_gate >= params.structural_b.x && residual_gate >= params.structural_b.y) {
        grammar_component = 0.32;
    }
    let hazard_raw =
        params.weights_a.x * residual_gate +
        params.weights_a.y * depth_gate +
        params.weights_a.z * normal_gate +
        params.weights_a.w * neighbor_gate +
        params.weights_b.x * thin_gate +
        params.weights_b.y * history_instability +
        params.weights_b.z * grammar_component;
    let hazard = smoothstep_threshold(
        params.hazard_curve_threshold.x,
        params.hazard_curve_threshold.y,
        clamp(hazard_raw, 0.0, 1.0)
    );
    trust_out[idx] = 1.0 - hazard;
    alpha_out[idx] = params.alpha_range.x + (params.alpha_range.y - params.alpha_range.x) * hazard;
    intervention_out[idx] = hazard;
}
"#;
238
239pub fn try_execute_host_minimum_kernel(
240    inputs: &OwnedHostTemporalInputs,
241    parameters: HostSupervisionParameters,
242) -> Result<Option<GpuKernelResult>> {
243    pollster::block_on(try_execute_host_minimum_kernel_async(inputs, parameters))
244}
245
246async fn try_execute_host_minimum_kernel_async(
247    inputs: &OwnedHostTemporalInputs,
248    parameters: HostSupervisionParameters,
249) -> Result<Option<GpuKernelResult>> {
250    let instance = wgpu::Instance::default();
251    let adapter = match instance
252        .request_adapter(&wgpu::RequestAdapterOptions {
253            power_preference: wgpu::PowerPreference::HighPerformance,
254            compatible_surface: None,
255            force_fallback_adapter: false,
256        })
257        .await
258    {
259        Some(adapter) => adapter,
260        None => return Ok(None),
261    };
262
263    let adapter_info = adapter.get_info();
264    let adapter_limits = adapter.limits();
265    let (device, queue) = adapter
266        .request_device(
267            &wgpu::DeviceDescriptor {
268                label: Some("dsfb-computer-graphics-gpu-path"),
269                required_features: wgpu::Features::empty(),
270                required_limits: wgpu::Limits {
271                    max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
272                    max_buffer_size: adapter_limits.max_buffer_size,
273                    ..wgpu::Limits::default()
274                },
275            },
276            None,
277        )
278        .await
279        .map_err(|error| Error::Message(format!("failed to request wgpu device: {error}")))?;
280
281    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
282        label: Some("dsfb-host-minimum-wgsl"),
283        source: wgpu::ShaderSource::Wgsl(SHADER_SOURCE.into()),
284    });
285    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
286        label: Some("dsfb-host-minimum-layout"),
287        entries: &[
288            storage_layout_entry(0, true),
289            storage_layout_entry(1, true),
290            storage_layout_entry(2, true),
291            storage_layout_entry(3, true),
292            uniform_layout_entry(4),
293            storage_layout_entry(5, false),
294            storage_layout_entry(6, false),
295            storage_layout_entry(7, false),
296        ],
297    });
298    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
299        label: Some("dsfb-host-minimum-pipeline-layout"),
300        bind_group_layouts: &[&bind_group_layout],
301        push_constant_ranges: &[],
302    });
303    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
304        label: Some("dsfb-host-minimum-pipeline"),
305        layout: Some(&pipeline_layout),
306        module: &shader,
307        entry_point: "main",
308    });
309    let max_binding_size = device.limits().max_storage_buffer_binding_size as usize;
310    let chunk = if requires_tiled_dispatch(inputs, max_binding_size) {
311        execute_host_minimum_tiled(
312            &device,
313            &queue,
314            &pipeline,
315            &bind_group_layout,
316            inputs,
317            parameters,
318            max_binding_size,
319        )?
320    } else {
321        execute_host_minimum_chunk(
322            &device,
323            &queue,
324            &pipeline,
325            &bind_group_layout,
326            inputs,
327            parameters,
328        )?
329    };
330
331    Ok(Some(GpuKernelResult {
332        adapter_name: adapter_info.name,
333        backend: format!("{:?}", adapter_info.backend),
334        trust: chunk.trust,
335        alpha: chunk.alpha,
336        intervention: chunk.intervention,
337        total_ms: chunk.total_ms,
338        dispatch_ms: chunk.dispatch_ms,
339        readback_ms: chunk.readback_ms,
340        workgroup_size: (1, 1, 1),
341    }))
342}
343
344fn requires_tiled_dispatch(inputs: &OwnedHostTemporalInputs, max_binding_size: usize) -> bool {
345    let pixel_count = inputs.width().saturating_mul(inputs.height());
346    let largest_binding_bytes = pixel_count.saturating_mul(std::mem::size_of::<GpuNormalPair>());
347    largest_binding_bytes > max_binding_size
348}
349
350fn execute_host_minimum_tiled(
351    device: &wgpu::Device,
352    queue: &wgpu::Queue,
353    pipeline: &wgpu::ComputePipeline,
354    bind_group_layout: &wgpu::BindGroupLayout,
355    inputs: &OwnedHostTemporalInputs,
356    parameters: HostSupervisionParameters,
357    max_binding_size: usize,
358) -> Result<ChunkExecutionResult> {
359    let width = inputs.width();
360    let height = inputs.height();
361    let bytes_per_row = width
362        .saturating_mul(std::mem::size_of::<GpuNormalPair>())
363        .max(1);
364    let max_rows_with_padding = max_binding_size / bytes_per_row;
365    let stripe_rows = max_rows_with_padding.saturating_sub(2).max(1);
366    if stripe_rows == 0 {
367        return Err(Error::Message(
368            "GPU tiled dispatch could not derive a non-zero stripe height".to_string(),
369        ));
370    }
371
372    let pixel_count = width * height;
373    let mut trust = Vec::with_capacity(pixel_count);
374    let mut alpha = Vec::with_capacity(pixel_count);
375    let mut intervention = Vec::with_capacity(pixel_count);
376    let mut total_ms = 0.0;
377    let mut dispatch_ms = 0.0;
378    let mut readback_ms = 0.0;
379    let mut output_row_start = 0usize;
380
381    while output_row_start < height {
382        let output_rows = stripe_rows.min(height - output_row_start);
383        let pad_top = usize::from(output_row_start > 0);
384        let pad_bottom = usize::from(output_row_start + output_rows < height);
385        let sub_start = output_row_start.saturating_sub(pad_top);
386        let sub_end = (output_row_start + output_rows + pad_bottom).min(height);
387        let sub_inputs = slice_inputs_rows(inputs, sub_start, sub_end);
388        let sub_result = execute_host_minimum_chunk(
389            device,
390            queue,
391            pipeline,
392            bind_group_layout,
393            &sub_inputs,
394            parameters,
395        )?;
396        let row_stride = width;
397        let kept_start = pad_top * row_stride;
398        let kept_len = output_rows * row_stride;
399        let kept_end = kept_start + kept_len;
400        trust.extend_from_slice(&sub_result.trust[kept_start..kept_end]);
401        alpha.extend_from_slice(&sub_result.alpha[kept_start..kept_end]);
402        intervention.extend_from_slice(&sub_result.intervention[kept_start..kept_end]);
403        total_ms += sub_result.total_ms;
404        dispatch_ms += sub_result.dispatch_ms;
405        readback_ms += sub_result.readback_ms;
406        output_row_start += output_rows;
407    }
408
409    Ok(ChunkExecutionResult {
410        trust,
411        alpha,
412        intervention,
413        total_ms,
414        dispatch_ms,
415        readback_ms,
416    })
417}
418
419fn execute_host_minimum_chunk(
420    device: &wgpu::Device,
421    queue: &wgpu::Queue,
422    pipeline: &wgpu::ComputePipeline,
423    bind_group_layout: &wgpu::BindGroupLayout,
424    inputs: &OwnedHostTemporalInputs,
425    parameters: HostSupervisionParameters,
426) -> Result<ChunkExecutionResult> {
427    let pixel_count = inputs.width() * inputs.height();
428    let color_current = pack_colors(&inputs.current_color);
429    let color_history = pack_colors(&inputs.reprojected_history);
430    let depth_pairs = pack_depth_pairs(&inputs.current_depth, &inputs.reprojected_depth);
431    let normal_pairs = pack_normal_pairs(&inputs.current_normals, &inputs.reprojected_normals);
432    let params = pack_params(inputs.width(), inputs.height(), parameters);
433
434    let current_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
435        label: Some("current-color"),
436        contents: bytemuck::cast_slice(&color_current),
437        usage: wgpu::BufferUsages::STORAGE,
438    });
439    let history_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
440        label: Some("reprojected-history"),
441        contents: bytemuck::cast_slice(&color_history),
442        usage: wgpu::BufferUsages::STORAGE,
443    });
444    let depth_pairs_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
445        label: Some("depth-pairs"),
446        contents: bytemuck::cast_slice(&depth_pairs),
447        usage: wgpu::BufferUsages::STORAGE,
448    });
449    let normal_pairs_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
450        label: Some("normal-pairs"),
451        contents: bytemuck::cast_slice(&normal_pairs),
452        usage: wgpu::BufferUsages::STORAGE,
453    });
454    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
455        label: Some("params"),
456        contents: bytemuck::bytes_of(&params),
457        usage: wgpu::BufferUsages::UNIFORM,
458    });
459
460    let output_size = (pixel_count * std::mem::size_of::<f32>()) as u64;
461    let trust_buffer = device.create_buffer(&wgpu::BufferDescriptor {
462        label: Some("trust-output"),
463        size: output_size,
464        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
465        mapped_at_creation: false,
466    });
467    let alpha_buffer = device.create_buffer(&wgpu::BufferDescriptor {
468        label: Some("alpha-output"),
469        size: output_size,
470        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
471        mapped_at_creation: false,
472    });
473    let intervention_buffer = device.create_buffer(&wgpu::BufferDescriptor {
474        label: Some("intervention-output"),
475        size: output_size,
476        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
477        mapped_at_creation: false,
478    });
479
480    let trust_staging = create_staging_buffer(device, output_size, "trust-staging");
481    let alpha_staging = create_staging_buffer(device, output_size, "alpha-staging");
482    let intervention_staging = create_staging_buffer(device, output_size, "intervention-staging");
483    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
484        label: Some("dsfb-host-minimum-bind-group"),
485        layout: bind_group_layout,
486        entries: &[
487            storage_binding(0, &current_buffer),
488            storage_binding(1, &history_buffer),
489            storage_binding(2, &depth_pairs_buffer),
490            storage_binding(3, &normal_pairs_buffer),
491            uniform_binding(4, &params_buffer),
492            storage_binding(5, &trust_buffer),
493            storage_binding(6, &alpha_buffer),
494            storage_binding(7, &intervention_buffer),
495        ],
496    });
497
498    let total_start = Instant::now();
499    let dispatch_start = Instant::now();
500    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
501        label: Some("dsfb-host-minimum-encoder"),
502    });
503    {
504        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
505            label: Some("dsfb-host-minimum-pass"),
506            timestamp_writes: None,
507        });
508        pass.set_pipeline(pipeline);
509        pass.set_bind_group(0, &bind_group, &[]);
510        let groups_x = inputs.width() as u32;
511        let groups_y = inputs.height() as u32;
512        pass.dispatch_workgroups(groups_x, groups_y, 1);
513    }
514    encoder.copy_buffer_to_buffer(&trust_buffer, 0, &trust_staging, 0, output_size);
515    encoder.copy_buffer_to_buffer(&alpha_buffer, 0, &alpha_staging, 0, output_size);
516    encoder.copy_buffer_to_buffer(
517        &intervention_buffer,
518        0,
519        &intervention_staging,
520        0,
521        output_size,
522    );
523    queue.submit(Some(encoder.finish()));
524    device.poll(wgpu::Maintain::Wait);
525    let dispatch_ms = dispatch_start.elapsed().as_secs_f64() * 1000.0;
526
527    let readback_start = Instant::now();
528    let trust = read_f32_buffer(device, &trust_staging, pixel_count)?;
529    let alpha = read_f32_buffer(device, &alpha_staging, pixel_count)?;
530    let intervention = read_f32_buffer(device, &intervention_staging, pixel_count)?;
531    let readback_ms = readback_start.elapsed().as_secs_f64() * 1000.0;
532
533    Ok(ChunkExecutionResult {
534        trust,
535        alpha,
536        intervention,
537        total_ms: total_start.elapsed().as_secs_f64() * 1000.0,
538        dispatch_ms,
539        readback_ms,
540    })
541}
542
543fn slice_inputs_rows(
544    inputs: &OwnedHostTemporalInputs,
545    row_start: usize,
546    row_end: usize,
547) -> OwnedHostTemporalInputs {
548    let height = row_end.saturating_sub(row_start);
549    let width = inputs.width();
550    OwnedHostTemporalInputs {
551        current_color: slice_frame_rows(&inputs.current_color, row_start, row_end),
552        reprojected_history: slice_frame_rows(&inputs.reprojected_history, row_start, row_end),
553        motion_vectors: slice_rows(&inputs.motion_vectors, width, row_start, row_end),
554        current_depth: slice_rows(&inputs.current_depth, width, row_start, row_end),
555        reprojected_depth: slice_rows(&inputs.reprojected_depth, width, row_start, row_end),
556        current_normals: slice_rows(&inputs.current_normals, width, row_start, row_end),
557        reprojected_normals: slice_rows(&inputs.reprojected_normals, width, row_start, row_end),
558        visibility_hint: inputs
559            .visibility_hint
560            .as_ref()
561            .map(|mask| slice_rows(mask, width, row_start, row_end)),
562        thin_hint: inputs
563            .thin_hint
564            .as_ref()
565            .map(|field| ScalarField::from_values(width, height, slice_rows(field.values(), width, row_start, row_end))),
566    }
567}
568
569fn slice_frame_rows(frame: &crate::frame::ImageFrame, row_start: usize, row_end: usize) -> crate::frame::ImageFrame {
570    let width = frame.width();
571    let height = row_end.saturating_sub(row_start);
572    let mut pixels = Vec::with_capacity(width * height);
573    for y in row_start..row_end {
574        for x in 0..width {
575            pixels.push(frame.get(x, y));
576        }
577    }
578    crate::frame::ImageFrame::from_pixels(width, height, pixels)
579}
580
/// Copies rows `row_start..row_end` out of a row-major buffer that is
/// `width` elements wide.
///
/// # Panics
///
/// Panics if the requested rows lie outside `values` or if
/// `row_end < row_start` (with a non-zero width).
fn slice_rows<T: Copy>(values: &[T], width: usize, row_start: usize, row_end: usize) -> Vec<T> {
    let element_range = (row_start * width)..(row_end * width);
    Vec::from(&values[element_range])
}
586
587fn create_staging_buffer(device: &wgpu::Device, size: u64, label: &str) -> wgpu::Buffer {
588    device.create_buffer(&wgpu::BufferDescriptor {
589        label: Some(label),
590        size,
591        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
592        mapped_at_creation: false,
593    })
594}
595
596fn read_f32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, count: usize) -> Result<Vec<f32>> {
597    let slice = buffer.slice(..);
598    let (sender, receiver) = mpsc::channel();
599    slice.map_async(wgpu::MapMode::Read, move |result| {
600        let _ = sender.send(result);
601    });
602    device.poll(wgpu::Maintain::Wait);
603    receiver
604        .recv()
605        .map_err(|_| Error::Message("failed to receive GPU map_async status".to_string()))?
606        .map_err(|error| Error::Message(format!("failed to map GPU staging buffer: {error}")))?;
607    let mapped = slice.get_mapped_range();
608    let values = bytemuck::cast_slice::<u8, f32>(&mapped).to_vec();
609    drop(mapped);
610    buffer.unmap();
611    if values.len() != count {
612        return Err(Error::Message(format!(
613            "GPU readback size mismatch: expected {count} floats, got {}",
614            values.len()
615        )));
616    }
617    Ok(values)
618}
619
620fn pack_colors(frame: &crate::frame::ImageFrame) -> Vec<GpuColor> {
621    frame
622        .pixels()
623        .iter()
624        .map(|pixel| GpuColor {
625            value: [pixel.r, pixel.g, pixel.b, 1.0],
626        })
627        .collect()
628}
629
630fn pack_depth_pairs(current: &[f32], history: &[f32]) -> Vec<GpuDepthPair> {
631    current
632        .iter()
633        .zip(history.iter())
634        .map(|(current, history)| GpuDepthPair {
635            value: [*current, *history],
636        })
637        .collect()
638}
639
640fn pack_normal_pairs(
641    current: &[crate::scene::Normal3],
642    history: &[crate::scene::Normal3],
643) -> Vec<GpuNormalPair> {
644    current
645        .iter()
646        .zip(history.iter())
647        .map(|(current, history)| GpuNormalPair {
648            current: [current.x, current.y, current.z, 0.0],
649            history: [history.x, history.y, history.z, 0.0],
650        })
651        .collect()
652}
653
654fn pack_params(width: usize, height: usize, parameters: HostSupervisionParameters) -> GpuParams {
655    GpuParams {
656        size: [width as u32, height as u32, 0, 0],
657        alpha_range: [
658            parameters.alpha_range.min,
659            parameters.alpha_range.max,
660            0.0,
661            0.0,
662        ],
663        residual_threshold: [
664            parameters.thresholds.residual.low,
665            parameters.thresholds.residual.high,
666            0.0,
667            0.0,
668        ],
669        depth_threshold: [
670            parameters.thresholds.depth.low,
671            parameters.thresholds.depth.high,
672            0.0,
673            0.0,
674        ],
675        normal_threshold: [
676            parameters.thresholds.normal.low,
677            parameters.thresholds.normal.high,
678            0.0,
679            0.0,
680        ],
681        neighborhood_threshold: [
682            parameters.thresholds.neighborhood.low,
683            parameters.thresholds.neighborhood.high,
684            0.0,
685            0.0,
686        ],
687        local_contrast_threshold: [
688            parameters.thresholds.local_contrast.low,
689            parameters.thresholds.local_contrast.high,
690            0.0,
691            0.0,
692        ],
693        hazard_curve_threshold: [
694            parameters.thresholds.hazard_curve.low,
695            parameters.thresholds.hazard_curve.high,
696            0.0,
697            0.0,
698        ],
699        weights_a: [
700            parameters.weights.residual,
701            parameters.weights.depth,
702            parameters.weights.normal,
703            parameters.weights.neighborhood,
704        ],
705        weights_b: [
706            parameters.weights.thin,
707            parameters.weights.history_instability,
708            parameters.weights.grammar,
709            0.0,
710        ],
711        history_instability_mix: [
712            parameters.thresholds.history_instability_residual_mix,
713            parameters.thresholds.history_instability_neighborhood_mix,
714            0.0,
715            0.0,
716        ],
717        structural_a: [
718            parameters.structural.disocclusion_like,
719            parameters.structural.unstable_residual,
720            parameters.structural.unstable_neighborhood,
721            0.0,
722        ],
723        structural_b: [
724            parameters.structural.thin_edge,
725            parameters.structural.thin_residual,
726            0.0,
727            0.0,
728        ],
729    }
730}
731
732fn storage_layout_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
733    wgpu::BindGroupLayoutEntry {
734        binding,
735        visibility: wgpu::ShaderStages::COMPUTE,
736        ty: wgpu::BindingType::Buffer {
737            ty: wgpu::BufferBindingType::Storage { read_only },
738            has_dynamic_offset: false,
739            min_binding_size: None,
740        },
741        count: None,
742    }
743}
744
745fn uniform_layout_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
746    wgpu::BindGroupLayoutEntry {
747        binding,
748        visibility: wgpu::ShaderStages::COMPUTE,
749        ty: wgpu::BindingType::Buffer {
750            ty: wgpu::BufferBindingType::Uniform,
751            has_dynamic_offset: false,
752            min_binding_size: None,
753        },
754        count: None,
755    }
756}
757
758fn storage_binding<'a>(binding: u32, buffer: &'a wgpu::Buffer) -> wgpu::BindGroupEntry<'a> {
759    wgpu::BindGroupEntry {
760        binding,
761        resource: buffer.as_entire_binding(),
762    }
763}
764
765fn uniform_binding<'a>(binding: u32, buffer: &'a wgpu::Buffer) -> wgpu::BindGroupEntry<'a> {
766    wgpu::BindGroupEntry {
767        binding,
768        resource: buffer.as_entire_binding(),
769    }
770}