use crate::deflate_encoder::DeflateEncoder;
use crate::error::ZipError;
use crate::header;
use flate2::Compression;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use super::directory_writer::DirectoryWriter;
use super::entry_writer::EntryWriter;
use super::helpers::CountWriter;
use super::stored_entry::StoredEntry;
/// Streaming ZIP archive writer over an asynchronous byte sink `W`.
///
/// Entries are started with the `append_*` methods; `finalize` then writes
/// the central directory and the end-of-central-directory records.
pub struct ZipWriter<W: AsyncWrite + Unpin> {
/// Underlying sink. The `append_*` methods and `finalize` take it out of
/// this slot (leaving `None`) while they use it; `append_symlink` puts it
/// back once its entry has been written completely.
pub(crate) inner: Option<W>,
/// One record per completed entry; `finalize` serializes each of them into
/// the central directory.
pub(crate) entries: Vec<StoredEntry>,
/// Compression level; level 0 makes `append_file` use the stored method.
level: Compression,
/// Running byte offset in the output. It is advanced by the length of every
/// record written here and is used as the local-header offset of the next
/// entry and as the central-directory offset.
pub(crate) pos: u64,
/// When set, a call that finds `inner` empty reports `ZipError::Poisoned`
/// instead of `ZipError::InvalidState`.
pub(crate) poisoned: bool,
}
impl<W: AsyncWrite + Unpin> ZipWriter<W> {
pub fn new(inner: W) -> Self {
Self {
inner: Some(inner),
entries: Vec::new(),
level: Compression::default(),
pos: 0,
poisoned: false,
}
}
pub fn with_level(mut self, level: Compression) -> Self {
self.level = level;
self
}
pub async fn append_file<'a>(&'a mut self, name: &str) -> Result<EntryWriter<'a, W>, ZipError> {
let mut inner = self.inner.take().ok_or_else(|| {
if self.poisoned {
ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
} else {
ZipError::InvalidState("entry writer already active".to_string())
}
})?;
let is_stored = self.level.level() == 0;
let method = if is_stored {
header::METHOD_STORED
} else {
header::METHOD_DEFLATE
};
let needs_zip64 = self.pos > header::U32_MAX;
let lfh = header::LocalFileHeader::new(name, method, needs_zip64);
let lfh_bytes = lfh.serialize()?;
inner.write_all(&lfh_bytes).await?;
let offset = self.pos;
self.pos += lfh_bytes.len() as u64;
let (deflate_encoder, passthrough) = if is_stored {
(None, Some(CountWriter { inner, count: 0 }))
} else {
(
Some(DeflateEncoder::new(
CountWriter { inner, count: 0 },
self.level,
)),
None,
)
};
Ok(EntryWriter {
zip: self,
deflate_encoder,
passthrough,
is_stored,
crc_hasher: crc32fast::Hasher::new(),
uncompressed_size: 0,
local_header_offset: offset,
name: name.to_string(),
mtime: None,
unix_permissions: None,
})
}
pub async fn append_directory<'a>(
&'a mut self,
name: &str,
) -> Result<DirectoryWriter<'a, W>, ZipError> {
let mut inner = self.inner.take().ok_or_else(|| {
if self.poisoned {
ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
} else {
ZipError::InvalidState("entry writer already active".to_string())
}
})?;
let needs_zip64 = self.pos > header::U32_MAX;
let lfh = header::LocalFileHeader::new(name, header::METHOD_STORED, needs_zip64);
let lfh_bytes = lfh.serialize()?;
inner.write_all(&lfh_bytes).await?;
let offset = self.pos;
self.pos += lfh_bytes.len() as u64;
Ok(DirectoryWriter {
zip: self,
writer: Some(inner),
name: name.to_string(),
local_header_offset: offset,
mtime: None,
unix_permissions: None,
})
}
pub async fn append_symlink(&mut self, name: &str, target: &str) -> Result<(), ZipError> {
let mut inner = self.inner.take().ok_or_else(|| {
if self.poisoned {
ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
} else {
ZipError::InvalidState("entry writer already active".to_string())
}
})?;
let needs_zip64 = self.pos > header::U32_MAX;
let lfh = header::LocalFileHeader::new(name, header::METHOD_STORED, needs_zip64);
let lfh_bytes = lfh.serialize()?;
inner.write_all(&lfh_bytes).await?;
let offset = self.pos;
self.pos += lfh_bytes.len() as u64;
let target_bytes = target.as_bytes();
inner.write_all(target_bytes).await?;
self.pos += target_bytes.len() as u64;
let mut hasher = crc32fast::Hasher::new();
hasher.update(target_bytes);
let crc32 = hasher.finalize();
let data_size = target_bytes.len() as u64;
let dd = header::DataDescriptor {
crc32,
compressed_size: data_size,
uncompressed_size: data_size,
zip64: data_size > header::U32_MAX || offset > header::U32_MAX,
};
let dd_bytes = dd.serialize();
inner.write_all(&dd_bytes).await.map_err(|e| {
self.poisoned = true;
ZipError::Io(e)
})?;
self.pos += dd_bytes.len() as u64;
self.entries.push(StoredEntry {
name: name.to_string(),
crc32,
compressed_size: data_size,
uncompressed_size: data_size,
local_header_offset: offset,
is_directory: false,
is_symlink: true,
is_stored: false,
mtime: None,
unix_mtime: None,
unix_permissions: None,
});
self.inner = Some(inner);
Ok(())
}
pub async fn finalize(mut self) -> Result<(), ZipError> {
let mut inner = self.inner.take().ok_or_else(|| {
if self.poisoned {
ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
} else {
ZipError::InvalidState("entry writer still active".to_string())
}
})?;
let cd_offset = self.pos;
for entry in &self.entries {
let cd_entry = entry.to_central_dir_entry();
let data = cd_entry.serialize()?;
inner.write_all(&data).await?;
self.pos += data.len() as u64;
}
let cd_size = self.pos - cd_offset;
let total_entries = self.entries.len() as u64;
let needs_zip64 =
total_entries > 0xFFFF || cd_size > header::U32_MAX || cd_offset > header::U32_MAX;
if needs_zip64 {
let eocdr64 = header::Zip64Eocdr {
total_entries,
cd_size,
cd_offset,
};
let data = eocdr64.serialize();
let eocdr64_offset = self.pos;
inner.write_all(&data).await?;
self.pos += data.len() as u64;
let locator = header::Zip64EocdrLocator { eocdr64_offset };
inner.write_all(&locator.serialize()).await?;
self.pos += 20;
}
let eocdr = header::Eocdr {
total_entries,
cd_size,
cd_offset,
};
inner.write_all(&eocdr.serialize()).await?;
inner.shutdown().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::writer::test_utils::lookup_entry;
    use flate2::Compression;
    use tokio::io::AsyncWriteExt;

    /// Local file header signature.
    const LOCAL_HEADER_SIG: &[u8; 4] = b"PK\x03\x04";
    /// Central directory file header signature.
    const CENTRAL_HEADER_SIG: &[u8; 4] = b"PK\x01\x02";
    /// End of central directory record signature.
    const EOCDR_SIG: &[u8; 4] = b"PK\x05\x06";
    /// ZIP64 end of central directory record signature.
    const ZIP64_EOCDR_SIG: &[u8; 4] = b"PK\x06\x06";
    /// ZIP64 end of central directory locator signature.
    const ZIP64_LOCATOR_SIG: &[u8; 4] = b"PK\x06\x07";

    /// Returns true when `sig` occurs anywhere in `bytes`.
    fn contains_sig(bytes: &[u8], sig: &[u8; 4]) -> bool {
        bytes.windows(4).any(|window| window == sig)
    }

    /// Counts the 4-byte windows of `bytes` equal to `sig`.
    fn count_sig(bytes: &[u8], sig: &[u8; 4]) -> usize {
        bytes.windows(4).filter(|window| *window == sig).count()
    }

    /// Index of the first occurrence of `sig`; panics if it is absent.
    fn first_sig(bytes: &[u8], sig: &[u8; 4]) -> usize {
        bytes.windows(4).position(|window| window == sig).unwrap()
    }

    /// Index of the last occurrence of `sig`; panics if it is absent.
    fn last_sig(bytes: &[u8], sig: &[u8; 4]) -> usize {
        bytes.windows(4).rposition(|window| window == sig).unwrap()
    }

    /// Reads a little-endian `u16` starting at `at`.
    fn le_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_le_bytes([bytes[at], bytes[at + 1]])
    }

    /// Reads a little-endian `u32` starting at `at`.
    fn le_u32(bytes: &[u8], at: usize) -> u32 {
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    }

    /// Appends one file entry holding `data` and closes it.
    async fn add_file(zip: &mut ZipWriter<&mut Vec<u8>>, name: &str, data: &[u8]) {
        let mut entry = zip.append_file(name).await.unwrap();
        entry.write_all(data).await.unwrap();
        entry.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_zip_write_single_file() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        add_file(&mut zip, "hello.txt", b"Hello, World!").await;
        zip.finalize().await.unwrap();

        assert!(buf.len() > 30);
        assert!(contains_sig(&buf, LOCAL_HEADER_SIG));
        assert!(contains_sig(&buf, CENTRAL_HEADER_SIG));
        assert!(contains_sig(&buf, EOCDR_SIG));
    }

    #[tokio::test]
    async fn test_zip_write_multiple_files() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        add_file(&mut zip, "a.txt", b"aaa").await;
        add_file(&mut zip, "b.txt", b"bbb").await;
        zip.finalize().await.unwrap();

        assert_eq!(count_sig(&buf, CENTRAL_HEADER_SIG), 2);
    }

    #[tokio::test]
    async fn test_zip_compression_ratio() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::best());
        let payload = vec![b'A'; 1024];
        add_file(&mut zip, "repeated.txt", &payload).await;
        zip.finalize().await.unwrap();

        let entry = lookup_entry(&buf, 0);
        assert!(
            entry.compressed_size < entry.uncompressed_size,
            "compressed {} >= uncompressed {}",
            entry.compressed_size,
            entry.uncompressed_size
        );
    }

    #[tokio::test]
    async fn test_symlink_entry() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        zip.append_symlink("link.txt", "target.txt").await.unwrap();
        zip.finalize().await.unwrap();

        // Central directory header fields.
        let cd = &buf[first_sig(&buf, CENTRAL_HEADER_SIG)..];
        let version_made_by = le_u16(cd, 4);
        assert_eq!(version_made_by >> 8, 3, "expected Unix host OS for symlink");
        assert_eq!(le_u16(cd, 6), 10, "expected VERSION_STORED for symlink");
        assert_eq!(le_u16(cd, 10), 0, "expected METHOD_STORED for symlink");
        let unix_mode = le_u32(cd, 38) >> 16;
        assert!(
            unix_mode & 0o170000 == 0o120000,
            "expected S_IFLNK in external_file_attributes, got {:06o}",
            unix_mode
        );

        // The payload directly follows the local header, name and extra field.
        let lfh_pos = first_sig(&buf, LOCAL_HEADER_SIG);
        let lfh = &buf[lfh_pos..];
        let name_len = le_u16(lfh, 26) as usize;
        let extra_len = le_u16(lfh, 28) as usize;
        let data_start = lfh_pos + 30 + name_len + extra_len;
        let data = &buf[data_start..data_start + 10];
        assert_eq!(data, b"target.txt", "symlink target mismatch");
    }

    #[tokio::test]
    async fn test_zip64_finalize_many_entries() {
        let last_index: u16 = 0xFFFF;
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::none());
        for i in 0..=last_index {
            add_file(&mut zip, &format!("f{i}"), b"x").await;
        }
        zip.finalize().await.unwrap();

        let eocdr_pos = last_sig(&buf, EOCDR_SIG);
        assert_eq!(
            le_u16(&buf[eocdr_pos..], 8),
            0xFFFF,
            "EOCDR total_entries should be sentinel 0xFFFF for ZIP64"
        );

        let locator_pos = last_sig(&buf, ZIP64_LOCATOR_SIG);
        assert_eq!(&buf[locator_pos..locator_pos + 4], ZIP64_LOCATOR_SIG);
        let z64_pos = last_sig(&buf, ZIP64_EOCDR_SIG);
        assert_eq!(&buf[z64_pos..z64_pos + 4], ZIP64_EOCDR_SIG);
        assert!(
            z64_pos < locator_pos && locator_pos < eocdr_pos,
            "expected Zip64Eocdr < Zip64EocdrLocator < Eocdr, got {z64_pos} < {locator_pos} < {eocdr_pos}"
        );

        assert_eq!(
            count_sig(&buf, CENTRAL_HEADER_SIG),
            last_index as usize + 1
        );

        assert_eq!(
            &buf[33..37],
            b"PK\x07\x08",
            "first entry should have DD signature"
        );
        assert_eq!(
            &buf[49..53],
            b"PK\x03\x04",
            "next LFH at offset 49 confirms 16-byte DD (non-ZIP64) for small-entry ZIP64 archive"
        );
    }

    #[tokio::test]
    async fn test_stored_entry_level_zero() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::none());
        let data = b"Hello, stored entry!";
        add_file(&mut zip, "stored.txt", data).await;
        zip.finalize().await.unwrap();

        let cd = &buf[first_sig(&buf, CENTRAL_HEADER_SIG)..];
        assert_eq!(le_u16(cd, 10), 0, "expected METHOD_STORED for level=0 entry");
        assert_eq!(
            le_u16(cd, 6),
            10,
            "expected VERSION_STORED for level=0 entry"
        );

        let compressed_size = u64::from(le_u32(cd, 20));
        let uncompressed_size = u64::from(le_u32(cd, 24));
        assert_eq!(
            compressed_size, uncompressed_size,
            "stored entry should have equal compressed and uncompressed sizes"
        );
        assert_eq!(compressed_size, data.len() as u64);

        let lfh_pos = first_sig(&buf, LOCAL_HEADER_SIG);
        assert_eq!(
            le_u16(&buf, lfh_pos + 8),
            0,
            "LFH method should be STORED for level=0"
        );
    }
}