1//! GPU compute backend trait and discovery
2//!
3//! Defines the [`ComputeBackend`] trait implemented by each GPU vendor
4//! backend (wgpu, CUDA) and the types needed for backend auto-selection.
5
6#[cfg(feature = "cuda")]
7pub mod cuda;
8pub mod wgpu_backend;
9
10use std::sync::atomic::AtomicBool;
11use std::sync::{Arc, OnceLock};
12
13use crush_core::error::Result;
14use tracing::{debug, info, trace, warn};
15
16// ---------------------------------------------------------------------------
17// GpuVendor
18// ---------------------------------------------------------------------------
19
/// Known GPU hardware vendors.
///
/// Adapters whose vendor is not recognized map to [`GpuVendor::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuVendor {
    /// NVIDIA GPUs — the only vendor the CUDA backend can target.
    Nvidia,
    /// AMD GPUs.
    Amd,
    /// Intel GPUs.
    Intel,
    /// Apple GPUs.
    Apple,
    /// Any vendor not listed above.
    Other,
}
29
30impl std::fmt::Display for GpuVendor {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Nvidia => write!(f, "NVIDIA"),
34            Self::Amd => write!(f, "AMD"),
35            Self::Intel => write!(f, "Intel"),
36            Self::Apple => write!(f, "Apple"),
37            Self::Other => write!(f, "Other"),
38        }
39    }
40}
41
42// ---------------------------------------------------------------------------
43// GpuInfo
44// ---------------------------------------------------------------------------
45
/// Runtime information about a discovered GPU.
///
/// Populated by the backend that selected the adapter and used for
/// logging and capability checks during discovery.
#[derive(Debug, Clone)]
pub struct GpuInfo {
    /// Human-readable adapter name (e.g. "NVIDIA `GeForce` RTX 4090").
    pub name: String,
    /// Hardware vendor.
    pub vendor: GpuVendor,
    /// Estimated VRAM in bytes.
    /// NOTE(review): presumably a heuristic on APIs that don't report
    /// dedicated memory directly — confirm against the backend impls.
    pub vram_bytes: u64,
    /// Graphics API backend in use (e.g. "Vulkan", "Metal", "CUDA").
    pub api_backend: String,
}
58
59// ---------------------------------------------------------------------------
60// CompressedTile
61// ---------------------------------------------------------------------------
62
/// A single compressed tile ready for GPU decompression dispatch.
///
/// The fields mirror the on-disk tile header; `data` is the payload only.
#[derive(Debug, Clone)]
pub struct CompressedTile {
    /// Compressed payload bytes (excluding `TileHeader`).
    pub data: Vec<u8>,
    /// Expected uncompressed size.
    pub uncompressed_size: u32,
    /// Sub-stream count within this tile.
    // NOTE(review): presumably always >= 1 for valid tiles — confirm at the
    // call sites that build these.
    pub sub_stream_count: u8,
    /// CRC32 of the uncompressed data (0 if checksums disabled).
    pub checksum: u32,
}
75
76// ---------------------------------------------------------------------------
77// ComputeBackend trait
78// ---------------------------------------------------------------------------
79
/// Abstraction over GPU compute backends (wgpu, CUDA).
///
/// All methods that can fail return [`crush_core::error::Result`] so the
/// engine can decide whether to fall back to CPU.
///
/// Implementations must be `Send + Sync`: the backend is cached process-wide
/// behind an `Arc` and may be shared across threads.
pub trait ComputeBackend: Send + Sync {
    /// Backend display name (e.g. "wgpu-Vulkan", "CUDA").
    fn name(&self) -> &str;

    /// Information about the GPU selected by this backend.
    fn gpu_info(&self) -> &GpuInfo;

    /// Decompress a batch of compressed tiles on the GPU.
    ///
    /// Returns one `Vec<u8>` per input tile in the same order.
    ///
    /// # Cancellation
    ///
    /// Implementations **should** check `cancel` between tile batches
    /// and return `CrushError::Cancelled` when set.
    ///
    /// # Errors
    ///
    /// May return any GPU error variant wrapped in a `CrushError`.
    fn decompress_tiles(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>>;

    /// Decompress a batch of `GDeflate`-encoded tiles on the GPU.
    ///
    /// Returns one `Vec<u8>` per input tile in the same order.
    /// The output is already in the correct byte order (no de-interleaving).
    ///
    /// # Cancellation
    ///
    /// Implementations **should** check `cancel` between tile dispatches
    /// and return `CrushError::Cancelled` when set.
    ///
    /// # Errors
    ///
    /// May return any GPU error variant wrapped in a `CrushError`.
    fn decompress_tiles_gdeflate(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>>;

    /// Release GPU resources held by this backend.
    // NOTE(review): presumably the backend must not be used for further
    // dispatches after this is called — confirm against implementations.
    fn release(&self);
}
131
/// Minimum VRAM requirement in bytes (2 GB).
// NOTE(review): presumably enforced by each backend during adapter
// selection — confirm in the wgpu/CUDA probe code.
pub const MIN_VRAM_BYTES: u64 = 2 * 1024 * 1024 * 1024;

/// GPU memory budget for decompression dispatch in bytes (256 MB).
pub const GPU_MEMORY_BUDGET: u64 = 256 * 1024 * 1024;

/// Maximum number of tiles to batch into a single GPU submission.
/// 512 tiles × ~200KB GPU buffers ≈ 100MB, well within `GPU_MEMORY_BUDGET` (256MB).
pub const MAX_TILES_PER_BATCH: usize = 512;
141
142// ---------------------------------------------------------------------------
143// Shared helpers
144// ---------------------------------------------------------------------------
145
/// De-interleave sub-stream outputs back to the original tile byte order.
///
/// The LZ77 GPU kernel decompresses each sub-stream independently into a
/// separate region of the output buffer. This function reconstructs the
/// original byte order by round-robin reading from each sub-stream:
/// byte `i` of the original tile came from sub-stream `i % n`, position
/// `i / n` within that sub-stream.
///
/// Each sub-stream occupies a `ceil(uncompressed_size / n)`-byte region of
/// `raw_output`; only the first `ss_lengths[i]` bytes of region `i` are
/// valid decoded data. Out-of-range regions and missing length entries are
/// treated as empty, so malformed/truncated input degrades to a short
/// output instead of panicking.
///
/// Returns an empty `Vec` when `sub_stream_count` or `uncompressed_size`
/// is zero (a zero sub-stream count previously hit a divide-by-zero panic
/// in `div_ceil`).
#[must_use]
pub fn deinterleave(
    raw_output: &[u8],
    ss_lengths: &[u32],
    sub_stream_count: u32,
    uncompressed_size: u32,
) -> Vec<u8> {
    let n = sub_stream_count as usize;
    // Guard malformed input: n == 0 would divide by zero below, and a
    // zero-size tile has nothing to reconstruct.
    if n == 0 || uncompressed_size == 0 {
        return Vec::new();
    }
    let max_per_ss = (uncompressed_size as usize).div_ceil(n);

    // Extract each sub-stream's decoded bytes, clamped to the buffer so a
    // short `raw_output` or `ss_lengths` yields empty slices, not a panic.
    let sub_streams: Vec<&[u8]> = (0..n)
        .map(|i| {
            let start = i * max_per_ss;
            let len = ss_lengths.get(i).copied().unwrap_or(0) as usize;
            let end = (start + len).min(raw_output.len());
            let actual_start = start.min(raw_output.len());
            &raw_output[actual_start..end]
        })
        .collect();

    // De-interleave: byte i of the original tile came from sub-stream i%n,
    // position i/n within that sub-stream.
    let mut output = Vec::with_capacity(uncompressed_size as usize);
    let max_len = sub_streams.iter().map(|s| s.len()).max().unwrap_or(0);
    for j in 0..max_len {
        for ss in &sub_streams {
            if j < ss.len() {
                output.push(ss[j]);
            }
            // Stop exactly at the declared size even if sub-streams carry
            // trailing padding.
            if output.len() == uncompressed_size as usize {
                return output;
            }
        }
    }

    output
}
191
192// ---------------------------------------------------------------------------
193// Backend auto-discovery (cached)
194// ---------------------------------------------------------------------------
195
/// Cached GPU backend singleton.
///
/// GPU device creation is expensive (50-500 ms) and rapid creation/destruction
/// destabilizes Windows DX12 drivers, causing `DXGI_ERROR_DEVICE_REMOVED` and
/// `device.lose()` in wgpu. By caching the backend for the process lifetime
/// we avoid these issues and match how games and other GPU applications work.
///
/// The inner `Option` means a *negative* result ("no usable GPU") is cached
/// too, so discovery is not re-attempted on every call.
static CACHED_BACKEND: OnceLock<Option<Arc<dyn ComputeBackend>>> = OnceLock::new();
203
/// Attempt to create a CUDA backend.
///
/// Wrapped in `catch_unwind` because cudarc's nvrtc loading panics if the
/// NVIDIA Runtime Compiler library is not installed (e.g. `nvrtc.dll` on
/// Windows, `libnvrtc.so` on Linux).
///
/// Returns `Ok(backend)` on success, `Err(message)` explaining why CUDA
/// is unavailable on failure.
///
/// NOTE(review): swapping the panic hook is process-wide; a panic on any
/// *other* thread while the no-op hook is installed is silently suppressed.
/// Acceptable for a short probe, but confirm discovery stays serialized.
#[cfg(feature = "cuda")]
fn try_cuda() -> std::result::Result<Arc<dyn ComputeBackend>, String> {
    debug!("Probing CUDA backend...");

    // Temporarily silence the default panic hook so the user doesn't see
    // Rust's "thread panicked at ..." noise when nvrtc is missing.
    let prev_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(|_| {}));

    let result = std::panic::catch_unwind(cuda::CudaBackend::try_new);

    // Restore the original panic hook.
    std::panic::set_hook(prev_hook);

    match result {
        Ok(Ok(Some(backend))) => {
            let gi = backend.gpu_info();
            info!(
                gpu = %gi.name,
                vram_mb = gi.vram_bytes / 1024 / 1024,
                "CUDA backend ready: {}",
                gi.name
            );
            Ok(Arc::new(backend) as Arc<dyn ComputeBackend>)
        }
        Ok(Ok(None)) => {
            debug!("CUDA probe: no compatible NVIDIA GPU found");
            Err("no compatible NVIDIA GPU found".to_owned())
        }
        Ok(Err(e)) => {
            debug!("CUDA probe failed: {e}");
            // `to_string()` over `format!("{e}")`: same output, idiomatic
            // (clippy::useless_format).
            Err(e.to_string())
        }
        Err(panic_info) => {
            // Panic payloads are `String` (formatted panics) or
            // `&'static str` (literal panics); anything else is generic.
            let msg = panic_info
                .downcast_ref::<String>()
                .map(String::as_str)
                .or_else(|| panic_info.downcast_ref::<&str>().copied())
                .unwrap_or("unknown panic");
            warn!("CUDA probe panicked (nvrtc likely missing): {msg}");
            Err(format!(
                "CUDA runtime compiler (nvrtc) not found. \
                 Install the CUDA Toolkit to enable CUDA decompression. \
                 Detail: {msg}"
            ))
        }
    }
}
260
261/// Attempt to create a wgpu backend, returning `None` on failure.
262fn try_wgpu() -> Option<Arc<dyn ComputeBackend>> {
263    debug!("Probing wgpu backend...");
264    match wgpu_backend::WgpuBackend::try_new() {
265        Ok(Some(backend)) => {
266            let gi = backend.gpu_info();
267            info!(
268                gpu = %gi.name,
269                api = %gi.api_backend,
270                vram_mb = gi.vram_bytes / 1024 / 1024,
271                "wgpu backend ready: {} ({})",
272                gi.name,
273                gi.api_backend
274            );
275            Some(Arc::new(backend) as Arc<dyn ComputeBackend>)
276        }
277        Ok(None) => {
278            debug!("wgpu probe: no compatible GPU found");
279            None
280        }
281        Err(e) => {
282            warn!("wgpu backend init failed: {e}");
283            None
284        }
285    }
286}
287
288/// Discover the best available GPU backend, caching the result.
289///
290/// The backend is created once and reused for the process lifetime.
291/// If the GPU device becomes lost during use, the engine's `catch_unwind`
292/// safety net converts the error and falls back to CPU decompression.
293///
294/// Backend selection is controlled by [`crate::get_config()`]`.backend`:
295/// - `Auto`: try CUDA first (if feature enabled), then wgpu
296/// - `Cuda`: only try CUDA
297/// - `Wgpu`: only try wgpu
298///
299/// Returns `Ok(None)` if no compatible GPU is found.
300///
301/// # Errors
302///
303/// This function always returns `Ok`. GPU initialization errors are
304/// handled internally and result in `Ok(None)` (no GPU available).
305pub fn discover_gpu() -> Result<Option<Arc<dyn ComputeBackend>>> {
306    use crush_core::error::{CrushError, PluginError};
307
308    use crate::BackendPreference;
309
310    // We cannot return errors from inside `get_or_init`, so we do a two-step:
311    // first check if the cached value already exists, then handle the init
312    // with potential errors.
313    if let Some(cached) = CACHED_BACKEND.get() {
314        trace!("Using cached GPU backend");
315        return Ok(cached.clone());
316    }
317
318    let pref = crate::get_config().backend;
319    info!(preference = ?pref, "Discovering GPU backend (preference: {pref:?})");
320
321    let backend: Option<Arc<dyn ComputeBackend>> = match pref {
322        BackendPreference::Auto => {
323            #[cfg(feature = "cuda")]
324            {
325                match try_cuda() {
326                    Ok(backend) => Some(backend),
327                    Err(msg) => {
328                        info!("CUDA unavailable ({msg}), trying wgpu");
329                        try_wgpu()
330                    }
331                }
332            }
333            #[cfg(not(feature = "cuda"))]
334            {
335                debug!("CUDA feature not compiled in, trying wgpu only");
336                try_wgpu()
337            }
338        }
339        BackendPreference::Cuda => {
340            #[cfg(feature = "cuda")]
341            {
342                match try_cuda() {
343                    Ok(backend) => Some(backend),
344                    Err(msg) => {
345                        return Err(CrushError::from(PluginError::OperationFailed(format!(
346                            "CUDA backend requested but unavailable: {msg}"
347                        ))));
348                    }
349                }
350            }
351            #[cfg(not(feature = "cuda"))]
352            {
353                return Err(CrushError::from(PluginError::OperationFailed(
354                    "CUDA backend not included in this build. \
355                     Reinstall with: cargo install crush-cli --features cuda"
356                        .to_owned(),
357                )));
358            }
359        }
360        BackendPreference::Wgpu => try_wgpu(),
361    };
362
363    if let Some(ref b) = backend {
364        info!(backend = b.name(), "GPU backend selected: {}", b.name());
365    } else {
366        info!("No GPU backend available, will use CPU fallback");
367    }
368
369    // Cache the result (race-safe: if another thread beat us, use their value).
370    let _ = CACHED_BACKEND.set(backend.clone());
371    Ok(CACHED_BACKEND.get().cloned().flatten())
372}