// crush_gpu/backend/wgpu_backend.rs

//! wgpu compute shader backend (Vulkan/Metal/DX12)
//!
//! Provides GPU decompression via wgpu's cross-platform compute shader API.
//! Requires Vulkan 1.2 / Metal 2 / DX12 + 2 GB VRAM minimum.

use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

use crush_core::error::{CrushError, PluginError, Result};

use super::{CompressedTile, ComputeBackend, GpuInfo, GpuVendor, MIN_VRAM_BYTES};
12
/// WGSL compute shader source — LZ77 (v1 compat, embedded at compile time).
/// Compiled into the `pipeline` field by `WgpuBackend::try_new`.
const DECOMPRESS_SHADER: &str = include_str!("../shader/decompress.wgsl");

/// WGSL compute shader source — `GDeflate` (v2, embedded at compile time).
/// Compiled into the `gdeflate_pipeline` field by `WgpuBackend::try_new`.
const GDEFLATE_SHADER: &str = include_str!("../shader/gdeflate_decompress.wgsl");

/// Maximum time to wait for GPU work to complete before treating it as a hang.
/// 5 seconds is generous for a single tile (64KB) decompression dispatch.
/// On Windows, TDR typically resets the GPU after ~2s, so this catches hangs
/// that survive TDR as well.
///
/// Used by every blocking `device.poll` in this module.
const GPU_POLL_TIMEOUT: Duration = Duration::from_secs(5);
24
/// wgpu-backed GPU compute backend.
///
/// Holds two compute pipelines: one for LZ77 (v1) and one for `GDeflate` (v2).
/// Constructed via `WgpuBackend::try_new`; all GPU resources are released by
/// RAII when the backend is dropped (see `release`).
pub struct WgpuBackend {
    /// Adapter name / vendor / estimated-VRAM summary returned by `gpu_info()`.
    info: GpuInfo,
    /// Logical device used for all buffer, bind group, and pipeline creation.
    device: wgpu::Device,
    /// Queue used for uploads, command submission, and readback.
    queue: wgpu::Queue,
    // LZ77 pipeline (v1)
    pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    // GDeflate pipeline (v2)
    gdeflate_pipeline: wgpu::ComputePipeline,
    gdeflate_bgl: wgpu::BindGroupLayout,
}
39
/// Uniform struct matching the `TileMeta` layout in the LZ77 WGSL shader.
/// 8 × u32 = 32 bytes.
///
/// `#[repr(C)]` + `bytemuck::Pod` let the struct be uploaded to the GPU as
/// raw bytes; the field order must stay in sync with the WGSL definition.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct TileMeta {
    /// Byte offset into the compressed buffer (always 0 here — each dispatch
    /// uploads exactly one tile).
    compressed_offset: u32,
    /// Unpadded compressed payload size in bytes.
    compressed_size: u32,
    /// Expected decompressed size of the whole tile in bytes.
    uncompressed_size: u32,
    /// Number of independent sub-streams within the tile.
    sub_stream_count: u32,
    /// Byte offset into the output buffer (always 0 here).
    output_offset: u32,
    /// Index of this tile in the original tile sequence.
    tile_index: u32,
    /// Padding to keep the struct at exactly 8 × u32 (32 bytes).
    _pad0: u32,
    _pad1: u32,
}
54
/// Metadata struct matching `GDeflateMeta` in the `GDeflate` WGSL shader.
/// 4 × u32 = 16 bytes.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GDeflateMeta {
    /// Compressed payload size in bytes, after padding to a 4-byte multiple.
    payload_size: u32,
    /// Expected decompressed size in bytes.
    uncompressed_size: u32,
    /// Padding to keep the struct at exactly 4 × u32 (16 bytes).
    _pad0: u32,
    _pad1: u32,
}
65
/// GPU buffers needed for a single tile dispatch.
struct TileBuffers {
    /// `TileMeta` storage buffer, filled via `queue.write_buffer`.
    meta: wgpu::Buffer,
    /// Compressed input payload (padded to a 4-byte multiple).
    compressed: wgpu::Buffer,
    /// Decompressed output written by the shader.
    output: wgpu::Buffer,
    /// Per-sub-stream output lengths written by the shader.
    lengths: wgpu::Buffer,
    /// CPU-mappable staging copy of `output` for readback.
    out_staging: wgpu::Buffer,
    /// CPU-mappable staging copy of `lengths` for readback.
    len_staging: wgpu::Buffer,
}
75
76/// Allocate all GPU buffers for a single tile dispatch.
77///
78/// Uses `catch_unwind` to prevent wgpu internal panics (e.g. on OOM)
79/// from crashing the entire process.
80fn create_tile_buffers(
81    device: &wgpu::Device,
82    comp_data: &[u8],
83    out_aligned: u64,
84    len_size: u64,
85) -> Result<TileBuffers> {
86    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
87        let buf = |label, size, usage| {
88            device.create_buffer(&wgpu::BufferDescriptor {
89                label: Some(label),
90                size,
91                usage,
92                mapped_at_creation: false,
93            })
94        };
95        let storage_dst = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST;
96        let storage_src = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
97        let map_read = wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ;
98
99        TileBuffers {
100            meta: buf(
101                "tile_meta",
102                std::mem::size_of::<TileMeta>() as u64,
103                storage_dst,
104            ),
105            compressed: buf("compressed_data", comp_data.len() as u64, storage_dst),
106            output: buf("decompressed_data", out_aligned, storage_src),
107            lengths: buf("sub_stream_lengths", len_size, storage_src),
108            out_staging: buf("out_staging", out_aligned, map_read),
109            len_staging: buf("len_staging", len_size, map_read),
110        }
111    }))
112    .map_err(|e| {
113        let msg = e
114            .downcast_ref::<String>()
115            .map(String::as_str)
116            .or_else(|| e.downcast_ref::<&str>().copied())
117            .unwrap_or("unknown GPU buffer allocation panic");
118        CrushError::from(PluginError::OperationFailed(format!(
119            "GPU buffer allocation failed: {msg}"
120        )))
121    })
122}
123
/// Parse sub-stream length `u32` values from raw little-endian bytes.
///
/// Reads `n` consecutive little-endian `u32` values from `len_bytes`.
/// Any entry whose 4-byte span extends past the end of `len_bytes` is
/// filled with 0, preserving the original defensive behaviour for a
/// short readback buffer.
fn parse_ss_lengths(len_bytes: &[u8], n: u32) -> Vec<u32> {
    (0..n as usize)
        .map(|i| {
            len_bytes
                .get(i * 4..i * 4 + 4)
                // `get` only yields a slice of exactly 4 bytes here, so the
                // array conversion cannot fail.
                .map_or(0, |b| u32::from_le_bytes(b.try_into().expect("4-byte slice")))
        })
        .collect()
}
142
143/// Create the bind group layout with 4 storage buffer bindings for the LZ77 shader.
144fn create_bind_group_layout(device: &wgpu::Device) -> wgpu::BindGroupLayout {
145    let storage_entry = |binding: u32, read_only: bool| wgpu::BindGroupLayoutEntry {
146        binding,
147        visibility: wgpu::ShaderStages::COMPUTE,
148        ty: wgpu::BindingType::Buffer {
149            ty: wgpu::BufferBindingType::Storage { read_only },
150            has_dynamic_offset: false,
151            min_binding_size: None,
152        },
153        count: None,
154    };
155
156    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
157        label: Some("decompress_bgl"),
158        entries: &[
159            storage_entry(0, true),  // tile_meta
160            storage_entry(1, true),  // compressed_data
161            storage_entry(2, false), // decompressed_data
162            storage_entry(3, false), // sub_stream_lengths
163        ],
164    })
165}
166
167/// Create the bind group layout with 3 storage buffer bindings for the `GDeflate` shader.
168fn create_gdeflate_bgl(device: &wgpu::Device) -> wgpu::BindGroupLayout {
169    let storage_entry = |binding: u32, read_only: bool| wgpu::BindGroupLayoutEntry {
170        binding,
171        visibility: wgpu::ShaderStages::COMPUTE,
172        ty: wgpu::BindingType::Buffer {
173            ty: wgpu::BufferBindingType::Storage { read_only },
174            has_dynamic_offset: false,
175            min_binding_size: None,
176        },
177        count: None,
178    };
179
180    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
181        label: Some("gdeflate_bgl"),
182        entries: &[
183            storage_entry(0, true),  // meta (GDeflateMeta)
184            storage_entry(1, true),  // compressed (GDeflate payload)
185            storage_entry(2, false), // output (decompressed bytes)
186        ],
187    })
188}
189
impl WgpuBackend {
    /// Attempt to create a wgpu backend by discovering a suitable GPU adapter.
    ///
    /// Returns `None` if no compatible GPU is found: no adapter at all, a
    /// software/CPU adapter, or an adapter whose estimated VRAM is below
    /// `MIN_VRAM_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns an error if wgpu initialisation fails unexpectedly.
    pub fn try_new() -> Result<Option<Self>> {
        // Native GPU APIs only — no GL or browser fallbacks.
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::VULKAN | wgpu::Backends::METAL | wgpu::Backends::DX12,
            ..Default::default()
        });

        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }));

        // Adapter discovery failure means "no GPU", not a hard error — the
        // caller is expected to fall back to the CPU path.
        let Ok(adapter) = adapter else {
            return Ok(None);
        };

        let adapter_info = adapter.get_info();

        // Reject software/CPU adapters — they won't provide GPU acceleration.
        if adapter_info.device_type == wgpu::DeviceType::Cpu {
            return Ok(None);
        }

        // Use max_buffer_size as a rough VRAM proxy. Note: this is the
        // driver-reported maximum single-buffer size, not total VRAM.
        // On discrete GPUs it's typically 2+ GB. We use a conservative
        // check and rely on catch_unwind + CPU fallback for OOM safety.
        let limits = adapter.limits();
        let estimated_vram = limits.max_buffer_size;
        if estimated_vram < MIN_VRAM_BYTES {
            return Ok(None);
        }

        // Classify by PCI vendor ID (NVIDIA 0x10DE, AMD 0x1002, Intel 0x8086).
        // Apple GPUs don't carry a PCI ID, so fall back to name/driver matching.
        let vendor = match adapter_info.vendor {
            0x10DE => GpuVendor::Nvidia,
            0x1002 => GpuVendor::Amd,
            0x8086 => GpuVendor::Intel,
            _ if adapter_info.driver.to_lowercase().contains("apple")
                || adapter_info.name.to_lowercase().contains("apple") =>
            {
                GpuVendor::Apple
            }
            _ => GpuVendor::Other,
        };

        // Default features/limits keep us on the widest compatibility path.
        let (device, queue) = pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
            label: Some("crush-gpu"),
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::default(),
            memory_hints: wgpu::MemoryHints::Performance,
            trace: wgpu::Trace::Off,
            experimental_features: wgpu::ExperimentalFeatures::default(),
        }))
        .map_err(|e| PluginError::OperationFailed(format!("wgpu device request failed: {e}")))?;

        // --- LZ77 pipeline (v1) ---
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("decompress.wgsl"),
            source: wgpu::ShaderSource::Wgsl(DECOMPRESS_SHADER.into()),
        });

        let bind_group_layout = create_bind_group_layout(&device);

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("decompress_pipeline_layout"),
            bind_group_layouts: &[&bind_group_layout],
            immediate_size: 0,
        });

        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("decompress_pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader_module,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });

        // --- GDeflate pipeline (v2) ---
        let gdeflate_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("gdeflate_decompress.wgsl"),
            source: wgpu::ShaderSource::Wgsl(GDEFLATE_SHADER.into()),
        });

        let gdeflate_bgl = create_gdeflate_bgl(&device);

        let gdeflate_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("gdeflate_pipeline_layout"),
            bind_group_layouts: &[&gdeflate_bgl],
            immediate_size: 0,
        });

        let gdeflate_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("gdeflate_pipeline"),
            layout: Some(&gdeflate_pl),
            module: &gdeflate_module,
            entry_point: Some("main"),
            compilation_options: wgpu::PipelineCompilationOptions::default(),
            cache: None,
        });

        let info = GpuInfo {
            name: adapter_info.name.clone(),
            vendor,
            vram_bytes: estimated_vram,
            api_backend: format!("{:?}", adapter_info.backend),
        };

        Ok(Some(Self {
            info,
            device,
            queue,
            pipeline,
            bind_group_layout,
            gdeflate_pipeline,
            gdeflate_bgl,
        }))
    }

    /// Map two staging buffers, poll the device, and return their contents.
    ///
    /// Both `map_async` requests are issued before the single blocking
    /// `poll`, so one device wait covers both mappings; the mpsc channels
    /// carry each map result back from the wgpu callbacks.
    ///
    /// # Errors
    ///
    /// Fails if the poll times out (`GPU_POLL_TIMEOUT`), the device is lost,
    /// or either buffer mapping fails.
    fn readback_buffers(
        &self,
        out_staging: &wgpu::Buffer,
        lengths_staging: &wgpu::Buffer,
    ) -> Result<(Vec<u8>, Vec<u8>)> {
        let out_slice = out_staging.slice(..);
        let lengths_slice = lengths_staging.slice(..);

        let (out_tx, out_rx) = std::sync::mpsc::channel();
        out_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = out_tx.send(result);
        });
        let (len_tx, len_rx) = std::sync::mpsc::channel();
        lengths_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = len_tx.send(result);
        });

        // Blocking wait with a timeout so a hung GPU/driver surfaces as an
        // error instead of stalling this thread forever.
        self.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: Some(GPU_POLL_TIMEOUT),
            })
            .map_err(|e| {
                PluginError::OperationFailed(format!(
                    "GPU poll failed (timeout or device lost): {e}"
                ))
            })?;

        // Each recv has two failure layers: the channel itself (callback
        // dropped without sending) and the map result it carries.
        out_rx
            .recv()
            .map_err(|e| PluginError::OperationFailed(format!("GPU readback channel error: {e}")))?
            .map_err(|e| PluginError::OperationFailed(format!("GPU output map failed: {e}")))?;
        len_rx
            .recv()
            .map_err(|e| PluginError::OperationFailed(format!("GPU readback channel error: {e}")))?
            .map_err(|e| PluginError::OperationFailed(format!("GPU lengths map failed: {e}")))?;

        // Copy out of the mapped ranges; the staging buffers (and their
        // mappings) are released when the caller drops them.
        let out_bytes = out_slice.get_mapped_range().to_vec();
        let len_bytes = lengths_slice.get_mapped_range().to_vec();
        Ok((out_bytes, len_bytes))
    }

    /// Decompress a single tile on the GPU and return the raw sub-stream outputs.
    ///
    /// Returns the shader's raw output bytes plus the per-sub-stream lengths
    /// it reported; the caller deinterleaves them into the final tile bytes.
    fn dispatch_tile(&self, tile: &CompressedTile, tile_index: u32) -> Result<(Vec<u8>, Vec<u32>)> {
        let n = u32::from(tile.sub_stream_count);
        if n == 0 {
            return Err(CrushError::InvalidFormat(
                "tile has zero sub-stream count".to_owned(),
            ));
        }

        // Guard against crafted archives with absurd uncompressed_size that
        // would cause u32 overflow in `n * max_per_ss`.
        let max_tile_size: u32 = crate::format::DEFAULT_TILE_SIZE;
        if tile.uncompressed_size > max_tile_size.saturating_mul(2) {
            return Err(CrushError::InvalidFormat(format!(
                "tile uncompressed_size {} exceeds maximum {}",
                tile.uncompressed_size,
                max_tile_size * 2,
            )));
        }

        let max_per_ss = tile.uncompressed_size.div_ceil(n);
        let output_buf_size = n.checked_mul(max_per_ss).ok_or_else(|| {
            CrushError::InvalidFormat(format!("output buffer size overflow: {n} * {max_per_ss}"))
        })?;

        // Pad the upload to a 4-byte multiple (storage buffers are addressed
        // as u32 words in WGSL).
        let mut comp_data = tile.data.clone();
        while !comp_data.len().is_multiple_of(4) {
            comp_data.push(0);
        }

        // NOTE(review): compressed_size is the UNPADDED length while the
        // uploaded buffer is padded — presumably the shader only reads
        // compressed_size bytes; confirm against decompress.wgsl.
        let meta = TileMeta {
            compressed_offset: 0,
            compressed_size: u32::try_from(tile.data.len())
                .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
            uncompressed_size: tile.uncompressed_size,
            sub_stream_count: n,
            output_offset: 0,
            tile_index,
            _pad0: 0,
            _pad1: 0,
        };

        // Buffer sizes rounded up to 4-byte multiples with a 4-byte floor so
        // they are never zero-sized.
        let out_aligned = u64::from(output_buf_size.div_ceil(4) * 4).max(4);
        let len_size = (u64::from(n) * 4).max(4);
        let bufs = create_tile_buffers(&self.device, &comp_data, out_aligned, len_size)?;

        self.queue
            .write_buffer(&bufs.meta, 0, bytemuck::bytes_of(&meta));
        self.queue.write_buffer(&bufs.compressed, 0, &comp_data);

        // Binding indices must match create_bind_group_layout.
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("decompress_bg"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: bufs.meta.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: bufs.compressed.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: bufs.output.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: bufs.lengths.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("decompress_encoder"),
            });
        {
            // Scoped so the pass is ended before the buffer-to-buffer copies
            // are recorded on the same encoder.
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("decompress_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            // One workgroup per tile dispatch.
            pass.dispatch_workgroups(1, 1, 1);
        }
        // Copy GPU-only storage buffers into the mappable staging buffers.
        encoder.copy_buffer_to_buffer(&bufs.output, 0, &bufs.out_staging, 0, out_aligned);
        encoder.copy_buffer_to_buffer(&bufs.lengths, 0, &bufs.len_staging, 0, len_size);
        self.queue.submit(std::iter::once(encoder.finish()));

        let (out_bytes, len_bytes) = self.readback_buffers(&bufs.out_staging, &bufs.len_staging)?;

        Ok((out_bytes, parse_ss_lengths(&len_bytes, n)))
    }
}
456
/// GPU buffers for a single `GDeflate` tile dispatch.
struct GDeflateBuffers {
    /// `GDeflateMeta` storage buffer, filled via `queue.write_buffer`.
    meta: wgpu::Buffer,
    /// Compressed payload (padded to a 4-byte multiple).
    compressed: wgpu::Buffer,
    /// Decompressed output written by the shader.
    output: wgpu::Buffer,
    /// CPU-mappable staging copy of `output` for readback.
    out_staging: wgpu::Buffer,
}
464
465/// Allocate GPU buffers for a `GDeflate` tile dispatch.
466fn create_gdeflate_buffers(
467    device: &wgpu::Device,
468    comp_data: &[u8],
469    out_aligned: u64,
470) -> Result<GDeflateBuffers> {
471    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
472        let buf = |label, size, usage| {
473            device.create_buffer(&wgpu::BufferDescriptor {
474                label: Some(label),
475                size,
476                usage,
477                mapped_at_creation: false,
478            })
479        };
480        let storage_dst = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST;
481        let storage_src = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC;
482        let map_read = wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ;
483
484        GDeflateBuffers {
485            meta: buf(
486                "gdeflate_meta",
487                std::mem::size_of::<GDeflateMeta>() as u64,
488                storage_dst,
489            ),
490            compressed: buf("gdeflate_compressed", comp_data.len() as u64, storage_dst),
491            output: buf("gdeflate_output", out_aligned, storage_src),
492            out_staging: buf("gdeflate_staging", out_aligned, map_read),
493        }
494    }))
495    .map_err(|e| {
496        let msg = e
497            .downcast_ref::<String>()
498            .map(String::as_str)
499            .or_else(|| e.downcast_ref::<&str>().copied())
500            .unwrap_or("unknown GPU buffer allocation panic");
501        CrushError::from(PluginError::OperationFailed(format!(
502            "GDeflate GPU buffer allocation failed: {msg}"
503        )))
504    })
505}
506
507impl ComputeBackend for WgpuBackend {
508    #[allow(clippy::unnecessary_literal_bound)]
509    fn name(&self) -> &str {
510        "wgpu"
511    }
512
513    fn gpu_info(&self) -> &GpuInfo {
514        &self.info
515    }
516
517    fn decompress_tiles(
518        &self,
519        tiles: &[CompressedTile],
520        cancel: &AtomicBool,
521    ) -> Result<Vec<Vec<u8>>> {
522        // Wrap the entire GPU dispatch loop in catch_unwind so that panics
523        // inside wgpu (e.g. "device is lost", driver crashes, TDR) are
524        // converted to errors instead of crashing the process.
525        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
526            self.decompress_tiles_inner(tiles, cancel)
527        }))
528        .unwrap_or_else(|e| {
529            let msg = e
530                .downcast_ref::<String>()
531                .map(String::as_str)
532                .or_else(|| e.downcast_ref::<&str>().copied())
533                .unwrap_or("unknown GPU panic");
534            Err(CrushError::from(PluginError::OperationFailed(format!(
535                "GPU panic caught (falling back to CPU): {msg}"
536            ))))
537        })
538    }
539
540    fn decompress_tiles_gdeflate(
541        &self,
542        tiles: &[CompressedTile],
543        cancel: &AtomicBool,
544    ) -> Result<Vec<Vec<u8>>> {
545        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
546            self.decompress_tiles_gdeflate_inner(tiles, cancel)
547        }))
548        .unwrap_or_else(|e| {
549            let msg = e
550                .downcast_ref::<String>()
551                .map(String::as_str)
552                .or_else(|| e.downcast_ref::<&str>().copied())
553                .unwrap_or("unknown GPU panic");
554            Err(CrushError::from(PluginError::OperationFailed(format!(
555                "GDeflate GPU panic caught (falling back to CPU): {msg}"
556            ))))
557        })
558    }
559
560    fn release(&self) {
561        // wgpu resources are dropped automatically via RAII.
562    }
563}
564
/// Validated and padded tile data ready for GPU upload.
struct PreparedTile {
    /// Compressed payload padded to a 4-byte multiple.
    padded_data: Vec<u8>,
    /// Shader metadata derived from the tile.
    meta: GDeflateMeta,
    /// Output buffer size, rounded up to a 4-byte multiple (minimum 4).
    out_aligned: u64,
}
571
572/// Validate a tile and prepare its padded data + metadata for GPU dispatch.
573fn prepare_gdeflate_tile(tile: &CompressedTile) -> Result<PreparedTile> {
574    let max_tile_size: u32 = crate::format::DEFAULT_TILE_SIZE;
575    if tile.uncompressed_size > max_tile_size.saturating_mul(2) {
576        return Err(CrushError::InvalidFormat(format!(
577            "tile uncompressed_size {} exceeds maximum {}",
578            tile.uncompressed_size,
579            max_tile_size * 2,
580        )));
581    }
582
583    let mut padded_data = tile.data.clone();
584    while !padded_data.len().is_multiple_of(4) {
585        padded_data.push(0);
586    }
587
588    let meta = GDeflateMeta {
589        payload_size: u32::try_from(padded_data.len())
590            .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
591        uncompressed_size: tile.uncompressed_size,
592        _pad0: 0,
593        _pad1: 0,
594    };
595
596    let out_aligned = u64::from(tile.uncompressed_size.div_ceil(4) * 4).max(4);
597
598    Ok(PreparedTile {
599        padded_data,
600        meta,
601        out_aligned,
602    })
603}
604
605/// Create a `GDeflate` bind group for a single tile's buffers.
606fn create_gdeflate_bind_group(
607    device: &wgpu::Device,
608    layout: &wgpu::BindGroupLayout,
609    bufs: &GDeflateBuffers,
610) -> wgpu::BindGroup {
611    device.create_bind_group(&wgpu::BindGroupDescriptor {
612        label: Some("gdeflate_bg"),
613        layout,
614        entries: &[
615            wgpu::BindGroupEntry {
616                binding: 0,
617                resource: bufs.meta.as_entire_binding(),
618            },
619            wgpu::BindGroupEntry {
620                binding: 1,
621                resource: bufs.compressed.as_entire_binding(),
622            },
623            wgpu::BindGroupEntry {
624                binding: 2,
625                resource: bufs.output.as_entire_binding(),
626            },
627        ],
628    })
629}
630
impl WgpuBackend {
    /// Dispatch a batch of `GDeflate` tiles in a single GPU submission.
    ///
    /// All tiles share one `CommandEncoder`, one `ComputePass`, one `queue.submit()`,
    /// and one `device.poll()`. Each tile gets its own buffers and bind group since
    /// buffer sizes vary per tile. This eliminates per-tile host-GPU synchronization
    /// overhead.
    fn dispatch_batch_gdeflate(&self, tiles: &[CompressedTile]) -> Result<Vec<Vec<u8>>> {
        // Prepare all tiles (validate, pad, build metadata). Any invalid tile
        // fails the whole batch before GPU resources are touched.
        let prepared: Vec<PreparedTile> = tiles
            .iter()
            .map(prepare_gdeflate_tile)
            .collect::<Result<Vec<_>>>()?;

        // Allocate all GPU buffers upfront.
        let all_bufs: Vec<GDeflateBuffers> = prepared
            .iter()
            .map(|p| create_gdeflate_buffers(&self.device, &p.padded_data, p.out_aligned))
            .collect::<Result<Vec<_>>>()?;

        // Upload all metadata and compressed data.
        for (p, bufs) in prepared.iter().zip(all_bufs.iter()) {
            self.queue
                .write_buffer(&bufs.meta, 0, bytemuck::bytes_of(&p.meta));
            self.queue.write_buffer(&bufs.compressed, 0, &p.padded_data);
        }

        // Create all bind groups.
        let bind_groups: Vec<wgpu::BindGroup> = all_bufs
            .iter()
            .map(|bufs| create_gdeflate_bind_group(&self.device, &self.gdeflate_bgl, bufs))
            .collect();

        // One encoder, one compute pass, multiple dispatches.
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("gdeflate_batch_encoder"),
            });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("gdeflate_batch_pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.gdeflate_pipeline);
            for bg in &bind_groups {
                pass.set_bind_group(0, bg, &[]);
                // One workgroup per tile.
                pass.dispatch_workgroups(1, 1, 1);
            }
        }

        // Copy all output buffers to staging buffers.
        for (p, bufs) in prepared.iter().zip(all_bufs.iter()) {
            encoder.copy_buffer_to_buffer(&bufs.output, 0, &bufs.out_staging, 0, p.out_aligned);
        }

        // Single submit for all tiles.
        self.queue.submit(std::iter::once(encoder.finish()));

        // Map all staging buffers, single poll, collect results.
        self.readback_batch(&all_bufs, tiles)
    }

    /// Map all staging buffers, poll once, and collect decompressed results.
    ///
    /// All `map_async` requests are issued first so a single blocking `poll`
    /// covers every mapping in the batch.
    ///
    /// # Errors
    ///
    /// Fails on poll timeout / device loss, or if any buffer mapping fails.
    fn readback_batch(
        &self,
        all_bufs: &[GDeflateBuffers],
        tiles: &[CompressedTile],
    ) -> Result<Vec<Vec<u8>>> {
        let receivers: Vec<_> = all_bufs
            .iter()
            .map(|bufs| {
                let slice = bufs.out_staging.slice(..);
                let (tx, rx) = std::sync::mpsc::channel();
                slice.map_async(wgpu::MapMode::Read, move |result| {
                    let _ = tx.send(result);
                });
                rx
            })
            .collect();

        // Blocking wait with a timeout so a hung GPU/driver surfaces as an
        // error instead of stalling this thread forever.
        self.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: Some(GPU_POLL_TIMEOUT),
            })
            .map_err(|e| {
                PluginError::OperationFailed(format!(
                    "GDeflate GPU poll failed (timeout or device lost): {e}"
                ))
            })?;

        let mut results = Vec::with_capacity(tiles.len());
        for (i, rx) in receivers.into_iter().enumerate() {
            // Two failure layers per tile: the channel itself (callback
            // dropped without sending) and the map result it carries.
            rx.recv()
                .map_err(|e| {
                    PluginError::OperationFailed(format!(
                        "GDeflate GPU readback channel error: {e}"
                    ))
                })?
                .map_err(|e| {
                    PluginError::OperationFailed(format!("GDeflate GPU output map failed: {e}"))
                })?;

            let slice = all_bufs[i].out_staging.slice(..);
            let out_bytes = slice.get_mapped_range().to_vec();
            // Staging buffers are 4-byte aligned; trim back to the tile's
            // declared uncompressed size (clamped defensively to the mapped
            // length).
            let size = tiles[i].uncompressed_size as usize;
            results.push(out_bytes[..size.min(out_bytes.len())].to_vec());
        }

        Ok(results)
    }

    /// Inner dispatch loop for `GDeflate` tiles — batched for throughput.
    ///
    /// Processes tiles in chunks of `super::MAX_TILES_PER_BATCH`, checking for
    /// cancellation between batches. Each batch is dispatched as a single
    /// GPU submission to minimize host-GPU synchronization overhead.
    fn decompress_tiles_gdeflate_inner(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>> {
        let mut results = Vec::with_capacity(tiles.len());
        for batch in tiles.chunks(super::MAX_TILES_PER_BATCH) {
            // Cancellation is only observed between batches; an in-flight
            // batch runs to completion.
            if cancel.load(Ordering::Relaxed) {
                return Err(CrushError::Cancelled);
            }
            let batch_results = self.dispatch_batch_gdeflate(batch)?;
            results.extend(batch_results);
        }
        Ok(results)
    }

    /// Inner dispatch loop, separated so `decompress_tiles` can wrap it in
    /// `catch_unwind` to prevent wgpu panics from crashing the process.
    ///
    /// Dispatches LZ77 (v1) tiles one at a time, deinterleaving each tile's
    /// sub-stream output on the CPU via `super::deinterleave`.
    fn decompress_tiles_inner(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>> {
        let mut results = Vec::with_capacity(tiles.len());

        for (i, tile) in tiles.iter().enumerate() {
            // Check for cancellation before each tile dispatch.
            if cancel.load(Ordering::Relaxed) {
                return Err(CrushError::Cancelled);
            }

            let tile_index =
                u32::try_from(i).map_err(|e| PluginError::OperationFailed(e.to_string()))?;

            let (raw_output, ss_lengths) = self.dispatch_tile(tile, tile_index)?;

            // Reassemble the tile's original byte order from the shader's
            // per-sub-stream output.
            let decompressed = super::deinterleave(
                &raw_output,
                &ss_lengths,
                u32::from(tile.sub_stream_count),
                tile.uncompressed_size,
            );

            results.push(decompressed);
        }

        Ok(results)
    }
}