async_deflate_zip/writer/
entry_writer.rs1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use crate::error::ZipError;
6use crate::header;
7
8use crate::deflate_encoder::DeflateEncoder;
9use tokio::io::{AsyncWrite, AsyncWriteExt};
10
11use super::helpers::CountWriter;
12use super::stored_entry::StoredEntry;
13use super::zip_writer::ZipWriter;
14
15pin_project_lite::pin_project! {
16 pub struct EntryWriter<'a, W>
27 where
28 W: AsyncWrite,
29 W: Unpin,
30 {
31 pub(crate) zip: &'a mut ZipWriter<W>,
32 #[pin]
33 pub(crate) deflate_encoder: Option<DeflateEncoder<CountWriter<W>>>,
34 #[pin]
35 pub(crate) passthrough: Option<CountWriter<W>>,
36 pub(crate) is_stored: bool,
37 pub(crate) crc_hasher: crc32fast::Hasher,
38 pub(crate) uncompressed_size: u64,
39 pub(crate) local_header_offset: u64,
40 pub(crate) name: String,
41 pub(crate) mtime: Option<std::time::SystemTime>,
42 pub(crate) unix_permissions: Option<u32>,
43 }
44
45 impl<W> PinnedDrop for EntryWriter<'_, W>
46 where
47 W: AsyncWrite,
48 W: Unpin,
49 {
50 fn drop(this: Pin<&mut Self>) {
51 let this = this.project();
52 if this.deflate_encoder.is_some() || this.passthrough.is_some() {
53 this.zip.poisoned = true;
55 }
56 }
57 }
58}
59
60impl<W: AsyncWrite + Unpin> EntryWriter<'_, W> {
61 pub fn set_mtime(&mut self, mtime: std::time::SystemTime) -> &mut Self {
66 self.mtime = Some(mtime);
67 self
68 }
69
70 pub fn set_permissions(&mut self, mode: u32) -> &mut Self {
76 self.unix_permissions = Some(mode & 0o7777);
77 self
78 }
79
80 pub async fn close(mut self) -> Result<(), ZipError> {
94 let (compressed_size, mut inner) = if self.is_stored {
95 let cw = self
96 .passthrough
97 .take()
98 .ok_or_else(|| ZipError::InvalidState("entry already closed".to_string()))?;
99 (cw.count, cw.inner)
100 } else {
101 let mut encoder = self
102 .deflate_encoder
103 .take()
104 .ok_or_else(|| ZipError::InvalidState("entry already closed".to_string()))?;
105 encoder.shutdown().await?;
106
107 let count_writer: CountWriter<W> = encoder.into_inner();
109 let compressed_size = count_writer.count;
110 (compressed_size, count_writer.inner)
111 };
112
113 let crc32 = self.crc_hasher.clone().finalize();
114
115 let dd = header::DataDescriptor {
116 crc32,
117 compressed_size,
118 uncompressed_size: self.uncompressed_size,
119 zip64: compressed_size > header::U32_MAX
122 || self.uncompressed_size > header::U32_MAX
123 || self.local_header_offset > header::U32_MAX,
124 };
125 let dd_bytes = dd.serialize();
126 inner.write_all(&dd_bytes).await.map_err(|e| {
127 self.zip.poisoned = true;
128 ZipError::Io(e)
129 })?;
130
131 self.zip.pos += compressed_size + dd_bytes.len() as u64;
133
134 let (mtime_msdos, unix_mtime) = header::mtime_to_ms_dos_and_unix(self.mtime);
135
136 self.zip.entries.push(StoredEntry {
137 name: self.name.clone(),
138 crc32,
139 compressed_size,
140 uncompressed_size: self.uncompressed_size,
141 local_header_offset: self.local_header_offset,
142 is_directory: false,
143 is_symlink: false,
144 is_stored: self.is_stored,
145 mtime: mtime_msdos,
146 unix_mtime,
147 unix_permissions: self.unix_permissions,
148 });
149
150 self.zip.inner = Some(inner);
152 Ok(())
153 }
154}
155
156impl<W: AsyncWrite + Unpin> AsyncWrite for EntryWriter<'_, W> {
157 fn poll_write(
158 self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 buf: &[u8],
161 ) -> Poll<io::Result<usize>> {
162 let this = self.project();
163 let result = if *this.is_stored {
164 match this.passthrough.as_pin_mut() {
165 Some(w) => w.poll_write(cx, buf),
166 None => {
167 this.zip.poisoned = true;
168 return Poll::Ready(Err(ZipError::Poisoned(
169 "write after entry closed".to_string(),
170 )
171 .into()));
172 }
173 }
174 } else {
175 match this.deflate_encoder.as_pin_mut() {
176 Some(e) => e.poll_write(cx, buf),
177 None => {
178 this.zip.poisoned = true;
179 return Poll::Ready(Err(ZipError::Poisoned(
180 "write after entry closed".to_string(),
181 )
182 .into()));
183 }
184 }
185 };
186 match result {
187 Poll::Ready(Ok(n)) => {
188 this.crc_hasher.update(&buf[..n]);
189 *this.uncompressed_size += n as u64;
190 Poll::Ready(Ok(n))
191 }
192 other => other,
193 }
194 }
195
196 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
197 let this = self.project();
198 if *this.is_stored {
199 match this.passthrough.as_pin_mut() {
200 Some(w) => w.poll_flush(cx),
201 None => {
202 this.zip.poisoned = true;
203 Poll::Ready(Err(ZipError::Poisoned(
204 "flush after entry closed".to_string(),
205 )
206 .into()))
207 }
208 }
209 } else {
210 match this.deflate_encoder.as_pin_mut() {
211 Some(e) => e.poll_flush(cx),
212 None => {
213 this.zip.poisoned = true;
214 Poll::Ready(Err(ZipError::Poisoned(
215 "flush after entry closed".to_string(),
216 )
217 .into()))
218 }
219 }
220 }
221 }
222
223 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
224 let this = self.project();
225 if *this.is_stored {
226 match this.passthrough.as_pin_mut() {
227 Some(w) => w.poll_shutdown(cx),
228 None => {
229 this.zip.poisoned = true;
230 Poll::Ready(Err(ZipError::Poisoned(
231 "shutdown after entry closed".to_string(),
232 )
233 .into()))
234 }
235 }
236 } else {
237 match this.deflate_encoder.as_pin_mut() {
238 Some(e) => e.poll_shutdown(cx),
239 None => {
240 this.zip.poisoned = true;
241 Poll::Ready(Err(ZipError::Poisoned(
242 "shutdown after entry closed".to_string(),
243 )
244 .into()))
245 }
246 }
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::super::*;
    use crate::header;
    use tokio::io::AsyncWriteExt;

    /// Writes an archive holding one file entry and returns its bytes.
    ///
    /// `mtime` and `mode` are applied to the entry (in that order) only when
    /// they are `Some`.
    async fn archive_with_entry(
        name: &str,
        payload: &[u8],
        mtime: Option<std::time::SystemTime>,
        mode: Option<u32>,
    ) -> Vec<u8> {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);
        let mut entry = zip.append_file(name).await.unwrap();
        if let Some(mtime) = mtime {
            entry.set_mtime(mtime);
        }
        if let Some(mode) = mode {
            entry.set_permissions(mode);
        }
        entry.write_all(payload).await.unwrap();
        entry.close().await.unwrap();
        zip.finalize().await.unwrap();
        buf
    }

    /// Returns the archive bytes starting at the first `PK\x01\x02` signature.
    fn central_directory(archive: &[u8]) -> &[u8] {
        let start = archive
            .windows(4)
            .position(|w| w == b"PK\x01\x02")
            .unwrap();
        &archive[start..]
    }

    /// Reads a little-endian `u16` at byte offset `at`.
    fn le_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_le_bytes(bytes[at..at + 2].try_into().unwrap())
    }

    /// Reads a little-endian `u32` at byte offset `at`.
    fn le_u32(bytes: &[u8], at: usize) -> u32 {
        u32::from_le_bytes(bytes[at..at + 4].try_into().unwrap())
    }

    /// Packs the Unix epoch, shifted to the local offset (UTC when the local
    /// offset cannot be determined), into an MS-DOS time word.
    fn epoch_as_local_dos_time() -> u16 {
        let local_offset =
            time::UtcOffset::current_local_offset().unwrap_or(time::UtcOffset::UTC);
        let local_epoch =
            time::OffsetDateTime::from(std::time::SystemTime::UNIX_EPOCH).to_offset(local_offset);
        ((local_epoch.hour() as u16) << 11)
            | ((local_epoch.minute() as u16) << 5)
            | (local_epoch.second() as u16 / 2)
    }

    #[tokio::test]
    async fn test_entry_mtime_epoch() {
        let archive = archive_with_entry(
            "epoch.txt",
            b"test",
            Some(std::time::SystemTime::UNIX_EPOCH),
            None,
        )
        .await;
        let cd = central_directory(&archive);

        let dos_time = le_u16(cd, 12);
        let dos_date = le_u16(cd, 14);
        assert_eq!(
            dos_time,
            epoch_as_local_dos_time(),
            "expected local time for epoch"
        );
        assert_eq!(dos_date, (1 << 5) | 1, "expected MS-DOS date for 1980-01-01");
    }

    #[tokio::test]
    async fn test_entry_permissions() {
        let archive = archive_with_entry("perm_test.txt", b"test", None, Some(0o644)).await;
        let cd = central_directory(&archive);

        let external_attrs = le_u32(cd, 38);
        assert_eq!(external_attrs, ((0o644 | 0o100000) as u32) << 16);
        let version_made_by = le_u16(cd, 4);
        assert!(version_made_by >> 8 == 3, "expected Unix host OS");
    }

    #[tokio::test]
    async fn test_entry_setuid_permissions() {
        let archive = archive_with_entry("setuid_test.txt", b"test", None, Some(0o4755)).await;
        let cd = central_directory(&archive);

        let external_attrs = le_u32(cd, 38);
        assert_eq!(external_attrs, ((0o4755 | 0o100000) as u32) << 16);
    }

    #[tokio::test]
    async fn test_entry_mtime_and_permissions() {
        let archive = archive_with_entry(
            "both.txt",
            b"test",
            Some(std::time::SystemTime::UNIX_EPOCH),
            Some(0o755),
        )
        .await;
        let cd = central_directory(&archive);

        let dos_time = le_u16(cd, 12);
        assert_eq!(dos_time, epoch_as_local_dos_time());
        let external_attrs = le_u32(cd, 38);
        assert_eq!(external_attrs, ((0o755 | 0o100000) as u32) << 16);
        let version_made_by = le_u16(cd, 4);
        assert!(
            version_made_by >> 8 == 3,
            "expected version_made_by upper byte = 3 (Unix), got {}",
            version_made_by >> 8
        );
    }

    #[tokio::test]
    async fn test_entry_mtime_appears_in_cd_extra() {
        let archive = archive_with_entry(
            "mtime_test.txt",
            b"hello",
            Some(std::time::SystemTime::UNIX_EPOCH),
            None,
        )
        .await;
        let cd = central_directory(&archive);

        let name_len = le_u16(cd, 28) as usize;
        let extra_len = le_u16(cd, 30) as usize;

        // The extra field follows the 46-byte fixed header and the file name.
        let extra_start = 46 + name_len;
        let extra = &cd[extra_start..extra_start + extra_len];
        let has_ts_extra = extra.windows(2).any(|w| w == b"UT");
        assert!(
            has_ts_extra,
            "CD entry extra should contain extended timestamp (0x5455/UT) when mtime is set"
        );
        assert!(
            extra_len >= 4,
            "extra_len should be >= 4 when mtime is set, got {extra_len}"
        );
        let version_made_by = le_u16(cd, 4);
        assert_eq!(
            version_made_by >> 8,
            3,
            "expected Unix host OS when mtime is set"
        );
    }

    #[tokio::test]
    async fn test_entry_default_no_metadata() {
        let archive = archive_with_entry("default.txt", b"test", None, None).await;
        let cd = central_directory(&archive);

        let external_attrs = le_u32(cd, 38);
        assert_eq!(external_attrs, 0);
        let version_made_by = le_u16(cd, 4);
        assert_eq!(version_made_by, header::VERSION_DEFLATE);
    }

    #[tokio::test]
    async fn test_entry_drop_poisons_zip_writer() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);

        // Abandon an entry without closing it.
        drop(zip.append_file("lost.txt").await.unwrap());

        let err = match zip.append_file("another.txt").await {
            Ok(_) => panic!("expected Err, got Ok"),
            Err(e) => e,
        };
        assert!(
            err.to_string().contains("archive corrupted"),
            "expected 'archive corrupted', got: {err}"
        );
    }

    #[tokio::test]
    async fn test_entry_drop_poison_affects_finalize() {
        let mut buf = Vec::new();
        let mut zip = ZipWriter::new(&mut buf);

        // Abandon an entry without closing it.
        drop(zip.append_file("lost.txt").await.unwrap());

        let err = zip.finalize().await.unwrap_err();
        assert!(
            err.to_string().contains("archive corrupted"),
            "expected 'archive corrupted', got: {err}"
        );
    }
}