1use std::io::{Read, Seek};
2
3use zip::ZipArchive;
4use zip::read::ZipFile;
5use zip::result::ZipError;
6
7use crate::{OxdocError, Result};
8
/// Default cap on a single part's uncompressed size: 64 MiB.
const DEFAULT_MAX_PART_UNCOMPRESSED_SIZE: u64 = 64 << 20;
/// Default maximum number of uncompressed bytes allowed per compressed byte.
const DEFAULT_MAX_PART_COMPRESSION_RATIO: u64 = 200;
/// Default uncompressed size (4 MiB) from which the ratio check is applied.
const DEFAULT_MIN_RATIO_CHECK_SIZE: u64 = 4 << 20;
12
/// Size and compression-ratio limits enforced on individual package parts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct OoxmlLimits {
    /// Largest uncompressed size, in bytes, a single part may have.
    pub max_part_uncompressed_size: u64,
    /// Largest allowed ratio of a part's uncompressed size to its compressed size.
    pub max_part_compression_ratio: u64,
    /// Uncompressed size, in bytes, from which the compression-ratio check is
    /// applied; smaller parts skip it.
    pub min_ratio_check_size: u64,
}
19
20impl Default for OoxmlLimits {
21 fn default() -> Self {
22 Self {
23 max_part_uncompressed_size: DEFAULT_MAX_PART_UNCOMPRESSED_SIZE,
24 max_part_compression_ratio: DEFAULT_MAX_PART_COMPRESSION_RATIO,
25 min_ratio_check_size: DEFAULT_MIN_RATIO_CHECK_SIZE,
26 }
27 }
28}
29
30pub struct OoxmlPackage<R: Read + Seek> {
31 archive: ZipArchive<R>,
32 limits: OoxmlLimits,
33}
34
35impl<R: Read + Seek> OoxmlPackage<R> {
36 pub fn new(reader: R) -> Result<Self> {
37 Self::with_limits(reader, OoxmlLimits::default())
38 }
39
40 pub fn with_limits(reader: R, limits: OoxmlLimits) -> Result<Self> {
41 Ok(Self {
42 archive: ZipArchive::new(reader)?,
43 limits,
44 })
45 }
46
47 pub fn with_entry<T>(
48 &mut self,
49 path: &str,
50 read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
51 ) -> Result<T> {
52 self.with_entry_limits(path, self.limits, read_entry)
53 }
54
55 pub fn with_entry_limits<T>(
56 &mut self,
57 path: &str,
58 limits: OoxmlLimits,
59 read_entry: impl FnOnce(&mut dyn Read) -> Result<T>,
60 ) -> Result<T> {
61 match self.archive.by_name(path) {
62 Ok(mut entry) => {
63 validate_entry(path, &entry, limits)?;
64 let mut entry =
65 LimitedEntryReader::new(&mut entry, limits.max_part_uncompressed_size);
66 let result = read_entry(&mut entry);
67 if entry.exceeded_limit() {
68 Err(OxdocError::PartTooLarge {
69 path: path.to_owned(),
70 size: entry.observed_size(),
71 limit: limits.max_part_uncompressed_size,
72 })
73 } else {
74 result
75 }
76 }
77 Err(err) => Err(map_zip_entry_error(path, err)),
78 }
79 }
80
81 pub fn read_to_string(&mut self, path: &str) -> Result<String> {
82 self.with_entry(path, |entry| {
83 let mut content = String::new();
84 entry.read_to_string(&mut content)?;
85 Ok(content)
86 })
87 }
88
89 pub fn contains(&mut self, path: &str) -> bool {
90 self.archive.by_name(path).is_ok()
91 }
92
93 pub fn contains_any(&mut self, paths: &[&str]) -> bool {
94 paths.iter().any(|path| self.contains(path))
95 }
96
97 pub fn part_names(&mut self) -> Vec<String> {
98 (0..self.archive.len())
99 .filter_map(|index| {
100 self.archive.by_index(index).ok().and_then(|entry| {
101 entry
102 .enclosed_name()
103 .and_then(|path| path.to_str().map(str::to_owned))
104 })
105 })
106 .collect()
107 }
108}
109
/// A [`Read`] adapter that hands out at most a fixed number of bytes from the
/// wrapped reader and records whether the wrapped reader held more than that.
struct LimitedEntryReader<'a, R: Read + ?Sized> {
    /// The wrapped reader the bytes come from.
    source: &'a mut R,
    /// Maximum number of bytes that may be delivered to the caller.
    cap: u64,
    /// Number of bytes delivered to the caller so far.
    consumed: u64,
    /// Set once a byte beyond `cap` has been seen in `source`.
    overflowed: bool,
    /// Byte count recorded by the most recent read or overflow probe.
    peak: u64,
}

impl<'a, R: Read + ?Sized> LimitedEntryReader<'a, R> {
    /// Wraps `inner`, allowing at most `limit` bytes to be read through it.
    fn new(inner: &'a mut R, limit: u64) -> Self {
        Self {
            source: inner,
            cap: limit,
            consumed: 0,
            overflowed: false,
            peak: 0,
        }
    }

    /// Reports whether the wrapped reader produced data beyond the limit.
    fn exceeded_limit(&self) -> bool {
        self.overflowed
    }

    /// Returns the larger of the last recorded size and the bytes delivered.
    fn observed_size(&self) -> u64 {
        std::cmp::max(self.peak, self.consumed)
    }

    /// Called once the cap is used up: reads one byte from the wrapped reader
    /// to tell a clean end of stream apart from an over-long stream.
    ///
    /// Returns `Ok(0)` at end of stream. Otherwise it marks the reader as
    /// overflowed and returns an `InvalidData` error.
    fn probe_past_cap(&mut self) -> std::io::Result<usize> {
        let mut sentinel = [0u8; 1];
        match self.source.read(&mut sentinel)? {
            0 => Ok(0),
            extra => {
                self.overflowed = true;
                self.peak = self.consumed.saturating_add(extra as u64);
                Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "OOXML part exceeded its configured uncompressed size limit",
                ))
            }
        }
    }
}

impl<R: Read + ?Sized> Read for LimitedEntryReader<'_, R> {
    /// Reads into `buf` without ever delivering more than the cap in total.
    ///
    /// An empty `buf` returns `Ok(0)` without touching the wrapped reader.
    /// Once the cap has been reached, each call probes the wrapped reader for
    /// one more byte and fails with `InvalidData` if one is found.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let budget = self.cap.saturating_sub(self.consumed);
        if budget == 0 {
            return self.probe_past_cap();
        }

        // A budget too large for `usize` cannot restrict a buffer that fits in
        // memory, so the whole buffer is usable in that case.
        let window = usize::try_from(budget).map_or(buf.len(), |left| left.min(buf.len()));
        let got = self.source.read(&mut buf[..window])?;
        self.consumed = self.consumed.saturating_add(got as u64);
        self.peak = self.consumed;
        Ok(got)
    }
}
166
167fn validate_entry<R: Read + ?Sized>(
168 path: &str,
169 entry: &ZipFile<'_, R>,
170 limits: OoxmlLimits,
171) -> Result<()> {
172 if entry.encrypted() {
173 return Err(OxdocError::UnsupportedEncryptedPart(path.to_owned()));
174 }
175
176 if entry.is_dir() {
177 return Err(OxdocError::SuspiciousZipEntry {
178 path: path.to_owned(),
179 reason: "required OOXML part resolves to a directory".to_owned(),
180 });
181 }
182
183 if entry.enclosed_name().is_none() {
184 return Err(OxdocError::SuspiciousZipEntry {
185 path: path.to_owned(),
186 reason: "ZIP entry name is not enclosed within the package".to_owned(),
187 });
188 }
189
190 let size = entry.size();
191 if size > limits.max_part_uncompressed_size {
192 return Err(OxdocError::PartTooLarge {
193 path: path.to_owned(),
194 size,
195 limit: limits.max_part_uncompressed_size,
196 });
197 }
198
199 let compressed_size = entry.compressed_size();
200 if size >= limits.min_ratio_check_size
201 && (compressed_size == 0
202 || size > compressed_size.saturating_mul(limits.max_part_compression_ratio))
203 {
204 return Err(OxdocError::SuspiciousZipEntry {
205 path: path.to_owned(),
206 reason: format!(
207 "uncompressed size {size} bytes is too large for compressed size {compressed_size} bytes"
208 ),
209 });
210 }
211
212 Ok(())
213}
214
215fn map_zip_entry_error(path: &str, err: ZipError) -> OxdocError {
216 match err {
217 ZipError::FileNotFound => OxdocError::MissingPart(path.to_owned()),
218 ZipError::UnsupportedArchive(reason) if reason == ZipError::PASSWORD_REQUIRED => {
219 OxdocError::UnsupportedEncryptedPart(path.to_owned())
220 }
221 err => OxdocError::CorruptedZip(err),
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Write};
    use zip::write::{SimpleFileOptions, ZipWriter};

    /// Builds an in-memory ZIP archive holding the given `(name, bytes)` entries.
    fn create_zip(content: &[(&str, &[u8])]) -> Cursor<Vec<u8>> {
        let mut writer = ZipWriter::new(Cursor::new(Vec::new()));
        for &(name, bytes) in content {
            writer
                .start_file(name, SimpleFileOptions::default())
                .unwrap();
            writer.write_all(bytes).unwrap();
        }
        writer.finish().unwrap()
    }

    /// `contains_any` finds a present entry among absent ones, and
    /// `read_to_string` returns an entry's text.
    #[test]
    fn test_vfs_contains_any_and_read_to_string() {
        let archive = create_zip(&[("file1.txt", b"hello"), ("file2.txt", b"world")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        assert!(package.contains_any(&["nonexistent.txt", "file2.txt"]));
        assert!(!package.contains_any(&["nonexistent.txt"]));

        let text = package.read_to_string("file1.txt").unwrap();
        assert_eq!(text, "hello");
    }

    /// Reading into an empty buffer yields zero bytes and no error.
    #[test]
    fn test_vfs_limited_reader_empty_buf() {
        let archive = create_zip(&[("file1.txt", b"hello")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        let outcome = package.with_entry("file1.txt", |reader| {
            let mut empty: [u8; 0] = [];
            assert_eq!(reader.read(&mut empty).unwrap(), 0);
            Ok(())
        });
        outcome.unwrap();
    }

    /// Each kind of ZIP lookup error maps to its `OxdocError` variant.
    #[test]
    fn test_map_zip_entry_error() {
        let missing = map_zip_entry_error("test", ZipError::FileNotFound);
        assert!(matches!(missing, OxdocError::MissingPart(_)));

        let encrypted = map_zip_entry_error(
            "test",
            ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED),
        );
        assert!(matches!(encrypted, OxdocError::UnsupportedEncryptedPart(_)));

        let corrupted = map_zip_entry_error("test", ZipError::InvalidArchive("bad".into()));
        assert!(matches!(corrupted, OxdocError::CorruptedZip(_)));
    }

    /// `contains` reflects entry presence, and an 11-byte part read under a
    /// 5-byte limit is reported as too large.
    #[test]
    fn test_vfs_contains_and_limits() {
        let archive = create_zip(&[("file1.txt", b"hello world")]);
        let mut package = OoxmlPackage::new(archive).unwrap();

        assert!(package.contains("file1.txt"));
        assert!(!package.contains("file2.txt"));

        let tight = OoxmlLimits {
            max_part_uncompressed_size: 5,
            ..Default::default()
        };
        let outcome = package.with_entry_limits("file1.txt", tight, |reader| {
            let mut sink = String::new();
            reader.read_to_string(&mut sink)?;
            Ok(())
        });
        let err = outcome.unwrap_err();
        assert!(matches!(err, OxdocError::PartTooLarge { .. }));
    }
}