use std::ffi::{CStr, c_void};
use std::ptr::null_mut;
use std::sync::Mutex;
use super::nvcomp_sys::cuda::*;
use super::nvcomp_sys::nvcomp::*;
use super::{Algo, BitcompDataType, Codec, Error, Result};
/// Magic bytes identifying an "FCG1" frame produced by this codec.
const FRAME_MAGIC: [u8; 4] = *b"FCG1";
/// Fixed frame-header size: magic (4) + algo tag (1) + reserved (3)
/// + original size u64 (8) + chunk size u32 (4) + chunk count u32 (4).
const HEADER_FIXED_BYTES: usize = 4 + 1 + 3 + 8 + 4 + 4;
/// Default uncompressed bytes per chunk fed to the batched nvCOMP APIs.
const DEFAULT_CHUNK_SIZE: usize = 64 * 1024;
/// GPU compression codec backed by NVIDIA nvCOMP's batched C API.
///
/// Owns a dedicated CUDA stream plus (behind a mutex) all reusable device,
/// pinned-host, and per-chunk metadata buffers, so one instance can be
/// shared across threads with GPU work serialized per codec.
pub struct NvcompCodec {
    // Compression algorithm this codec dispatches to.
    algo: Algo,
    // Uncompressed bytes per batched chunk.
    chunk_size: usize,
    // Dedicated CUDA stream all transfers and kernels are enqueued on.
    stream: cudaStream_t,
    // Reusable scratch state; the mutex serializes compress/decompress.
    inner: Mutex<NvcompCodecInner>,
}
// SAFETY: the raw pointers inside `inner` are only touched while holding the
// mutex (or in Drop, which has exclusive access). The `stream` handle is an
// opaque CUDA handle; this relies on the CUDA runtime API being thread-safe
// for calls on the same stream handle — NOTE(review): confirm that no caller
// races `cuda_stream()` usage against compress/decompress on another thread.
unsafe impl Send for NvcompCodec {}
unsafe impl Sync for NvcompCodec {}
impl std::fmt::Debug for NvcompCodec {
    /// Debug output shows only the plain configuration fields; the raw
    /// CUDA stream handle and the scratch buffers are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("NvcompCodec");
        dbg.field("algo", &self.algo);
        dbg.field("chunk_size", &self.chunk_size);
        dbg.finish()
    }
}
impl NvcompCodec {
    /// Creates a codec for `algo` using the default 64 KiB chunk size.
    pub fn new(algo: Algo) -> Result<Self> {
        Self::with_chunk_size(algo, DEFAULT_CHUNK_SIZE)
    }
    /// Creates a codec for `algo` with an explicit uncompressed chunk size.
    ///
    /// # Errors
    /// * `Error::UnsupportedAlgo` for algorithms without an nvCOMP batched
    ///   backend in this module.
    /// * `Error::Compress` when `chunk_size` is 0 or above 16 MiB, or when
    ///   the CUDA stream cannot be created.
    pub fn with_chunk_size(algo: Algo, chunk_size: usize) -> Result<Self> {
        match algo {
            Algo::Snappy | Algo::Lz4 | Algo::Zstd | Algo::GDeflate | Algo::Bitcomp { .. } => {}
            other => return Err(Error::UnsupportedAlgo(other)),
        }
        // 16 MiB cap keeps per-chunk size tables representable as u32 and
        // bounds the worst-case device allocations.
        if chunk_size == 0 || chunk_size > (1 << 24) {
            return Err(Error::Compress(format!(
                "nvcomp chunk_size must be in (0, 16 MiB]; got {chunk_size}"
            )));
        }
        let mut stream: cudaStream_t = null_mut();
        let rc = unsafe { cudaStreamCreate(&mut stream) };
        check_cuda(rc, "cudaStreamCreate")?;
        Ok(Self {
            algo,
            chunk_size,
            stream,
            inner: Mutex::new(NvcompCodecInner::default()),
        })
    }
    /// Returns `true` when the CUDA runtime reports at least one device.
    pub fn is_available() -> bool {
        let mut count = 0i32;
        let rc = unsafe { cudaGetDeviceCount(&mut count) };
        rc == CUDA_SUCCESS && count > 0
    }
    /// Raw CUDA stream on which this codec enqueues all of its work.
    pub fn cuda_stream(&self) -> cudaStream_t {
        self.stream
    }
}
impl Drop for NvcompCodec {
fn drop(&mut self) {
if !self.stream.is_null() {
unsafe { cudaStreamDestroy(self.stream) };
}
}
}
impl Codec for NvcompCodec {
    /// Algorithm configured for this codec.
    fn algo(&self) -> Algo {
        self.algo
    }
    /// Compresses `input` into an `FCG1` frame appended to `output`.
    fn compress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<()> {
        let mut guard = self.inner.lock().expect("nvcomp codec inner poisoned");
        compress_chunked(
            self.algo,
            self.chunk_size,
            self.stream,
            &mut guard,
            input,
            output,
        )
    }
    /// Decompresses an `FCG1` frame from `input`, appending to `output`.
    fn decompress(&self, input: &[u8], output: &mut Vec<u8>) -> Result<()> {
        let mut guard = self.inner.lock().expect("nvcomp codec inner poisoned");
        decompress_chunked(self.stream, &mut guard, input, output)
    }
    /// Conservative upper bound on the framed size of `uncompressed_len`
    /// input bytes (header + u32 size table + worst-case payload).
    fn max_compressed_len(&self, uncompressed_len: usize) -> usize {
        let chunk = self.chunk_size;
        let num_chunks = uncompressed_len.div_ceil(chunk).max(1);
        // Per-algorithm worst-case expansion of a single chunk.
        let per_chunk_bound = match self.algo {
            Algo::Snappy => 32 + chunk + chunk / 6,
            Algo::Lz4 => chunk + chunk / 255 + 16,
            Algo::Zstd => chunk + chunk / 200 + 64,
            Algo::GDeflate => chunk + chunk / 200 + 64,
            Algo::Bitcomp { .. } => chunk + chunk / 64 + 64,
            _ => chunk,
        };
        HEADER_FIXED_BYTES + 4 * num_chunks + per_chunk_bound * num_chunks
    }
}
#[derive(Default)]
struct NvcompCodecInner {
d_uncomp: *mut c_void,
d_uncomp_cap: usize,
d_comp: *mut c_void,
d_comp_cap: usize,
d_temp: *mut c_void,
d_temp_cap: usize,
d_uncomp_ptrs: *mut c_void,
d_uncomp_sizes: *mut c_void,
d_comp_ptrs: *mut c_void,
d_comp_sizes: *mut c_void,
d_uncomp_buf_sizes: *mut c_void,
d_uncomp_actual_sizes: *mut c_void,
d_statuses: *mut c_void,
chunks_cap: usize,
h_pinned_input: *mut c_void,
h_pinned_input_cap: usize,
h_pinned_output: *mut c_void,
h_pinned_output_cap: usize,
h_pinned_meta: *mut c_void,
h_pinned_meta_cap: usize,
h_uncomp_ptrs: Vec<*const c_void>,
h_uncomp_sizes: Vec<usize>,
h_comp_ptrs: Vec<*mut c_void>,
h_comp_sizes: Vec<usize>,
h_uncomp_buf_sizes: Vec<usize>,
h_statuses: Vec<nvcompStatus_t>,
}
impl Drop for NvcompCodecInner {
fn drop(&mut self) {
unsafe {
for p in [
&mut self.d_uncomp,
&mut self.d_comp,
&mut self.d_temp,
&mut self.d_uncomp_ptrs,
&mut self.d_uncomp_sizes,
&mut self.d_comp_ptrs,
&mut self.d_comp_sizes,
&mut self.d_uncomp_buf_sizes,
&mut self.d_uncomp_actual_sizes,
&mut self.d_statuses,
] {
if !p.is_null() {
cudaFree(*p);
*p = null_mut();
}
}
for p in [
&mut self.h_pinned_input,
&mut self.h_pinned_output,
&mut self.h_pinned_meta,
] {
if !p.is_null() {
cudaFreeHost(*p);
*p = null_mut();
}
}
}
}
}
impl NvcompCodecInner {
    /// Ensures the selected device buffer holds at least `needed` bytes,
    /// reallocating (rounded up to whole MiB) when too small.
    ///
    /// Growing discards old contents; callers always rewrite the buffer
    /// before reading it.
    fn ensure_d_buf(&mut self, kind: BufKind, needed: usize) -> Result<()> {
        if needed == 0 {
            return Ok(());
        }
        let (slot, cap) = match kind {
            BufKind::Uncomp => (&mut self.d_uncomp, &mut self.d_uncomp_cap),
            BufKind::Comp => (&mut self.d_comp, &mut self.d_comp_cap),
            BufKind::Temp => (&mut self.d_temp, &mut self.d_temp_cap),
        };
        if *cap >= needed {
            return Ok(());
        }
        if !slot.is_null() {
            unsafe { cudaFree(*slot) };
            *slot = null_mut();
            // Capacity is cleared before reallocating so a cudaMalloc
            // failure below leaves a consistent (empty) state.
            *cap = 0;
        }
        // Round up to whole MiB to amortize reallocations.
        let alloc_size = needed.div_ceil(1 << 20).max(1) << 20;
        check_cuda(unsafe { cudaMalloc(slot, alloc_size) }, "cudaMalloc(buf)")?;
        *cap = alloc_size;
        Ok(())
    }
    /// Ensures the device metadata arrays and their host mirrors can
    /// describe at least `chunks` chunks (grown to a power of two, >= 64).
    fn ensure_metadata(&mut self, chunks: usize) -> Result<()> {
        if self.chunks_cap >= chunks {
            return Ok(());
        }
        unsafe {
            for p in [
                &mut self.d_uncomp_ptrs,
                &mut self.d_uncomp_sizes,
                &mut self.d_comp_ptrs,
                &mut self.d_comp_sizes,
                &mut self.d_uncomp_buf_sizes,
                &mut self.d_uncomp_actual_sizes,
                &mut self.d_statuses,
            ] {
                if !p.is_null() {
                    cudaFree(*p);
                    *p = null_mut();
                }
            }
        }
        // BUGFIX: reset the capacity *before* reallocating. If any
        // cudaMalloc below fails, the stale capacity must not make a later
        // call believe the (now freed) arrays are still usable.
        self.chunks_cap = 0;
        let target = chunks.next_power_of_two().max(64);
        let ptr_bytes = target * std::mem::size_of::<*const c_void>();
        let size_bytes = target * std::mem::size_of::<usize>();
        let status_bytes = target * std::mem::size_of::<nvcompStatus_t>();
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_uncomp_ptrs, ptr_bytes) },
            "cudaMalloc(uncomp_ptrs)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_uncomp_sizes, size_bytes) },
            "cudaMalloc(uncomp_sizes)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_comp_ptrs, ptr_bytes) },
            "cudaMalloc(comp_ptrs)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_comp_sizes, size_bytes) },
            "cudaMalloc(comp_sizes)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_uncomp_buf_sizes, size_bytes) },
            "cudaMalloc(uncomp_buf_sizes)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_uncomp_actual_sizes, size_bytes) },
            "cudaMalloc(uncomp_actual_sizes)",
        )?;
        check_cuda(
            unsafe { cudaMalloc(&mut self.d_statuses, status_bytes) },
            "cudaMalloc(statuses)",
        )?;
        self.chunks_cap = target;
        // Keep the host mirrors at least as large as the device arrays.
        self.h_uncomp_ptrs.resize(target, std::ptr::null());
        self.h_uncomp_sizes.resize(target, 0);
        self.h_comp_ptrs.resize(target, std::ptr::null_mut());
        self.h_comp_sizes.resize(target, 0);
        self.h_uncomp_buf_sizes.resize(target, 0);
        self.h_statuses.resize(target, 0);
        Ok(())
    }
    /// Ensures the selected pinned host buffer holds at least `needed`
    /// bytes (rounded up to whole MiB).
    fn ensure_pinned(&mut self, kind: PinnedKind, needed: usize) -> Result<()> {
        if needed == 0 {
            return Ok(());
        }
        let (slot, cap) = match kind {
            PinnedKind::Input => (&mut self.h_pinned_input, &mut self.h_pinned_input_cap),
            PinnedKind::Output => (&mut self.h_pinned_output, &mut self.h_pinned_output_cap),
            PinnedKind::Meta => (&mut self.h_pinned_meta, &mut self.h_pinned_meta_cap),
        };
        if *cap >= needed {
            return Ok(());
        }
        if !slot.is_null() {
            unsafe { cudaFreeHost(*slot) };
            *slot = null_mut();
            *cap = 0;
        }
        let alloc_size = needed.div_ceil(1 << 20).max(1) << 20;
        check_cuda(
            unsafe { cudaHostAlloc(slot, alloc_size, cudaHostAllocDefault) },
            "cudaHostAlloc",
        )?;
        *cap = alloc_size;
        Ok(())
    }
}
/// Selects which device scratch buffer `ensure_d_buf` grows.
#[derive(Clone, Copy)]
enum BufKind {
    Uncomp,
    Comp,
    Temp,
}
/// Selects which pinned host buffer `ensure_pinned` grows.
#[derive(Clone, Copy)]
enum PinnedKind {
    Input,
    Output,
    Meta,
}
fn check_cuda(rc: cudaError_t, what: &'static str) -> Result<()> {
if rc == CUDA_SUCCESS {
return Ok(());
}
let msg = unsafe {
let s = cudaGetErrorString(rc);
if s.is_null() {
"unknown".to_string()
} else {
CStr::from_ptr(s).to_string_lossy().into_owned()
}
};
Err(Error::Compress(format!(
"CUDA error in {what}: code={rc} ({msg})"
)))
}
fn check_nvcomp(status: nvcompStatus_t, what: &'static str) -> Result<()> {
if status == nvcompSuccess {
Ok(())
} else {
Err(Error::Compress(format!(
"nvCOMP error in {what}: code={status} ({})",
status_str(status)
)))
}
}
/// Asks nvCOMP for the worst-case compressed size of one chunk of
/// `max_chunk` uncompressed bytes under `algo`.
///
/// Non-Bitcomp backends are queried with default format options; Bitcomp
/// uses the options derived from its configured element data type.
fn compress_get_max_output_chunk_size(algo: Algo, max_chunk: usize) -> Result<usize> {
    let mut out = 0usize;
    // SAFETY: `out` is a valid out-pointer for the duration of each call.
    let status = unsafe {
        match algo {
            Algo::Snappy => nvcompBatchedSnappyCompressGetMaxOutputChunkSize(
                max_chunk,
                Default::default(),
                &mut out,
            ),
            Algo::Lz4 => nvcompBatchedLZ4CompressGetMaxOutputChunkSize(
                max_chunk,
                Default::default(),
                &mut out,
            ),
            Algo::Zstd => nvcompBatchedZstdCompressGetMaxOutputChunkSize(
                max_chunk,
                Default::default(),
                &mut out,
            ),
            Algo::GDeflate => nvcompBatchedGdeflateCompressGetMaxOutputChunkSize(
                max_chunk,
                Default::default(),
                &mut out,
            ),
            Algo::Bitcomp { data_type } => nvcompBatchedBitcompCompressGetMaxOutputChunkSize(
                max_chunk,
                bitcomp_format_opts(data_type),
                &mut out,
            ),
            _ => return Err(Error::UnsupportedAlgo(algo)),
        }
    };
    check_nvcomp(status, "GetMaxOutputChunkSize")?;
    Ok(out)
}
/// Queries the nvCOMP scratch ("temp") buffer size needed to compress
/// `num_chunks` chunks (each at most `max_chunk` bytes, `total_uncomp`
/// bytes overall).
///
/// Only the Bitcomp query inspects the device-side chunk arrays and takes
/// the stream; the other backends compute the bound from the scalar
/// arguments alone, so `d_uncomp_ptrs` / `d_uncomp_sizes` / `stream` go
/// unused for them.
#[allow(clippy::too_many_arguments)]
fn compress_get_temp_size(
    algo: Algo,
    d_uncomp_ptrs: *const *const c_void,
    d_uncomp_sizes: *const usize,
    num_chunks: usize,
    max_chunk: usize,
    total_uncomp: usize,
    stream: cudaStream_t,
) -> Result<usize> {
    let mut out = 0usize;
    // SAFETY: `out` is a valid out-pointer; for Bitcomp the device arrays
    // were populated by the caller before this call was issued.
    let status = unsafe {
        match algo {
            Algo::Snappy => nvcompBatchedSnappyCompressGetTempSizeAsync(
                num_chunks,
                max_chunk,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::Lz4 => nvcompBatchedLZ4CompressGetTempSizeAsync(
                num_chunks,
                max_chunk,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::Zstd => nvcompBatchedZstdCompressGetTempSizeAsync(
                num_chunks,
                max_chunk,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::GDeflate => nvcompBatchedGdeflateCompressGetTempSizeAsync(
                num_chunks,
                max_chunk,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            // The Bitcomp variant is a *Sync* query: it reads the per-chunk
            // arrays on the device and blocks on `stream`.
            Algo::Bitcomp { data_type } => nvcompBatchedBitcompCompressGetTempSizeSync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                num_chunks,
                max_chunk,
                bitcomp_format_opts(data_type),
                &mut out,
                total_uncomp,
                stream,
            ),
            _ => return Err(Error::UnsupportedAlgo(algo)),
        }
    };
    check_nvcomp(status, "CompressGetTempSize")?;
    Ok(out)
}
/// Launches the batched nvCOMP compression kernel for `algo` on `stream`.
///
/// All `d_*` arguments are device pointers prepared by the caller:
/// per-chunk input pointers/sizes, per-chunk output slot pointers, and
/// output arrays for the produced compressed sizes and per-chunk statuses.
/// The launch is asynchronous; results are valid only after the stream is
/// synchronized.
#[allow(clippy::too_many_arguments)]
fn dispatch_compress(
    algo: Algo,
    d_uncomp_ptrs: *const *const c_void,
    d_uncomp_sizes: *const usize,
    max_chunk: usize,
    num_chunks: usize,
    d_temp: *mut c_void,
    temp_bytes: usize,
    d_comp_ptrs: *const *mut c_void,
    d_comp_sizes: *mut usize,
    d_statuses: *mut nvcompStatus_t,
    stream: cudaStream_t,
) -> Result<()> {
    // SAFETY: the caller guarantees every device pointer describes
    // `num_chunks` valid entries and that `d_temp` holds `temp_bytes`.
    let status = unsafe {
        match algo {
            Algo::Snappy => nvcompBatchedSnappyCompressAsync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                max_chunk,
                num_chunks,
                d_temp,
                temp_bytes,
                d_comp_ptrs,
                d_comp_sizes,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Lz4 => nvcompBatchedLZ4CompressAsync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                max_chunk,
                num_chunks,
                d_temp,
                temp_bytes,
                d_comp_ptrs,
                d_comp_sizes,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Zstd => nvcompBatchedZstdCompressAsync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                max_chunk,
                num_chunks,
                d_temp,
                temp_bytes,
                d_comp_ptrs,
                d_comp_sizes,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::GDeflate => nvcompBatchedGdeflateCompressAsync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                max_chunk,
                num_chunks,
                d_temp,
                temp_bytes,
                d_comp_ptrs,
                d_comp_sizes,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Bitcomp { data_type } => nvcompBatchedBitcompCompressAsync(
                d_uncomp_ptrs,
                d_uncomp_sizes,
                max_chunk,
                num_chunks,
                d_temp,
                temp_bytes,
                d_comp_ptrs,
                d_comp_sizes,
                bitcomp_format_opts(data_type),
                d_statuses,
                stream,
            ),
            _ => return Err(Error::UnsupportedAlgo(algo)),
        }
    };
    check_nvcomp(status, "CompressAsync")
}
/// Queries the nvCOMP scratch ("temp") buffer size needed to decompress
/// `num_chunks` chunks of up to `chunk_size` uncompressed bytes each
/// (`total_uncomp` bytes overall).
///
/// NOTE(review): the Bitcomp branch passes `Default::default()` options
/// rather than the frame's data-type options — confirm against the sys
/// binding that the temp-size query ignores the format options.
fn decompress_get_temp_size(
    algo: Algo,
    num_chunks: usize,
    chunk_size: usize,
    total_uncomp: usize,
) -> Result<usize> {
    let mut out = 0usize;
    // SAFETY: `out` is a valid out-pointer for the duration of each call.
    let status = unsafe {
        match algo {
            Algo::Snappy => nvcompBatchedSnappyDecompressGetTempSizeAsync(
                num_chunks,
                chunk_size,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::Lz4 => nvcompBatchedLZ4DecompressGetTempSizeAsync(
                num_chunks,
                chunk_size,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::Zstd => nvcompBatchedZstdDecompressGetTempSizeAsync(
                num_chunks,
                chunk_size,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::GDeflate => nvcompBatchedGdeflateDecompressGetTempSizeAsync(
                num_chunks,
                chunk_size,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            Algo::Bitcomp { .. } => nvcompBatchedBitcompDecompressGetTempSizeAsync(
                num_chunks,
                chunk_size,
                Default::default(),
                &mut out,
                total_uncomp,
            ),
            _ => return Err(Error::UnsupportedAlgo(algo)),
        }
    };
    check_nvcomp(status, "DecompressGetTempSize")?;
    Ok(out)
}
/// Launches the batched nvCOMP decompression kernel for `algo` on `stream`.
///
/// All `d_*` arguments are device pointers prepared by the caller. The
/// kernel writes the actually-produced sizes into `d_uncomp_actual_sizes`
/// and one status per chunk into `d_statuses`; both are only valid after
/// the stream is synchronized, and callers are responsible for checking
/// the statuses.
#[allow(clippy::too_many_arguments)]
fn dispatch_decompress(
    algo: Algo,
    d_comp_ptrs: *const *const c_void,
    d_comp_sizes: *const usize,
    d_uncomp_buf_sizes: *const usize,
    d_uncomp_actual_sizes: *mut usize,
    num_chunks: usize,
    d_temp: *mut c_void,
    temp_bytes: usize,
    d_uncomp_ptrs: *const *mut c_void,
    d_statuses: *mut nvcompStatus_t,
    stream: cudaStream_t,
) -> Result<()> {
    // SAFETY: the caller guarantees every device pointer describes
    // `num_chunks` valid entries and that `d_temp` holds `temp_bytes`.
    let status = unsafe {
        match algo {
            Algo::Snappy => nvcompBatchedSnappyDecompressAsync(
                d_comp_ptrs,
                d_comp_sizes,
                d_uncomp_buf_sizes,
                d_uncomp_actual_sizes,
                num_chunks,
                d_temp,
                temp_bytes,
                d_uncomp_ptrs,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Lz4 => nvcompBatchedLZ4DecompressAsync(
                d_comp_ptrs,
                d_comp_sizes,
                d_uncomp_buf_sizes,
                d_uncomp_actual_sizes,
                num_chunks,
                d_temp,
                temp_bytes,
                d_uncomp_ptrs,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Zstd => nvcompBatchedZstdDecompressAsync(
                d_comp_ptrs,
                d_comp_sizes,
                d_uncomp_buf_sizes,
                d_uncomp_actual_sizes,
                num_chunks,
                d_temp,
                temp_bytes,
                d_uncomp_ptrs,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::GDeflate => nvcompBatchedGdeflateDecompressAsync(
                d_comp_ptrs,
                d_comp_sizes,
                d_uncomp_buf_sizes,
                d_uncomp_actual_sizes,
                num_chunks,
                d_temp,
                temp_bytes,
                d_uncomp_ptrs,
                Default::default(),
                d_statuses,
                stream,
            ),
            Algo::Bitcomp { .. } => nvcompBatchedBitcompDecompressAsync(
                d_comp_ptrs,
                d_comp_sizes,
                d_uncomp_buf_sizes,
                d_uncomp_actual_sizes,
                num_chunks,
                d_temp,
                temp_bytes,
                d_uncomp_ptrs,
                Default::default(),
                d_statuses,
                stream,
            ),
            _ => return Err(Error::UnsupportedAlgo(algo)),
        }
    };
    check_nvcomp(status, "DecompressAsync")
}
/// Compresses `input` into an `FCG1` frame appended to `output`.
///
/// Frame layout: fixed header (see `write_header`) + one little-endian u32
/// compressed size per chunk + the tightly packed compressed chunks.
/// The input is split into `chunk_size` pieces, staged on the device,
/// compressed with the batched nvCOMP API on `stream`, and gathered back.
///
/// # Errors
/// `Error::Compress` on any CUDA/nvCOMP failure, including per-chunk
/// compression failures.
fn compress_chunked(
    algo: Algo,
    chunk_size: usize,
    stream: cudaStream_t,
    inner: &mut NvcompCodecInner,
    input: &[u8],
    output: &mut Vec<u8>,
) -> Result<()> {
    // Empty input still produces a valid, header-only frame (0 chunks).
    if input.is_empty() {
        write_header(algo, chunk_size, 0, &[], output);
        return Ok(());
    }
    let num_chunks = input.len().div_ceil(chunk_size);
    let max_chunk_bytes = chunk_size;
    // Worst-case compressed chunk size, rounded up to a 256-byte stride so
    // each chunk's output slot in d_comp starts aligned.
    let raw_max = compress_get_max_output_chunk_size(algo, max_chunk_bytes)?;
    let max_comp_chunk_bytes = raw_max.div_ceil(256) * 256;
    let comp_buf_bytes = max_comp_chunk_bytes * num_chunks;
    inner.ensure_d_buf(BufKind::Uncomp, input.len())?;
    inner.ensure_d_buf(BufKind::Comp, comp_buf_bytes)?;
    inner.ensure_metadata(num_chunks)?;
    // NOTE(review): this pinned input buffer is reserved but never used on
    // the compress path (the H2D below reads `input` directly) — confirm
    // whether it can be dropped or should be used for staging.
    inner.ensure_pinned(PinnedKind::Input, input.len())?;
    // Pinned metadata buffer sized for four arrays of the widest element
    // type plus the per-chunk status array.
    let meta_bytes_each = num_chunks
        * std::mem::size_of::<usize>()
            .max(std::mem::size_of::<*const c_void>())
            .max(std::mem::size_of::<nvcompStatus_t>());
    let meta_total = meta_bytes_each * 4 + num_chunks * std::mem::size_of::<nvcompStatus_t>();
    inner.ensure_pinned(PinnedKind::Meta, meta_total)?;
    // Describe each chunk: device input pointer, length (last chunk may be
    // short), and the strided device slot its compressed bytes land in.
    for i in 0..num_chunks {
        let off = i * chunk_size;
        let end = (off + chunk_size).min(input.len());
        inner.h_uncomp_ptrs[i] = unsafe { (inner.d_uncomp as *const u8).add(off) as *const c_void };
        inner.h_uncomp_sizes[i] = end - off;
        inner.h_comp_ptrs[i] =
            unsafe { (inner.d_comp as *mut u8).add(i * max_comp_chunk_bytes) as *mut c_void };
    }
    let ptr_bytes = num_chunks * std::mem::size_of::<*const c_void>();
    let size_bytes = num_chunks * std::mem::size_of::<usize>();
    let status_bytes = num_chunks * std::mem::size_of::<nvcompStatus_t>();
    // Pack [uncomp_ptrs | uncomp_sizes | comp_ptrs] into the pinned meta
    // buffer so the three H2D transfers below are fast async copies.
    let meta_base = inner.h_pinned_meta as *mut u8;
    unsafe {
        std::ptr::copy_nonoverlapping(
            inner.h_uncomp_ptrs.as_ptr() as *const u8,
            meta_base,
            ptr_bytes,
        );
        std::ptr::copy_nonoverlapping(
            inner.h_uncomp_sizes.as_ptr() as *const u8,
            meta_base.add(ptr_bytes),
            size_bytes,
        );
        std::ptr::copy_nonoverlapping(
            inner.h_comp_ptrs.as_ptr() as *const u8,
            meta_base.add(ptr_bytes + size_bytes),
            ptr_bytes,
        );
    }
    // Stage the raw input on the device. `input` is pageable host memory,
    // so this cudaMemcpyAsync completes synchronously w.r.t. the host.
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                inner.d_uncomp,
                input.as_ptr() as *const c_void,
                input.len(),
                cudaMemcpyKind::cudaMemcpyHostToDevice,
                stream,
            )
        },
        "cudaMemcpyAsync(input H2D)",
    )?;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                inner.d_uncomp_ptrs,
                inner.h_pinned_meta,
                ptr_bytes,
                cudaMemcpyKind::cudaMemcpyHostToDevice,
                stream,
            )
        },
        "cudaMemcpyAsync(uncomp_ptrs H2D)",
    )?;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                inner.d_uncomp_sizes,
                (meta_base as *const u8).add(ptr_bytes) as *const c_void,
                size_bytes,
                cudaMemcpyKind::cudaMemcpyHostToDevice,
                stream,
            )
        },
        "cudaMemcpyAsync(uncomp_sizes H2D)",
    )?;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                inner.d_comp_ptrs,
                (meta_base as *const u8).add(ptr_bytes + size_bytes) as *const c_void,
                ptr_bytes,
                cudaMemcpyKind::cudaMemcpyHostToDevice,
                stream,
            )
        },
        "cudaMemcpyAsync(comp_ptrs H2D)",
    )?;
    // Size nvCOMP's scratch buffer, then launch the batched compression.
    let temp_bytes = compress_get_temp_size(
        algo,
        inner.d_uncomp_ptrs as *const *const c_void,
        inner.d_uncomp_sizes as *const usize,
        num_chunks,
        max_chunk_bytes,
        input.len(),
        stream,
    )?;
    inner.ensure_d_buf(BufKind::Temp, temp_bytes)?;
    dispatch_compress(
        algo,
        inner.d_uncomp_ptrs as *const *const c_void,
        inner.d_uncomp_sizes as *const usize,
        max_chunk_bytes,
        num_chunks,
        inner.d_temp,
        temp_bytes,
        inner.d_comp_ptrs as *const *mut c_void,
        inner.d_comp_sizes as *mut usize,
        inner.d_statuses as *mut nvcompStatus_t,
        stream,
    )?;
    // Copy results back. The pinned meta buffer is reused for sizes and
    // statuses — safe because all H2D reads of it were enqueued earlier on
    // this same stream, so stream ordering prevents a data race.
    let bulk_d2h_bytes = num_chunks * max_comp_chunk_bytes;
    inner.ensure_pinned(PinnedKind::Output, bulk_d2h_bytes.max(1))?;
    let pinned_post = inner.h_pinned_meta as *mut u8;
    let pinned_out = inner.h_pinned_output as *mut u8;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                pinned_post as *mut c_void,
                inner.d_comp_sizes,
                size_bytes,
                cudaMemcpyKind::cudaMemcpyDeviceToHost,
                stream,
            )
        },
        "cudaMemcpyAsync(comp_sizes D2H)",
    )?;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                pinned_post.add(size_bytes) as *mut c_void,
                inner.d_statuses,
                status_bytes,
                cudaMemcpyKind::cudaMemcpyDeviceToHost,
                stream,
            )
        },
        "cudaMemcpyAsync(statuses D2H)",
    )?;
    check_cuda(
        unsafe {
            cudaMemcpyAsync(
                pinned_out as *mut c_void,
                inner.d_comp as *const c_void,
                bulk_d2h_bytes,
                cudaMemcpyKind::cudaMemcpyDeviceToHost,
                stream,
            )
        },
        "cudaMemcpyAsync(bulk d_comp D2H)",
    )?;
    // Block until all transfers and kernels on the stream have completed;
    // only after this are the pinned buffers valid to read.
    check_cuda(
        unsafe { cudaStreamSynchronize(stream) },
        "cudaStreamSynchronize(compress)",
    )?;
    // Mirror the produced sizes/statuses into host vectors, then fail if
    // any chunk reported an error.
    unsafe {
        std::ptr::copy_nonoverlapping(
            pinned_post as *const usize,
            inner.h_comp_sizes.as_mut_ptr(),
            num_chunks,
        );
        std::ptr::copy_nonoverlapping(
            pinned_post.add(size_bytes) as *const nvcompStatus_t,
            inner.h_statuses.as_mut_ptr(),
            num_chunks,
        );
    }
    for (i, st) in inner.h_statuses[..num_chunks].iter().enumerate() {
        if *st != nvcompSuccess {
            return Err(Error::Compress(format!(
                "nvcomp per-chunk failure at chunk {i}: status={st} ({})",
                status_str(*st)
            )));
        }
    }
    // Emit the frame: header + size table, then gather the compressed
    // chunks from their 256-byte strided slots into a tight payload.
    let total_comp: usize = inner.h_comp_sizes[..num_chunks].iter().sum();
    write_header(
        algo,
        chunk_size,
        input.len(),
        &inner.h_comp_sizes[..num_chunks],
        output,
    );
    let start = output.len();
    output.resize(start + total_comp, 0);
    let dst_base = output[start..].as_mut_ptr();
    let mut cursor = 0usize;
    for i in 0..num_chunks {
        let sz = inner.h_comp_sizes[i];
        unsafe {
            std::ptr::copy_nonoverlapping(
                pinned_out.add(i * max_comp_chunk_bytes),
                dst_base.add(cursor),
                sz,
            );
        }
        cursor += sz;
    }
    Ok(())
}
/// Appends an `FCG1` frame header to `output`.
///
/// Layout: magic (4) + algo tag (1) + reserved (3, first byte carries the
/// Bitcomp element type) + original size as u64 LE + chunk size as u32 LE
/// + chunk count as u32 LE + one u32 LE compressed size per chunk.
fn write_header(
    algo: Algo,
    chunk_size: usize,
    orig_size: usize,
    chunk_sizes: &[usize],
    output: &mut Vec<u8>,
) {
    output.extend_from_slice(&FRAME_MAGIC);
    output.push(algo_tag(algo));
    let mut reserved = [0u8; 3];
    if let Algo::Bitcomp { data_type } = algo {
        reserved[0] = bitcomp_data_type_tag(data_type);
    }
    output.extend_from_slice(&reserved);
    output.extend_from_slice(&(orig_size as u64).to_le_bytes());
    output.extend_from_slice(&(chunk_size as u32).to_le_bytes());
    output.extend_from_slice(&(chunk_sizes.len() as u32).to_le_bytes());
    for &sz in chunk_sizes {
        output.extend_from_slice(&(sz as u32).to_le_bytes());
    }
}
/// Stable on-wire algorithm tag written into the frame header.
///
/// Must remain the inverse of `algo_from_header`; 0xff marks algorithms
/// that can never appear in a valid frame.
fn algo_tag(algo: Algo) -> u8 {
    match algo {
        Algo::Snappy => 1,
        Algo::Lz4 => 2,
        Algo::Zstd => 3,
        Algo::Bitcomp { .. } => 4,
        Algo::GDeflate => 5,
        _ => 0xff,
    }
}
/// Stable on-wire tag for a Bitcomp element type (stored in the header's
/// first reserved byte). Must remain the inverse of
/// `bitcomp_data_type_from_tag`.
fn bitcomp_data_type_tag(dt: BitcompDataType) -> u8 {
    match dt {
        BitcompDataType::Char => 0,
        BitcompDataType::Uint8 => 1,
        BitcompDataType::Uint16 => 2,
        BitcompDataType::Uint32 => 3,
        BitcompDataType::Uint64 => 4,
        BitcompDataType::Int8 => 5,
        BitcompDataType::Int16 => 6,
        BitcompDataType::Int32 => 7,
        BitcompDataType::Int64 => 8,
        BitcompDataType::Float32 => 9,
        BitcompDataType::Float64 => 10,
        BitcompDataType::BFloat16 => 11,
    }
}
fn bitcomp_data_type_from_tag(tag: u8) -> Result<BitcompDataType> {
match tag {
0 => Ok(BitcompDataType::Char),
1 => Ok(BitcompDataType::Uint8),
2 => Ok(BitcompDataType::Uint16),
3 => Ok(BitcompDataType::Uint32),
4 => Ok(BitcompDataType::Uint64),
5 => Ok(BitcompDataType::Int8),
6 => Ok(BitcompDataType::Int16),
7 => Ok(BitcompDataType::Int32),
8 => Ok(BitcompDataType::Int64),
9 => Ok(BitcompDataType::Float32),
10 => Ok(BitcompDataType::Float64),
11 => Ok(BitcompDataType::BFloat16),
_ => Err(Error::Decompress(format!(
"unknown bitcomp data-type tag: {tag}"
))),
}
}
/// Maps the public Bitcomp element type onto nvCOMP's `nvcompType_t`.
///
/// `Int8` intentionally maps to `NVCOMP_TYPE_CHAR` (same as `Char`): the
/// signed 8-bit case shares that nvCOMP type.
fn bitcomp_to_nvcomp_type(dt: BitcompDataType) -> nvcompType_t {
    match dt {
        BitcompDataType::Char => NVCOMP_TYPE_CHAR,
        BitcompDataType::Uint8 => NVCOMP_TYPE_UCHAR,
        BitcompDataType::Uint16 => NVCOMP_TYPE_USHORT,
        BitcompDataType::Uint32 => NVCOMP_TYPE_UINT,
        BitcompDataType::Uint64 => NVCOMP_TYPE_ULONGLONG,
        BitcompDataType::Int8 => NVCOMP_TYPE_CHAR,
        BitcompDataType::Int16 => NVCOMP_TYPE_SHORT,
        BitcompDataType::Int32 => NVCOMP_TYPE_INT,
        BitcompDataType::Int64 => NVCOMP_TYPE_LONGLONG,
        BitcompDataType::Float32 => NVCOMP_TYPE_FLOAT,
        BitcompDataType::Float64 => NVCOMP_TYPE_DOUBLE,
        BitcompDataType::BFloat16 => NVCOMP_TYPE_BFLOAT16,
    }
}
/// Builds the Bitcomp batched-format options for the given element type,
/// using the default Bitcomp algorithm variant.
fn bitcomp_format_opts(dt: BitcompDataType) -> nvcompBatchedBitcompFormatOpts {
    let data_type = bitcomp_to_nvcomp_type(dt);
    nvcompBatchedBitcompFormatOpts {
        algorithm_type: NVCOMP_BITCOMP_FORMAT_DEFAULT as std::ffi::c_int,
        data_type,
        reserved: [0; 56],
    }
}
fn algo_from_header(tag: u8, reserved: [u8; 3]) -> Result<Algo> {
match tag {
1 => Ok(Algo::Snappy),
2 => Ok(Algo::Lz4),
3 => Ok(Algo::Zstd),
4 => {
let dt = bitcomp_data_type_from_tag(reserved[0])?;
Ok(Algo::Bitcomp { data_type: dt })
}
5 => Ok(Algo::GDeflate),
_ => Err(Error::Decompress(format!("unknown algo tag: {tag}"))),
}
}
fn decompress_chunked(
stream: cudaStream_t,
inner: &mut NvcompCodecInner,
input: &[u8],
output: &mut Vec<u8>,
) -> Result<()> {
if input.len() < HEADER_FIXED_BYTES {
return Err(Error::Decompress(format!(
"nvcomp frame too short: {} bytes",
input.len()
)));
}
if input[0..4] != FRAME_MAGIC {
return Err(Error::Decompress("missing FCG1 magic".into()));
}
let reserved = [input[5], input[6], input[7]];
let algo = algo_from_header(input[4], reserved)?;
let orig_size = u64::from_le_bytes(input[8..16].try_into().unwrap()) as usize;
let chunk_size = u32::from_le_bytes(input[16..20].try_into().unwrap()) as usize;
let num_chunks = u32::from_le_bytes(input[20..24].try_into().unwrap()) as usize;
if num_chunks == 0 {
return Ok(());
}
let sizes_off = HEADER_FIXED_BYTES;
let payload_off = sizes_off + 4 * num_chunks;
if input.len() < payload_off {
return Err(Error::Decompress(format!(
"nvcomp frame truncated: need {payload_off} bytes for sizes table, got {}",
input.len()
)));
}
inner.ensure_metadata(num_chunks)?;
for i in 0..num_chunks {
let s = sizes_off + 4 * i;
inner.h_comp_sizes[i] = u32::from_le_bytes(input[s..s + 4].try_into().unwrap()) as usize;
}
let total_comp: usize = inner.h_comp_sizes[..num_chunks].iter().sum();
if input.len() < payload_off + total_comp {
return Err(Error::Decompress(format!(
"nvcomp frame truncated: need {} bytes of payload, got {}",
total_comp,
input.len() - payload_off
)));
}
let payload = &input[payload_off..payload_off + total_comp];
let needs_strided_layout = matches!(algo, Algo::Bitcomp { .. });
let (comp_buf_bytes, stride) = if needs_strided_layout {
let raw_max = compress_get_max_output_chunk_size(algo, chunk_size)?;
let stride = raw_max.div_ceil(256) * 256;
(stride * num_chunks, stride)
} else {
(total_comp, 0)
};
inner.ensure_d_buf(BufKind::Comp, comp_buf_bytes.max(1))?;
inner.ensure_d_buf(BufKind::Uncomp, orig_size.max(1))?;
let ptr_bytes = num_chunks * std::mem::size_of::<*const c_void>();
let size_bytes = num_chunks * std::mem::size_of::<usize>();
let status_bytes = num_chunks * std::mem::size_of::<nvcompStatus_t>();
let meta_total = ptr_bytes * 2 + size_bytes * 3 + status_bytes;
inner.ensure_pinned(PinnedKind::Meta, meta_total)?;
if needs_strided_layout {
inner.ensure_pinned(PinnedKind::Input, comp_buf_bytes)?;
let pinned_in = inner.h_pinned_input as *mut u8;
let mut payload_cursor = 0usize;
for i in 0..num_chunks {
let sz = inner.h_comp_sizes[i];
unsafe {
std::ptr::copy_nonoverlapping(
payload.as_ptr().add(payload_cursor),
pinned_in.add(i * stride),
sz,
);
}
payload_cursor += sz;
}
for i in 0..num_chunks {
inner.h_comp_ptrs[i] =
unsafe { (inner.d_comp as *mut u8).add(i * stride) as *mut c_void };
let off = i * chunk_size;
inner.h_uncomp_ptrs[i] =
unsafe { (inner.d_uncomp as *const u8).add(off) as *const c_void };
let end = (off + chunk_size).min(orig_size);
inner.h_uncomp_buf_sizes[i] = end - off;
}
} else {
let mut comp_cursor = 0usize;
for i in 0..num_chunks {
inner.h_comp_ptrs[i] =
unsafe { (inner.d_comp as *mut u8).add(comp_cursor) as *mut c_void };
comp_cursor += inner.h_comp_sizes[i];
let off = i * chunk_size;
inner.h_uncomp_ptrs[i] =
unsafe { (inner.d_uncomp as *const u8).add(off) as *const c_void };
let end = (off + chunk_size).min(orig_size);
inner.h_uncomp_buf_sizes[i] = end - off;
}
}
let meta_base = inner.h_pinned_meta as *mut u8;
let mut moff = 0usize;
unsafe {
std::ptr::copy_nonoverlapping(
inner.h_comp_ptrs.as_ptr() as *const u8,
meta_base.add(moff),
ptr_bytes,
);
moff += ptr_bytes;
std::ptr::copy_nonoverlapping(
inner.h_comp_sizes.as_ptr() as *const u8,
meta_base.add(moff),
size_bytes,
);
moff += size_bytes;
std::ptr::copy_nonoverlapping(
inner.h_uncomp_ptrs.as_ptr() as *const u8,
meta_base.add(moff),
ptr_bytes,
);
moff += ptr_bytes;
std::ptr::copy_nonoverlapping(
inner.h_uncomp_buf_sizes.as_ptr() as *const u8,
meta_base.add(moff),
size_bytes,
);
}
let (h2d_src, h2d_bytes) = if needs_strided_layout {
(inner.h_pinned_input as *const c_void, comp_buf_bytes)
} else {
(payload.as_ptr() as *const c_void, total_comp)
};
check_cuda(
unsafe {
cudaMemcpyAsync(
inner.d_comp,
h2d_src,
h2d_bytes,
cudaMemcpyKind::cudaMemcpyHostToDevice,
stream,
)
},
"cudaMemcpyAsync(payload H2D)",
)?;
let mut moff = 0usize;
for (dst, n) in [
(inner.d_comp_ptrs, ptr_bytes),
(inner.d_comp_sizes, size_bytes),
(inner.d_uncomp_ptrs, ptr_bytes),
(inner.d_uncomp_buf_sizes, size_bytes),
] {
check_cuda(
unsafe {
cudaMemcpyAsync(
dst,
meta_base.add(moff) as *const c_void,
n,
cudaMemcpyKind::cudaMemcpyHostToDevice,
stream,
)
},
"cudaMemcpyAsync(meta H2D)",
)?;
moff += n;
}
let temp_bytes = decompress_get_temp_size(algo, num_chunks, chunk_size, orig_size)?;
inner.ensure_d_buf(BufKind::Temp, temp_bytes)?;
dispatch_decompress(
algo,
inner.d_comp_ptrs as *const *const c_void,
inner.d_comp_sizes as *const usize,
inner.d_uncomp_buf_sizes as *const usize,
inner.d_uncomp_actual_sizes as *mut usize,
num_chunks,
inner.d_temp,
temp_bytes,
inner.d_uncomp_ptrs as *const *mut c_void,
inner.d_statuses as *mut nvcompStatus_t,
stream,
)?;
let start = output.len();
output.resize(start + orig_size, 0);
check_cuda(
unsafe {
cudaMemcpyAsync(
output[start..].as_mut_ptr() as *mut c_void,
inner.d_uncomp,
orig_size,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
stream,
)
},
"cudaMemcpyAsync(uncomp D2H)",
)?;
check_cuda(
unsafe { cudaStreamSynchronize(stream) },
"cudaStreamSynchronize(decompress)",
)?;
Ok(())
}