async_deflate_zip/writer/
zip_writer.rs1use crate::deflate_encoder::DeflateEncoder;
2use crate::error::ZipError;
3use crate::header;
4
5use flate2::Compression;
6use tokio::io::{AsyncWrite, AsyncWriteExt};
7
8use super::directory_writer::DirectoryWriter;
9use super::entry_writer::EntryWriter;
10use super::helpers::CountWriter;
11use super::stored_entry::StoredEntry;
12
13pub struct ZipWriter<W: AsyncWrite + Unpin> {
38 pub(crate) inner: Option<W>,
39 pub(crate) entries: Vec<StoredEntry>,
40 level: Compression,
41 pub(crate) pos: u64,
42 pub(crate) poisoned: bool,
43}
44
45impl<W: AsyncWrite + Unpin> ZipWriter<W> {
46 pub fn new(inner: W) -> Self {
51 Self {
52 inner: Some(inner),
53 entries: Vec::new(),
54 level: Compression::default(),
55 pos: 0,
56 poisoned: false,
57 }
58 }
59
60 pub fn with_level(mut self, level: Compression) -> Self {
74 self.level = level;
75 self
76 }
77
78 pub async fn append_file<'a>(&'a mut self, name: &str) -> Result<EntryWriter<'a, W>, ZipError> {
105 let mut inner = self.inner.take().ok_or_else(|| {
106 if self.poisoned {
107 ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
108 } else {
109 ZipError::InvalidState("entry writer already active".to_string())
110 }
111 })?;
112
113 let is_stored = self.level.level() == 0;
114 let method = if is_stored {
115 header::METHOD_STORED
116 } else {
117 header::METHOD_DEFLATE
118 };
119
120 let needs_zip64 = self.pos > header::U32_MAX;
121 let lfh = header::LocalFileHeader::new(name, method, needs_zip64);
122 let lfh_bytes = lfh.serialize()?;
123 inner.write_all(&lfh_bytes).await?;
124 let offset = self.pos;
125 self.pos += lfh_bytes.len() as u64;
126
127 let (deflate_encoder, passthrough) = if is_stored {
128 (None, Some(CountWriter { inner, count: 0 }))
129 } else {
130 (
131 Some(DeflateEncoder::new(
132 CountWriter { inner, count: 0 },
133 self.level,
134 )),
135 None,
136 )
137 };
138
139 Ok(EntryWriter {
140 zip: self,
141 deflate_encoder,
142 passthrough,
143 is_stored,
144 crc_hasher: crc32fast::Hasher::new(),
145 uncompressed_size: 0,
146 local_header_offset: offset,
147 name: name.to_string(),
148 mtime: None,
149 unix_permissions: None,
150 })
151 }
152
153 pub async fn append_directory<'a>(
178 &'a mut self,
179 name: &str,
180 ) -> Result<DirectoryWriter<'a, W>, ZipError> {
181 let mut inner = self.inner.take().ok_or_else(|| {
182 if self.poisoned {
183 ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
184 } else {
185 ZipError::InvalidState("entry writer already active".to_string())
186 }
187 })?;
188 let needs_zip64 = self.pos > header::U32_MAX;
189 let lfh = header::LocalFileHeader::new(name, header::METHOD_STORED, needs_zip64);
190 let lfh_bytes = lfh.serialize()?;
191 inner.write_all(&lfh_bytes).await?;
192 let offset = self.pos;
193 self.pos += lfh_bytes.len() as u64;
194
195 Ok(DirectoryWriter {
196 zip: self,
197 writer: Some(inner),
198 name: name.to_string(),
199 local_header_offset: offset,
200 mtime: None,
201 unix_permissions: None,
202 })
203 }
204
205 pub async fn append_symlink(&mut self, name: &str, target: &str) -> Result<(), ZipError> {
232 let mut inner = self.inner.take().ok_or_else(|| {
233 if self.poisoned {
234 ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
235 } else {
236 ZipError::InvalidState("entry writer already active".to_string())
237 }
238 })?;
239 let needs_zip64 = self.pos > header::U32_MAX;
240 let lfh = header::LocalFileHeader::new(name, header::METHOD_STORED, needs_zip64);
241 let lfh_bytes = lfh.serialize()?;
242 inner.write_all(&lfh_bytes).await?;
243 let offset = self.pos;
244 self.pos += lfh_bytes.len() as u64;
245
246 let target_bytes = target.as_bytes();
248 inner.write_all(target_bytes).await?;
249 self.pos += target_bytes.len() as u64;
250
251 let mut hasher = crc32fast::Hasher::new();
253 hasher.update(target_bytes);
254 let crc32 = hasher.finalize();
255 let data_size = target_bytes.len() as u64;
256
257 let dd = header::DataDescriptor {
258 crc32,
259 compressed_size: data_size,
260 uncompressed_size: data_size,
261 zip64: data_size > header::U32_MAX || offset > header::U32_MAX,
262 };
263 let dd_bytes = dd.serialize();
264 inner.write_all(&dd_bytes).await.map_err(|e| {
265 self.poisoned = true;
266 ZipError::Io(e)
267 })?;
268 self.pos += dd_bytes.len() as u64;
269
270 self.entries.push(StoredEntry {
271 name: name.to_string(),
272 crc32,
273 compressed_size: data_size,
274 uncompressed_size: data_size,
275 local_header_offset: offset,
276 is_directory: false,
277 is_symlink: true,
278 is_stored: false,
279 mtime: None,
280 unix_mtime: None,
281 unix_permissions: None,
282 });
283 self.inner = Some(inner);
284 Ok(())
285 }
286
287 pub async fn finalize(mut self) -> Result<(), ZipError> {
302 let mut inner = self.inner.take().ok_or_else(|| {
303 if self.poisoned {
304 ZipError::Poisoned("previous entry was dropped without calling close()".to_string())
305 } else {
306 ZipError::InvalidState("entry writer still active".to_string())
307 }
308 })?;
309 let cd_offset = self.pos;
310
311 for entry in &self.entries {
312 let cd_entry = entry.to_central_dir_entry();
313 let data = cd_entry.serialize()?;
314 inner.write_all(&data).await?;
315 self.pos += data.len() as u64;
316 }
317
318 let cd_size = self.pos - cd_offset;
319 let total_entries = self.entries.len() as u64;
320 let needs_zip64 =
321 total_entries > 0xFFFF || cd_size > header::U32_MAX || cd_offset > header::U32_MAX;
322
323 if needs_zip64 {
324 let eocdr64 = header::Zip64Eocdr {
325 total_entries,
326 cd_size,
327 cd_offset,
328 };
329 let data = eocdr64.serialize();
330 let eocdr64_offset = self.pos;
331 inner.write_all(&data).await?;
332 self.pos += data.len() as u64;
333
334 let locator = header::Zip64EocdrLocator { eocdr64_offset };
335 inner.write_all(&locator.serialize()).await?;
336 self.pos += 20;
337 }
338
339 let eocdr = header::Eocdr {
340 total_entries,
341 cd_size,
342 cd_offset,
343 };
344 inner.write_all(&eocdr.serialize()).await?;
345 inner.shutdown().await?;
346 Ok(())
347 }
348}
349
#[cfg(test)]
mod tests {
    // Byte-level checks of the archives produced by `ZipWriter`. Field offsets
    // follow the PKWARE APPNOTE layouts of the respective records.
    use super::*;
    use crate::writer::test_utils::lookup_entry;
    use flate2::Compression;
    use tokio::io::AsyncWriteExt;

    /// Position of the first occurrence of the 4-byte signature `sig` in `buf`.
    fn find_sig(buf: &[u8], sig: &[u8; 4]) -> Option<usize> {
        buf.windows(4).position(|w| w == sig)
    }

    /// Position of the last occurrence of the 4-byte signature `sig` in `buf`.
    fn rfind_sig(buf: &[u8], sig: &[u8; 4]) -> Option<usize> {
        buf.windows(4).rposition(|w| w == sig)
    }

    /// Number of 4-byte windows of `buf` equal to `sig`.
    fn count_sig(buf: &[u8], sig: &[u8; 4]) -> usize {
        buf.windows(4).filter(|w| w == sig).count()
    }

    /// Little-endian u16 at byte offset `at`.
    fn le_u16(buf: &[u8], at: usize) -> u16 {
        u16::from_le_bytes(buf[at..at + 2].try_into().unwrap())
    }

    /// Little-endian u32 at byte offset `at`.
    fn le_u32(buf: &[u8], at: usize) -> u32 {
        u32::from_le_bytes(buf[at..at + 4].try_into().unwrap())
    }

    /// Little-endian u64 at byte offset `at`.
    fn le_u64(buf: &[u8], at: usize) -> u64 {
        u64::from_le_bytes(buf[at..at + 8].try_into().unwrap())
    }

    #[tokio::test]
    async fn test_zip_write_single_file() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("hello.txt").await.unwrap();
        entry.write_all(b"Hello, World!").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        assert!(buf.len() > 30);
        assert!(find_sig(&buf, b"PK\x03\x04").is_some());
        assert!(find_sig(&buf, b"PK\x01\x02").is_some());
        assert!(find_sig(&buf, b"PK\x05\x06").is_some());
    }

    #[tokio::test]
    async fn test_zip_write_multiple_files() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);

        for (name, data) in [("a.txt", b"aaa"), ("b.txt", b"bbb")] {
            let mut entry = zip.append_file(name).await.unwrap();
            entry.write_all(data).await.unwrap();
            entry.close().await.unwrap();
        }

        zip.finalize().await.unwrap();
        assert_eq!(count_sig(&buf, b"PK\x01\x02"), 2);
    }

    #[tokio::test]
    async fn test_zip_compression_ratio() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::best());

        let data = vec![b'A'; 1024];
        let mut entry = zip.append_file("repeated.txt").await.unwrap();
        entry.write_all(&data).await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let entry = lookup_entry(&buf, 0);
        assert!(
            entry.compressed_size < entry.uncompressed_size,
            "compressed {} >= uncompressed {}",
            entry.compressed_size,
            entry.uncompressed_size
        );
    }

    #[tokio::test]
    async fn test_symlink_entry() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        zip.append_symlink("link.txt", "target.txt").await.unwrap();
        zip.finalize().await.unwrap();

        // Central directory header fields.
        let cd = &buf[find_sig(&buf, b"PK\x01\x02").unwrap()..];
        assert_eq!(le_u16(cd, 4) >> 8, 3, "expected Unix host OS for symlink");
        assert_eq!(le_u16(cd, 6), 10, "expected VERSION_STORED for symlink");
        assert_eq!(le_u16(cd, 10), 0, "expected METHOD_STORED for symlink");

        let mode = le_u32(cd, 38) >> 16;
        assert!(
            mode & 0o170000 == 0o120000,
            "expected S_IFLNK in external_file_attributes, got {:06o}",
            mode
        );

        // The symlink target is the entry's data, right after the local header.
        let lfh = &buf[find_sig(&buf, b"PK\x03\x04").unwrap()..];
        let data_start = 30 + usize::from(le_u16(lfh, 26)) + usize::from(le_u16(lfh, 28));
        let target = b"target.txt";
        assert_eq!(
            &lfh[data_start..data_start + target.len()],
            target,
            "symlink target mismatch"
        );
    }

    #[tokio::test]
    async fn test_zip64_finalize_many_entries() {
        let num_entries: u16 = 0xFFFF;
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::none());

        for i in 0..=num_entries {
            let name = format!("f{i}");
            let mut entry = zip.append_file(&name).await.unwrap();
            entry.write_all(b"x").await.unwrap();
            entry.close().await.unwrap();
        }

        zip.finalize().await.unwrap();
        let expected_entries = u64::from(num_entries) + 1;

        // EOCDR: +8 entries on this disk, +10 total entries (both u16).
        let eocdr_pos = rfind_sig(&buf, b"PK\x05\x06").unwrap();
        assert_eq!(
            le_u16(&buf, eocdr_pos + 8),
            0xFFFF,
            "EOCDR entries-on-disk should be sentinel 0xFFFF for ZIP64"
        );
        assert_eq!(
            le_u16(&buf, eocdr_pos + 10),
            0xFFFF,
            "EOCDR total_entries should be sentinel 0xFFFF for ZIP64"
        );

        let locator_pos = rfind_sig(&buf, b"PK\x06\x07").unwrap();
        let z64_pos = rfind_sig(&buf, b"PK\x06\x06").unwrap();
        assert!(
            z64_pos < locator_pos && locator_pos < eocdr_pos,
            "expected Zip64Eocdr < Zip64EocdrLocator < Eocdr, got {z64_pos} < {locator_pos} < {eocdr_pos}"
        );

        // Locator +8: absolute offset of the ZIP64 EOCDR (u64).
        assert_eq!(
            le_u64(&buf, locator_pos + 8),
            z64_pos as u64,
            "ZIP64 locator should point at the ZIP64 EOCDR"
        );
        // ZIP64 EOCDR +32: total number of entries (u64).
        assert_eq!(
            le_u64(&buf, z64_pos + 32),
            expected_entries,
            "ZIP64 EOCDR should carry the real entry count"
        );

        assert_eq!(count_sig(&buf, b"PK\x01\x02"), num_entries as usize + 1);

        // First entry "f0": 30-byte LFH (assuming no extra field) + 2-byte name
        // + 1 data byte = 33, followed by a 16-byte non-ZIP64 data descriptor.
        assert_eq!(
            &buf[33..37],
            b"PK\x07\x08",
            "first entry should have DD signature"
        );
        assert_eq!(
            &buf[49..53],
            b"PK\x03\x04",
            "next LFH at offset 49 confirms 16-byte DD (non-ZIP64) for small-entry ZIP64 archive"
        );
    }

    #[tokio::test]
    async fn test_stored_entry_level_zero() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf).with_level(Compression::none());

        let data = b"Hello, stored entry!";
        let mut entry = zip.append_file("stored.txt").await.unwrap();
        entry.write_all(data).await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = &buf[find_sig(&buf, b"PK\x01\x02").unwrap()..];
        assert_eq!(le_u16(cd, 10), 0, "expected METHOD_STORED for level=0 entry");
        assert_eq!(
            le_u16(cd, 6),
            10,
            "expected VERSION_STORED for level=0 entry"
        );

        let compressed_size = u64::from(le_u32(cd, 20));
        let uncompressed_size = u64::from(le_u32(cd, 24));
        assert_eq!(
            compressed_size, uncompressed_size,
            "stored entry should have equal compressed and uncompressed sizes"
        );
        assert_eq!(compressed_size, data.len() as u64);

        let lfh_pos = find_sig(&buf, b"PK\x03\x04").unwrap();
        assert_eq!(
            le_u16(&buf, lfh_pos + 8),
            0,
            "LFH method should be STORED for level=0"
        );
    }
}