Skip to main content

crush_gpu/
engine.rs

1//! Compression and decompression orchestration
2//!
3//! Implements tile-based compression using 64KB tiles with `GDeflate` encoding.
4//! Compression runs on the CPU ([`crate::gdeflate::gdeflate_compress_tile`]).
5//! Decompression can run on GPU (via [`crate::backend::ComputeBackend`]) or CPU fallback.
6
7use std::sync::atomic::{AtomicBool, Ordering};
8
9use crc32fast::Hasher;
10use crush_core::error::{CrushError, PluginError, Result};
11use tracing::{debug, info, warn};
12
13use crate::format::{
14    padding_to_alignment, GpuFileFooter, GpuFileHeader, TileFlags, TileHeader, TileIndexEntry,
15    TileIndexHeader, DEFAULT_SUB_STREAM_COUNT, DEFAULT_TILE_SIZE,
16};
17use crate::gdeflate;
18
19// ============================================================================
20// EngineConfig
21// ============================================================================
22
/// Tunable parameters that control how the GPU engine packs and checks tiles.
#[derive(Debug, Clone)]
pub struct EngineConfig {
    /// Size of each tile in bytes; the default is 65536 (64 KB).
    pub tile_size: u32,
    /// How many parallel sub-streams each tile is split into (default 32).
    pub sub_stream_count: u8,
    /// When `true`, a CRC32 checksum is recorded for every tile.
    pub enable_checksums: bool,
    /// When `true`, GPU decompression is never attempted (CPU-only mode).
    pub force_cpu: bool,
}
35
36impl Default for EngineConfig {
37    fn default() -> Self {
38        Self {
39            tile_size: DEFAULT_TILE_SIZE,
40            sub_stream_count: DEFAULT_SUB_STREAM_COUNT,
41            enable_checksums: true,
42            force_cpu: false,
43        }
44    }
45}
46
47// ============================================================================
48// Public API
49// ============================================================================
50
/// Compress `input` into a GPU-format archive.
///
/// Archive layout produced: file header (padded to 128-byte alignment),
/// then per tile a [`TileHeader`] + payload (each tile padded to 128-byte
/// alignment), then the tile index, then the [`GpuFileFooter`].
///
/// # Errors
///
/// * [`CrushError::Cancelled`] if `cancel` is set during processing.
/// * [`PluginError::OperationFailed`] on internal encoding failures.
pub fn compress(input: &[u8], config: &EngineConfig, cancel: &AtomicBool) -> Result<Vec<u8>> {
    // Zero-length input gets a minimal header + empty index + footer archive.
    if input.is_empty() {
        return write_empty_archive(config);
    }

    let tile_size = config.tile_size as usize;
    // chunks() never yields an empty slice; only the final tile may be short.
    let tiles: Vec<&[u8]> = input.chunks(tile_size).collect();
    let tile_count = tiles.len();

    // Pre-allocate output buffer (header + estimated body).
    let mut output = Vec::with_capacity(GpuFileHeader::SIZE + input.len());

    // ── File header ──────────────────────────────────────────────────
    let header = GpuFileHeader::new(
        u64::try_from(tile_count).map_err(|e| PluginError::OperationFailed(e.to_string()))?,
        u64::try_from(input.len()).map_err(|e| PluginError::OperationFailed(e.to_string()))?,
    );
    output.extend_from_slice(&header.to_bytes());

    // Pad file header to 128-byte alignment so tiles start aligned.
    let hdr_pad = padding_to_alignment(output.len());
    output.resize(output.len() + hdr_pad, 0);

    // ── Compress each tile ───────────────────────────────────────────
    let mut index_entries: Vec<TileIndexEntry> = Vec::with_capacity(tile_count);

    for (i, tile_data) in tiles.iter().enumerate() {
        // Cooperative cancellation, checked once per tile.
        if cancel.load(Ordering::Relaxed) {
            return Err(CrushError::Cancelled);
        }

        let is_last = i + 1 == tile_count;

        // Absolute offset of this tile's header; recorded in the index so
        // readers can seek straight to the tile.
        let tile_offset =
            u64::try_from(output.len()).map_err(|e| PluginError::OperationFailed(e.to_string()))?;

        // CRC32 of the UNCOMPRESSED tile bytes.
        // NOTE(review): 0 doubles as the "no checksum" sentinel on the
        // decode side (it checks `checksum != 0`), so a tile whose true
        // CRC32 happens to be 0 will silently skip verification there.
        let checksum = if config.enable_checksums {
            let mut h = Hasher::new();
            h.update(tile_data);
            h.finalize()
        } else {
            0
        };

        // Compress tile with GDeflate (handles 32-way sub-stream interleaving internally).
        let compressed_payload = gdeflate::gdeflate_compress_tile(tile_data)?;

        let compressed_size = u32::try_from(compressed_payload.len())
            .map_err(|e| PluginError::OperationFailed(e.to_string()))?;
        let uncompressed_size = u32::try_from(tile_data.len())
            .map_err(|e| PluginError::OperationFailed(e.to_string()))?;

        // Build tile flags.
        let mut flags = TileFlags::default();
        if is_last {
            flags = flags.with_last_tile();
        }
        // If "compression" expanded the data, store raw.
        let (final_payload, final_compressed_size, final_flags) =
            if compressed_payload.len() >= tile_data.len() {
                // Stored mode – write raw tile
                let stored_size = u32::try_from(tile_data.len())
                    .map_err(|e| PluginError::OperationFailed(e.to_string()))?;
                (tile_data.to_vec(), stored_size, flags.with_stored())
            } else {
                (compressed_payload, compressed_size, flags)
            };

        let tile_header = TileHeader {
            version: 2,
            flags: final_flags,
            sub_stream_count: config.sub_stream_count,
            compressed_size: final_compressed_size,
            uncompressed_size,
            checksum,
            sub_stream_offsets_size: 0, // GDeflate handles sub-streams internally
        };

        output.extend_from_slice(&tile_header.to_bytes());
        output.extend_from_slice(&final_payload);

        // Pad to 128-byte alignment.  Because every tile starts on a
        // 128-byte boundary, padding the byte count written for this tile
        // is equivalent to padding the absolute output offset.
        let written = TileHeader::SIZE + final_payload.len();
        let pad = padding_to_alignment(written);
        output.resize(output.len() + pad, 0);

        let entry = TileIndexEntry {
            tile_offset,
            compressed_size: final_compressed_size,
            uncompressed_size,
            checksum,
            flags: u32::from(final_flags.0),
        };
        index_entries.push(entry);
    }

    // ── Tile index ───────────────────────────────────────────────────
    // The footer records this offset so readers can locate the index.
    let index_offset =
        u64::try_from(output.len()).map_err(|e| PluginError::OperationFailed(e.to_string()))?;

    let index_header = TileIndexHeader {
        entry_count: u32::try_from(index_entries.len())
            .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
        index_flags: 0,
    };
    output.extend_from_slice(&index_header.to_bytes());

    for entry in &index_entries {
        output.extend_from_slice(&entry.to_bytes());
    }

    // Index size = header + all entries, measured from the recorded offset.
    let idx_off_usize =
        usize::try_from(index_offset).map_err(|e| PluginError::OperationFailed(e.to_string()))?;
    let index_size = u32::try_from(output.len() - idx_off_usize)
        .map_err(|e| PluginError::OperationFailed(e.to_string()))?;

    // ── Footer ───────────────────────────────────────────────────────
    let footer = GpuFileFooter::new(index_offset, index_size);
    output.extend_from_slice(&footer.to_bytes());

    Ok(output)
}
179
/// Decompress a GPU-format archive back to the original bytes.
///
/// Tries GPU decompression first (unless `config.force_cpu` is set) and
/// falls back to the CPU path when no backend is available or the GPU
/// dispatch fails.  An explicit GPU-discovery error is propagated instead
/// of falling back.
///
/// # Errors
///
/// * [`CrushError::InvalidFormat`] if the archive header/footer is invalid.
/// * [`CrushError::Cancelled`] if `cancel` is set during processing.
pub fn decompress(input: &[u8], config: &EngineConfig, cancel: &AtomicBool) -> Result<Vec<u8>> {
    // Anything smaller than header + footer can only be a valid EMPTY archive.
    if input.len() < GpuFileHeader::SIZE + GpuFileFooter::SIZE {
        return try_decompress_empty_archive(input);
    }

    // ── Footer ───────────────────────────────────────────────────────
    // Footer sits at the very end and locates the tile index.
    let footer_start = input.len() - GpuFileFooter::SIZE;
    let footer_bytes: &[u8; GpuFileFooter::SIZE] = input[footer_start..]
        .try_into()
        .map_err(|_| CrushError::InvalidFormat("footer truncated".to_owned()))?;
    let footer = GpuFileFooter::from_bytes(footer_bytes)?;

    // ── File header ──────────────────────────────────────────────────
    let hdr_bytes: &[u8; GpuFileHeader::SIZE] = input[..GpuFileHeader::SIZE]
        .try_into()
        .map_err(|_| CrushError::InvalidFormat("header truncated".to_owned()))?;
    let header = GpuFileHeader::from_bytes(hdr_bytes)?;

    // An archive with zero tiles decompresses to nothing.
    if header.tile_count == 0 {
        return Ok(Vec::new());
    }

    let entries = read_tile_index(input, &footer, footer_start)?;

    // ── Try GPU decompression first ──────────────────────────────────
    if config.force_cpu {
        debug!("force_cpu=true, skipping GPU discovery");
    } else {
        debug!(
            tile_count = entries.len(),
            force_cpu = config.force_cpu,
            "Attempting GPU decompression for {} tiles",
            entries.len()
        );
        match crate::backend::discover_gpu() {
            Ok(Some(backend)) => {
                info!(
                    backend = backend.name(),
                    tiles = entries.len(),
                    "Using GPU backend '{}' for {} tiles",
                    backend.name(),
                    entries.len()
                );
                match decompress_tiles_gpu(input, &header, &entries, config, cancel, &*backend) {
                    Ok(output) => {
                        info!(
                            output_bytes = output.len(),
                            "GPU decompression succeeded ({} bytes)",
                            output.len()
                        );
                        return Ok(output);
                    }
                    // GPU failures are non-fatal here: log and fall through
                    // to the CPU path below.
                    Err(e) => {
                        warn!("GPU decompression failed, falling back to CPU: {e}");
                    }
                }
            }
            Ok(None) => {
                info!("No GPU backend available, using CPU");
            }
            Err(e) => {
                // GPU discovery returned an explicit error (e.g. user requested
                // --gpu-backend cuda but CUDA is unavailable). Propagate it.
                return Err(e);
            }
        }
    }

    // CPU fallback path (also the force_cpu path).
    info!(
        tiles = entries.len(),
        "Starting CPU decompression for {} tiles",
        entries.len()
    );
    decompress_tiles_cpu(input, &header, &entries, config, cancel)
}
261
262/// Attempt to parse an undersized archive as an empty archive.
263fn try_decompress_empty_archive(input: &[u8]) -> Result<Vec<u8>> {
264    if input.len() == GpuFileHeader::SIZE + TileIndexHeader::SIZE + GpuFileFooter::SIZE {
265        let hdr_bytes: &[u8; GpuFileHeader::SIZE] = input[..GpuFileHeader::SIZE]
266            .try_into()
267            .map_err(|_| CrushError::InvalidFormat("header truncated".to_owned()))?;
268        let header = GpuFileHeader::from_bytes(hdr_bytes)?;
269        if header.tile_count == 0 {
270            return Ok(Vec::new());
271        }
272    }
273    Err(CrushError::InvalidFormat(
274        "archive too small for GPU format".to_owned(),
275    ))
276}
277
278/// Read the tile index from the archive.
279fn read_tile_index(
280    input: &[u8],
281    footer: &GpuFileFooter,
282    footer_start: usize,
283) -> Result<Vec<TileIndexEntry>> {
284    let idx_off = usize::try_from(footer.index_offset)
285        .map_err(|_| CrushError::InvalidFormat("index offset too large for platform".to_owned()))?;
286    if idx_off + TileIndexHeader::SIZE > footer_start {
287        return Err(CrushError::IndexCorrupted(
288            "tile index header beyond archive".to_owned(),
289        ));
290    }
291    let idx_hdr_bytes: &[u8; TileIndexHeader::SIZE] = input
292        [idx_off..idx_off + TileIndexHeader::SIZE]
293        .try_into()
294        .map_err(|_| CrushError::IndexCorrupted("index header truncated".to_owned()))?;
295    let idx_hdr = TileIndexHeader::from_bytes(idx_hdr_bytes);
296
297    let entry_count = idx_hdr.entry_count as usize;
298    let entries_start = idx_off + TileIndexHeader::SIZE;
299    let mut entries = Vec::with_capacity(entry_count);
300    for i in 0..entry_count {
301        let e_off = entries_start + i * TileIndexEntry::SIZE;
302        if e_off + TileIndexEntry::SIZE > footer_start {
303            return Err(CrushError::IndexCorrupted(
304                "tile index entry beyond archive".to_owned(),
305            ));
306        }
307        let e_bytes: &[u8; TileIndexEntry::SIZE] = input[e_off..e_off + TileIndexEntry::SIZE]
308            .try_into()
309            .map_err(|_| CrushError::IndexCorrupted("entry truncated".to_owned()))?;
310        entries.push(TileIndexEntry::from_bytes(e_bytes));
311    }
312    Ok(entries)
313}
314
315/// CPU-fallback decompression of all tiles.
316fn decompress_tiles_cpu(
317    input: &[u8],
318    header: &GpuFileHeader,
319    entries: &[TileIndexEntry],
320    config: &EngineConfig,
321    cancel: &AtomicBool,
322) -> Result<Vec<u8>> {
323    let uncompressed_total = usize::try_from(header.uncompressed_size).map_err(|_| {
324        CrushError::InvalidFormat("uncompressed size too large for platform".to_owned())
325    })?;
326    let mut output = Vec::with_capacity(uncompressed_total);
327
328    for (i, entry) in entries.iter().enumerate() {
329        if cancel.load(Ordering::Relaxed) {
330            return Err(CrushError::Cancelled);
331        }
332
333        let tile_data = read_and_decompress_tile(input, entry, i, config.sub_stream_count)?;
334
335        if config.enable_checksums && entry.checksum != 0 {
336            let mut h = Hasher::new();
337            h.update(&tile_data);
338            let actual = h.finalize();
339            if actual != entry.checksum {
340                return Err(CrushError::ChecksumMismatch {
341                    block_index: u64::try_from(i)
342                        .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
343                    expected: entry.checksum,
344                    actual,
345                });
346            }
347        }
348
349        output.extend_from_slice(&tile_data);
350    }
351
352    Ok(output)
353}
354
355/// GPU decompression of all tiles via a [`ComputeBackend`].
356fn decompress_tiles_gpu(
357    input: &[u8],
358    header: &GpuFileHeader,
359    entries: &[TileIndexEntry],
360    config: &EngineConfig,
361    cancel: &AtomicBool,
362    backend: &dyn crate::backend::ComputeBackend,
363) -> Result<Vec<u8>> {
364    // Build CompressedTile structs from the archive entries.
365    debug!(
366        tile_count = entries.len(),
367        backend = backend.name(),
368        "Building compressed tile structs for GPU dispatch"
369    );
370    let build_start = std::time::Instant::now();
371    let mut tiles = Vec::with_capacity(entries.len());
372    for (i, entry) in entries.iter().enumerate() {
373        if cancel.load(Ordering::Relaxed) {
374            return Err(CrushError::Cancelled);
375        }
376        let tile_off = usize::try_from(entry.tile_offset).map_err(|_| {
377            CrushError::InvalidFormat("tile offset too large for platform".to_owned())
378        })?;
379        if tile_off + TileHeader::SIZE > input.len() {
380            return Err(CrushError::InvalidFormat(format!(
381                "tile {i} header at offset {tile_off} is beyond archive"
382            )));
383        }
384        let th_bytes: &[u8; TileHeader::SIZE] = input[tile_off..tile_off + TileHeader::SIZE]
385            .try_into()
386            .map_err(|_| CrushError::InvalidFormat("tile header truncated".to_owned()))?;
387        let tile_hdr = TileHeader::from_bytes(th_bytes);
388
389        let payload_start = tile_off + TileHeader::SIZE;
390        let payload_end = payload_start + tile_hdr.compressed_size as usize;
391        if payload_end > input.len() {
392            return Err(CrushError::InvalidFormat(format!(
393                "tile {i} payload extends beyond archive"
394            )));
395        }
396
397        tiles.push(crate::backend::CompressedTile {
398            data: input[payload_start..payload_end].to_vec(),
399            uncompressed_size: tile_hdr.uncompressed_size,
400            sub_stream_count: config.sub_stream_count,
401            checksum: entry.checksum,
402        });
403    }
404    let build_elapsed = build_start.elapsed();
405    debug!(
406        tile_count = tiles.len(),
407        elapsed_secs = build_elapsed.as_secs_f64(),
408        "Built {} tile structs in {:.1}ms",
409        tiles.len(),
410        build_elapsed.as_secs_f64() * 1000.0
411    );
412
413    // Dispatch to GPU backend (GDeflate path).
414    debug!(
415        "Dispatching {} tiles to {} backend...",
416        tiles.len(),
417        backend.name()
418    );
419    let dispatch_start = std::time::Instant::now();
420    let decompressed = backend.decompress_tiles_gdeflate(&tiles, cancel)?;
421    let dispatch_elapsed = dispatch_start.elapsed();
422    debug!(
423        elapsed_secs = dispatch_elapsed.as_secs_f64(),
424        "GPU dispatch completed in {:.1}ms",
425        dispatch_elapsed.as_secs_f64() * 1000.0
426    );
427
428    // Validate checksums and assemble output.
429    let uncompressed_total = usize::try_from(header.uncompressed_size).map_err(|_| {
430        CrushError::InvalidFormat("uncompressed size too large for platform".to_owned())
431    })?;
432    let mut output = Vec::with_capacity(uncompressed_total);
433    for (i, (tile_data, entry)) in decompressed.iter().zip(entries.iter()).enumerate() {
434        if config.enable_checksums && entry.checksum != 0 {
435            let mut h = Hasher::new();
436            h.update(tile_data);
437            let actual = h.finalize();
438            if actual != entry.checksum {
439                return Err(CrushError::ChecksumMismatch {
440                    block_index: u64::try_from(i)
441                        .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
442                    expected: entry.checksum,
443                    actual,
444                });
445            }
446        }
447        output.extend_from_slice(tile_data);
448    }
449
450    Ok(output)
451}
452
453/// Read a single tile from the archive and decompress it.
454fn read_and_decompress_tile(
455    input: &[u8],
456    entry: &TileIndexEntry,
457    tile_idx: usize,
458    sub_stream_count: u8,
459) -> Result<Vec<u8>> {
460    let tile_off = usize::try_from(entry.tile_offset)
461        .map_err(|_| CrushError::InvalidFormat("tile offset too large for platform".to_owned()))?;
462    if tile_off + TileHeader::SIZE > input.len() {
463        return Err(CrushError::InvalidFormat(format!(
464            "tile {tile_idx} header at offset {tile_off} is beyond archive"
465        )));
466    }
467    let th_bytes: &[u8; TileHeader::SIZE] = input[tile_off..tile_off + TileHeader::SIZE]
468        .try_into()
469        .map_err(|_| CrushError::InvalidFormat("tile header truncated".to_owned()))?;
470    let tile_hdr = TileHeader::from_bytes(th_bytes);
471
472    let payload_start = tile_off + TileHeader::SIZE;
473    let payload_end = payload_start + tile_hdr.compressed_size as usize;
474    if payload_end > input.len() {
475        return Err(CrushError::InvalidFormat(format!(
476            "tile {tile_idx} payload extends beyond archive"
477        )));
478    }
479    let payload = &input[payload_start..payload_end];
480
481    let _ = sub_stream_count; // GDeflate handles sub-streams internally
482    if tile_hdr.flags.stored() {
483        Ok(payload.to_vec())
484    } else {
485        gdeflate::gdeflate_decompress_tile(payload, tile_hdr.uncompressed_size as usize)
486    }
487}
488
489// (Tile compression and decompression are handled by the gdeflate module.)
490
491// ============================================================================
492// Internal: empty archive
493// ============================================================================
494
495fn write_empty_archive(config: &EngineConfig) -> Result<Vec<u8>> {
496    let header = GpuFileHeader::new(0, 0);
497    let mut output =
498        Vec::with_capacity(GpuFileHeader::SIZE + TileIndexHeader::SIZE + GpuFileFooter::SIZE);
499    output.extend_from_slice(&header.to_bytes());
500
501    let index_offset =
502        u64::try_from(output.len()).map_err(|e| PluginError::OperationFailed(e.to_string()))?;
503    let idx_hdr = TileIndexHeader {
504        entry_count: 0,
505        index_flags: 0,
506    };
507    output.extend_from_slice(&idx_hdr.to_bytes());
508    let index_size = u32::try_from(TileIndexHeader::SIZE)
509        .map_err(|e| PluginError::OperationFailed(e.to_string()))?;
510    let footer = GpuFileFooter::new(index_offset, index_size);
511    output.extend_from_slice(&footer.to_bytes());
512
513    let _ = config; // config used for consistency with non-empty path
514    Ok(output)
515}
516
517// ============================================================================
518// Public: random access API (US4)
519// ============================================================================
520
/// Loaded tile index for O(1) tile lookup.
///
/// Built by [`load_tile_index`]; pass it to [`decompress_tile_by_index`]
/// for random-access decompression without touching other tiles.
#[derive(Debug, Clone)]
pub struct TileIndex {
    /// Header from the archive.
    pub header: GpuFileHeader,
    /// One entry per tile, in order.
    pub entries: Vec<TileIndexEntry>,
}
529
530impl TileIndex {
531    /// Number of tiles in the archive.
532    #[must_use]
533    pub fn tile_count(&self) -> usize {
534        self.entries.len()
535    }
536
537    /// Get an entry by tile index.  Returns `None` if out of bounds.
538    #[must_use]
539    pub fn get(&self, index: usize) -> Option<&TileIndexEntry> {
540        self.entries.get(index)
541    }
542}
543
544/// Load the tile index from a GPU archive without decompressing any tiles.
545///
546/// This enables O(1) tile lookup for random access decompression.
547///
548/// # Errors
549///
550/// * [`CrushError::InvalidFormat`] if the archive is malformed.
551pub fn load_tile_index(archive: &[u8]) -> Result<TileIndex> {
552    let min_size = GpuFileHeader::SIZE + TileIndexHeader::SIZE + GpuFileFooter::SIZE;
553    if archive.len() < min_size {
554        return Err(CrushError::InvalidFormat(
555            "archive too small for tile index".to_owned(),
556        ));
557    }
558
559    let footer_start = archive.len() - GpuFileFooter::SIZE;
560    let footer_bytes: &[u8; GpuFileFooter::SIZE] = archive[footer_start..]
561        .try_into()
562        .map_err(|_| CrushError::InvalidFormat("footer truncated".to_owned()))?;
563    let footer = GpuFileFooter::from_bytes(footer_bytes)?;
564
565    let hdr_bytes: &[u8; GpuFileHeader::SIZE] = archive[..GpuFileHeader::SIZE]
566        .try_into()
567        .map_err(|_| CrushError::InvalidFormat("header truncated".to_owned()))?;
568    let header = GpuFileHeader::from_bytes(hdr_bytes)?;
569
570    let entries = read_tile_index(archive, &footer, footer_start)?;
571
572    Ok(TileIndex { header, entries })
573}
574
575/// Decompress a single tile by index from a GPU archive.
576///
577/// Only reads the target tile's header + payload — no other tiles are touched.
578///
579/// # Errors
580///
581/// * [`CrushError::InvalidFormat`] if the tile index is invalid.
582/// * [`CrushError::ChecksumMismatch`] if CRC validation fails.
583pub fn decompress_tile_by_index(
584    archive: &[u8],
585    tile_index: &TileIndex,
586    index: usize,
587    config: &EngineConfig,
588) -> Result<Vec<u8>> {
589    let entry = tile_index.get(index).ok_or_else(|| {
590        CrushError::InvalidFormat(format!(
591            "tile index {index} out of range ({})",
592            tile_index.tile_count()
593        ))
594    })?;
595
596    let tile_data = read_and_decompress_tile(archive, entry, index, config.sub_stream_count)?;
597
598    if config.enable_checksums && entry.checksum != 0 {
599        let mut h = Hasher::new();
600        h.update(&tile_data);
601        let actual = h.finalize();
602        if actual != entry.checksum {
603            return Err(CrushError::ChecksumMismatch {
604                block_index: u64::try_from(index)
605                    .map_err(|e| PluginError::OperationFailed(e.to_string()))?,
606                expected: entry.checksum,
607                actual,
608            });
609        }
610    }
611
612    Ok(tile_data)
613}