Skip to main content

async_deflate_zip/writer/
entry_writer.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use crate::error::ZipError;
6use crate::header;
7
8use crate::deflate_encoder::DeflateEncoder;
9use tokio::io::{AsyncWrite, AsyncWriteExt};
10
11use super::helpers::CountWriter;
12use super::stored_entry::StoredEntry;
13use super::zip_writer::ZipWriter;
14
15pin_project_lite::pin_project! {
16    /// A streaming writer for a single file entry in a ZIP archive.
17    ///
18    /// Obtained from [`ZipWriter::append_file`]. Data written through this
19    /// writer is compressed with DEFLATE and streamed to the underlying output.
20    ///
21    /// # Important
22    ///
23    /// The [`close`](EntryWriter::close) method **must** be called after all
24    /// data is written to finalize the deflate frame and write the Data
25    /// Descriptor. Dropping without closing will lose the entry.
26    pub struct EntryWriter<'a, W>
27    where
28        W: AsyncWrite,
29        W: Unpin,
30    {
31        pub(crate) zip: &'a mut ZipWriter<W>,
32        #[pin]
33        pub(crate) deflate_encoder: Option<DeflateEncoder<CountWriter<W>>>,
34        #[pin]
35        pub(crate) passthrough: Option<CountWriter<W>>,
36        pub(crate) is_stored: bool,
37        pub(crate) crc_hasher: crc32fast::Hasher,
38        pub(crate) uncompressed_size: u64,
39        pub(crate) local_header_offset: u64,
40        pub(crate) name: String,
41        pub(crate) mtime: Option<std::time::SystemTime>,
42        pub(crate) unix_permissions: Option<u32>,
43    }
44
45    impl<W> PinnedDrop for EntryWriter<'_, W>
46    where
47        W: AsyncWrite,
48        W: Unpin,
49    {
50        fn drop(this: Pin<&mut Self>) {
51            let this = this.project();
52            if this.deflate_encoder.is_some() || this.passthrough.is_some() {
53                // close() was never called — mark the ZipWriter as poisoned
54                this.zip.poisoned = true;
55            }
56        }
57    }
58}
59
60impl<W: AsyncWrite + Unpin> EntryWriter<'_, W> {
61    /// Set the modification time for this entry.
62    ///
63    /// The timestamp is stored in the Central Directory using MS-DOS format
64    /// and as an extended timestamp extra field (0x5455).
65    pub fn set_mtime(&mut self, mtime: std::time::SystemTime) -> &mut Self {
66        self.mtime = Some(mtime);
67        self
68    }
69
70    /// Set Unix file permissions for this entry.
71    ///
72    /// Provide permission bits including setuid/setgid/sticky (e.g., `0o4755`, `0o2755`).
73    /// The crate automatically adds the file type bit (`S_IFREG` for files,
74    /// `S_IFDIR` for directories).
75    pub fn set_permissions(&mut self, mode: u32) -> &mut Self {
76        self.unix_permissions = Some(mode & 0o7777);
77        self
78    }
79
80    /// Finalize the deflate frame, compute the CRC-32 checksum, and write
81    /// the Data Descriptor.
82    ///
83    /// This consumes the `EntryWriter`, flushes the deflate encoder, extracts
84    /// the compressed size, computes the CRC-32 of the uncompressed data, and
85    /// writes the trailing CRC-32 and sizes (Data Descriptor) after the compressed
86    /// data. The inner writer is returned to the parent [`ZipWriter`].
87    ///
88    /// # Errors
89    ///
90    /// Returns [`ZipError`] if `close` is called more than once, if the deflate
91    /// encoder fails to shut down (I/O error), or if writing the Data
92    /// Descriptor fails (I/O error).
93    pub async fn close(mut self) -> Result<(), ZipError> {
94        let (compressed_size, mut inner) = if self.is_stored {
95            let cw = self
96                .passthrough
97                .take()
98                .ok_or_else(|| ZipError::InvalidState("entry already closed".to_string()))?;
99            (cw.count, cw.inner)
100        } else {
101            let mut encoder = self
102                .deflate_encoder
103                .take()
104                .ok_or_else(|| ZipError::InvalidState("entry already closed".to_string()))?;
105            encoder.shutdown().await?;
106
107            // Extract the inner writer from the encoder stack
108            let count_writer: CountWriter<W> = encoder.into_inner();
109            let compressed_size = count_writer.count;
110            (compressed_size, count_writer.inner)
111        };
112
113        let crc32 = self.crc_hasher.clone().finalize();
114
115        let dd = header::DataDescriptor {
116            crc32,
117            compressed_size,
118            uncompressed_size: self.uncompressed_size,
119            // Use ZIP64 DD when any entry field exceeds 32 bits, consistent with
120            // CentralDirEntry::serialize() which also checks local_header_offset.
121            zip64: compressed_size > header::U32_MAX
122                || self.uncompressed_size > header::U32_MAX
123                || self.local_header_offset > header::U32_MAX,
124        };
125        let dd_bytes = dd.serialize();
126        inner.write_all(&dd_bytes).await.map_err(|e| {
127            self.zip.poisoned = true;
128            ZipError::Io(e)
129        })?;
130
131        // Update position tracker: compressed data + data descriptor
132        self.zip.pos += compressed_size + dd_bytes.len() as u64;
133
134        let (mtime_msdos, unix_mtime) = header::mtime_to_ms_dos_and_unix(self.mtime);
135
136        self.zip.entries.push(StoredEntry {
137            name: self.name.clone(),
138            crc32,
139            compressed_size,
140            uncompressed_size: self.uncompressed_size,
141            local_header_offset: self.local_header_offset,
142            is_directory: false,
143            is_symlink: false,
144            is_stored: self.is_stored,
145            mtime: mtime_msdos,
146            unix_mtime,
147            unix_permissions: self.unix_permissions,
148        });
149
150        // Return the inner writer to ZipWriter
151        self.zip.inner = Some(inner);
152        Ok(())
153    }
154}
155
156impl<W: AsyncWrite + Unpin> AsyncWrite for EntryWriter<'_, W> {
157    fn poll_write(
158        self: Pin<&mut Self>,
159        cx: &mut Context<'_>,
160        buf: &[u8],
161    ) -> Poll<io::Result<usize>> {
162        let this = self.project();
163        let result = if *this.is_stored {
164            match this.passthrough.as_pin_mut() {
165                Some(w) => w.poll_write(cx, buf),
166                None => {
167                    this.zip.poisoned = true;
168                    return Poll::Ready(Err(ZipError::Poisoned(
169                        "write after entry closed".to_string(),
170                    )
171                    .into()));
172                }
173            }
174        } else {
175            match this.deflate_encoder.as_pin_mut() {
176                Some(e) => e.poll_write(cx, buf),
177                None => {
178                    this.zip.poisoned = true;
179                    return Poll::Ready(Err(ZipError::Poisoned(
180                        "write after entry closed".to_string(),
181                    )
182                    .into()));
183                }
184            }
185        };
186        match result {
187            Poll::Ready(Ok(n)) => {
188                this.crc_hasher.update(&buf[..n]);
189                *this.uncompressed_size += n as u64;
190                Poll::Ready(Ok(n))
191            }
192            other => other,
193        }
194    }
195
196    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197        let this = self.project();
198        if *this.is_stored {
199            match this.passthrough.as_pin_mut() {
200                Some(w) => w.poll_flush(cx),
201                None => {
202                    this.zip.poisoned = true;
203                    Poll::Ready(Err(ZipError::Poisoned(
204                        "flush after entry closed".to_string(),
205                    )
206                    .into()))
207                }
208            }
209        } else {
210            match this.deflate_encoder.as_pin_mut() {
211                Some(e) => e.poll_flush(cx),
212                None => {
213                    this.zip.poisoned = true;
214                    Poll::Ready(Err(ZipError::Poisoned(
215                        "flush after entry closed".to_string(),
216                    )
217                    .into()))
218                }
219            }
220        }
221    }
222
223    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
224        let this = self.project();
225        if *this.is_stored {
226            match this.passthrough.as_pin_mut() {
227                Some(w) => w.poll_shutdown(cx),
228                None => {
229                    this.zip.poisoned = true;
230                    Poll::Ready(Err(ZipError::Poisoned(
231                        "shutdown after entry closed".to_string(),
232                    )
233                    .into()))
234                }
235            }
236        } else {
237            match this.deflate_encoder.as_pin_mut() {
238                Some(e) => e.poll_shutdown(cx),
239                None => {
240                    this.zip.poisoned = true;
241                    Poll::Ready(Err(ZipError::Poisoned(
242                        "shutdown after entry closed".to_string(),
243                    )
244                    .into()))
245                }
246            }
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::super::*;
    use crate::header;
    use tokio::io::AsyncWriteExt;

    /// Returns the archive bytes starting at the first Central Directory
    /// file header signature (`PK\x01\x02`).
    fn central_directory(archive: &[u8]) -> &[u8] {
        let start = archive
            .windows(4)
            .position(|w| w == b"PK\x01\x02")
            .unwrap();
        &archive[start..]
    }

    /// Reads a little-endian `u16` at `offset`.
    fn u16_at(bytes: &[u8], offset: usize) -> u16 {
        u16::from_le_bytes(bytes[offset..offset + 2].try_into().unwrap())
    }

    /// Reads a little-endian `u32` at `offset`.
    fn u32_at(bytes: &[u8], offset: usize) -> u32 {
        u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
    }

    /// MS-DOS time word for the Unix epoch expressed in the local UTC offset
    /// (falling back to UTC when the local offset cannot be determined).
    fn epoch_dos_time_local() -> u16 {
        let offset = time::UtcOffset::current_local_offset().unwrap_or(time::UtcOffset::UTC);
        let local =
            time::OffsetDateTime::from(std::time::SystemTime::UNIX_EPOCH).to_offset(offset);
        (local.hour() as u16) << 11 | (local.minute() as u16) << 5 | (local.second() as u16 / 2)
    }

    #[tokio::test]
    async fn test_entry_mtime_epoch() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("epoch.txt").await.unwrap();
        entry.set_mtime(std::time::SystemTime::UNIX_EPOCH);
        entry.write_all(b"test").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);
        let time = u16_at(cd, 12);
        let date = u16_at(cd, 14);
        let expected_time = epoch_dos_time_local();

        assert_eq!(time, expected_time, "expected local time for epoch");
        assert_eq!(date, (1 << 5) | 1, "expected MS-DOS date for 1980-01-01");
    }

    #[tokio::test]
    async fn test_entry_permissions() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("perm_test.txt").await.unwrap();
        entry.set_permissions(0o644);
        entry.write_all(b"test").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);
        // External file attributes carry the Unix mode in the upper 16 bits.
        let efa = u32_at(cd, 38);
        assert_eq!(efa, ((0o644 | 0o100000) as u32) << 16);
        let vmb = u16_at(cd, 4);
        assert!(vmb >> 8 == 3, "expected Unix host OS");
    }

    #[tokio::test]
    async fn test_entry_setuid_permissions() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("setuid_test.txt").await.unwrap();
        entry.set_permissions(0o4755);
        entry.write_all(b"test").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);
        let efa = u32_at(cd, 38);
        assert_eq!(efa, ((0o4755 | 0o100000) as u32) << 16);
    }

    #[tokio::test]
    async fn test_entry_mtime_and_permissions() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("both.txt").await.unwrap();
        entry.set_mtime(std::time::SystemTime::UNIX_EPOCH);
        entry.set_permissions(0o755);
        entry.write_all(b"test").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);

        let time = u16_at(cd, 12);
        let expected_time = epoch_dos_time_local();
        assert_eq!(time, expected_time);

        let efa = u32_at(cd, 38);
        assert_eq!(efa, ((0o755 | 0o100000) as u32) << 16);

        let vmb = u16_at(cd, 4);
        assert!(
            vmb >> 8 == 3,
            "expected version_made_by upper byte = 3 (Unix), got {}",
            vmb >> 8
        );
    }

    #[tokio::test]
    async fn test_entry_mtime_appears_in_cd_extra() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("mtime_test.txt").await.unwrap();
        entry.set_mtime(std::time::SystemTime::UNIX_EPOCH);
        entry.write_all(b"hello").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);
        let name_len = u16_at(cd, 28) as usize;
        let extra_len = u16_at(cd, 30) as usize;

        // The extra field follows the 46-byte fixed header and the file name.
        let extra_start = 46 + name_len;
        let extra = &cd[extra_start..extra_start + extra_len];
        let has_ts_extra = extra.windows(2).any(|w| w == b"UT");
        assert!(
            has_ts_extra,
            "CD entry extra should contain extended timestamp (0x5455/UT) when mtime is set"
        );
        assert!(
            extra_len >= 4,
            "extra_len should be >= 4 when mtime is set, got {extra_len}"
        );
        let vmb = u16_at(cd, 4);
        assert_eq!(vmb >> 8, 3, "expected Unix host OS when mtime is set");
    }

    #[tokio::test]
    async fn test_entry_default_no_metadata() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file("default.txt").await.unwrap();
        entry.write_all(b"test").await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();

        let cd = central_directory(&buf);
        let efa = u32_at(cd, 38);
        assert_eq!(efa, 0);
        let vmb = u16_at(cd, 4);
        assert_eq!(vmb, header::VERSION_DEFLATE);
    }

    #[tokio::test]
    async fn test_entry_drop_poisons_zip_writer() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);

        // Dropping the entry without calling `close()` must poison the writer.
        drop(zip.append_file("lost.txt").await.unwrap());

        let result = zip.append_file("another.txt").await;
        assert!(result.is_err(), "expected Err, got Ok");
        let err = result.err().unwrap();
        assert!(
            err.to_string().contains("archive corrupted"),
            "expected 'archive corrupted', got: {err}"
        );
    }

    #[tokio::test]
    async fn test_entry_drop_poison_affects_finalize() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);

        // Dropping the entry without calling `close()` must poison the writer.
        drop(zip.append_file("lost.txt").await.unwrap());

        let err = zip.finalize().await.unwrap_err();
        assert!(
            err.to_string().contains("archive corrupted"),
            "expected 'archive corrupted', got: {err}"
        );
    }
}