Skip to main content

compress_tools/
iterator.rs

1use std::{
2    ffi::{CStr, CString},
3    io::{Read, Seek, SeekFrom, Write},
4    slice,
5};
6
7use libc::{c_char, c_int, c_void};
8
9#[cfg(target_os = "windows")]
10use crate::stat;
11#[cfg(not(target_os = "windows"))]
12use libc::stat;
13
14use crate::{
15    error::archive_result, ffi, ffi::UTF8LocaleGuard, libarchive_entry_is_dir, DecodeCallback,
16    Error, Result, READER_BUFFER_SIZE,
17};
18
19struct HeapReadSeekerPipe<R: Read + Seek> {
20    reader: R,
21    buffer: [u8; READER_BUFFER_SIZE],
22}
23
24/// The contents of an archive, yielded in order from the beginning to the end
25/// of the archive.
26///
27/// Each entry, file or directory, will have a
28/// [`ArchiveContents::StartOfEntry`], zero or more
29/// [`ArchiveContents::DataChunk`], and then a corresponding
30/// [`ArchiveContents::EndOfEntry`] to mark that the entry has been read to
31/// completion.
32pub enum ArchiveContents {
33    /// Marks the start of an entry, either a file or a directory.
34    StartOfEntry(String, stat),
35    /// A chunk of uncompressed data from the entry. Entries may have zero or
36    /// more chunks.
37    DataChunk(Vec<u8>),
38    /// Marks the end of the entry that was started by the previous
39    /// StartOfEntry.
40    EndOfEntry,
41    Err(Error),
42}
43
44/// Filter for an archive iterator to skip decompression of unwanted
45/// entries.
46///
47/// Gets called on an encounter of a new archive entry with the filename and
48/// file status information of that entry.
49/// The entry is processed on a return value of `true` and ignored on `false`.
50pub type EntryFilterCallbackFn = dyn Fn(&str, &stat) -> bool;
51
/// Passphrase for decrypting encrypted archive entries.
///
/// Create one with [`ArchivePassword::new`]. That constructor rejects strings
/// containing an interior NUL byte, because libarchive's C API takes
/// NUL-terminated strings and could not receive them intact.
pub struct ArchivePassword(std::ffi::CString);
58
59impl ArchivePassword {
60    pub fn new<S: AsRef<str>>(password: S) -> Result<Self> {
61        CString::new(password.as_ref())
62            .map(Self)
63            .map_err(|e| Error::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e)))
64    }
65
66    pub(crate) fn as_ptr(&self) -> *const c_char {
67        self.0.as_ptr()
68    }
69}
70
71/// An iterator over the contents of an archive.
72#[allow(clippy::module_name_repetitions)]
73pub struct ArchiveIterator<R: Read + Seek> {
74    archive_entry: *mut ffi::archive_entry,
75    archive_reader: *mut ffi::archive,
76
77    decode: DecodeCallback,
78    in_file: bool,
79    current_is_dir: bool,
80    closed: bool,
81    error: bool,
82    mtree_format: bool,
83    filter: Option<Box<EntryFilterCallbackFn>>,
84
85    _pipe: Box<HeapReadSeekerPipe<R>>,
86    _utf8_guard: UTF8LocaleGuard,
87}
88
89impl<R: Read + Seek> Iterator for ArchiveIterator<R> {
90    type Item = ArchiveContents;
91
92    fn next(&mut self) -> Option<Self::Item> {
93        debug_assert!(!self.closed);
94
95        if self.error {
96            return None;
97        }
98
99        loop {
100            let next = if self.in_file {
101                unsafe { self.next_data_chunk() }
102            } else {
103                unsafe { self.unsafe_next_header() }
104            };
105
106            match &next {
107                ArchiveContents::StartOfEntry(name, stat) => {
108                    debug_assert!(!self.in_file);
109
110                    if let Some(filter) = &self.filter {
111                        if !filter(name, stat) {
112                            continue;
113                        }
114                    }
115
116                    self.in_file = true;
117                    break Some(next);
118                }
119                ArchiveContents::DataChunk(_) => {
120                    debug_assert!(self.in_file);
121                    break Some(next);
122                }
123                ArchiveContents::EndOfEntry if self.in_file => {
124                    self.in_file = false;
125                    break Some(next);
126                }
127                ArchiveContents::EndOfEntry => break None,
128                ArchiveContents::Err(_) => {
129                    self.error = true;
130                    break Some(next);
131                }
132            }
133        }
134    }
135}
136
137impl<R: Read + Seek> ArchiveIterator<R> {
138    pub fn next_header(&mut self) -> Option<ArchiveContents> {
139        debug_assert!(!self.closed);
140
141        if self.error {
142            return None;
143        }
144
145        let next = unsafe { self.unsafe_next_header() };
146
147        match &next {
148            ArchiveContents::StartOfEntry(name, stat) => {
149                if let Some(filter) = &self.filter {
150                    if !filter(name, stat) {
151                        return None;
152                    }
153                }
154
155                self.in_file = true;
156                Some(next)
157            }
158            ArchiveContents::Err(_) => {
159                self.error = true;
160                Some(next)
161            }
162            _ => None,
163        }
164    }
165}
166
167impl<R: Read + Seek> Drop for ArchiveIterator<R> {
168    fn drop(&mut self) {
169        drop(self.free());
170    }
171}
172
173impl<R: Read + Seek> ArchiveIterator<R> {
174    fn new(
175        mut source: R,
176        decode: DecodeCallback,
177        filter: Option<Box<EntryFilterCallbackFn>>,
178        password: Option<ArchivePassword>,
179        raw_format: bool,
180        mtree_format: bool,
181    ) -> Result<ArchiveIterator<R>>
182    where
183        R: Read + Seek,
184    {
185        let utf8_guard = ffi::UTF8LocaleGuard::new();
186        // libarchive only sniffs the format from offset 0.
187        source.seek(SeekFrom::Start(0))?;
188        crate::zip_preflight::reject_unsupported_zip_methods(&mut source)?;
189        let reader = source;
190        let buffer = [0; READER_BUFFER_SIZE];
191        let mut pipe = Box::new(HeapReadSeekerPipe { reader, buffer });
192
193        unsafe {
194            let archive_entry: *mut ffi::archive_entry = std::ptr::null_mut();
195            let archive_reader = ffi::archive_read_new();
196
197            let res = (|| {
198                if let Some(password) = password {
199                    archive_result(
200                        ffi::archive_read_add_passphrase(archive_reader, password.as_ptr()),
201                        archive_reader,
202                    )?;
203                }
204
205                archive_result(
206                    ffi::archive_read_support_filter_all(archive_reader),
207                    archive_reader,
208                )?;
209
210                if raw_format {
211                    archive_result(
212                        ffi::archive_read_support_format_raw(archive_reader),
213                        archive_reader,
214                    )?;
215                }
216
217                archive_result(
218                    ffi::archive_read_set_seek_callback(
219                        archive_reader,
220                        Some(libarchive_heap_seek_callback::<R>),
221                    ),
222                    archive_reader,
223                )?;
224
225                if archive_reader.is_null() {
226                    return Err(Error::NullArchive);
227                }
228
229                archive_result(
230                    ffi::archive_read_support_format_all(archive_reader),
231                    archive_reader,
232                )?;
233
234                archive_result(
235                    ffi::archive_read_open(
236                        archive_reader,
237                        std::ptr::addr_of_mut!(*pipe) as *mut c_void,
238                        None,
239                        Some(libarchive_heap_seekableread_callback::<R>),
240                        None,
241                    ),
242                    archive_reader,
243                )?;
244
245                Ok(())
246            })();
247
248            let iter = ArchiveIterator {
249                archive_entry,
250                archive_reader,
251
252                decode,
253                in_file: false,
254                current_is_dir: false,
255                closed: false,
256                error: false,
257                mtree_format,
258                filter,
259
260                _pipe: pipe,
261                _utf8_guard: utf8_guard,
262            };
263
264            res?;
265            Ok(iter)
266        }
267    }
268
269    /// Iterate over the contents of an archive, streaming the contents of each
270    /// entry in small chunks.
271    ///
272    /// The [`ArchiveContents::StartOfEntry`] variant carries the entry's
273    /// `stat` struct, so `stat.st_size` gives the uncompressed size reported
274    /// by the archive header without having to consume the data chunks.
275    ///
276    /// ```no_run
277    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
278    /// use compress_tools::*;
279    /// use std::fs::File;
280    ///
281    /// let file = File::open("tree.tar")?;
282    ///
283    /// let mut name = String::default();
284    /// let mut size = 0;
285    /// let decode_utf8 = |bytes: &[u8]| Ok(std::str::from_utf8(bytes)?.to_owned());
286    /// let mut iter = ArchiveIterator::from_read_with_encoding(file, decode_utf8)?;
287    ///
288    /// for content in &mut iter {
289    ///     match content {
290    ///         ArchiveContents::StartOfEntry(s, stat) => {
291    ///             name = s;
292    ///             println!("header reports {} bytes for {}", stat.st_size, name);
293    ///         }
294    ///         ArchiveContents::DataChunk(v) => size += v.len(),
295    ///         ArchiveContents::EndOfEntry => {
296    ///             println!("Entry {} was {} bytes", name, size);
297    ///             size = 0;
298    ///         }
299    ///         ArchiveContents::Err(e) => {
300    ///             Err(e)?;
301    ///         }
302    ///     }
303    /// }
304    ///
305    /// iter.close()?;
306    /// # Ok(())
307    /// # }
308    /// ```
309    pub fn from_read_with_encoding(source: R, decode: DecodeCallback) -> Result<ArchiveIterator<R>>
310    where
311        R: Read + Seek,
312    {
313        Self::new(source, decode, None, None, false, true)
314    }
315
316    /// Iterate over the contents of an archive, streaming the contents of each
317    /// entry in small chunks.
318    ///
319    /// ```no_run
320    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
321    /// use compress_tools::*;
322    /// use std::fs::File;
323    ///
324    /// let file = File::open("tree.tar")?;
325    ///
326    /// let mut name = String::default();
327    /// let mut size = 0;
328    /// let mut iter = ArchiveIterator::from_read(file)?;
329    ///
330    /// for content in &mut iter {
331    ///     match content {
332    ///         ArchiveContents::StartOfEntry(s, _) => name = s,
333    ///         ArchiveContents::DataChunk(v) => size += v.len(),
334    ///         ArchiveContents::EndOfEntry => {
335    ///             println!("Entry {} was {} bytes", name, size);
336    ///             size = 0;
337    ///         }
338    ///         ArchiveContents::Err(e) => {
339    ///             Err(e)?;
340    ///         }
341    ///     }
342    /// }
343    ///
344    /// iter.close()?;
345    /// # Ok(())
346    /// # }
347    /// ```
348    pub fn from_read(source: R) -> Result<ArchiveIterator<R>>
349    where
350        R: Read + Seek,
351    {
352        Self::new(source, crate::decode_utf8, None, None, false, true)
353    }
354
355    /// Close the iterator, freeing up the associated resources.
356    ///
357    /// Resources will be freed on drop if this is not called, but any errors
358    /// during freeing on drop will be lost.
359    pub fn close(mut self) -> Result<()> {
360        self.free()
361    }
362
363    fn free(&mut self) -> Result<()> {
364        if self.closed {
365            return Ok(());
366        }
367
368        self.closed = true;
369        unsafe {
370            archive_result(
371                ffi::archive_read_close(self.archive_reader),
372                self.archive_reader,
373            )?;
374            archive_result(
375                ffi::archive_read_free(self.archive_reader),
376                self.archive_reader,
377            )?;
378        }
379        Ok(())
380    }
381
382    unsafe fn unsafe_next_header(&mut self) -> ArchiveContents {
383        match ffi::archive_read_next_header(self.archive_reader, &mut self.archive_entry) {
384            ffi::ARCHIVE_EOF => ArchiveContents::EndOfEntry,
385            ffi::ARCHIVE_OK | ffi::ARCHIVE_WARN => {
386                if !self.mtree_format {
387                    if let Err(e) = reject_mtree_format(self.archive_reader) {
388                        return ArchiveContents::Err(e);
389                    }
390                }
391                let _utf8_guard = ffi::WindowsUTF8LocaleGuard::new();
392                let cstr = CStr::from_ptr(ffi::archive_entry_pathname(self.archive_entry));
393                let file_name = match (self.decode)(cstr.to_bytes()) {
394                    Ok(f) => f,
395                    Err(e) => return ArchiveContents::Err(e),
396                };
397                let stat = *ffi::archive_entry_stat(self.archive_entry);
398                self.current_is_dir = libarchive_entry_is_dir(self.archive_entry);
399                ArchiveContents::StartOfEntry(file_name, stat)
400            }
401            _ => ArchiveContents::Err(Error::from(self.archive_reader)),
402        }
403    }
404
405    unsafe fn next_data_chunk(&mut self) -> ArchiveContents {
406        if self.current_is_dir {
407            return ArchiveContents::EndOfEntry;
408        }
409
410        let mut buffer = std::ptr::null();
411        let mut offset = 0;
412        let mut size = 0;
413        let mut target = Vec::with_capacity(READER_BUFFER_SIZE);
414
415        match ffi::archive_read_data_block(self.archive_reader, &mut buffer, &mut size, &mut offset)
416        {
417            ffi::ARCHIVE_EOF => ArchiveContents::EndOfEntry,
418            ffi::ARCHIVE_OK | ffi::ARCHIVE_WARN => {
419                if size > 0 {
420                    // fixes: (as buffer is null then) unsafe precondition(s) violated:
421                    // slice::from_raw_parts requires the pointer to be aligned and non-null, and
422                    // the total size of the slice not to exceed `isize::MAX`
423                    let content = slice::from_raw_parts(buffer as *const u8, size);
424                    let write = target.write_all(content);
425                    if let Err(e) = write {
426                        ArchiveContents::Err(e.into())
427                    } else {
428                        ArchiveContents::DataChunk(target)
429                    }
430                } else {
431                    ArchiveContents::DataChunk(target)
432                }
433            }
434            _ => ArchiveContents::Err(Error::from(self.archive_reader)),
435        }
436    }
437}
438
439// Must be called after a successful `archive_read_next_header`, since
440// libarchive only populates the format code once a header has been read.
441unsafe fn reject_mtree_format(archive_reader: *mut ffi::archive) -> Result<()> {
442    if ffi::archive_format(archive_reader) & ffi::ARCHIVE_FORMAT_BASE_MASK
443        == ffi::ARCHIVE_FORMAT_MTREE
444    {
445        return Err(Error::Extraction {
446            code: None,
447            details: "mtree specifications are not treated as archives".to_string(),
448        });
449    }
450    Ok(())
451}
452
453unsafe extern "C" fn libarchive_heap_seek_callback<R: Read + Seek>(
454    _: *mut ffi::archive,
455    client_data: *mut c_void,
456    offset: ffi::la_int64_t,
457    whence: c_int,
458) -> i64 {
459    let pipe = (client_data as *mut HeapReadSeekerPipe<R>)
460        .as_mut()
461        .unwrap();
462    let whence = match whence {
463        0 => SeekFrom::Start(offset as u64),
464        1 => SeekFrom::Current(offset),
465        2 => SeekFrom::End(offset),
466        _ => return -1,
467    };
468
469    match pipe.reader.seek(whence) {
470        Ok(offset) => offset as i64,
471        Err(_) => -1,
472    }
473}
474
475unsafe extern "C" fn libarchive_heap_seekableread_callback<R: Read + Seek>(
476    archive: *mut ffi::archive,
477    client_data: *mut c_void,
478    buffer: *mut *const c_void,
479) -> ffi::la_ssize_t {
480    let pipe = (client_data as *mut HeapReadSeekerPipe<R>)
481        .as_mut()
482        .unwrap();
483
484    *buffer = pipe.buffer.as_ptr() as *const c_void;
485
486    match pipe.reader.read(&mut pipe.buffer) {
487        Ok(size) => size as ffi::la_ssize_t,
488        Err(e) => {
489            let description = CString::new(e.to_string()).unwrap();
490
491            ffi::archive_set_error(archive, e.raw_os_error().unwrap_or(0), description.as_ptr());
492
493            -1
494        }
495    }
496}
497
498#[must_use]
499pub struct ArchiveIteratorBuilder<R>
500where
501    R: Read + Seek,
502{
503    source: R,
504    decoder: DecodeCallback,
505    filter: Option<Box<EntryFilterCallbackFn>>,
506    password: Option<ArchivePassword>,
507    raw_format: bool,
508    mtree_format: bool,
509}
510
511/// A builder to generate an archive iterator over the contents of an
512/// archive, streaming the contents of each entry in small chunks.
513/// The default configuration is identical to `ArchiveIterator::from_read`.
514///
515/// # Example
516///
517/// ```no_run
518/// use compress_tools::{ArchiveContents, ArchiveIteratorBuilder};
519/// use std::path::Path;
520/// use std::ffi::OsStr;
521///
522/// let source = std::fs::File::open("tests/fixtures/tree.tar").expect("Failed to open file");
523/// let decode_utf8 = |bytes: &[u8]| Ok(std::str::from_utf8(bytes)?.to_owned());
524///
525/// for content in ArchiveIteratorBuilder::new(source)
526///     .decoder(decode_utf8)
527///     .filter(|name, stat| Path::new(name).file_name() == Some(OsStr::new("foo")) || stat.st_size == 42)
528///     .build()
529///     .expect("Failed to initialize archive")
530///     {
531///         if let ArchiveContents::StartOfEntry(name, _stat) = content {
532///             println!("{name}");
533///         }
534///     }
535/// ```
536impl<R> ArchiveIteratorBuilder<R>
537where
538    R: Read + Seek,
539{
540    /// Create a new builder for an archive iterator. Default configuration is
541    /// identical to `ArchiveIterator::from_read`.
542    pub fn new(source: R) -> ArchiveIteratorBuilder<R> {
543        ArchiveIteratorBuilder {
544            source,
545            decoder: crate::decode_utf8,
546            filter: None,
547            password: None,
548            raw_format: false,
549            mtree_format: true,
550        }
551    }
552
553    /// Use a custom decoder to decode filenames of archive entries.
554    /// By default an UTF-8 decoder (`decode_utf8`) is used.
555    pub fn decoder(mut self, decoder: DecodeCallback) -> ArchiveIteratorBuilder<R> {
556        self.decoder = decoder;
557        self
558    }
559
560    /// Use a filter to skip unwanted entries and their decompression.
561    /// By default all entries are iterated.
562    pub fn filter<F>(mut self, filter: F) -> ArchiveIteratorBuilder<R>
563    where
564        F: Fn(&str, &stat) -> bool + 'static,
565    {
566        self.filter = Some(Box::new(filter));
567        self
568    }
569
570    /// Set a custom password to decode content of archive entries.
571    pub fn with_password(mut self, password: ArchivePassword) -> ArchiveIteratorBuilder<R> {
572        self.password = Some(password);
573        self
574    }
575
576    /// Enable libarchive's "raw" format handler, which parses any byte
577    /// stream as a single-entry archive with pathname `data`.
578    ///
579    /// Disabled by default so the iterator rejects input that isn't a real
580    /// archive. Enable it only when you intentionally want to iterate over
581    /// arbitrary non-archive streams (e.g. a standalone gzip file).
582    pub fn raw_format(mut self, enable: bool) -> ArchiveIteratorBuilder<R> {
583        self.raw_format = enable;
584        self
585    }
586
587    /// Accept entries from libarchive's "mtree" format handler (default).
588    ///
589    /// libarchive's mtree parser is permissive and will match free-form
590    /// text (a plain gunzip'd text file is enough); pass `false` to
591    /// reject those matches and error out instead.
592    pub fn mtree_format(mut self, enable: bool) -> ArchiveIteratorBuilder<R> {
593        self.mtree_format = enable;
594        self
595    }
596
597    /// Finish the builder and generate the configured `ArchiveIterator`.
598    pub fn build(self) -> Result<ArchiveIterator<R>> {
599        ArchiveIterator::new(
600            self.source,
601            self.decoder,
602            self.filter,
603            self.password,
604            self.raw_format,
605            self.mtree_format,
606        )
607    }
608}