crush_gpu/backend/mod.rs
1//! GPU compute backend trait and discovery
2//!
3//! Defines the [`ComputeBackend`] trait implemented by each GPU vendor
4//! backend (wgpu, CUDA) and the types needed for backend auto-selection.
5
6#[cfg(feature = "cuda")]
7pub mod cuda;
8pub mod wgpu_backend;
9
10use std::sync::atomic::AtomicBool;
11use std::sync::{Arc, OnceLock};
12
13use crush_core::error::Result;
14use tracing::{debug, info, trace, warn};
15
16// ---------------------------------------------------------------------------
17// GpuVendor
18// ---------------------------------------------------------------------------
19
/// Known GPU hardware vendors.
///
/// Classifies a discovered adapter; used for display (via `Display`)
/// and available to callers for vendor-specific decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpuVendor {
    /// NVIDIA GPUs.
    Nvidia,
    /// AMD GPUs.
    Amd,
    /// Intel GPUs.
    Intel,
    /// Apple GPUs.
    Apple,
    /// Any vendor not recognized above.
    Other,
}
29
30impl std::fmt::Display for GpuVendor {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Nvidia => write!(f, "NVIDIA"),
34 Self::Amd => write!(f, "AMD"),
35 Self::Intel => write!(f, "Intel"),
36 Self::Apple => write!(f, "Apple"),
37 Self::Other => write!(f, "Other"),
38 }
39 }
40}
41
42// ---------------------------------------------------------------------------
43// GpuInfo
44// ---------------------------------------------------------------------------
45
/// Runtime information about a discovered GPU.
///
/// Produced by a backend during discovery; this module logs the name,
/// API, and VRAM when a backend is selected.
#[derive(Debug, Clone)]
pub struct GpuInfo {
    /// Human-readable adapter name (e.g. "NVIDIA `GeForce` RTX 4090").
    pub name: String,
    /// Hardware vendor.
    pub vendor: GpuVendor,
    /// Estimated VRAM in bytes.
    pub vram_bytes: u64,
    /// Graphics API backend in use (e.g. "Vulkan", "Metal", "CUDA").
    pub api_backend: String,
}
58
59// ---------------------------------------------------------------------------
60// CompressedTile
61// ---------------------------------------------------------------------------
62
/// A single compressed tile ready for GPU decompression dispatch.
#[derive(Debug, Clone)]
pub struct CompressedTile {
    /// Compressed payload bytes (excluding `TileHeader`).
    pub data: Vec<u8>,
    /// Expected uncompressed size.
    pub uncompressed_size: u32,
    /// Sub-stream count within this tile.
    /// Sub-stream-interleaved kernel output is reassembled with
    /// [`deinterleave`] — presumably fed this value; confirm at call sites.
    pub sub_stream_count: u8,
    /// CRC32 of the uncompressed data (0 if checksums disabled).
    pub checksum: u32,
}
75
76// ---------------------------------------------------------------------------
77// ComputeBackend trait
78// ---------------------------------------------------------------------------
79
/// Abstraction over GPU compute backends (wgpu, CUDA).
///
/// All methods that can fail return [`crush_core::error::Result`] so the
/// engine can decide whether to fall back to CPU.
///
/// Implementations must be `Send + Sync`: [`discover_gpu`] caches a single
/// backend instance for the process lifetime and shares it across threads.
pub trait ComputeBackend: Send + Sync {
    /// Backend display name (e.g. "wgpu-Vulkan", "CUDA").
    fn name(&self) -> &str;

    /// Information about the GPU selected by this backend.
    fn gpu_info(&self) -> &GpuInfo;

    /// Decompress a batch of compressed tiles on the GPU.
    ///
    /// Returns one `Vec<u8>` per input tile in the same order.
    ///
    /// # Cancellation
    ///
    /// Cancellation is cooperative: implementations **should** check
    /// `cancel` between tile batches and return `CrushError::Cancelled`
    /// when set.
    ///
    /// # Errors
    ///
    /// May return any GPU error variant wrapped in a `CrushError`.
    fn decompress_tiles(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>>;

    /// Decompress a batch of `GDeflate`-encoded tiles on the GPU.
    ///
    /// Returns one `Vec<u8>` per input tile in the same order.
    /// The output is already in the correct byte order (no de-interleaving).
    ///
    /// # Cancellation
    ///
    /// Cancellation is cooperative: implementations **should** check
    /// `cancel` between tile dispatches and return `CrushError::Cancelled`
    /// when set.
    ///
    /// # Errors
    ///
    /// May return any GPU error variant wrapped in a `CrushError`.
    fn decompress_tiles_gdeflate(
        &self,
        tiles: &[CompressedTile],
        cancel: &AtomicBool,
    ) -> Result<Vec<Vec<u8>>>;

    /// Release GPU resources held by this backend.
    ///
    /// NOTE(review): nothing in this module calls `release` (dropping the
    /// cached `Arc` does not invoke it) — presumably the engine calls it
    /// explicitly at shutdown; confirm with the backend implementations.
    fn release(&self);
}
131
/// Minimum VRAM requirement in bytes (2 GB).
///
/// NOTE(review): not referenced in this module; presumably enforced by the
/// backend probes (`WgpuBackend::try_new` / `CudaBackend::try_new`) —
/// confirm there.
pub const MIN_VRAM_BYTES: u64 = 2 * 1024 * 1024 * 1024;

/// GPU memory budget for decompression dispatch in bytes (256 MB).
pub const GPU_MEMORY_BUDGET: u64 = 256 * 1024 * 1024;

/// Maximum number of tiles to batch into a single GPU submission.
/// 512 tiles × ~200KB GPU buffers ≈ 100MB, well within `GPU_MEMORY_BUDGET` (256MB).
pub const MAX_TILES_PER_BATCH: usize = 512;
141
142// ---------------------------------------------------------------------------
143// Shared helpers
144// ---------------------------------------------------------------------------
145
/// De-interleave sub-stream outputs back to the original tile byte order.
///
/// The LZ77 GPU kernel decompresses each sub-stream independently into a
/// separate region of the output buffer, `max_per_ss` bytes apart. This
/// function reconstructs the original byte order by round-robin reading
/// from each sub-stream: byte `i` of the original tile came from
/// sub-stream `i % n`, position `i / n` within that sub-stream.
///
/// `ss_lengths` holds the decoded length of each sub-stream and must have
/// at least `sub_stream_count` entries. Slices are clamped to
/// `raw_output.len()`, so a short buffer cannot cause an out-of-bounds
/// panic (the result is simply truncated).
///
/// Returns an empty vector when `sub_stream_count` is zero — previously
/// this panicked with a divide-by-zero inside `div_ceil`.
#[must_use]
pub fn deinterleave(
    raw_output: &[u8],
    ss_lengths: &[u32],
    sub_stream_count: u32,
    uncompressed_size: u32,
) -> Vec<u8> {
    let n = sub_stream_count as usize;
    // Guard: zero sub-streams means there is nothing to reconstruct, and
    // `div_ceil(0)` below would panic.
    if n == 0 {
        return Vec::new();
    }
    let max_per_ss = (uncompressed_size as usize).div_ceil(n);

    // Extract each sub-stream's decoded bytes, clamping to the buffer end.
    let sub_streams: Vec<&[u8]> = (0..n)
        .map(|i| {
            let start = i * max_per_ss;
            let len = ss_lengths[i] as usize;
            let actual_start = start.min(raw_output.len());
            let end = (start + len).min(raw_output.len());
            &raw_output[actual_start..end]
        })
        .collect();

    // De-interleave: byte i of the original tile came from sub-stream i%n,
    // position i/n within that sub-stream. Stop as soon as the expected
    // uncompressed size is reached.
    let mut output = Vec::with_capacity(uncompressed_size as usize);
    let max_len = sub_streams.iter().map(|s| s.len()).max().unwrap_or(0);
    for j in 0..max_len {
        for ss in &sub_streams {
            if j < ss.len() {
                output.push(ss[j]);
            }
            if output.len() == uncompressed_size as usize {
                return output;
            }
        }
    }

    output
}
191
192// ---------------------------------------------------------------------------
193// Backend auto-discovery (cached)
194// ---------------------------------------------------------------------------
195
/// Cached GPU backend singleton.
///
/// An unset cell means discovery has not run yet; a stored `None` means
/// discovery ran and found no usable GPU (so CPU fallback is used).
///
/// GPU device creation is expensive (50-500 ms) and rapid creation/destruction
/// destabilizes Windows DX12 drivers, causing `DXGI_ERROR_DEVICE_REMOVED` and
/// `device.lose()` in wgpu. By caching the backend for the process lifetime
/// we avoid these issues and match how games and other GPU applications work.
static CACHED_BACKEND: OnceLock<Option<Arc<dyn ComputeBackend>>> = OnceLock::new();
203
/// Attempt to create a CUDA backend.
///
/// Wrapped in `catch_unwind` because cudarc's nvrtc loading panics if the
/// NVIDIA Runtime Compiler library is not installed (e.g. `nvrtc.dll` on
/// Windows, `libnvrtc.so` on Linux).
///
/// Returns `Ok(backend)` on success, `Err(message)` explaining why CUDA
/// is unavailable on failure.
#[cfg(feature = "cuda")]
fn try_cuda() -> std::result::Result<Arc<dyn ComputeBackend>, String> {
    debug!("Probing CUDA backend...");

    // Swap in a no-op panic hook for the duration of the probe so a missing
    // nvrtc doesn't print Rust's "thread panicked at ..." noise, then put
    // the original hook back.
    let prev_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(|_| {}));
    let probe = std::panic::catch_unwind(cuda::CudaBackend::try_new);
    std::panic::set_hook(prev_hook);

    match probe {
        Err(panic_info) => {
            // Recover the panic message whether it was a String or &str payload.
            let msg = panic_info
                .downcast_ref::<String>()
                .map(String::as_str)
                .or_else(|| panic_info.downcast_ref::<&str>().copied())
                .unwrap_or("unknown panic");
            warn!("CUDA probe panicked (nvrtc likely missing): {msg}");
            Err(format!(
                "CUDA runtime compiler (nvrtc) not found. \
                 Install the CUDA Toolkit to enable CUDA decompression. \
                 Detail: {msg}"
            ))
        }
        Ok(Err(e)) => {
            debug!("CUDA probe failed: {e}");
            Err(e.to_string())
        }
        Ok(Ok(None)) => {
            debug!("CUDA probe: no compatible NVIDIA GPU found");
            Err("no compatible NVIDIA GPU found".to_owned())
        }
        Ok(Ok(Some(backend))) => {
            let gi = backend.gpu_info();
            info!(
                gpu = %gi.name,
                vram_mb = gi.vram_bytes / 1024 / 1024,
                "CUDA backend ready: {}",
                gi.name
            );
            Ok(Arc::new(backend) as Arc<dyn ComputeBackend>)
        }
    }
}
260
261/// Attempt to create a wgpu backend, returning `None` on failure.
262fn try_wgpu() -> Option<Arc<dyn ComputeBackend>> {
263 debug!("Probing wgpu backend...");
264 match wgpu_backend::WgpuBackend::try_new() {
265 Ok(Some(backend)) => {
266 let gi = backend.gpu_info();
267 info!(
268 gpu = %gi.name,
269 api = %gi.api_backend,
270 vram_mb = gi.vram_bytes / 1024 / 1024,
271 "wgpu backend ready: {} ({})",
272 gi.name,
273 gi.api_backend
274 );
275 Some(Arc::new(backend) as Arc<dyn ComputeBackend>)
276 }
277 Ok(None) => {
278 debug!("wgpu probe: no compatible GPU found");
279 None
280 }
281 Err(e) => {
282 warn!("wgpu backend init failed: {e}");
283 None
284 }
285 }
286}
287
/// Discover the best available GPU backend, caching the result.
///
/// The backend is created once and reused for the process lifetime.
/// If the GPU device becomes lost during use, the engine's `catch_unwind`
/// safety net converts the error and falls back to CPU decompression.
///
/// Backend selection is controlled by [`crate::get_config()`]`.backend`:
/// - `Auto`: try CUDA first (if feature enabled), then wgpu
/// - `Cuda`: only try CUDA
/// - `Wgpu`: only try wgpu
///
/// Returns `Ok(None)` if no compatible GPU is found.
///
/// # Errors
///
/// Under `Auto` and `Wgpu` preferences, GPU initialization errors are
/// handled internally and result in `Ok(None)` (no GPU available).
/// Under the explicit `Cuda` preference, an error is returned when CUDA
/// is unavailable or not compiled in; that error is NOT cached, so a
/// subsequent call will probe again.
pub fn discover_gpu() -> Result<Option<Arc<dyn ComputeBackend>>> {
    use crush_core::error::{CrushError, PluginError};

    use crate::BackendPreference;

    // We cannot return errors from inside `get_or_init`, so we do a two-step:
    // first check if the cached value already exists, then handle the init
    // with potential errors.
    if let Some(cached) = CACHED_BACKEND.get() {
        trace!("Using cached GPU backend");
        return Ok(cached.clone());
    }

    let pref = crate::get_config().backend;
    info!(preference = ?pref, "Discovering GPU backend (preference: {pref:?})");

    let backend: Option<Arc<dyn ComputeBackend>> = match pref {
        BackendPreference::Auto => {
            // Auto: prefer CUDA when compiled in; any CUDA failure falls
            // back to wgpu rather than erroring out.
            #[cfg(feature = "cuda")]
            {
                match try_cuda() {
                    Ok(backend) => Some(backend),
                    Err(msg) => {
                        info!("CUDA unavailable ({msg}), trying wgpu");
                        try_wgpu()
                    }
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                debug!("CUDA feature not compiled in, trying wgpu only");
                try_wgpu()
            }
        }
        BackendPreference::Cuda => {
            // Explicit CUDA request: failure is a hard error, not a
            // silent fallback to wgpu or CPU.
            #[cfg(feature = "cuda")]
            {
                match try_cuda() {
                    Ok(backend) => Some(backend),
                    Err(msg) => {
                        return Err(CrushError::from(PluginError::OperationFailed(format!(
                            "CUDA backend requested but unavailable: {msg}"
                        ))));
                    }
                }
            }
            #[cfg(not(feature = "cuda"))]
            {
                return Err(CrushError::from(PluginError::OperationFailed(
                    "CUDA backend not included in this build. \
                     Reinstall with: cargo install crush-cli --features cuda"
                        .to_owned(),
                )));
            }
        }
        BackendPreference::Wgpu => try_wgpu(),
    };

    if let Some(ref b) = backend {
        info!(backend = b.name(), "GPU backend selected: {}", b.name());
    } else {
        info!("No GPU backend available, will use CPU fallback");
    }

    // Cache the result (race-safe: if another thread beat us, use their value).
    // Returning from the cell (not our local) guarantees all callers see the
    // same instance even when `set` lost the race.
    let _ = CACHED_BACKEND.set(backend.clone());
    Ok(CACHED_BACKEND.get().cloned().flatten())
}
372}