// oxdoc-core 1.1.0
//
// Core OOXML parsing library for oxdoc
// Documentation
use std::io::{Read, Seek};

use zip::ZipArchive;
use zip::read::ZipFile;
use zip::result::ZipError;

use crate::{OxdocError, Result};

/// Default value for [`OoxmlLimits::max_part_uncompressed_size`]: 64 MiB.
const DEFAULT_MAX_PART_UNCOMPRESSED_SIZE: u64 = 64 * 1024 * 1024;
/// Default value for [`OoxmlLimits::max_part_compression_ratio`]: 200.
const DEFAULT_MAX_PART_COMPRESSION_RATIO: u64 = 200;
/// Default value for [`OoxmlLimits::min_ratio_check_size`]: 4 MiB.
const DEFAULT_MIN_RATIO_CHECK_SIZE: u64 = 4 * 1024 * 1024;

/// Resource limits applied to each package part before and while it is read.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OoxmlLimits {
    /// Maximum uncompressed size of a single part, in bytes. Checked against the size
    /// declared by the ZIP entry and enforced again while the part is being read.
    pub max_part_uncompressed_size: u64,
    /// Maximum allowed multiple of the compressed size: a part whose declared
    /// uncompressed size is greater than `compressed_size * max_part_compression_ratio`
    /// is rejected.
    pub max_part_compression_ratio: u64,
    /// Declared uncompressed size, in bytes, below which the compression-ratio check
    /// is not applied.
    pub min_ratio_check_size: u64,
}

impl Default for OoxmlLimits {
    /// Builds the limits from the module's `DEFAULT_*` constants.
    fn default() -> Self {
        let max_part_uncompressed_size = DEFAULT_MAX_PART_UNCOMPRESSED_SIZE;
        let max_part_compression_ratio = DEFAULT_MAX_PART_COMPRESSION_RATIO;
        let min_ratio_check_size = DEFAULT_MIN_RATIO_CHECK_SIZE;

        Self {
            max_part_uncompressed_size,
            max_part_compression_ratio,
            min_ratio_check_size,
        }
    }
}

/// An OOXML package: a ZIP archive whose entries are read under [`OoxmlLimits`].
pub struct OoxmlPackage<R: Read + Seek> {
    /// The underlying ZIP archive.
    archive: ZipArchive<R>,
    /// Limits used by the methods that do not take explicit limits.
    limits: OoxmlLimits,
}

impl<R: Read + Seek> OoxmlPackage<R> {
    /// Opens a package from `reader` using [`OoxmlLimits::default`].
    ///
    /// # Errors
    ///
    /// Returns an error if `reader` cannot be opened as a ZIP archive.
    pub fn new(reader: R) -> Result<Self> {
        Self::with_limits(reader, OoxmlLimits::default())
    }

    /// Opens a package from `reader`, remembering `limits` as the package defaults.
    ///
    /// # Errors
    ///
    /// Returns an error if `reader` cannot be opened as a ZIP archive.
    pub fn with_limits(reader: R, limits: OoxmlLimits) -> Result<Self> {
        let archive = ZipArchive::new(reader)?;
        Ok(Self { archive, limits })
    }

    /// Runs `read_entry` on the part at `path`, using the package's own limits.
    ///
    /// See [`Self::with_entry_limits`] for the checks performed and the errors returned.
    pub fn with_entry<T>(
        &mut self,
        path: &str,
        read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
    ) -> Result<T> {
        self.with_entry_limits(path, self.limits, read_entry)
    }

    /// Runs `read_entry` on the part at `path`, enforcing the given `limits`.
    ///
    /// The entry is validated before `read_entry` is called, and the reader handed to
    /// the callback stops yielding data once `limits.max_part_uncompressed_size` bytes
    /// have been read.
    ///
    /// # Errors
    ///
    /// * The error produced by `map_zip_entry_error` if the entry cannot be opened.
    /// * Any error from `validate_entry`.
    /// * [`OxdocError::PartTooLarge`] if the part held more data than the size limit
    ///   allows, regardless of what `read_entry` returned.
    /// * Otherwise, whatever `read_entry` returned.
    pub fn with_entry_limits<T>(
        &mut self,
        path: &str,
        limits: OoxmlLimits,
        read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
    ) -> Result<T> {
        let mut raw = self
            .archive
            .by_name(path)
            .map_err(|err| map_zip_entry_error(path, err))?;
        validate_entry(path, &raw, limits)?;

        let mut limited = LimitedEntryReader::new(&mut raw, limits.max_part_uncompressed_size);
        let result = read_entry(&mut limited);

        // The size violation wins over the callback's own result: the callback may have
        // swallowed the reader's I/O error or wrapped it in a less specific error.
        if limited.exceeded_limit() {
            return Err(OxdocError::PartTooLarge {
                path: path.to_owned(),
                size: limited.observed_size(),
                limit: limits.max_part_uncompressed_size,
            });
        }
        result
    }

    /// Reads the whole part at `path` into a `String`, using the package's own limits.
    ///
    /// # Errors
    ///
    /// Returns the errors of [`Self::with_entry`], plus the converted I/O error raised
    /// by `Read::read_to_string` (for example when the data is not valid UTF-8).
    pub fn read_to_string(&mut self, path: &str) -> Result<String> {
        self.with_entry(path, |entry| {
            let mut content = String::new();
            entry.read_to_string(&mut content)?;
            Ok(content)
        })
    }

    /// Returns `true` if the archive has an entry named `path`.
    ///
    /// This is a pure name lookup: the entry is not opened, so a part that exists but
    /// cannot be read (for example an encrypted one) is still reported as present and
    /// the failure surfaces, with a specific error, when the part is actually read.
    pub fn contains(&mut self, path: &str) -> bool {
        self.archive.index_for_name(path).is_some()
    }

    /// Returns `true` if the archive has an entry for at least one of `paths`.
    pub fn contains_any(&mut self, paths: &[&str]) -> bool {
        paths.iter().any(|path| self.contains(path))
    }

    /// Lists the names of the archive's entries, in index order.
    ///
    /// Entries are skipped when they cannot be opened, when `enclosed_name()` returns
    /// `None` for them, or when the resulting path is not valid UTF-8.
    pub fn part_names(&mut self) -> Vec<String> {
        (0..self.archive.len())
            .filter_map(|index| {
                let entry = self.archive.by_index(index).ok()?;
                let name = entry.enclosed_name()?;
                name.to_str().map(str::to_owned)
            })
            .collect()
    }
}

/// `Read` adapter that hands out at most `limit` bytes from `inner` and records
/// whether `inner` held more data than that.
struct LimitedEntryReader<'a, R: Read + ?Sized> {
    /// The wrapped reader.
    inner: &'a mut R,
    /// Maximum number of bytes that may be handed to the caller.
    limit: u64,
    /// Number of bytes handed to the caller so far.
    read: u64,
    /// Set once `inner` yielded data after `limit` bytes had already been handed out.
    exceeded: bool,
    /// Number of bytes pulled from `inner` so far, including the probe byte that
    /// revealed an overrun.
    observed_size: u64,
}

impl<'a, R: Read + ?Sized> LimitedEntryReader<'a, R> {
    /// Wraps `inner`, allowing at most `limit` bytes to be read through the adapter.
    fn new(inner: &'a mut R, limit: u64) -> Self {
        Self {
            inner,
            limit,
            read: 0,
            exceeded: false,
            observed_size: 0,
        }
    }

    /// Returns `true` once `inner` has been seen to contain more than `limit` bytes.
    fn exceeded_limit(&self) -> bool {
        self.exceeded
    }

    /// Returns how many bytes have been pulled from `inner` so far. After an overrun
    /// this is a lower bound on the real size (`limit` plus the probe byte).
    fn observed_size(&self) -> u64 {
        self.observed_size.max(self.read)
    }

    /// Builds the I/O error reported when the size limit is exceeded.
    fn limit_exceeded_error() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "OOXML part exceeded its configured uncompressed size limit",
        )
    }
}

impl<R: Read + ?Sized> Read for LimitedEntryReader<'_, R> {
    /// Reads into `buf` without ever handing out more than `limit` bytes in total.
    ///
    /// Once `limit` bytes have been delivered, a one-byte probe of `inner` tells a
    /// clean end of stream (`Ok(0)`) apart from an overrun (`InvalidData` error).
    /// After an overrun every further non-empty read fails with the same error.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        // Keep the failure sticky: without this, a caller that retries after the error
        // would probe `inner` again, consume more of it, and could be told `Ok(0)`,
        // i.e. a clean EOF on what is really a truncated stream.
        if self.exceeded {
            return Err(Self::limit_exceeded_error());
        }

        if self.read >= self.limit {
            let mut probe = [0u8; 1];
            let bytes = self.inner.read(&mut probe)?;
            if bytes == 0 {
                return Ok(0);
            }
            self.exceeded = true;
            self.observed_size = self.read.saturating_add(bytes as u64);
            return Err(Self::limit_exceeded_error());
        }

        let remaining = self.limit - self.read;
        // The minimum is never larger than `buf.len()`, so the cast back cannot truncate.
        let allowed = (buf.len() as u64).min(remaining) as usize;
        let bytes = self.inner.read(&mut buf[..allowed])?;
        self.read = self.read.saturating_add(bytes as u64);
        self.observed_size = self.read;
        Ok(bytes)
    }
}

/// Checks a ZIP entry's metadata against `limits` before any of its data is read.
///
/// `path` is only used to label the returned errors.
///
/// # Errors
///
/// * [`OxdocError::UnsupportedEncryptedPart`] if the entry is encrypted.
/// * [`OxdocError::SuspiciousZipEntry`] if the entry is a directory, if
///   `enclosed_name()` returns `None` for it, or if its declared uncompressed size is
///   at least `limits.min_ratio_check_size` and greater than
///   `compressed_size * limits.max_part_compression_ratio`.
/// * [`OxdocError::PartTooLarge`] if the declared uncompressed size is greater than
///   `limits.max_part_uncompressed_size`.
fn validate_entry<R: Read + ?Sized>(
    path: &str,
    entry: &ZipFile<'_, R>,
    limits: OoxmlLimits,
) -> Result<()> {
    if entry.encrypted() {
        return Err(OxdocError::UnsupportedEncryptedPart(path.to_owned()));
    }

    if entry.is_dir() {
        return Err(OxdocError::SuspiciousZipEntry {
            path: path.to_owned(),
            reason: "required OOXML part resolves to a directory".to_owned(),
        });
    }

    if entry.enclosed_name().is_none() {
        return Err(OxdocError::SuspiciousZipEntry {
            path: path.to_owned(),
            reason: "ZIP entry name is not enclosed within the package".to_owned(),
        });
    }

    let size = entry.size();
    if size > limits.max_part_uncompressed_size {
        return Err(OxdocError::PartTooLarge {
            path: path.to_owned(),
            size,
            limit: limits.max_part_uncompressed_size,
        });
    }

    let compressed_size = entry.compressed_size();
    // A zero compressed size gives a budget of 0, so any non-empty part that claims to
    // occupy no compressed bytes is still rejected. A genuinely empty part (0 bytes
    // from 0 bytes) is not a ratio violation and passes even when
    // `min_ratio_check_size` is 0.
    let ratio_budget = compressed_size.saturating_mul(limits.max_part_compression_ratio);
    if size >= limits.min_ratio_check_size && size > ratio_budget {
        return Err(OxdocError::SuspiciousZipEntry {
            path: path.to_owned(),
            reason: format!(
                "uncompressed size {size} bytes is too large for compressed size {compressed_size} bytes"
            ),
        });
    }

    Ok(())
}

/// Converts a failure to open the entry at `path` into an [`OxdocError`].
///
/// A missing entry becomes `MissingPart`, a password-required failure becomes
/// `UnsupportedEncryptedPart`, and every other error is wrapped in `CorruptedZip`.
fn map_zip_entry_error(path: &str, err: ZipError) -> OxdocError {
    if matches!(err, ZipError::FileNotFound) {
        return OxdocError::MissingPart(path.to_owned());
    }

    let needs_password = matches!(
        &err,
        ZipError::UnsupportedArchive(reason) if *reason == ZipError::PASSWORD_REQUIRED
    );
    if needs_password {
        return OxdocError::UnsupportedEncryptedPart(path.to_owned());
    }

    OxdocError::CorruptedZip(err)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Write};
    use zip::write::{SimpleFileOptions, ZipWriter};

    /// Builds an in-memory ZIP archive holding the given `(name, data)` entries.
    fn create_zip(content: &[(&str, &[u8])]) -> Cursor<Vec<u8>> {
        let mut writer = ZipWriter::new(Cursor::new(Vec::new()));
        for &(name, data) in content {
            writer
                .start_file(name, SimpleFileOptions::default())
                .unwrap();
            writer.write_all(data).unwrap();
        }
        writer.finish().unwrap()
    }

    /// `contains_any` finds any listed entry, and `read_to_string` returns its text.
    #[test]
    fn test_vfs_contains_any_and_read_to_string() {
        let archive = create_zip(&[("file1.txt", b"hello"), ("file2.txt", b"world")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        let found_second = package.contains_any(&["nonexistent.txt", "file2.txt"]);
        assert!(found_second);
        let found_missing = package.contains_any(&["nonexistent.txt"]);
        assert!(!found_missing);

        let text = package.read_to_string("file1.txt").unwrap();
        assert_eq!(text, "hello");
    }

    /// Reading into an empty buffer yields zero bytes without error.
    #[test]
    fn test_vfs_limited_reader_empty_buf() {
        let archive = create_zip(&[("file1.txt", b"hello")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        let outcome = package.with_entry("file1.txt", |reader| {
            let mut empty: [u8; 0] = [];
            assert_eq!(reader.read(&mut empty).unwrap(), 0);
            Ok(())
        });
        outcome.unwrap();
    }

    /// Each kind of ZIP error maps to the matching `OxdocError` variant.
    #[test]
    fn test_map_zip_entry_error() {
        let missing = map_zip_entry_error("test", ZipError::FileNotFound);
        assert!(matches!(missing, OxdocError::MissingPart(_)));

        let encrypted = map_zip_entry_error(
            "test",
            ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED),
        );
        assert!(matches!(encrypted, OxdocError::UnsupportedEncryptedPart(_)));

        let corrupted = map_zip_entry_error("test", ZipError::InvalidArchive("bad".into()));
        assert!(matches!(corrupted, OxdocError::CorruptedZip(_)));
    }

    /// `contains` reflects the archive contents, and an 11-byte part read under a
    /// 5-byte limit fails with `PartTooLarge`.
    #[test]
    fn test_vfs_contains_and_limits() {
        let archive = create_zip(&[("file1.txt", b"hello world")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        assert!(package.contains("file1.txt"));
        assert!(!package.contains("file2.txt"));

        // Exceed limit
        let tight = OoxmlLimits {
            max_part_uncompressed_size: 5,
            ..Default::default()
        };
        let outcome = package.with_entry_limits("file1.txt", tight, |reader| {
            let mut text = String::new();
            reader.read_to_string(&mut text)?;
            Ok(())
        });
        let err = outcome.unwrap_err();
        assert!(matches!(err, OxdocError::PartTooLarge { .. }));
    }
}