walker_common/compression/
limit.rs

1use std::io::{Error, ErrorKind, Write};
2
/// A writer that limits how many bytes may be passed through to the wrapped writer `W`.
///
/// Writing more data than the limit allows fails with an error; bytes beyond the limit
/// are never handed to the inner writer.
///
/// The `Write` bound lives on the `impl` blocks rather than on the struct itself, so the
/// type can be named without repeating the bound everywhere.
#[derive(Debug)]
pub struct LimitWriter<W> {
    /// The wrapped writer receiving the data.
    writer: W,
    /// Maximum number of bytes that may be written in total.
    limit: usize,
    /// Number of bytes accepted by the inner writer so far.
    current: usize,
}
12
13impl<W> LimitWriter<W>
14where
15    W: Write,
16{
17    /// Create a new writer, providing the limit.
18    pub fn new(writer: W, limit: usize) -> Self {
19        Self {
20            writer,
21            limit,
22            current: 0,
23        }
24    }
25
26    /// Close writer, return the inner writer.
27    ///
28    /// Note: Closing the writer will not flush it before.
29    pub fn close(self) -> W {
30        self.writer
31    }
32}
33
34impl<W> Write for LimitWriter<W>
35where
36    W: Write,
37{
38    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
39        // check what is remaining
40        let remaining = self.limit.saturating_sub(self.current);
41        // if noting is left ...
42        if remaining == 0 {
43            // ... return an error
44            return Err(Error::new(ErrorKind::WriteZero, "write limit exceeded"));
45        }
46
47        // write out remaining bytes, maxing out at limit
48        let to_write = remaining.min(buf.len());
49        let bytes_written = self.writer.write(&buf[..to_write])?;
50        self.current += bytes_written;
51
52        Ok(bytes_written)
53    }
54
55    fn flush(&mut self) -> std::io::Result<()> {
56        self.writer.flush()
57    }
58}
59
#[cfg(test)]
mod test {
    use crate::compression::LimitWriter;
    use std::io::{Cursor, Write};

    /// Copy `data` through a `LimitWriter` capped at `limit` bytes and return what
    /// reached the inner buffer.
    fn perform_write(data: &[u8], limit: usize) -> Result<Vec<u8>, std::io::Error> {
        let mut source = Cursor::new(data);
        let mut sink = LimitWriter::new(Vec::new(), limit);
        std::io::copy(&mut source, &mut sink)?;
        sink.flush()?;
        Ok(sink.close())
    }

    /// Assert that `input` passes through a writer limited to `limit` bytes unchanged.
    #[track_caller]
    fn assert_passes(input: &[u8], limit: usize) {
        assert!(matches!(
            perform_write(input, limit).as_deref(),
            Ok(output) if output == input
        ));
    }

    /// Assert that writing `input` through a writer limited to `limit` bytes fails.
    #[track_caller]
    fn assert_rejected(input: &[u8], limit: usize) {
        assert!(perform_write(input, limit).is_err());
    }

    #[test]
    fn write_ok() {
        assert_passes(b"0123456789", 100);
        assert_passes(b"", 100);
        // exactly at the limit
        assert_passes(b"0123456789", 10);
        // just below the limit
        assert_passes(b"012345678", 10);
    }

    #[test]
    fn write_err() {
        // one byte over the limit
        assert_rejected(b"01234567890", 10);
        // two bytes over the limit
        assert_rejected(b"012345678901", 10);
    }
}