Skip to main content

compress_tools/
async_support.rs

1// Copyright (C) 2019-2021 O.S. Systems Sofware LTDA
2//
3// SPDX-License-Identifier: MIT OR Apache-2.0
4
//! Generic async support with which you can use your own thread pool by
//! implementing the [`BlockingExecutor`] trait.
7
8use crate::{
9    ArchiveContents, ArchiveIteratorBuilder, ArchivePassword, DecodeCallback, Ownership, Result,
10    READER_BUFFER_SIZE,
11};
12use async_trait::async_trait;
13use futures_channel::mpsc::{channel, Receiver, Sender};
14use futures_core::{FusedStream, Stream};
15use futures_executor::block_on;
16use futures_io::{AsyncRead, AsyncSeek, AsyncWrite};
17use futures_util::{
18    io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
19    join,
20    sink::SinkExt,
21    stream::StreamExt,
22};
23use std::{
24    future::Future,
25    io::{ErrorKind, Read, Seek, SeekFrom, Write},
26    path::Path,
27    pin::Pin,
28    task::{Context, Poll},
29};
30
31#[async_trait]
32pub trait BlockingExecutor {
33    /// Execute the provided function on a thread where blocking is acceptable
34    /// (in some kind of thread pool).
35    async fn execute_blocking<T, F>(f: F) -> Result<T>
36    where
37        T: Send + 'static,
38        F: FnOnce() -> T + Send + 'static;
39}
40
41// ----------------------------------------------------------------------------
42// Stream-only reader wrapper (used by `uncompress_data`, which never seeks)
43// ----------------------------------------------------------------------------
44
45struct AsyncReadWrapper {
46    rx: Receiver<Vec<u8>>,
47}
48
49impl Read for AsyncReadWrapper {
50    fn read(&mut self, mut buf: &mut [u8]) -> std::io::Result<usize> {
51        if self.rx.is_terminated() {
52            return Ok(0);
53        }
54        assert_eq!(buf.len(), READER_BUFFER_SIZE);
55        Ok(match block_on(self.rx.next()) {
56            Some(data) => {
57                buf.write_all(&data)?;
58                data.len()
59            }
60            None => 0,
61        })
62    }
63}
64
65fn make_async_read_wrapper_and_worker<R>(
66    mut read: R,
67) -> (AsyncReadWrapper, impl Future<Output = Result<()>>)
68where
69    R: AsyncRead + Unpin,
70{
71    let (mut tx, rx) = channel(0);
72    (AsyncReadWrapper { rx }, async move {
73        loop {
74            let mut data = vec![0; READER_BUFFER_SIZE];
75            let read = read.read(&mut data).await?;
76            data.truncate(read);
77            if read == 0 || tx.send(data).await.is_err() {
78                break;
79            }
80        }
81        Ok(())
82    })
83}
84
85// ----------------------------------------------------------------------------
86// Seekable read bridge (used by list / extract / iterator paths)
87// ----------------------------------------------------------------------------
88//
89// libarchive's seekable formats (ZIP, 7z, …) issue `seek()` calls through the
90// synchronous callback it registers with us. When the caller supplies an
91// `AsyncRead + AsyncSeek` source, we cannot call `.await` from inside that
92// C callback, so we stand up a request/response channel pair: the sync side
93// sends a `BridgeReq` describing the desired operation and blocks on the
94// matching `BridgeRes`, while an async worker future awaits the operation on
95// the underlying source.
96
/// Request sent from the synchronous wrapper to the async bridge worker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BridgeReq {
    /// Read at most this many bytes from the underlying source.
    Read(usize),
    /// Reposition the underlying source.
    Seek(SeekFrom),
}
101
/// Response sent from the async bridge worker back to the synchronous
/// wrapper; each variant answers the request variant of the same name.
#[derive(Debug)]
enum BridgeRes {
    /// Outcome of a read: the bytes that were read, or the I/O error.
    Read(std::io::Result<Vec<u8>>),
    /// Outcome of a seek: the value reported by the source, or the I/O error.
    Seek(std::io::Result<u64>),
}
106
107pub(crate) struct SeekableAsyncReadWrapper {
108    req_tx: Sender<BridgeReq>,
109    res_rx: Receiver<BridgeRes>,
110}
111
112impl Read for SeekableAsyncReadWrapper {
113    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
114        if block_on(self.req_tx.send(BridgeReq::Read(buf.len()))).is_err() {
115            return Ok(0);
116        }
117        match block_on(self.res_rx.next()) {
118            Some(BridgeRes::Read(Ok(data))) => {
119                let n = data.len().min(buf.len());
120                buf[..n].copy_from_slice(&data[..n]);
121                Ok(n)
122            }
123            Some(BridgeRes::Read(Err(e))) => Err(e),
124            Some(BridgeRes::Seek(_)) | None => Ok(0),
125        }
126    }
127}
128
129impl Seek for SeekableAsyncReadWrapper {
130    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
131        if block_on(self.req_tx.send(BridgeReq::Seek(pos))).is_err() {
132            return Err(std::io::Error::new(
133                ErrorKind::BrokenPipe,
134                "async seek bridge closed",
135            ));
136        }
137        match block_on(self.res_rx.next()) {
138            Some(BridgeRes::Seek(r)) => r,
139            Some(BridgeRes::Read(_)) | None => Err(std::io::Error::new(
140                ErrorKind::BrokenPipe,
141                "async seek bridge closed",
142            )),
143        }
144    }
145}
146
147fn make_seekable_read_wrapper_and_worker<R>(
148    mut read: R,
149) -> (SeekableAsyncReadWrapper, impl Future<Output = Result<()>>)
150where
151    R: AsyncRead + AsyncSeek + Unpin,
152{
153    let (req_tx, mut req_rx) = channel::<BridgeReq>(0);
154    let (mut res_tx, res_rx) = channel::<BridgeRes>(0);
155    let worker = async move {
156        while let Some(req) = req_rx.next().await {
157            let res = match req {
158                BridgeReq::Read(n) => {
159                    let mut buf = vec![0u8; n];
160                    match read.read(&mut buf).await {
161                        Ok(size) => {
162                            buf.truncate(size);
163                            BridgeRes::Read(Ok(buf))
164                        }
165                        Err(e) => BridgeRes::Read(Err(e)),
166                    }
167                }
168                BridgeReq::Seek(pos) => BridgeRes::Seek(read.seek(pos).await),
169            };
170            if res_tx.send(res).await.is_err() {
171                break;
172            }
173        }
174        Ok(())
175    };
176    (SeekableAsyncReadWrapper { req_tx, res_rx }, worker)
177}
178
179// ----------------------------------------------------------------------------
// Write bridge
181// ----------------------------------------------------------------------------
182
183pub(crate) struct AsyncWriteWrapper {
184    tx: Sender<Vec<u8>>,
185}
186
187impl Write for AsyncWriteWrapper {
188    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
189        match block_on(self.tx.send(buf.to_owned())) {
190            Ok(()) => Ok(buf.len()),
191            Err(err) => Err(std::io::Error::new(ErrorKind::Other, err)),
192        }
193    }
194
195    fn flush(&mut self) -> std::io::Result<()> {
196        block_on(self.tx.send(vec![])).map_err(|err| std::io::Error::new(ErrorKind::Other, err))
197    }
198}
199
200fn make_async_write_wrapper_and_worker<W>(
201    mut write: W,
202) -> (AsyncWriteWrapper, impl Future<Output = Result<()>>)
203where
204    W: AsyncWrite + Unpin,
205{
206    let (tx, mut rx) = channel(0);
207    (AsyncWriteWrapper { tx }, async move {
208        while let Some(v) = rx.next().await {
209            if v.is_empty() {
210                write.flush().await?;
211            } else {
212                write.write_all(&v).await?;
213            }
214        }
215        Ok(())
216    })
217}
218
219// ----------------------------------------------------------------------------
220// High-level wrappers
221// ----------------------------------------------------------------------------
222
223async fn wrap_async_read_and_write<B, R, W, F, T>(_: B, read: R, write: W, f: F) -> Result<T>
224where
225    B: BlockingExecutor,
226    R: AsyncRead + Unpin,
227    W: AsyncWrite + Unpin,
228    F: FnOnce(AsyncReadWrapper, AsyncWriteWrapper) -> T + Send + 'static,
229    T: Send + 'static,
230{
231    let (async_read_wrapper, async_read_wrapper_worker) = make_async_read_wrapper_and_worker(read);
232    let (async_write_wrapper, async_write_wrapper_worker) =
233        make_async_write_wrapper_and_worker(write);
234    let g = B::execute_blocking(move || f(async_read_wrapper, async_write_wrapper));
235    let join = join!(async_read_wrapper_worker, async_write_wrapper_worker, g);
236    join.0?;
237    join.1?;
238    join.2
239}
240
241async fn wrap_async_seek_read<B, R, F, T>(_: B, read: R, f: F) -> Result<T>
242where
243    B: BlockingExecutor,
244    R: AsyncRead + AsyncSeek + Unpin,
245    F: FnOnce(SeekableAsyncReadWrapper) -> T + Send + 'static,
246    T: Send + 'static,
247{
248    let (seekable_wrapper, seekable_worker) = make_seekable_read_wrapper_and_worker(read);
249    let g = B::execute_blocking(move || f(seekable_wrapper));
250    let join = join!(seekable_worker, g);
251    join.0?;
252    join.1
253}
254
255async fn wrap_async_seek_read_and_write<B, R, W, F, T>(_: B, read: R, write: W, f: F) -> Result<T>
256where
257    B: BlockingExecutor,
258    R: AsyncRead + AsyncSeek + Unpin,
259    W: AsyncWrite + Unpin,
260    F: FnOnce(SeekableAsyncReadWrapper, AsyncWriteWrapper) -> T + Send + 'static,
261    T: Send + 'static,
262{
263    let (seekable_wrapper, seekable_worker) = make_seekable_read_wrapper_and_worker(read);
264    let (async_write_wrapper, async_write_wrapper_worker) =
265        make_async_write_wrapper_and_worker(write);
266    let g = B::execute_blocking(move || f(seekable_wrapper, async_write_wrapper));
267    let join = join!(seekable_worker, async_write_wrapper_worker, g);
268    join.0?;
269    join.1?;
270    join.2
271}
272
273// ----------------------------------------------------------------------------
274// Public async entry points
275// ----------------------------------------------------------------------------
276
277/// Async version of
278/// [`list_archive_files_with_encoding`](crate::
279/// list_archive_files_with_encoding).
280pub async fn list_archive_files_with_encoding<B, R>(
281    blocking_executor: B,
282    source: R,
283    decode: DecodeCallback,
284) -> Result<Vec<String>>
285where
286    B: BlockingExecutor,
287    R: AsyncRead + AsyncSeek + Unpin,
288{
289    wrap_async_seek_read(blocking_executor, source, move |source| {
290        crate::list_archive_files_with_encoding(source, decode)
291    })
292    .await?
293}
294
295/// Async version of [`list_archive_files`](crate::list_archive_files).
296pub async fn list_archive_files<B, R>(blocking_executor: B, source: R) -> Result<Vec<String>>
297where
298    B: BlockingExecutor,
299    R: AsyncRead + AsyncSeek + Unpin,
300{
301    wrap_async_seek_read(blocking_executor, source, crate::list_archive_files).await?
302}
303
304/// Async version of
305/// [`list_archive_entries_with_encoding`](crate::
306/// list_archive_entries_with_encoding).
307pub async fn list_archive_entries_with_encoding<B, R>(
308    blocking_executor: B,
309    source: R,
310    decode: DecodeCallback,
311) -> Result<Vec<crate::ArchiveEntryInfo>>
312where
313    B: BlockingExecutor,
314    R: AsyncRead + AsyncSeek + Unpin,
315{
316    wrap_async_seek_read(blocking_executor, source, move |source| {
317        crate::list_archive_entries_with_encoding(source, decode)
318    })
319    .await?
320}
321
322/// Async version of [`list_archive_entries`](crate::list_archive_entries).
323pub async fn list_archive_entries<B, R>(
324    blocking_executor: B,
325    source: R,
326) -> Result<Vec<crate::ArchiveEntryInfo>>
327where
328    B: BlockingExecutor,
329    R: AsyncRead + AsyncSeek + Unpin,
330{
331    wrap_async_seek_read(blocking_executor, source, crate::list_archive_entries).await?
332}
333
334/// Async version of [`uncompress_data`](crate::uncompress_data).
335pub async fn uncompress_data<B, R, W>(blocking_executor: B, source: R, target: W) -> Result<usize>
336where
337    B: BlockingExecutor,
338    R: AsyncRead + Unpin,
339    W: AsyncWrite + Unpin,
340{
341    wrap_async_read_and_write(blocking_executor, source, target, |source, target| {
342        crate::uncompress_data(source, target)
343    })
344    .await?
345}
346
347/// Async version of
348/// [`uncompress_archive_with_encoding`](crate::
349/// uncompress_archive_with_encoding).
350pub async fn uncompress_archive_with_encoding<B, R>(
351    blocking_executor: B,
352    source: R,
353    dest: &Path,
354    ownership: Ownership,
355    decode: DecodeCallback,
356) -> Result<()>
357where
358    B: BlockingExecutor,
359    R: AsyncRead + AsyncSeek + Unpin,
360{
361    let dest = dest.to_owned();
362    wrap_async_seek_read(blocking_executor, source, move |source| {
363        crate::uncompress_archive_with_encoding(source, &dest, ownership, decode)
364    })
365    .await?
366}
367
368/// Async version of [`uncompress_archive`](crate::uncompress_archive).
369pub async fn uncompress_archive<B, R>(
370    blocking_executor: B,
371    source: R,
372    dest: &Path,
373    ownership: Ownership,
374) -> Result<()>
375where
376    B: BlockingExecutor,
377    R: AsyncRead + AsyncSeek + Unpin,
378{
379    let dest = dest.to_owned();
380    wrap_async_seek_read(blocking_executor, source, move |source| {
381        crate::uncompress_archive(source, &dest, ownership)
382    })
383    .await?
384}
385
386/// Async version of
387/// [`uncompress_archive_file_with_encoding`](crate::
388/// uncompress_archive_file_with_encoding).
389pub async fn uncompress_archive_file_with_encoding<B, R, W>(
390    blocking_executor: B,
391    source: R,
392    target: W,
393    path: &str,
394    decode: DecodeCallback,
395) -> Result<usize>
396where
397    B: BlockingExecutor,
398    R: AsyncRead + AsyncSeek + Unpin,
399    W: AsyncWrite + Unpin,
400{
401    let path = path.to_owned();
402    wrap_async_seek_read_and_write(blocking_executor, source, target, move |source, target| {
403        crate::uncompress_archive_file_with_encoding(source, target, &path, decode)
404    })
405    .await?
406}
407
408/// Async version of
409/// [`uncompress_archive_file`](crate::uncompress_archive_file).
410pub async fn uncompress_archive_file<B, R, W>(
411    blocking_executor: B,
412    source: R,
413    target: W,
414    path: &str,
415) -> Result<usize>
416where
417    B: BlockingExecutor,
418    R: AsyncRead + AsyncSeek + Unpin,
419    W: AsyncWrite + Unpin,
420{
421    let path = path.to_owned();
422    wrap_async_seek_read_and_write(blocking_executor, source, target, move |source, target| {
423        crate::uncompress_archive_file(source, target, &path)
424    })
425    .await?
426}
427
428// ----------------------------------------------------------------------------
429// Async archive iterator
430// ----------------------------------------------------------------------------
431
/// A filter callback for the async archive iterator.
///
/// The callback receives an entry name and its [`crate::stat`] and returns
/// a `bool`.
///
/// Differs from the synchronous [`crate::EntryFilterCallbackFn`] only in that
/// it must be `Send + Sync` so that the filter can cross into the blocking
/// worker driving the sync iterator.
pub type AsyncEntryFilterCallbackFn = dyn Fn(&str, &crate::stat) -> bool + Send + Sync;
438
439/// Asynchronous streaming iterator over the contents of an archive.
440///
441/// Yields [`ArchiveContents`] items in the same order and shape as the
442/// synchronous [`ArchiveIterator`]. The sync iterator and its libarchive
443/// state live on a dedicated blocking worker; entries are forwarded through
444/// a bounded channel and surfaced through this [`Stream`] impl.
445///
446/// Polling this stream also drives the bridge worker future that services
447/// the sync side's `read`/`seek` requests and the blocking pump's
448/// `JoinHandle` — so progress only happens while the consumer polls.
449/// Dropping the iterator closes the entry channel; the pump notices on its
450/// next send and exits.
451pub struct AsyncArchiveIterator {
452    rx: Receiver<ArchiveContents>,
453    worker: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,
454    pump: Option<Pin<Box<dyn Future<Output = Result<()>> + Send>>>,
455}
456
457impl Stream for AsyncArchiveIterator {
458    type Item = ArchiveContents;
459
460    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
461        let this = &mut *self;
462        if let Some(worker) = this.worker.as_mut() {
463            if let Poll::Ready(res) = worker.as_mut().poll(cx) {
464                this.worker = None;
465                if let Err(e) = res {
466                    return Poll::Ready(Some(ArchiveContents::Err(e)));
467                }
468            }
469        }
470        if let Some(pump) = this.pump.as_mut() {
471            if let Poll::Ready(res) = pump.as_mut().poll(cx) {
472                this.pump = None;
473                if let Err(e) = res {
474                    return Poll::Ready(Some(ArchiveContents::Err(e)));
475                }
476            }
477        }
478        Pin::new(&mut this.rx).poll_next(cx)
479    }
480}
481
482pub(crate) fn new_async_archive_iterator<B, R>(
483    source: R,
484    decode: DecodeCallback,
485    filter: Option<Box<AsyncEntryFilterCallbackFn>>,
486    password: Option<ArchivePassword>,
487) -> AsyncArchiveIterator
488where
489    B: BlockingExecutor + 'static,
490    R: AsyncRead + AsyncSeek + Unpin + Send + 'static,
491{
492    let (mut entry_tx, entry_rx) = channel::<ArchiveContents>(1);
493    let (seekable_wrapper, seekable_worker) = make_seekable_read_wrapper_and_worker(source);
494
495    let pump_fut = async move {
496        let r: Result<()> = B::execute_blocking(move || -> Result<()> {
497            let mut builder = ArchiveIteratorBuilder::new(seekable_wrapper).decoder(decode);
498            if let Some(filter) = filter {
499                builder = builder.filter(move |name, stat| filter(name, stat));
500            }
501            if let Some(password) = password {
502                builder = builder.with_password(password);
503            }
504            let mut iter = builder.build()?;
505            for content in iter.by_ref() {
506                if block_on(entry_tx.send(content)).is_err() {
507                    // Consumer dropped the receiver; stop forwarding and
508                    // close the iterator so libarchive state is released
509                    // promptly.
510                    break;
511                }
512            }
513            iter.close()
514        })
515        .await?;
516        r
517    };
518
519    AsyncArchiveIterator {
520        rx: entry_rx,
521        worker: Some(Box::pin(seekable_worker)),
522        pump: Some(Box::pin(pump_fut)),
523    }
524}