libcoreinst/io/
limit.rs

1// Copyright 2019 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io::{self, Read, Write};
16
/// A [`Read`] adapter that yields at most `length` bytes from `source`.
///
/// Reaching the limit is only an error if `source` still has data. Once
/// `length` bytes have been read, the next non-empty read probes `source` for
/// one more byte. If it gets one, the read fails with a
/// "collision with `conflict` at offset `length`" error (the probed byte is
/// consumed and discarded). If `source` is at EOF, it returns `Ok(0)`.
#[derive(Debug)]
pub struct LimitReader<R> {
    source: R,
    /// Maximum number of bytes that may be read; also reported as the offset
    /// in the collision error.
    length: u64,
    /// Bytes that may still be read before the limit is reached.
    remaining: u64,
    /// Label for whatever the data would collide with past the limit; used
    /// only in the collision error message.
    conflict: String,
}

impl<R> LimitReader<R> {
    /// Wraps `source`, allowing at most `length` bytes to be read from it.
    ///
    /// `conflict` is included in the error returned if `source` turns out to
    /// hold more than `length` bytes.
    pub fn new(source: R, length: u64, conflict: String) -> Self {
        Self {
            source,
            length,
            remaining: length,
            conflict,
        }
    }

    /// Builds the error reported when `source` has data beyond the limit.
    fn collision(&self) -> io::Error {
        io::Error::other(format!(
            "collision with {} at offset {}",
            self.conflict, self.length
        ))
    }
}

impl<R: Read> Read for LimitReader<R> {
    /// Reads at most `min(buf.len(), remaining)` bytes from the source.
    ///
    /// # Errors
    ///
    /// Propagates any error from the source. Once the limit has been reached,
    /// returns a collision error if the source yields another byte.
    ///
    /// # Panics
    ///
    /// Panics if the source claims to have read more bytes than the slice it
    /// was given, which violates the [`Read::read`] contract.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if self.remaining == 0 {
            // At the limit: this is only an error if the source isn't at EOF.
            // Probe into a scratch byte so the caller's buffer is not
            // overwritten with data we are about to reject.
            let mut probe = [0u8; 1];
            return match self.source.read(&mut probe)? {
                0 => Ok(0),
                _ => Err(self.collision()),
            };
        }
        // The minimum is <= buf.len(), so the cast to usize is lossless.
        let allowed = self.remaining.min(buf.len() as u64) as usize;
        let count = self.source.read(&mut buf[..allowed])?;
        self.remaining = self
            .remaining
            .checked_sub(count as u64)
            .expect("read more bytes than allowed");
        Ok(count)
    }
}
60
/// A [`Write`] adapter that passes at most `length` bytes through to `sink`.
///
/// A write that straddles the limit is truncated to a short write. Any
/// non-empty write after the limit has been reached fails with a
/// "collision with `conflict` at offset `length`" error. No byte past the
/// limit is ever handed to `sink`.
#[derive(Debug)]
pub struct LimitWriter<W> {
    sink: W,
    /// Maximum number of bytes that may be written; also reported as the
    /// offset in the collision error.
    length: u64,
    /// Bytes that may still be written before the limit is reached.
    remaining: u64,
    /// Label for whatever the data would collide with past the limit; used
    /// only in the collision error message.
    conflict: String,
}

impl<W> LimitWriter<W> {
    /// Wraps `sink`, allowing at most `length` bytes to be written to it.
    ///
    /// `conflict` is included in the error returned when a write is attempted
    /// past the limit.
    pub fn new(sink: W, length: u64, conflict: String) -> Self {
        Self {
            sink,
            length,
            remaining: length,
            conflict,
        }
    }

    /// Builds the error reported when a write would exceed the limit.
    fn collision(&self) -> io::Error {
        io::Error::other(format!(
            "collision with {} at offset {}",
            self.conflict, self.length
        ))
    }
}

impl<W: Write> Write for LimitWriter<W> {
    /// Writes at most `min(buf.len(), remaining)` bytes to the sink.
    ///
    /// # Errors
    ///
    /// Propagates any error from the sink. Returns a collision error for a
    /// non-empty `buf` once the limit has been reached.
    ///
    /// # Panics
    ///
    /// Panics if the sink claims to have written more bytes than the slice it
    /// was given, which violates the [`Write::write`] contract.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if self.remaining == 0 {
            return Err(self.collision());
        }
        // The minimum is <= buf.len(), so the cast to usize is lossless.
        let allowed = self.remaining.min(buf.len() as u64) as usize;
        let count = self.sink.write(&buf[..allowed])?;
        self.remaining = self
            .remaining
            .checked_sub(count as u64)
            .expect("wrote more bytes than allowed");
        Ok(count)
    }

    /// Flushes the underlying sink.
    fn flush(&mut self) -> io::Result<()> {
        self.sink.flush()
    }
}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Test payload: the bytes 0..100, so every offset holds a distinct value.
    fn payload() -> Vec<u8> {
        (0..100).collect()
    }

    /// Reads `payload()` through a 60-byte buffer and a `LimitReader` of
    /// `limit` bytes.
    ///
    /// Each read must return the next entry of `chunks` bytes of in-order
    /// data. The read after that must return `Ok(0)` when `collision` is
    /// `None`, or fail with exactly that message otherwise.
    fn check_reader(limit: u64, chunks: &[usize], collision: Option<&str>) {
        let data = payload();
        let mut file = Cursor::new(data.clone());
        let mut lim = LimitReader::new(&mut file, limit, "foo".into());
        let mut buf = [0u8; 60];
        let mut offset = 0;
        for &len in chunks {
            assert_eq!(lim.read(&mut buf).unwrap(), len);
            assert_eq!(buf[..len], data[offset..offset + len]);
            offset += len;
        }
        match collision {
            None => assert_eq!(lim.read(&mut buf).unwrap(), 0),
            Some(msg) => assert_eq!(lim.read(&mut buf).unwrap_err().to_string(), msg),
        }
    }

    /// Writes `payload()` with `write_all` (then `flush`) through a
    /// `LimitWriter` of `limit` bytes. Returns the result and the bytes that
    /// reached the sink.
    fn write_all_limited(limit: u64) -> (io::Result<()>, Vec<u8>) {
        let data = payload();
        let mut outbuf = Vec::new();
        let mut lim = LimitWriter::new(&mut outbuf, limit, "foo".into());
        let result = lim.write_all(&data).and_then(|()| lim.flush());
        // Release the mutable borrow of `outbuf` before handing it back.
        drop(lim);
        (result, outbuf)
    }

    #[test]
    fn limit_reader_test() {
        // limit larger than file
        check_reader(150, &[60, 40], None);
        // limit exactly equal to file
        check_reader(100, &[60, 40], None);
        // buffer smaller than limit
        check_reader(90, &[60, 30], Some("collision with foo at offset 90"));
        // buffer exactly equal to limit
        check_reader(60, &[60], Some("collision with foo at offset 60"));
        // buffer larger than limit
        check_reader(50, &[50], Some("collision with foo at offset 50"));
        // zero limit: any data at all is a collision
        check_reader(0, &[], Some("collision with foo at offset 0"));
    }

    #[test]
    fn limit_reader_edge_cases() {
        // zero limit on an empty source is a clean EOF
        let mut lim = LimitReader::new(Cursor::new(Vec::<u8>::new()), 0, "foo".into());
        assert_eq!(lim.read(&mut [0u8; 8]).unwrap(), 0);

        // an empty buffer returns Ok(0) even when data lies past the limit
        let mut lim = LimitReader::new(Cursor::new(payload()), 0, "foo".into());
        assert_eq!(lim.read(&mut [0u8; 0]).unwrap(), 0);

        // read_to_end stops cleanly when the limit coincides with EOF
        let mut lim = LimitReader::new(Cursor::new(payload()), 100, "foo".into());
        let mut out = Vec::new();
        assert_eq!(lim.read_to_end(&mut out).unwrap(), 100);
        assert_eq!(out, payload());
    }

    #[test]
    fn limit_writer_test() {
        let data = payload();

        // limit larger than data
        let (result, outbuf) = write_all_limited(150);
        result.unwrap();
        assert_eq!(data, outbuf);

        // limit exactly equal to data
        let (result, outbuf) = write_all_limited(100);
        result.unwrap();
        assert_eq!(data, outbuf);

        // limit smaller than data: the prefix up to the limit lands, then it fails
        let (result, outbuf) = write_all_limited(90);
        assert_eq!(
            result.unwrap_err().to_string(),
            "collision with foo at offset 90"
        );
        assert_eq!(&data[..90], &outbuf[..]);

        // directly test writing in multiple chunks
        let mut outbuf: Vec<u8> = Vec::new();
        let mut lim = LimitWriter::new(&mut outbuf, 90, "foo".into());
        assert_eq!(lim.write(&data[0..60]).unwrap(), 60);
        assert_eq!(lim.write(&data[60..100]).unwrap(), 30); // short write
        assert_eq!(
            lim.write(&data[90..100]).unwrap_err().to_string(),
            "collision with foo at offset 90"
        );
        assert_eq!(lim.write(&data[0..0]).unwrap(), 0);
        assert_eq!(&data[0..90], &outbuf);
    }

    #[test]
    fn limit_writer_zero_limit() {
        let mut outbuf: Vec<u8> = Vec::new();
        let mut lim = LimitWriter::new(&mut outbuf, 0, "foo".into());
        // empty writes are always allowed
        assert_eq!(lim.write(&[0u8; 0]).unwrap(), 0);
        // any real data collides immediately
        assert_eq!(
            lim.write(b"x").unwrap_err().to_string(),
            "collision with foo at offset 0"
        );
        lim.flush().unwrap();
        assert!(outbuf.is_empty());
    }
}