use core::{marker::PhantomData, ptr::NonNull};
use crate::{
parse_code, ptr_mut, ptr_void, CompressionLevel, InBuffer, OutBuffer,
SafeResult, WriteBuf, SEEKABLE_FRAMEINDEX_TOOLARGE,
};
/// Error returned when a frame index lies outside the seek table.
///
/// The `frame_compressed_offset` / `frame_decompressed_offset` accessors in this
/// file return it when zstd reports `SEEKABLE_FRAMEINDEX_TOOLARGE`.
#[derive(Debug, PartialEq)]
pub struct FrameIndexTooLargeError;

impl core::fmt::Display for FrameIndexTooLargeError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Frame index too large")
    }
}

// Allows the error to be boxed as `dyn Error` and propagated with `?` in `std` builds.
#[cfg(feature = "std")]
impl std::error::Error for FrameIndexTooLargeError {}
/// Owning handle to a zstd seekable compression stream (`ZSTD_seekable_CStream`).
///
/// The pointer is obtained in `try_create` and released by the `Drop` impl.
pub struct SeekableCStream(NonNull<zstd_sys::ZSTD_seekable_CStream>);
// NOTE(review): these impls assert that the C stream may be moved to, and referenced
// from, other threads. Every method in this file that touches the pointer takes
// `&mut self`; thread-safety of the C object itself should be confirmed against zstd.
unsafe impl Send for SeekableCStream {}
unsafe impl Sync for SeekableCStream {}
impl Default for SeekableCStream {
fn default() -> Self {
SeekableCStream::create()
}
}
impl SeekableCStream {
pub fn try_create() -> Option<Self> {
Some(SeekableCStream(NonNull::new(unsafe {
zstd_sys::ZSTD_seekable_createCStream()
})?))
}
pub fn create() -> Self {
Self::try_create()
.expect("zstd returned null pointer when creating new seekable compression stream")
}
pub fn init(
&mut self,
compression_level: CompressionLevel,
checksum_flag: bool,
max_frame_size: u32,
) -> SafeResult {
let code = unsafe {
zstd_sys::ZSTD_seekable_initCStream(
self.0.as_ptr(),
compression_level,
checksum_flag as i32,
max_frame_size,
)
};
parse_code(code)
}
pub fn compress_stream<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
input: &mut InBuffer<'_>,
) -> SafeResult {
let mut output = output.wrap();
let mut input = input.wrap();
let code = unsafe {
zstd_sys::ZSTD_seekable_compressStream(
self.0.as_ptr(),
ptr_mut(&mut output),
ptr_mut(&mut input),
)
};
parse_code(code)
}
pub fn end_frame<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> SafeResult {
let mut output = output.wrap();
let code = unsafe {
zstd_sys::ZSTD_seekable_endFrame(
self.0.as_ptr(),
ptr_mut(&mut output),
)
};
parse_code(code)
}
pub fn end_stream<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> SafeResult {
let mut output = output.wrap();
let code = unsafe {
zstd_sys::ZSTD_seekable_endStream(
self.0.as_ptr(),
ptr_mut(&mut output),
)
};
parse_code(code)
}
}
impl Drop for SeekableCStream {
    /// Releases the underlying zstd stream via `ZSTD_seekable_freeCStream`.
    fn drop(&mut self) {
        let stream = self.0.as_ptr();
        // SAFETY: `stream` is the non-null pointer this wrapper was built around;
        // it is freed here, as the wrapper goes away.
        unsafe {
            zstd_sys::ZSTD_seekable_freeCStream(stream);
        }
    }
}
/// Owning handle to a zstd frame log (`ZSTD_frameLog`).
///
/// The pointer is obtained in `try_create` and released by the `Drop` impl.
pub struct FrameLog(NonNull<zstd_sys::ZSTD_frameLog>);
// NOTE(review): these impls assert that the C frame log may be moved to, and
// referenced from, other threads. Every method in this file that touches the pointer
// takes `&mut self`; thread-safety of the C object itself should be confirmed against zstd.
unsafe impl Send for FrameLog {}
unsafe impl Sync for FrameLog {}
impl FrameLog {
pub fn try_create(checksum_flag: bool) -> Option<Self> {
Some(FrameLog(
NonNull::new(unsafe {
zstd_sys::ZSTD_seekable_createFrameLog(checksum_flag as i32)
})?,
))
}
pub fn create(checksum_flag: bool) -> Self {
Self::try_create(checksum_flag)
.expect("Zstd returned null pointer when creating new frame log")
}
pub fn log_frame(
&mut self,
compressed_size: u32,
decompressed_size: u32,
checksum: Option<u32>,
) -> SafeResult {
let code = unsafe {
zstd_sys::ZSTD_seekable_logFrame(
self.0.as_ptr(),
compressed_size,
decompressed_size,
checksum.unwrap_or_default(),
)
};
parse_code(code)
}
pub fn write_seek_table<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> SafeResult {
let mut output = output.wrap();
let code = unsafe {
zstd_sys::ZSTD_seekable_writeSeekTable(
self.0.as_ptr(),
ptr_mut(&mut output),
)
};
parse_code(code)
}
}
impl Drop for FrameLog {
    /// Releases the underlying frame log via `ZSTD_seekable_freeFrameLog`.
    fn drop(&mut self) {
        let log = self.0.as_ptr();
        // SAFETY: `log` is the non-null pointer this wrapper was built around;
        // it is freed here, as the wrapper goes away.
        unsafe {
            zstd_sys::ZSTD_seekable_freeFrameLog(log);
        }
    }
}
/// Owning handle to a zstd seekable decompressor (`ZSTD_seekable`).
///
/// The `'a` marker is the lifetime of the byte slice handed to `init_buff`, so
/// the handle cannot outlive that buffer.
pub struct Seekable<'a>(NonNull<zstd_sys::ZSTD_seekable>, PhantomData<&'a ()>);
// NOTE(review): these impls assert that the C object may be moved to, and referenced
// from, other threads. Several accessors in this file take `&self` and call into zstd;
// presumably those calls are read-only on the C side — confirm against zstd.
unsafe impl Send for Seekable<'_> {}
unsafe impl Sync for Seekable<'_> {}
impl Default for Seekable<'_> {
fn default() -> Self {
Seekable::create()
}
}
impl<'a> Seekable<'a> {
    /// Allocates a new seekable decompressor.
    ///
    /// Returns `None` when `ZSTD_seekable_create` yields a null pointer.
    pub fn try_create() -> Option<Self> {
        let raw = unsafe { zstd_sys::ZSTD_seekable_create() };
        NonNull::new(raw).map(|handle| Seekable(handle, PhantomData))
    }

    /// Allocates a new seekable decompressor.
    ///
    /// # Panics
    ///
    /// Panics if zstd returns a null pointer.
    pub fn create() -> Self {
        let created = Self::try_create();
        created.expect("Zstd returned null pointer when creating new seekable")
    }

    /// Initialises the decompressor from an in-memory buffer via
    /// `ZSTD_seekable_initBuff`.
    ///
    /// `src` is borrowed for `'a`, the lifetime carried by this handle. The
    /// returned code is interpreted by `parse_code`.
    pub fn init_buff(&mut self, src: &'a [u8]) -> SafeResult {
        let handle = self.0.as_ptr();
        let src_len = src.len();
        parse_code(unsafe {
            zstd_sys::ZSTD_seekable_initBuff(handle, ptr_void(src), src_len)
        })
    }

    /// Calls `ZSTD_seekable_decompress` for `offset`, writing into `dst` through
    /// `WriteBuf::write_from`.
    pub fn decompress<C: WriteBuf + ?Sized>(
        &mut self,
        dst: &mut C,
        offset: u64,
    ) -> SafeResult {
        let handle = self.0.as_ptr();
        unsafe {
            dst.write_from(|buffer, capacity| {
                let code = zstd_sys::ZSTD_seekable_decompress(
                    handle, buffer, capacity, offset,
                );
                parse_code(code)
            })
        }
    }

    /// Calls `ZSTD_seekable_decompressFrame` for `frame_index`, writing into
    /// `dst` through `WriteBuf::write_from`.
    pub fn decompress_frame<C: WriteBuf + ?Sized>(
        &mut self,
        dst: &mut C,
        frame_index: u32,
    ) -> SafeResult {
        let handle = self.0.as_ptr();
        unsafe {
            dst.write_from(|buffer, capacity| {
                let code = zstd_sys::ZSTD_seekable_decompressFrame(
                    handle,
                    buffer,
                    capacity,
                    frame_index,
                );
                parse_code(code)
            })
        }
    }

    /// Returns the frame count reported by `ZSTD_seekable_getNumFrames`.
    pub fn num_frames(&self) -> u32 {
        let handle = self.0.as_ptr();
        unsafe { zstd_sys::ZSTD_seekable_getNumFrames(handle) }
    }

    /// Returns the value of `ZSTD_seekable_getFrameCompressedOffset` for
    /// `frame_index`.
    ///
    /// # Errors
    ///
    /// [`FrameIndexTooLargeError`] when zstd answers with
    /// `SEEKABLE_FRAMEINDEX_TOOLARGE`.
    pub fn frame_compressed_offset(
        &self,
        frame_index: u32,
    ) -> Result<u64, FrameIndexTooLargeError> {
        let handle = self.0.as_ptr();
        let raw = unsafe {
            zstd_sys::ZSTD_seekable_getFrameCompressedOffset(handle, frame_index)
        };
        if raw == SEEKABLE_FRAMEINDEX_TOOLARGE {
            Err(FrameIndexTooLargeError)
        } else {
            Ok(raw)
        }
    }

    /// Returns the value of `ZSTD_seekable_getFrameDecompressedOffset` for
    /// `frame_index`.
    ///
    /// # Errors
    ///
    /// [`FrameIndexTooLargeError`] when zstd answers with
    /// `SEEKABLE_FRAMEINDEX_TOOLARGE`.
    pub fn frame_decompressed_offset(
        &self,
        frame_index: u32,
    ) -> Result<u64, FrameIndexTooLargeError> {
        let handle = self.0.as_ptr();
        let raw = unsafe {
            zstd_sys::ZSTD_seekable_getFrameDecompressedOffset(handle, frame_index)
        };
        if raw == SEEKABLE_FRAMEINDEX_TOOLARGE {
            Err(FrameIndexTooLargeError)
        } else {
            Ok(raw)
        }
    }

    /// Returns the result of `ZSTD_seekable_getFrameCompressedSize` for
    /// `frame_index`, interpreted by `parse_code`.
    pub fn frame_compressed_size(&self, frame_index: u32) -> SafeResult {
        let handle = self.0.as_ptr();
        parse_code(unsafe {
            zstd_sys::ZSTD_seekable_getFrameCompressedSize(handle, frame_index)
        })
    }

    /// Returns the result of `ZSTD_seekable_getFrameDecompressedSize` for
    /// `frame_index`, interpreted by `parse_code`.
    pub fn frame_decompressed_size(&self, frame_index: u32) -> SafeResult {
        let handle = self.0.as_ptr();
        parse_code(unsafe {
            zstd_sys::ZSTD_seekable_getFrameDecompressedSize(handle, frame_index)
        })
    }

    /// Returns the frame index `ZSTD_seekable_offsetToFrameIndex` reports for
    /// `offset`.
    // NOTE(review): `offset` is presumably a position in the decompressed data —
    // confirm against zstd_seekable.h.
    pub fn offset_to_frame_index(&self, offset: u64) -> u32 {
        let handle = self.0.as_ptr();
        unsafe { zstd_sys::ZSTD_seekable_offsetToFrameIndex(handle, offset) }
    }
}
impl<'a> Drop for Seekable<'a> {
    /// Releases the underlying decompressor via `ZSTD_seekable_free`.
    fn drop(&mut self) {
        let handle = self.0.as_ptr();
        // SAFETY: `handle` is the non-null pointer this wrapper was built around;
        // it is freed here, as the wrapper goes away.
        unsafe {
            zstd_sys::ZSTD_seekable_free(handle);
        }
    }
}
/// A [`Seekable`] initialised from a boxed `Read + Seek` source via
/// `Seekable::init_advanced`.
///
/// Owns both the zstd handle and the heap-allocated source; the source is
/// reclaimed in this type's `Drop` impl.
#[cfg(feature = "std")]
pub struct AdvancedSeekable<'a, F> {
// The underlying zstd handle; exposed through `Deref`/`DerefMut`.
inner: Seekable<'a>,
// Raw pointer produced by `Box::into_raw` in `init_advanced`; also handed to
// zstd as the opaque argument of the read/seek callbacks.
src: *mut F,
}
// The raw `src` pointer suppresses the auto traits, so they are restored here
// conditionally on `F`.
#[cfg(feature = "std")]
unsafe impl<F> Send for AdvancedSeekable<'_, F> where F: Send {}
#[cfg(feature = "std")]
unsafe impl<F> Sync for AdvancedSeekable<'_, F> where F: Sync {}
// Gives shared access to every `&self` method of the wrapped `Seekable`.
#[cfg(feature = "std")]
impl<'a, F> core::ops::Deref for AdvancedSeekable<'a, F> {
type Target = Seekable<'a>;
/// Borrows the wrapped [`Seekable`].
fn deref(&self) -> &Self::Target {
&self.inner
}
}
// Gives exclusive access to every `&mut self` method of the wrapped `Seekable`.
#[cfg(feature = "std")]
impl<'a, F> core::ops::DerefMut for AdvancedSeekable<'a, F> {
/// Mutably borrows the wrapped [`Seekable`].
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
#[cfg(feature = "std")]
impl<'a, F> Drop for AdvancedSeekable<'a, F> {
    /// Reclaims and drops the boxed source. The `inner` handle is dropped
    /// afterwards by the usual field drop.
    fn drop(&mut self) {
        // SAFETY: `src` was produced by `Box::into_raw` in `Seekable::init_advanced`
        // and is turned back into a `Box` only here.
        let boxed_src: std::boxed::Box<F> = unsafe { std::boxed::Box::from_raw(self.src) };
        drop(boxed_src);
    }
}
impl<'a> Seekable<'a> {
    /// Initialises this decompressor from an arbitrary `Read + Seek` source.
    ///
    /// The boxed `src` is handed to zstd as the opaque argument of the
    /// `advanced_read` / `advanced_seek` callbacks through
    /// `ZSTD_seekable_initAdvanced`. On success the returned
    /// [`AdvancedSeekable`] owns both the zstd handle and the source.
    ///
    /// # Errors
    ///
    /// Returns the raw zstd error code when `crate::is_error` reports that
    /// initialisation failed. In that case both the handle and `src` are
    /// dropped before returning.
    #[cfg(feature = "std")]
    #[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "std")))]
    pub fn init_advanced<F>(
        self,
        src: std::boxed::Box<F>,
    ) -> Result<AdvancedSeekable<'a, F>, crate::ErrorCode>
    where
        F: std::io::Read + std::io::Seek,
    {
        let opaque: *mut F = std::boxed::Box::into_raw(src);

        // Give ownership of the handle and the raw source pointer to the wrapper
        // *before* calling into zstd. Every exit path (including the error return
        // below) then releases the source through `AdvancedSeekable`'s `Drop`;
        // previously the error path leaked the box.
        let seekable = AdvancedSeekable {
            inner: self,
            src: opaque,
        };

        let custom_file = zstd_sys::ZSTD_seekable_customFile {
            opaque: opaque as *mut core::ffi::c_void,
            read: Some(advanced_read::<F>),
            seek: Some(advanced_seek::<F>),
        };
        let code = unsafe {
            zstd_sys::ZSTD_seekable_initAdvanced(
                seekable.inner.0.as_ptr(),
                custom_file,
            )
        };
        if crate::is_error(code) {
            // `seekable` is dropped here: first the boxed source, then the handle.
            return Err(code);
        }
        Ok(seekable)
    }
}
/// `seek` callback installed by `Seekable::init_advanced`.
///
/// Translates a C-style `(offset, origin)` pair into a [`std::io::SeekFrom`]
/// and applies it to the `S` behind `opaque`.
///
/// Returns `0` on success. Returns `-1` when `opaque` is null, `origin` is not
/// one of the three recognised values, a `SEEK_SET` offset is negative, or the
/// seek itself fails.
///
/// # Safety
///
/// `opaque` must be null or point to a live `S` that is not accessed through
/// any other reference for the duration of the call.
#[cfg(feature = "std")]
unsafe extern "C" fn advanced_seek<S: std::io::Seek>(
    opaque: *mut core::ffi::c_void,
    offset: ::core::ffi::c_longlong,
    origin: ::core::ffi::c_int,
) -> ::core::ffi::c_int {
    use core::convert::TryFrom;
    use std::io::SeekFrom;

    // Origin values accepted from the C side.
    const SEEK_SET: i32 = 0;
    const SEEK_CUR: i32 = 1;
    const SEEK_END: i32 = 2;

    // A typed cast plus `as_mut` replaces the former `transmute` and rejects a
    // null pointer instead of materialising an invalid reference.
    // SAFETY: per this function's contract, a non-null `opaque` points to a live,
    // exclusively accessible `S`.
    let Some(seeker) = opaque.cast::<S>().as_mut() else {
        return -1;
    };

    let pos = match origin {
        SEEK_SET => {
            // An absolute position cannot be negative.
            let Ok(start) = u64::try_from(offset) else {
                return -1;
            };
            SeekFrom::Start(start)
        }
        SEEK_CUR => SeekFrom::Current(offset),
        SEEK_END => SeekFrom::End(offset),
        _ => return -1,
    };

    match seeker.seek(pos) {
        Ok(_) => 0,
        Err(_) => -1,
    }
}
/// `read` callback installed by `Seekable::init_advanced`.
///
/// Fills the `n` bytes at `buffer` from the `R` behind `opaque` using
/// `read_exact`.
///
/// Returns `0` when all `n` bytes were read (trivially so for `n == 0`).
/// Returns `-1` when `opaque` is null, `buffer` is null for a non-empty read,
/// or the reader fails or runs out of data early.
///
/// # Safety
///
/// `opaque` must be null or point to a live `R` that is not accessed through
/// any other reference for the duration of the call. A non-null `buffer` must
/// be valid for writes of `n` bytes.
#[cfg(feature = "std")]
unsafe extern "C" fn advanced_read<R: std::io::Read>(
    opaque: *mut core::ffi::c_void,
    buffer: *mut core::ffi::c_void,
    n: usize,
) -> ::core::ffi::c_int {
    // A typed cast plus `as_mut` replaces the former `transmute` and rejects a
    // null pointer instead of materialising an invalid reference.
    // SAFETY: per this function's contract, a non-null `opaque` points to a live,
    // exclusively accessible `R`.
    let Some(reader) = opaque.cast::<R>().as_mut() else {
        return -1;
    };
    if n == 0 {
        // Nothing to transfer; this also avoids forming a slice from a pointer
        // that may legitimately be null for an empty read.
        return 0;
    }
    if buffer.is_null() {
        return -1;
    }

    // SAFETY: `buffer` is non-null and, per the contract, writable for `n` bytes.
    let dest = std::slice::from_raw_parts_mut(buffer.cast::<u8>(), n);
    match reader.read_exact(dest) {
        Ok(()) => 0,
        Err(_) => -1,
    }
}
/// Error returned by `SeekTable::try_from_seekable` when zstd hands back a
/// null pointer instead of a seek table.
#[derive(Debug, PartialEq)]
pub struct SeekTableCreateError;

impl core::fmt::Display for SeekTableCreateError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Zstd returned null pointer when creating new seektable from seekable")
    }
}

// Allows the error to be boxed as `dyn Error` and propagated with `?` in `std` builds.
#[cfg(feature = "std")]
impl std::error::Error for SeekTableCreateError {}
/// Owning handle to a standalone zstd seek table (`ZSTD_seekTable`).
///
/// Created from a `Seekable` by `try_from_seekable` and released by the `Drop` impl.
pub struct SeekTable(NonNull<zstd_sys::ZSTD_seekTable>);
// NOTE(review): these impls assert that the C seek table may be moved to, and
// referenced from, other threads. The accessors in this file take `&self` and call
// into zstd; presumably those calls are read-only on the C side — confirm against zstd.
unsafe impl Send for SeekTable {}
unsafe impl Sync for SeekTable {}
impl SeekTable {
    /// Builds a seek table from `value` via
    /// `ZSTD_seekTable_create_fromSeekable`.
    ///
    /// # Errors
    ///
    /// [`SeekTableCreateError`] when zstd returns a null pointer.
    pub fn try_from_seekable<'a>(
        value: &Seekable<'a>,
    ) -> Result<Self, SeekTableCreateError> {
        let source = value.0.as_ptr();
        let raw = unsafe { zstd_sys::ZSTD_seekTable_create_fromSeekable(source) };
        match NonNull::new(raw) {
            Some(table) => Ok(Self(table)),
            None => Err(SeekTableCreateError),
        }
    }

    /// Returns the frame count reported by `ZSTD_seekTable_getNumFrames`.
    pub fn num_frames(&self) -> u32 {
        let table = self.0.as_ptr();
        unsafe { zstd_sys::ZSTD_seekTable_getNumFrames(table) }
    }

    /// Returns the value of `ZSTD_seekTable_getFrameCompressedOffset` for
    /// `frame_index`.
    ///
    /// # Errors
    ///
    /// [`FrameIndexTooLargeError`] when zstd answers with
    /// `SEEKABLE_FRAMEINDEX_TOOLARGE`.
    pub fn frame_compressed_offset(
        &self,
        frame_index: u32,
    ) -> Result<u64, FrameIndexTooLargeError> {
        let table = self.0.as_ptr();
        let raw = unsafe {
            zstd_sys::ZSTD_seekTable_getFrameCompressedOffset(table, frame_index)
        };
        if raw == SEEKABLE_FRAMEINDEX_TOOLARGE {
            Err(FrameIndexTooLargeError)
        } else {
            Ok(raw)
        }
    }

    /// Returns the value of `ZSTD_seekTable_getFrameDecompressedOffset` for
    /// `frame_index`.
    ///
    /// # Errors
    ///
    /// [`FrameIndexTooLargeError`] when zstd answers with
    /// `SEEKABLE_FRAMEINDEX_TOOLARGE`.
    pub fn frame_decompressed_offset(
        &self,
        frame_index: u32,
    ) -> Result<u64, FrameIndexTooLargeError> {
        let table = self.0.as_ptr();
        let raw = unsafe {
            zstd_sys::ZSTD_seekTable_getFrameDecompressedOffset(table, frame_index)
        };
        if raw == SEEKABLE_FRAMEINDEX_TOOLARGE {
            Err(FrameIndexTooLargeError)
        } else {
            Ok(raw)
        }
    }

    /// Returns the result of `ZSTD_seekTable_getFrameCompressedSize` for
    /// `frame_index`, interpreted by `parse_code`.
    pub fn frame_compressed_size(&self, frame_index: u32) -> SafeResult {
        let table = self.0.as_ptr();
        parse_code(unsafe {
            zstd_sys::ZSTD_seekTable_getFrameCompressedSize(table, frame_index)
        })
    }

    /// Returns the result of `ZSTD_seekTable_getFrameDecompressedSize` for
    /// `frame_index`, interpreted by `parse_code`.
    pub fn frame_decompressed_size(&self, frame_index: u32) -> SafeResult {
        let table = self.0.as_ptr();
        parse_code(unsafe {
            zstd_sys::ZSTD_seekTable_getFrameDecompressedSize(table, frame_index)
        })
    }

    /// Returns the frame index `ZSTD_seekTable_offsetToFrameIndex` reports for
    /// `offset`.
    // NOTE(review): `offset` is presumably a position in the decompressed data —
    // confirm against zstd_seekable.h.
    pub fn offset_to_frame_index(&self, offset: u64) -> u32 {
        let table = self.0.as_ptr();
        unsafe { zstd_sys::ZSTD_seekTable_offsetToFrameIndex(table, offset) }
    }
}
impl Drop for SeekTable {
    /// Releases the underlying seek table via `ZSTD_seekTable_free`.
    fn drop(&mut self) {
        let table = self.0.as_ptr();
        // SAFETY: `table` is the non-null pointer this wrapper was built around;
        // it is freed here, as the wrapper goes away.
        unsafe {
            zstd_sys::ZSTD_seekTable_free(table);
        }
    }
}