use crate::compression::{self, should_auto_compress};
use crate::error::{ParxError, Result};
use crate::format::{Compression, Header, Trailer, HEADER_SIZE, MAGIC};
use crate::proto::ParxManifest;
use bytes::Bytes;
use prost::Message;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
/// Builder that assembles a PARX container around a Parquet file's footer
/// (and optional page indexes); serialize with [`ParxWriter::finish`].
#[derive(Debug)]
pub struct ParxWriter {
    /// URI of the source Parquet file; empty when unknown.
    source_uri: String,
    /// Size in bytes of the source Parquet file.
    source_size: u64,
    /// Raw (uncompressed) Parquet footer bytes.
    footer_bytes: Bytes,
    /// Compression applied to footer/page-index payloads; `None` = stored verbatim.
    compression: Option<Compression>,
    /// Raw page-index bytes; empty when no page-index section is written.
    page_index_bytes: Bytes,
}
impl ParxWriter {
    /// Creates an empty writer: no source metadata, footer, page indexes, or
    /// compression configured.
    #[inline]
    pub fn new() -> Self {
        Self {
            source_uri: String::new(),
            source_size: 0,
            footer_bytes: Bytes::new(),
            compression: None,
            page_index_bytes: Bytes::new(),
        }
    }

    /// Builds a writer from an in-memory Parquet file.
    ///
    /// Validates the leading and trailing `PAR1` magics and the footer length
    /// recorded in the 8-byte trailer, then captures the raw footer bytes and
    /// the source size. The source URI is left empty.
    ///
    /// # Errors
    /// Returns [`ParxError::FileTooSmall`], [`ParxError::InvalidParquetMagic`],
    /// or [`ParxError::InvalidParquetFooterLength`] when `data` is not a
    /// structurally valid Parquet file.
    pub fn from_parquet_bytes(data: &[u8]) -> Result<Self> {
        // Minimum valid Parquet file: magic (4) + footer length (4) + magic (4).
        if data.len() < 12 {
            return Err(ParxError::FileTooSmall {
                size: data.len(),
                minimum: 12,
            });
        }
        let mut head_magic = [0u8; 4];
        head_magic.copy_from_slice(&data[0..4]);
        if &head_magic != b"PAR1" {
            return Err(ParxError::InvalidParquetMagic(head_magic));
        }
        let mut tail_magic = [0u8; 4];
        tail_magic.copy_from_slice(&data[data.len() - 4..]);
        if &tail_magic != b"PAR1" {
            return Err(ParxError::InvalidParquetMagic(tail_magic));
        }
        // Footer length is a little-endian u32 immediately before the tail magic.
        let footer_len = u32::from_le_bytes([
            data[data.len() - 8],
            data[data.len() - 7],
            data[data.len() - 6],
            data[data.len() - 5],
        ]) as u64;
        let file_size = data.len() as u64;
        // The footer plus both magics and the length field must fit in the file.
        if footer_len + 12 > file_size {
            return Err(ParxError::InvalidParquetFooterLength {
                footer_len,
                file_size,
            });
        }
        let footer_start = data.len() - 8 - footer_len as usize;
        let footer_bytes = &data[footer_start..data.len() - 8];
        let mut writer = Self::new();
        writer.set_source_size(file_size);
        writer.set_footer(footer_bytes);
        Ok(writer)
    }

    /// Builds a writer by reading only the footer region of a Parquet file on
    /// disk. Records the file's display path as the source URI.
    ///
    /// # Errors
    /// Returns an I/O error if the file cannot be read, or the same validation
    /// errors as [`ParxWriter::from_parquet_bytes`].
    pub fn from_parquet_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let (file_size, footer_bytes) = read_parquet_footer_from_file(path)?;
        let mut writer = Self::new();
        writer.set_source_size(file_size);
        writer.set_footer_owned(footer_bytes);
        writer.set_source_uri(path.display().to_string());
        Ok(writer)
    }

    /// Sets the URI recorded for the source Parquet file.
    #[inline]
    pub fn set_source_uri(&mut self, uri: impl Into<String>) {
        self.source_uri = uri.into();
    }

    /// Sets the recorded size of the source Parquet file, in bytes.
    #[inline]
    pub fn set_source_size(&mut self, size: u64) {
        self.source_size = size;
    }

    /// Sets the raw footer bytes, copying from the slice.
    #[inline]
    pub fn set_footer(&mut self, bytes: &[u8]) {
        self.footer_bytes = Bytes::copy_from_slice(bytes);
    }

    /// Sets the raw footer bytes without copying when the input is already owned.
    #[inline]
    pub fn set_footer_owned(&mut self, bytes: impl Into<Bytes>) {
        self.footer_bytes = bytes.into();
    }

    /// Selects the compression algorithm for the footer and page-index payloads.
    #[inline]
    pub fn set_compression(&mut self, compression: Compression) {
        self.compression = Some(compression);
    }

    /// Disables compression; payloads are stored verbatim.
    #[inline]
    pub fn clear_compression(&mut self) {
        self.compression = None;
    }

    /// Enables Zstd compression when the footer is large enough to benefit,
    /// as decided by [`should_auto_compress`]. Leaves compression unchanged
    /// otherwise.
    #[inline]
    pub fn auto_compress(&mut self) {
        if should_auto_compress(self.footer_bytes.len()) {
            self.compression = Some(Compression::Zstd);
        }
    }

    /// Returns the currently configured compression algorithm, if any.
    #[inline]
    pub const fn compression(&self) -> Option<Compression> {
        self.compression
    }

    /// Returns the recorded source file size, in bytes.
    #[inline]
    pub const fn source_size(&self) -> u64 {
        self.source_size
    }

    /// Returns the size of the (uncompressed) footer bytes held by this writer.
    #[inline]
    pub fn footer_size(&self) -> usize {
        self.footer_bytes.len()
    }

    /// Sets the raw page-index bytes, copying from the slice.
    #[inline]
    pub fn set_page_indexes(&mut self, bytes: &[u8]) {
        self.page_index_bytes = Bytes::copy_from_slice(bytes);
    }

    /// Sets the raw page-index bytes without copying when already owned.
    #[inline]
    pub fn set_page_indexes_owned(&mut self, bytes: impl Into<Bytes>) {
        self.page_index_bytes = bytes.into();
    }

    /// Returns `true` when page-index bytes have been provided.
    #[inline]
    pub fn has_page_indexes(&self) -> bool {
        !self.page_index_bytes.is_empty()
    }

    /// Serializes the PARX container: header, (possibly compressed) footer,
    /// optional page-index section, protobuf manifest, and trailer.
    ///
    /// # Panics
    /// Panics if compression fails (not expected on valid data) or if the
    /// encoded manifest exceeds `u32::MAX` bytes.
    pub fn finish(self) -> Vec<u8> {
        let mut header = Header::new();
        // Checksum of the original (uncompressed) Parquet footer, so readers
        // can verify round-trips against the source file.
        let source_footer_checksum = crc32c::crc32c(&self.footer_bytes).to_le_bytes().to_vec();
        // The header flag reflects the configured algorithm regardless of
        // payload sizes (same condition as the footer encoding below).
        if let Some(algo) = self.compression {
            header.set_compression(algo);
        }
        let (footer_payload, footer_uncompressed_size) =
            Self::encode_section(&self.footer_bytes, self.compression);
        let footer_checksum = crc32c::crc32c(&footer_payload).to_le_bytes().to_vec();
        let footer_offset = HEADER_SIZE as u64;
        let footer_length = footer_payload.len() as u64;
        let page_index_offset = footer_offset + footer_length;
        // An absent page-index section is emitted as zero bytes and is never
        // run through the compressor.
        let (page_index_payload, page_index_uncompressed_size) =
            if self.page_index_bytes.is_empty() {
                (Vec::new(), 0)
            } else {
                Self::encode_section(&self.page_index_bytes, self.compression)
            };
        let page_index_length = page_index_payload.len() as u64;
        let page_index_checksum = if page_index_length > 0 {
            crc32c::crc32c(&page_index_payload).to_le_bytes().to_vec()
        } else {
            Vec::new()
        };
        #[allow(clippy::cast_possible_truncation)]
        let created_at_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        let header_bytes = header.to_bytes();
        let manifest = ParxManifest {
            version: 1,
            source_uri: self.source_uri,
            source_size: self.source_size,
            source_footer_checksum,
            footer_offset,
            footer_length,
            footer_checksum,
            created_at_ms,
            footer_uncompressed_size,
            page_index_offset,
            page_index_length,
            page_index_checksum,
            page_index_uncompressed_size,
        };
        let manifest_bytes = manifest.encode_to_vec();
        let manifest_crc = crc32c::crc32c(&manifest_bytes);
        let manifest_len = u32::try_from(manifest_bytes.len()).expect("manifest too large (>4GB)");
        let trailer = Trailer::new(manifest_len, manifest_crc, MAGIC);
        let trailer_bytes = trailer.to_bytes();
        let total_size = header_bytes.len()
            + footer_payload.len()
            + page_index_payload.len()
            + manifest_bytes.len()
            + trailer_bytes.len();
        let mut output = Vec::with_capacity(total_size);
        output.extend_from_slice(&header_bytes);
        output.extend_from_slice(&footer_payload);
        output.extend_from_slice(&page_index_payload);
        output.extend_from_slice(&manifest_bytes);
        output.extend_from_slice(&trailer_bytes);
        output
    }

    /// Compresses `bytes` with `algo` when set, or copies them verbatim.
    ///
    /// Returns the on-disk payload and the uncompressed size; by manifest
    /// convention the uncompressed size is 0 when stored verbatim.
    fn encode_section(bytes: &Bytes, algo: Option<Compression>) -> (Vec<u8>, u64) {
        match algo {
            Some(algo) => {
                let compressed = compression::compress(bytes, algo)
                    .expect("compression should not fail on valid data");
                (compressed, bytes.len() as u64)
            }
            None => (bytes.to_vec(), 0),
        }
    }
}
impl Default for ParxWriter {
fn default() -> Self {
Self::new()
}
}
/// Reads and validates just the footer region of a Parquet file, returning
/// the file size together with the raw footer bytes.
///
/// Checks, in order: minimum size (12 bytes), leading `PAR1` magic, trailing
/// `PAR1` magic, and that the recorded footer length fits inside the file.
fn read_parquet_footer_from_file(path: &Path) -> Result<(u64, Bytes)> {
    let mut file = File::open(path)?;
    let file_size = file.metadata()?.len();
    // Smallest valid layout: magic (4) + footer length (4) + magic (4).
    if file_size < 12 {
        return Err(ParxError::FileTooSmall {
            size: usize::try_from(file_size).unwrap_or(usize::MAX),
            minimum: 12,
        });
    }
    let mut leading = [0u8; 4];
    file.read_exact(&mut leading)?;
    if &leading != b"PAR1" {
        return Err(ParxError::InvalidParquetMagic(leading));
    }
    // The last 8 bytes hold the little-endian footer length and the tail magic.
    file.seek(SeekFrom::End(-8))?;
    let mut tail = [0u8; 8];
    file.read_exact(&mut tail)?;
    let (len_part, magic_part) = tail.split_at(4);
    let footer_len = u64::from(u32::from_le_bytes(len_part.try_into().expect("slice len")));
    let trailing: [u8; 4] = magic_part.try_into().expect("slice len");
    if &trailing != b"PAR1" {
        return Err(ParxError::InvalidParquetMagic(trailing));
    }
    if footer_len + 12 > file_size {
        return Err(ParxError::InvalidParquetFooterLength {
            footer_len,
            file_size,
        });
    }
    // Footer sits immediately before the 8-byte length + magic trailer.
    let footer_start = file_size - 8 - footer_len;
    file.seek(SeekFrom::Start(footer_start))?;
    let mut footer = vec![0u8; footer_len as usize];
    file.read_exact(&mut footer)?;
    Ok((file_size, Bytes::from(footer)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the smallest structurally valid Parquet file: leading magic,
    /// a 3-byte footer (`abc`), the footer length, and the trailing magic.
    fn valid_parquet_bytes() -> Vec<u8> {
        let footer = b"abc";
        let mut bytes = Vec::with_capacity(footer.len() + 12);
        bytes.extend_from_slice(b"PAR1");
        bytes.extend_from_slice(footer);
        bytes.extend_from_slice(&(footer.len() as u32).to_le_bytes());
        bytes.extend_from_slice(b"PAR1");
        bytes
    }

    #[test]
    fn test_writer_creates_valid_structure() {
        let mut writer = ParxWriter::new();
        writer.set_source_size(1000);
        writer.set_footer(b"test footer");
        let out = writer.finish();
        // PARX magic must bracket the container on both ends.
        assert_eq!(&out[..4], b"PARX");
        assert_eq!(&out[out.len() - 4..], b"PARX");
    }

    #[test]
    fn test_set_footer_owned() {
        let mut writer = ParxWriter::new();
        writer.set_source_size(100);
        writer.set_footer_owned(vec![1, 2, 3, 4, 5]);
        let out = writer.finish();
        assert_eq!(&out[..4], b"PARX");
    }

    #[test]
    fn test_writer_with_compression() {
        let footer = b"test footer data that will be compressed".repeat(100);
        let mut writer = ParxWriter::new();
        writer.set_source_size(1000);
        writer.set_footer(&footer);
        writer.set_compression(Compression::Zstd);
        let out = writer.finish();
        assert_eq!(&out[..4], b"PARX");
        // The header must record both the flag and the algorithm.
        let header = Header::from_bytes(out[..HEADER_SIZE].try_into().unwrap());
        assert!(header.is_footer_compressed());
        assert_eq!(header.compression_algorithm(), Some(Compression::Zstd));
    }

    #[test]
    fn test_auto_compress() {
        // A tiny footer stays uncompressed.
        let mut small = ParxWriter::new();
        small.set_footer(b"small");
        small.auto_compress();
        assert!(small.compression.is_none());
        // A large footer switches on zstd.
        let mut large = ParxWriter::new();
        large.set_footer(&vec![0u8; 20_000]);
        large.auto_compress();
        assert_eq!(large.compression, Some(Compression::Zstd));
    }

    #[test]
    fn test_from_parquet_bytes() {
        let data = valid_parquet_bytes();
        let writer = ParxWriter::from_parquet_bytes(&data).unwrap();
        assert_eq!(writer.source_size, data.len() as u64);
        assert_eq!(writer.footer_bytes, Bytes::from_static(b"abc"));
    }

    #[test]
    fn test_from_parquet_bytes_invalid_magic() {
        let mut data = valid_parquet_bytes();
        data[..4].copy_from_slice(b"XXXX");
        let err = ParxWriter::from_parquet_bytes(&data).unwrap_err();
        assert!(matches!(err, ParxError::InvalidParquetMagic(m) if m == *b"XXXX"));
    }

    #[test]
    fn test_from_parquet_bytes_invalid_footer_length() {
        let mut data = valid_parquet_bytes();
        let file_size = data.len() as u64;
        // Overwrite the length field with an impossibly large value.
        let tail = data.len() - 8;
        data[tail..tail + 4].copy_from_slice(&100u32.to_le_bytes());
        let err = ParxWriter::from_parquet_bytes(&data).unwrap_err();
        assert!(matches!(
            err,
            ParxError::InvalidParquetFooterLength {
                footer_len: 100,
                file_size: f
            } if f == file_size
        ));
    }

    #[test]
    fn test_from_parquet_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("data.parquet");
        std::fs::write(&path, valid_parquet_bytes()).unwrap();
        let writer = ParxWriter::from_parquet_file(&path).unwrap();
        assert_eq!(writer.source_uri, path.display().to_string());
        assert_eq!(writer.footer_bytes, Bytes::from_static(b"abc"));
    }

    #[test]
    fn test_from_parquet_file_invalid_footer_length() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("broken.parquet");
        let mut data = valid_parquet_bytes();
        let tail = data.len() - 8;
        data[tail..tail + 4].copy_from_slice(&100u32.to_le_bytes());
        std::fs::write(&path, data).unwrap();
        let err = ParxWriter::from_parquet_file(&path).unwrap_err();
        assert!(matches!(
            err,
            ParxError::InvalidParquetFooterLength { footer_len: 100, .. }
        ));
    }
}