Skip to main content

oxdoc_core/
vfs.rs

1use std::io::{Read, Seek};
2
3use zip::ZipArchive;
4use zip::read::ZipFile;
5use zip::result::ZipError;
6
7use crate::{OxdocError, Result};
8
/// Default cap on a single part's uncompressed size: 64 MiB.
const DEFAULT_MAX_PART_UNCOMPRESSED_SIZE: u64 = 64 << 20;
/// Default cap on a single part's uncompressed-to-compressed size ratio.
const DEFAULT_MAX_PART_COMPRESSION_RATIO: u64 = 200;
/// Default declared size (4 MiB) from which the compression-ratio check is applied.
const DEFAULT_MIN_RATIO_CHECK_SIZE: u64 = 4 << 20;

/// Per-part safety limits used when reading entries out of an OOXML package.
///
/// Start from [`OoxmlLimits::default`] and override individual fields with
/// struct-update syntax, e.g. `OoxmlLimits { max_part_uncompressed_size: 1024, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OoxmlLimits {
    /// Largest uncompressed size, in bytes, that a part may declare in the archive or
    /// actually yield while being read.
    pub max_part_uncompressed_size: u64,
    /// Largest accepted ratio of a part's declared uncompressed size to its compressed size.
    pub max_part_compression_ratio: u64,
    /// Parts whose declared uncompressed size is below this many bytes skip the ratio check.
    pub min_ratio_check_size: u64,
}

impl Default for OoxmlLimits {
    /// 64 MiB per part, a 200:1 ratio ceiling, ratio checked from 4 MiB upwards.
    fn default() -> Self {
        let max_part_uncompressed_size = DEFAULT_MAX_PART_UNCOMPRESSED_SIZE;
        let max_part_compression_ratio = DEFAULT_MAX_PART_COMPRESSION_RATIO;
        let min_ratio_check_size = DEFAULT_MIN_RATIO_CHECK_SIZE;
        Self {
            max_part_uncompressed_size,
            max_part_compression_ratio,
            min_ratio_check_size,
        }
    }
}
29
30pub struct OoxmlPackage<R: Read + Seek> {
31    archive: ZipArchive<R>,
32    limits: OoxmlLimits,
33}
34
35impl<R: Read + Seek> OoxmlPackage<R> {
36    pub fn new(reader: R) -> Result<Self> {
37        Self::with_limits(reader, OoxmlLimits::default())
38    }
39
40    pub fn with_limits(reader: R, limits: OoxmlLimits) -> Result<Self> {
41        Ok(Self {
42            archive: ZipArchive::new(reader)?,
43            limits,
44        })
45    }
46
47    pub fn with_entry<T>(
48        &mut self,
49        path: &str,
50        read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
51    ) -> Result<T> {
52        self.with_entry_limits(path, self.limits, read_entry)
53    }
54
55    pub fn with_entry_limits<T>(
56        &mut self,
57        path: &str,
58        limits: OoxmlLimits,
59        read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
60    ) -> Result<T> {
61        match self.archive.by_name(path) {
62            Ok(mut entry) => {
63                validate_entry(path, &entry, limits)?;
64                let mut entry =
65                    LimitedEntryReader::new(&mut entry, limits.max_part_uncompressed_size);
66                let result = read_entry(&mut entry);
67                if entry.exceeded_limit() {
68                    Err(OxdocError::PartTooLarge {
69                        path: path.to_owned(),
70                        size: entry.observed_size(),
71                        limit: limits.max_part_uncompressed_size,
72                    })
73                } else {
74                    result
75                }
76            }
77            Err(err) => Err(map_zip_entry_error(path, err)),
78        }
79    }
80
81    pub fn read_to_string(&mut self, path: &str) -> Result<String> {
82        self.with_entry(path, |entry| {
83            let mut content = String::new();
84            entry.read_to_string(&mut content)?;
85            Ok(content)
86        })
87    }
88
89    pub fn contains(&mut self, path: &str) -> bool {
90        self.archive.by_name(path).is_ok()
91    }
92
93    pub fn contains_any(&mut self, paths: &[&str]) -> bool {
94        paths.iter().any(|path| self.contains(path))
95    }
96
97    pub fn part_names(&mut self) -> Vec<String> {
98        (0..self.archive.len())
99            .filter_map(|index| {
100                self.archive.by_index(index).ok().and_then(|entry| {
101                    entry
102                        .enclosed_name()
103                        .and_then(|path| path.to_str().map(str::to_owned))
104                })
105            })
106            .collect()
107    }
108}
109
/// `Read` adapter that refuses to hand out more than `limit` bytes from `inner`.
///
/// Up to `limit` bytes pass through unchanged. Once `limit` bytes have been delivered,
/// the next read probes `inner` for one more byte:
/// - a clean EOF is reported as `Ok(0)`, so a part of exactly `limit` bytes reads normally;
/// - any further data marks the reader as exceeded and fails with `ErrorKind::InvalidData`.
///
/// The exceeded state is sticky. Every later non-empty read fails again without
/// touching `inner`, so a caller that ignores the first error can never mistake a
/// truncated part for a complete one.
struct LimitedEntryReader<'a, R: Read + ?Sized> {
    inner: &'a mut R,
    /// Maximum number of bytes handed to the caller.
    limit: u64,
    /// Bytes handed to the caller so far; never exceeds `limit`.
    read: u64,
    /// Set once `inner` was found to hold data beyond `limit`.
    exceeded: bool,
}

impl<'a, R: Read + ?Sized> LimitedEntryReader<'a, R> {
    fn new(inner: &'a mut R, limit: u64) -> Self {
        Self {
            inner,
            limit,
            read: 0,
            exceeded: false,
        }
    }

    /// Returns `true` once data beyond `limit` has been observed.
    fn exceeded_limit(&self) -> bool {
        self.exceeded
    }

    /// Lower bound on the entry's real size.
    ///
    /// This is the bytes delivered, plus the single probed byte that proved the limit
    /// was exceeded.
    fn observed_size(&self) -> u64 {
        self.read.saturating_add(u64::from(self.exceeded))
    }

    fn limit_error() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "OOXML part exceeded its configured uncompressed size limit",
        )
    }
}

impl<R: Read + ?Sized> Read for LimitedEntryReader<'_, R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        if self.exceeded {
            return Err(Self::limit_error());
        }

        if self.read >= self.limit {
            // Budget spent: tell a part of exactly `limit` bytes (EOF) apart from a longer one.
            let mut probe = [0u8; 1];
            if self.inner.read(&mut probe)? == 0 {
                return Ok(0);
            }
            self.exceeded = true;
            return Err(Self::limit_error());
        }

        let remaining = self.limit - self.read;
        // `min` with `buf.len()` keeps the value within `usize`, so the cast is lossless.
        let allowed = (buf.len() as u64).min(remaining) as usize;
        let bytes = self.inner.read(&mut buf[..allowed])?;
        self.read = self.read.saturating_add(bytes as u64);
        Ok(bytes)
    }
}
166
167fn validate_entry<R: Read + ?Sized>(
168    path: &str,
169    entry: &ZipFile<'_, R>,
170    limits: OoxmlLimits,
171) -> Result<()> {
172    if entry.encrypted() {
173        return Err(OxdocError::UnsupportedEncryptedPart(path.to_owned()));
174    }
175
176    if entry.is_dir() {
177        return Err(OxdocError::SuspiciousZipEntry {
178            path: path.to_owned(),
179            reason: "required OOXML part resolves to a directory".to_owned(),
180        });
181    }
182
183    if entry.enclosed_name().is_none() {
184        return Err(OxdocError::SuspiciousZipEntry {
185            path: path.to_owned(),
186            reason: "ZIP entry name is not enclosed within the package".to_owned(),
187        });
188    }
189
190    let size = entry.size();
191    if size > limits.max_part_uncompressed_size {
192        return Err(OxdocError::PartTooLarge {
193            path: path.to_owned(),
194            size,
195            limit: limits.max_part_uncompressed_size,
196        });
197    }
198
199    let compressed_size = entry.compressed_size();
200    if size >= limits.min_ratio_check_size
201        && (compressed_size == 0
202            || size > compressed_size.saturating_mul(limits.max_part_compression_ratio))
203    {
204        return Err(OxdocError::SuspiciousZipEntry {
205            path: path.to_owned(),
206            reason: format!(
207                "uncompressed size {size} bytes is too large for compressed size {compressed_size} bytes"
208            ),
209        });
210    }
211
212    Ok(())
213}
214
215fn map_zip_entry_error(path: &str, err: ZipError) -> OxdocError {
216    match err {
217        ZipError::FileNotFound => OxdocError::MissingPart(path.to_owned()),
218        ZipError::UnsupportedArchive(reason) if reason == ZipError::PASSWORD_REQUIRED => {
219            OxdocError::UnsupportedEncryptedPart(path.to_owned())
220        }
221        err => OxdocError::CorruptedZip(err),
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Write};
    use zip::write::{SimpleFileOptions, ZipWriter};

    /// Builds an in-memory ZIP with one file per `(name, data)` pair, in the given order.
    fn create_zip(content: &[(&str, &[u8])]) -> Cursor<Vec<u8>> {
        let mut zip = ZipWriter::new(Cursor::new(Vec::new()));
        for (name, data) in content {
            zip.start_file(*name, SimpleFileOptions::default()).unwrap();
            zip.write_all(data).unwrap();
        }
        zip.finish().unwrap()
    }

    #[test]
    fn test_vfs_contains_any_and_read_to_string() {
        let cursor = create_zip(&[("file1.txt", b"hello"), ("file2.txt", b"world")]);
        let mut pkg = OoxmlPackage::new(cursor).unwrap();

        assert!(pkg.contains_any(&["nonexistent.txt", "file2.txt"]));
        assert!(!pkg.contains_any(&["nonexistent.txt"]));

        assert_eq!(pkg.read_to_string("file1.txt").unwrap(), "hello");
    }

    #[test]
    fn test_vfs_missing_part() {
        let mut pkg = OoxmlPackage::new(create_zip(&[("file1.txt", b"hello")])).unwrap();

        let err = pkg.read_to_string("missing.txt").unwrap_err();
        assert!(matches!(err, OxdocError::MissingPart(ref path) if path == "missing.txt"));
    }

    #[test]
    fn test_vfs_limited_reader_empty_buf() {
        let cursor = create_zip(&[("file1.txt", b"hello")]);
        let mut pkg = OoxmlPackage::new(cursor).unwrap();

        pkg.with_entry("file1.txt", |reader| {
            let mut buf = [];
            assert_eq!(reader.read(&mut buf).unwrap(), 0);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_vfs_part_exactly_at_limit_is_readable() {
        let mut pkg = OoxmlPackage::new(create_zip(&[("file1.txt", b"hello")])).unwrap();
        let limits = OoxmlLimits {
            max_part_uncompressed_size: 5,
            ..Default::default()
        };

        let content = pkg
            .with_entry_limits("file1.txt", limits, |reader| {
                let mut s = String::new();
                reader.read_to_string(&mut s)?;
                Ok(s)
            })
            .unwrap();
        assert_eq!(content, "hello");
    }

    #[test]
    fn test_vfs_part_names_in_archive_order() {
        let cursor = create_zip(&[
            ("[Content_Types].xml", b"<Types/>"),
            ("word/document.xml", b"<w:document/>"),
        ]);
        let mut pkg = OoxmlPackage::new(cursor).unwrap();

        assert_eq!(
            pkg.part_names(),
            vec![
                "[Content_Types].xml".to_owned(),
                "word/document.xml".to_owned()
            ]
        );
    }

    #[test]
    fn test_map_zip_entry_error() {
        assert!(matches!(
            map_zip_entry_error("test", ZipError::FileNotFound),
            OxdocError::MissingPart(_)
        ));
        assert!(matches!(
            map_zip_entry_error(
                "test",
                ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED)
            ),
            OxdocError::UnsupportedEncryptedPart(_)
        ));
        assert!(matches!(
            map_zip_entry_error("test", ZipError::InvalidArchive("bad".into())),
            OxdocError::CorruptedZip(_)
        ));
    }

    #[test]
    fn test_vfs_contains_and_limits() {
        let cursor = create_zip(&[("file1.txt", b"hello world")]);
        let mut pkg = OoxmlPackage::new(cursor).unwrap();

        assert!(pkg.contains("file1.txt"));
        assert!(!pkg.contains("file2.txt"));

        // The declared size (11 bytes) exceeds the 5-byte limit, so `validate_entry`
        // rejects the part before any of it is read.
        let limits = OoxmlLimits {
            max_part_uncompressed_size: 5,
            ..Default::default()
        };
        let err = pkg
            .with_entry_limits("file1.txt", limits, |reader| {
                let mut s = String::new();
                reader.read_to_string(&mut s)?;
                Ok(())
            })
            .unwrap_err();
        assert!(matches!(err, OxdocError::PartTooLarge { .. }));
    }
}