use alloc::vec;
use alloc::vec::Vec;
use zstd_safe::zstd_sys::ZSTD_ErrorCode;
use crate::{
SEEK_TABLE_INTEGRITY_SIZE, SEEKABLE_MAGIC_NUMBER, SEEKABLE_MAX_FRAMES, SKIPPABLE_HEADER_SIZE,
error::{Error, Result},
seekable::{OffsetFrom, Seekable},
};
/// Reads a little-endian `u32` from the four bytes of `$buf` starting at `$offset`.
///
/// Panics if any of the four indices is out of bounds for `$buf`.
macro_rules! read_le32 {
    ($buf:expr, $offset:expr) => {
        u32::from_le_bytes([
            $buf[$offset],
            $buf[$offset + 1],
            $buf[$offset + 2],
            $buf[$offset + 3],
        ])
    };
}
/// Writes the little-endian bytes of `$value` into `$buf`, resuming a partially written field.
///
/// `$offset` is the position of the field within the serialized output and `$write_pos` the
/// number of output bytes emitted so far. Nothing is written if the field has already been
/// emitted completely. Both `$buf_pos` and `$write_pos` are advanced by the number of bytes
/// copied. If the buffer is full afterwards, the enclosing function returns `$buf_pos`.
macro_rules! write_le32 {
    ($buf:expr, $buf_pos:expr, $write_pos:expr, $value:expr, $offset:expr) => {
        if $write_pos < $offset + 4 {
            let bytes = $value.to_le_bytes();
            // Copy as much of the field as is left, limited by the space left in the buffer
            let count = usize::min($buf.len() - $buf_pos, $offset + 4 - $write_pos);
            // Bytes of this field that were emitted by an earlier call
            let skip = $write_pos - $offset;
            $buf[$buf_pos..$buf_pos + count].copy_from_slice(&bytes[skip..skip + count]);
            $buf_pos += count;
            $write_pos += count;
            if $buf_pos == $buf.len() {
                return $buf_pos;
            }
        }
    };
}
/// Writes the seek table entry of the frame at `$self.frame_index`.
///
/// The entry consists of the compressed size followed by the decompressed size, both as
/// little-endian `u32`, starting at position `$offset` of the serialized output. The frame
/// index is only advanced once the entry is written completely; if the buffer fills up before
/// that, `write_le32!` returns from the enclosing function.
macro_rules! write_frame {
    ($buf:expr, $buf_pos:expr, $self:expr, $offset:expr) => {
        let (c_size, d_size) = {
            let frame = &$self.frames[$self.frame_index];
            (frame.c_size, frame.d_size)
        };
        write_le32!($buf, $buf_pos, $self.write_pos, c_size, $offset);
        write_le32!($buf, $buf_pos, $self.write_pos, d_size, $offset + 4);
        $self.frame_index += 1;
    };
}
/// Writes the seek table integrity field starting at position `$offset` of the serialized
/// output.
///
/// The field consists of the number of frames (little-endian `u32`), a descriptor byte that is
/// always zero, and the seekable magic number (little-endian `u32`). Writing can be resumed
/// at any byte; if the buffer fills up, the enclosing function returns `$buf_pos`.
macro_rules! write_integrity {
    ($buf:expr, $buf_pos:expr, $self:expr, $num_frames:expr, $offset:expr) => {
        write_le32!($buf, $buf_pos, $self.write_pos, $num_frames, $offset);
        if $self.write_pos < $offset + 5 {
            // `write_le32!` only checks for a full buffer when it copied something. If the
            // descriptor is the very next byte to emit, the buffer may have no room left (e.g.
            // an empty buffer), so check before indexing into it.
            if $buf_pos == $buf.len() {
                return $buf_pos;
            }
            $buf[$buf_pos] = 0;
            $buf_pos += 1;
            $self.write_pos += 1;
        }
        write_le32!(
            $buf,
            $buf_pos,
            $self.write_pos,
            SEEKABLE_MAGIC_NUMBER,
            $offset + 5
        );
    };
}
/// Number of bytes the serializer emits per frame: the compressed and the decompressed size,
/// each as a little-endian `u32`.
const SIZE_PER_FRAME: usize = 8;
/// Magic number expected in (and written to) the first four bytes of the skippable frame that
/// holds the seek table.
const SKIPPABLE_MAGIC_NUMBER: u32 = zstd_safe::zstd_sys::ZSTD_MAGIC_SKIPPABLE_START | 0xE;
/// Compressed and decompressed size of a single frame, as emitted by the serializer.
struct Frame {
    c_size: u32,
    d_size: u32,
}

/// Cumulative offsets at a frame boundary.
///
/// The entry at index `i` holds the compressed and decompressed offset at which frame `i`
/// starts, which is also where frame `i - 1` ends.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Entry {
    c_offset: u64,
    d_offset: u64,
}

/// The frame boundaries of a seek table.
///
/// A table describing `n` frames holds `n + 1` entries: one at the start of every frame plus a
/// final one marking the end of the last frame.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Entries(Vec<Entry>);

impl Entries {
    /// Creates an empty list with room for all entries of a table with `num_frames` frames.
    fn with_num_frames(num_frames: usize) -> Self {
        // `Vec::with_capacity` takes a number of elements, not bytes. A table with `num_frames`
        // frames stores one entry per frame plus the trailing end-of-data entry.
        Self(Vec::with_capacity(num_frames.saturating_add(1)))
    }

    /// Converts the cumulative offsets into the sizes of the individual frames.
    fn into_frames(self) -> Vec<Frame> {
        self.0
            .windows(2)
            .map(|w| Frame {
                // Offsets are accumulated from `u32` frame sizes, so the difference between two
                // neighbouring entries is expected to fit into a `u32` again.
                c_size: (w[1].c_offset - w[0].c_offset) as u32,
                d_size: (w[1].d_offset - w[0].d_offset) as u32,
            })
            .collect()
    }
}

impl core::ops::Index<u32> for Entries {
    type Output = Entry;

    /// Returns the entry at `index`, panics if `index` is out of bounds.
    fn index(&self, index: u32) -> &Self::Output {
        let idx = usize::try_from(index).expect("Frame index can be transformed to uisze");
        &self.0[idx]
    }
}
/// Incremental parser for the entries of a serialized seek table.
///
/// Created from the integrity field of the table, then fed with the raw entry bytes in one or
/// more chunks.
#[derive(Debug)]
struct Parser {
/// Number of frames announced by the integrity field.
num_frames: usize,
/// Size of one serialized entry in bytes, 12 if the table carries checksums, otherwise 8.
size_per_frame: usize,
/// Total size of the serialized seek table: entries, skippable header and integrity field.
seek_table_size: usize,
/// The entries parsed so far.
entries: Entries,
/// Running sum of the compressed sizes of all parsed frames.
c_offset: u64,
/// Running sum of the decompressed sizes of all parsed frames.
d_offset: u64,
}
impl Parser {
    /// Creates a parser from the bytes of a seek table integrity field.
    ///
    /// `buf` must hold the number of frames (little-endian `u32`), the descriptor byte and the
    /// seekable magic number (little-endian `u32`), in that order.
    ///
    /// # Errors
    ///
    /// Fails if the magic number does not match, if any of the descriptor bits 2 to 6 is set,
    /// or if the number of frames exceeds `SEEKABLE_MAX_FRAMES`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than nine bytes.
    fn from_bytes(buf: &[u8]) -> Result<Self> {
        let magic = read_le32!(buf, 5);
        if magic != SEEKABLE_MAGIC_NUMBER {
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_prefix_unknown));
        }

        let descriptor = buf[4];
        // Bits 2 to 6 of the descriptor must not be set
        if (descriptor & 0b0111_1100) != 0 {
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected));
        }
        // The highest bit tells whether every entry carries four extra checksum bytes
        let with_checksum = (descriptor & 0b1000_0000) != 0;

        let frame_count = read_le32!(buf, 0);
        if frame_count > SEEKABLE_MAX_FRAMES {
            return Err(Error::frame_index_too_large());
        }
        let num_frames =
            usize::try_from(frame_count).expect("Number of frames never exceeds usize");

        let size_per_frame: usize = if with_checksum { 12 } else { 8 };
        let entries_size = num_frames * size_per_frame;
        let seek_table_size = entries_size + SKIPPABLE_HEADER_SIZE + SEEK_TABLE_INTEGRITY_SIZE;

        Ok(Self {
            num_frames,
            size_per_frame,
            seek_table_size,
            entries: Entries::with_num_frames(num_frames),
            c_offset: 0,
            d_offset: 0,
        })
    }

    /// Checks the skippable frame header in the first eight bytes of `buf`.
    ///
    /// # Errors
    ///
    /// Fails if the skippable magic number does not match or if the frame size in the header
    /// disagrees with the seek table size derived from the integrity field.
    fn verify_skippable_header(&self, buf: &[u8]) -> Result<()> {
        let magic = read_le32!(buf, 0);
        if magic != SKIPPABLE_MAGIC_NUMBER {
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_prefix_unknown));
        }

        let frame_size = usize::try_from(read_le32!(buf, 4)).expect("frame size fits in usize");
        if frame_size + SKIPPABLE_HEADER_SIZE == self.seek_table_size {
            Ok(())
        } else {
            Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected))
        }
    }

    /// Parses as many complete entries from `buf` as possible.
    ///
    /// Returns the number of bytes consumed. A trailing, incomplete entry is left untouched so
    /// that the caller can pass it again together with the bytes that follow. Once the entry
    /// of the last frame is parsed, the closing entry that marks the end of the data is added.
    fn parse_entries(&mut self, buf: &[u8]) -> usize {
        let mut consumed: usize = 0;
        while self.entries.0.len() < self.num_frames {
            let next = consumed + self.size_per_frame;
            if next > buf.len() {
                return consumed;
            }
            // The entry of a frame holds the offsets at which the frame starts
            self.log_entry();
            let c_size = read_le32!(buf, consumed);
            let d_size = read_le32!(buf, consumed + 4);
            self.c_offset += u64::from(c_size);
            self.d_offset += u64::from(d_size);
            consumed = next;
        }
        self.log_entry();

        consumed
    }

    /// Appends an entry holding the current offsets.
    fn log_entry(&mut self) {
        let entry = Entry {
            c_offset: self.c_offset,
            d_offset: self.d_offset,
        };
        self.entries.0.push(entry);
    }

    /// Checks that exactly one entry per frame plus the closing entry has been parsed.
    fn verify(&self) -> Result<()> {
        let expected = self.num_frames + 1;
        if self.entries.0.len() != expected {
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected));
        }

        Ok(())
    }
}
/// The layout of a serialized seek table.
#[derive(Debug, Clone, Copy, Default)]
pub enum Format {
/// The integrity field directly follows the skippable header, in front of the frame entries.
/// A seek table in this format is read from the start of the source.
Head,
/// The integrity field follows the frame entries. A seek table in this format is read from
/// the end of the source. This is the default.
#[default]
Foot,
}
/// Maps frames to their positions in the compressed and decompressed data.
///
/// Holds the cumulative compressed and decompressed offsets of all frame boundaries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SeekTable {
// One entry per frame start plus a final entry marking the end of the last frame
entries: Entries,
}
impl Default for SeekTable {
fn default() -> Self {
Self::new()
}
}
impl From<Parser> for SeekTable {
fn from(value: Parser) -> Self {
SeekTable {
entries: value.entries,
}
}
}
impl SeekTable {
/// Creates an empty seek table.
///
/// The table starts with a single entry at offset zero and holds no frames.
pub fn new() -> Self {
    let origin = Entry {
        c_offset: 0,
        d_offset: 0,
    };

    Self {
        entries: Entries(vec![origin]),
    }
}
/// Reads a seek table in [`Format::Foot`] from `src`.
///
/// # Errors
///
/// Fails under the same conditions as [`SeekTable::from_seekable_format`].
pub fn from_seekable(src: &mut impl Seekable) -> Result<Self> {
    let format = Format::Foot;
    Self::from_seekable_format(src, format)
}
/// Reads a seek table in the given `format` from `src`.
///
/// The integrity field is fetched from `src` first. The source is then positioned at the
/// beginning of the seek table (the start of `src` for [`Format::Head`], seek table size bytes
/// before the end for [`Format::Foot`]) and the table is read in chunks of at most 8192 bytes.
///
/// # Errors
///
/// Fails if the integrity field or the skippable header is invalid, if `src` ends before the
/// complete seek table is read, or if reading from or positioning `src` fails.
pub fn from_seekable_format(src: &mut impl Seekable, format: Format) -> Result<Self> {
    let integrity = src.seek_table_integrity(format)?;
    let mut parser = Parser::from_bytes(&integrity)?;

    match format {
        Format::Head => src.set_offset(OffsetFrom::Start(0))?,
        Format::Foot => src.set_offset(OffsetFrom::End(-(parser.seek_table_size as i64)))?,
    };

    // Number of bytes in front of the first entry. In head format the integrity field sits
    // between the skippable header and the entries.
    let entries_start = match format {
        Format::Head => SKIPPABLE_HEADER_SIZE + SEEK_TABLE_INTEGRITY_SIZE,
        Format::Foot => SKIPPABLE_HEADER_SIZE,
    };

    let len = 8192.min(parser.seek_table_size);
    let mut buf = vec![0u8; len];
    let mut read = 0;
    // Read everything in front of the entries. A single read may return fewer bytes than
    // requested, hence always continue after the bytes received so far. The seek table is
    // never smaller than header plus integrity field, so the buffer has room for both.
    while read < entries_start {
        let n = src.read(&mut buf[read..])?;
        if n == 0 {
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected));
        }
        read += n;
    }
    parser.verify_skippable_header(&buf[..SKIPPABLE_HEADER_SIZE])?;

    // Number of entry bytes that are not parsed yet
    let mut remaining =
        parser.seek_table_size - SKIPPABLE_HEADER_SIZE - SEEK_TABLE_INTEGRITY_SIZE;
    let mut buf_start = entries_start;
    let mut buf_end = read;
    loop {
        let n = parser.parse_entries(&buf[buf_start..buf_end]);
        remaining -= n;
        if remaining == 0 {
            break;
        }

        // Move the bytes of an incomplete entry to the front and fill up the buffer
        let r = buf_start + n..buf_end;
        let offset = r.len();
        buf.copy_within(r, 0);
        let n = src.read(&mut buf[offset..])?;
        if n == 0 {
            // The source ended in the middle of the seek table
            return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected));
        }
        buf_start = 0;
        buf_end = offset + n;
    }
    parser.verify()?;

    Ok(parser.into())
}
/// Reads a seek table from `reader`.
///
/// The reader must yield a seek table whose integrity field directly follows the skippable
/// header, i.e. the layout of [`Format::Head`]. Not more than the bytes of the seek table are
/// requested from `reader`.
///
/// # Errors
///
/// Fails if the skippable header or the integrity field is invalid, if `reader` ends before
/// the complete seek table is read, or if reading fails.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn from_reader(mut reader: impl std::io::Read) -> Result<Self> {
    let mut buf = [0u8; SKIPPABLE_HEADER_SIZE + SEEK_TABLE_INTEGRITY_SIZE];
    reader.read_exact(&mut buf)?;
    let mut parser = Parser::from_bytes(&buf[SKIPPABLE_HEADER_SIZE..])?;
    parser.verify_skippable_header(&buf)?;

    // Number of entry bytes that are not parsed yet
    let mut remaining =
        parser.seek_table_size - SKIPPABLE_HEADER_SIZE - SEEK_TABLE_INTEGRITY_SIZE;
    let mut buf = vec![0u8; 8192.min(remaining)];
    // Number of bytes at the front of the buffer that were read but are not parsed yet
    let mut filled = 0;
    loop {
        // Never request more than what is left of the entries
        let end = buf.len().min(remaining);
        if filled < end {
            // A read may return fewer bytes than requested, only the bytes actually received
            // must be handed to the parser.
            let n = reader.read(&mut buf[filled..end])?;
            if n == 0 {
                return Err(Error::zstd(ZSTD_ErrorCode::ZSTD_error_corruption_detected));
            }
            filled += n;
        }

        let parsed = parser.parse_entries(&buf[..filled]);
        remaining -= parsed;
        if remaining == 0 {
            break;
        }

        // Keep the bytes of an incomplete entry at the front of the buffer
        buf.copy_within(parsed..filled, 0);
        filled -= parsed;
    }
    parser.verify()?;

    Ok(parser.into())
}
/// Appends a frame with the given compressed and decompressed size to the seek table.
///
/// # Errors
///
/// Fails if the seek table already holds `SEEKABLE_MAX_FRAMES` frames.
pub fn log_frame(&mut self, c_size: u32, d_size: u32) -> Result<()> {
    let frames = self.num_frames();
    if frames >= SEEKABLE_MAX_FRAMES {
        return Err(Error::frame_index_too_large());
    }

    // The last entry marks the end of the data logged so far
    let (c_end, d_end) = {
        let last = &self.entries[frames];
        (last.c_offset, last.d_offset)
    };
    self.entries.0.push(Entry {
        c_offset: c_end + u64::from(c_size),
        d_offset: d_end + u64::from(d_size),
    });

    Ok(())
}
/// The number of frames in the seek table.
pub fn num_frames(&self) -> u32 {
    // The entries hold one element more than there are frames
    let entry_count = self.entries.0.len();
    (entry_count - 1) as u32
}
/// The index of the frame that contains the compressed `offset`.
///
/// Offsets at or beyond the end of the compressed data map to the last frame.
pub fn frame_index_comp(&self, offset: u64) -> u32 {
    let compressed_offset = |index: u32| self.entries[index].c_offset;
    self.frame_index_at(offset, compressed_offset)
}
/// The index of the frame that contains the decompressed `offset`.
///
/// Offsets at or beyond the end of the decompressed data map to the last frame.
pub fn frame_index_decomp(&self, offset: u64) -> u32 {
    let decompressed_offset = |index: u32| self.entries[index].d_offset;
    self.frame_index_at(offset, decompressed_offset)
}
pub fn frame_start_comp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
Ok(self.entries[index].c_offset)
}
pub fn frame_start_decomp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
Ok(self.entries[index].d_offset)
}
pub fn frame_end_comp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
Ok(self.entries[index + 1].c_offset)
}
pub fn frame_end_decomp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
Ok(self.entries[index + 1].d_offset)
}
pub fn frame_size_comp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
let size = self.entries[index + 1].c_offset - self.entries[index].c_offset;
Ok(size)
}
pub fn frame_size_decomp(&self, index: u32) -> Result<u64> {
if index >= self.num_frames() {
return Err(Error::frame_index_too_large());
}
let size = self.entries[index + 1].d_offset - self.entries[index].d_offset;
Ok(size)
}
/// The compressed size of the largest frame, zero if the seek table holds no frames.
#[allow(clippy::missing_panics_doc)]
pub fn max_frame_size_comp(&self) -> u64 {
    // The size of a frame is the distance between two neighbouring entries
    self.entries
        .0
        .windows(2)
        .map(|pair| pair[1].c_offset - pair[0].c_offset)
        .max()
        .unwrap_or(0)
}
/// The decompressed size of the largest frame, zero if the seek table holds no frames.
#[allow(clippy::missing_panics_doc)]
pub fn max_frame_size_decomp(&self) -> u64 {
    // The size of a frame is the distance between two neighbouring entries
    self.entries
        .0
        .windows(2)
        .map(|pair| pair[1].d_offset - pair[0].d_offset)
        .max()
        .unwrap_or(0)
}
/// The total compressed size of all frames in the seek table.
#[allow(clippy::missing_panics_doc)]
pub fn size_comp(&self) -> u64 {
    // The last entry marks the end of the last frame
    let last = self
        .entries
        .0
        .last()
        .expect("Seek table entries are never empty");

    last.c_offset
}
/// The total decompressed size of all frames in the seek table.
#[allow(clippy::missing_panics_doc)]
pub fn size_decomp(&self) -> u64 {
    // The last entry marks the end of the last frame
    let last = self
        .entries
        .0
        .last()
        .expect("Seek table entries are never empty");

    last.d_offset
}
/// Converts the seek table into a [`Serializer`] that emits it in [`Format::Foot`].
pub fn into_serializer(self) -> Serializer {
    let format = Format::Foot;
    self.into_format_serializer(format)
}
/// Converts the seek table into a [`Serializer`] that emits it in the given `format`.
///
/// The serializer starts at the first byte of the serialized seek table.
pub fn into_format_serializer(self, format: Format) -> Serializer {
    let frames = self.entries.into_frames();

    Serializer {
        frames,
        frame_index: 0,
        write_pos: 0,
        format,
    }
}
/// Finds the index of the frame that contains `offset`.
///
/// `offset_at` returns the (compressed or decompressed) offset of the entry at the passed
/// index. Offsets at or beyond the end of the data map to the last frame. For a seek table
/// without frames the result is always zero.
fn frame_index_at(&self, offset: u64, offset_at: impl Fn(u32) -> u64) -> u32 {
    let num_frames = self.num_frames();
    if offset >= offset_at(num_frames) {
        // Saturate so that an empty seek table does not underflow
        return num_frames.saturating_sub(1);
    }

    // Binary search for the last entry whose offset is not greater than `offset`. The check
    // above guarantees that the frame is found in `low..high`.
    let mut low = 0;
    let mut high = num_frames;
    while low + 1 < high {
        let mid = low.midpoint(high);
        if offset_at(mid) <= offset {
            low = mid;
        } else {
            high = mid;
        }
    }

    low
}
}
/// Serializes a seek table, optionally in several steps.
///
/// The serializer keeps track of how much has been emitted, so the output can be written in
/// chunks of arbitrary size.
pub struct Serializer {
// Compressed and decompressed size of every frame
frames: Vec<Frame>,
// Index of the next frame whose entry is not completely written yet
frame_index: usize,
// Number of bytes of the serialized seek table emitted so far
write_pos: usize,
// Layout of the emitted seek table
format: Format,
}
impl Serializer {
    /// Writes the next part of the serialized seek table into `buf`.
    ///
    /// Continues where the previous call stopped and returns the number of bytes written.
    /// Returns zero once the seek table is emitted completely. Call this repeatedly until that
    /// happens, or pass a buffer of at least [`Serializer::encoded_len`] bytes.
    pub fn write_into(&mut self, buf: &mut [u8]) -> usize {
        let mut written = 0;
        let frame_count = self.frames.len() as u32;

        // Skippable frame header: magic number followed by the size of the frame content
        write_le32!(buf, written, self.write_pos, SKIPPABLE_MAGIC_NUMBER, 0);
        write_le32!(buf, written, self.write_pos, self.frame_size(), 4);

        // In head format the integrity field is placed in front of the frame entries
        let entries_start = match self.format {
            Format::Head => {
                write_integrity!(buf, written, self, frame_count, SKIPPABLE_HEADER_SIZE);
                SKIPPABLE_HEADER_SIZE + SEEK_TABLE_INTEGRITY_SIZE
            }
            Format::Foot => SKIPPABLE_HEADER_SIZE,
        };

        while self.frame_index < self.frames.len() {
            let frame_offset = entries_start + SIZE_PER_FRAME * self.frame_index;
            write_frame!(buf, written, self, frame_offset);
        }

        // In foot format the integrity field follows the frame entries
        if let Format::Foot = self.format {
            let integrity_offset = SKIPPABLE_HEADER_SIZE + SIZE_PER_FRAME * self.frames.len();
            write_integrity!(buf, written, self, frame_count, integrity_offset);
        }

        written
    }

    /// Resets the serializer so that the next write starts at the beginning of the seek table.
    pub fn reset(&mut self) {
        self.frame_index = 0;
        self.write_pos = 0;
    }

    /// The length of the serialized seek table in bytes.
    pub fn encoded_len(&self) -> usize {
        let entries_len = self.frames.len() * SIZE_PER_FRAME;
        SKIPPABLE_HEADER_SIZE + SEEK_TABLE_INTEGRITY_SIZE + entries_len
    }

    /// The size of the skippable frame content, i.e. everything after the skippable header.
    fn frame_size(&self) -> u32 {
        let content_len = self.encoded_len() - SKIPPABLE_HEADER_SIZE;
        content_len as u32
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::io::Read for Serializer {
    /// Fills `buf` with the next bytes of the serialized seek table, never fails.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let written = self.write_into(buf);
        Ok(written)
    }
}
// Unit and property tests for parsing and serializing seek tables.
#[cfg(test)]
mod tests {
use crate::BytesWrapper;
use super::*;
use proptest::prelude::*;
use zstd_safe::OutBuffer;
// Creates a seek table with `num_frames` frames. The first frame has a compressed size of 3
// and a decompressed size of 6, both sizes grow by one with every frame.
fn seek_table(num_frames: u32) -> SeekTable {
let mut st = SeekTable::new();
let mut c_size = 3;
let mut d_size = 6;
for _ in 0..num_frames {
st.log_frame(c_size, d_size).unwrap();
c_size += 1;
d_size += 1;
}
st
}
// Checks the index, start, end and size functions for every frame of a table whose frame `i`
// (1-based) has a compressed size of `i * 7` and a decompressed size of `i * 13`.
#[test]
fn frame_functions() {
const NUM_FRAMES: u32 = 1234;
let mut st = SeekTable::new();
for i in 1..=NUM_FRAMES {
st.log_frame(i * 7, i * 13).unwrap();
}
assert_eq!(st.num_frames(), NUM_FRAMES);
// Expected offsets at the start of the current frame
let mut c_offset = 0;
let mut d_offset = 0;
for i in 1..=NUM_FRAMES {
let j = i - 1;
let c_size = i as u64 * 7;
let d_size = i as u64 * 13;
assert_eq!(st.frame_index_comp(c_offset), j);
assert_eq!(st.frame_index_decomp(d_offset), j);
assert_eq!(st.frame_start_comp(j).unwrap(), c_offset);
assert_eq!(st.frame_start_decomp(j).unwrap(), d_offset);
assert_eq!(st.frame_end_comp(j).unwrap(), c_offset + c_size);
assert_eq!(st.frame_end_decomp(j).unwrap(), d_offset + d_size);
assert_eq!(st.frame_size_comp(j).unwrap(), c_size);
assert_eq!(st.frame_size_decomp(j).unwrap(), d_size);
c_offset += c_size;
d_offset += d_size;
}
// The last frame is the largest one
assert_eq!(st.max_frame_size_comp(), NUM_FRAMES as u64 * 7);
assert_eq!(st.max_frame_size_decomp(), NUM_FRAMES as u64 * 13);
}
// Serializes a seek table in one go and, after a reset, in chunks of `buf_len` bytes. Both
// ways have to emit exactly `encoded_len` bytes.
fn test_serialize(format: Format, num_frames: u32, buf_len: usize) {
let mut ser = seek_table(num_frames)
.clone()
.into_format_serializer(format);
let mut buf = vec![0; ser.encoded_len()];
let n = ser.write_into(&mut buf);
assert_eq!(n, buf.len());
// Nothing is left to write after the complete table has been emitted
let n = ser.write_into(&mut buf);
assert_eq!(n, 0);
ser.reset();
let mut buf = vec![0; buf_len];
let mut pos = 0;
while pos < ser.encoded_len() {
let n = ser.write_into(&mut buf);
pos += n;
}
assert_eq!(pos, ser.encoded_len());
}
// Serializes a seek table into a buffer and parses it again, the result has to equal the
// original table.
fn test_serde_cycle(format: Format, num_frames: u32) {
let st = seek_table(num_frames);
let mut ser = st.clone().into_format_serializer(format);
let mut buf = vec![0; ser.encoded_len()];
let n = ser.write_into(&mut buf);
assert_eq!(n, ser.encoded_len());
let mut wrapper = BytesWrapper::new(&buf);
let from_seekable = SeekTable::from_seekable_format(&mut wrapper, format).unwrap();
assert_eq!(from_seekable, st);
}
// Serializes a seek table in the default format, loads the bytes with the seekable
// implementation of `zstd_safe` and compares the frame data of both.
fn test_serialize_compatible_with_zstd_seekable(num_frames: u32) {
let st = seek_table(num_frames);
let mut ser = st.clone().into_serializer();
let mut buf = vec![0; ser.encoded_len()];
let n = ser.write_into(&mut buf);
assert_eq!(n, ser.encoded_len());
let mut seekable = zstd_safe::seekable::Seekable::create();
seekable.init_buff(&buf).unwrap();
assert_eq!(st.num_frames(), seekable.num_frames());
for i in 0..st.num_frames() {
assert_eq!(
st.frame_start_comp(i).unwrap(),
seekable.frame_compressed_offset(i).unwrap()
);
assert_eq!(
st.frame_start_decomp(i).unwrap(),
seekable.frame_decompressed_offset(i).unwrap()
);
assert_eq!(
st.frame_size_comp(i).unwrap(),
seekable.frame_compressed_size(i).unwrap() as u64
);
assert_eq!(
st.frame_size_decomp(i).unwrap(),
seekable.frame_decompressed_size(i).unwrap() as u64
);
}
}
// Writes a seek table with the frame log of `zstd_safe` and parses it with `SeekTable`.
fn test_deserialize_compatible_with_zstd_seekable(num_frames: u32) {
let mut fl = zstd_safe::seekable::FrameLog::create(true);
for i in 1..=num_frames {
fl.log_frame(i * 7, i * 13, Some(i)).unwrap();
}
// The buffer is sized for 12 bytes per frame
let cap = SKIPPABLE_HEADER_SIZE + (num_frames * 12) as usize + SEEK_TABLE_INTEGRITY_SIZE;
let mut buf = vec![0; cap];
let mut out_buf = OutBuffer::around(&mut buf);
let n = fl.write_seek_table(&mut out_buf).unwrap();
assert_eq!(n, 0);
let mut wrapper = BytesWrapper::new(&buf);
let st = SeekTable::from_seekable(&mut wrapper).unwrap();
assert_eq!(st.num_frames(), num_frames);
for i in 1..=num_frames {
let c_size = i as u64 * 7;
let d_size = i as u64 * 13;
assert_eq!(st.frame_size_comp(i - 1).unwrap(), c_size);
assert_eq!(st.frame_size_decomp(i - 1).unwrap(), d_size);
}
}
// Same as `test_serde_cycle`, but serializes through the `std::io::Read` implementation of
// the serializer.
#[cfg(feature = "std")]
fn test_serde_cycle_std(format: Format, num_frames: u32) {
let st = seek_table(num_frames);
let mut ser = st.clone().into_format_serializer(format);
let mut buf = std::io::Cursor::new(Vec::with_capacity(ser.encoded_len()));
let n = std::io::copy(&mut ser, &mut buf).unwrap();
assert_eq!(n, ser.encoded_len() as u64);
let mut wrapper = BytesWrapper::new(buf.get_ref());
let from_bytes = SeekTable::from_seekable_format(&mut wrapper, format).unwrap();
assert_eq!(from_bytes, st);
}
// Same as `test_serde_cycle_std`, but parses the seek table from a `BufReader` around the
// cursor instead of a `BytesWrapper`.
#[cfg(feature = "std")]
fn test_serde_cycle_buf(format: Format, num_frames: u32) {
let st = seek_table(num_frames);
let mut ser = st.clone().into_format_serializer(format);
let mut buf = std::io::Cursor::new(Vec::with_capacity(ser.encoded_len()));
let n = std::io::copy(&mut ser, &mut buf).unwrap();
assert_eq!(n, ser.encoded_len() as u64);
let mut buf = std::io::BufReader::new(buf);
let from_bytes = SeekTable::from_seekable_format(&mut buf, format).unwrap();
assert_eq!(from_bytes, st);
}
// Property tests that require the `std` feature
#[cfg(feature = "std")]
proptest! {
#[test]
fn serde_cycle_std(num_frames in 0..4096u32) {
test_serde_cycle_std(Format::Head, num_frames);
test_serde_cycle_std(Format::Foot, num_frames);
test_serde_cycle_buf(Format::Head, num_frames);
test_serde_cycle_buf(Format::Foot, num_frames);
}
}
// Property tests that run the helpers above with up to 4095 frames
proptest! {
#[test]
fn serialize(num_frames in 0..4096u32, buf_len in 1..64usize) {
test_serialize(Format::Head, num_frames, buf_len);
test_serialize(Format::Foot, num_frames, buf_len);
}
#[test]
fn serde_cycle(num_frames in 0..4096u32) {
test_serde_cycle(Format::Head, num_frames);
test_serde_cycle(Format::Foot, num_frames);
}
#[test]
fn serialize_compatible_with_zstd_seekable(num_frames in 0..4096u32) {
test_serialize_compatible_with_zstd_seekable(num_frames);
}
#[test]
fn deserialize_compatible_with_zstd_seekable(num_frames in 0..4096u32) {
test_deserialize_compatible_with_zstd_seekable(num_frames);
}
}
}