1#![warn(clippy::pedantic)]
51
52use std::io::{Read, Seek, SeekFrom, Write};
53use std::ops::{Bound, RangeBounds};
54
/// A view over a seekable stream that exposes only a contiguous range of its byte positions.
///
/// The window stores no position of its own: every read, write and seek works from the
/// inner stream's current position, translated so that `start` appears as position 0.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct IoWindow<T>
where
    T: Seek,
{
    /// The wrapped stream.
    inner: T,
    /// Absolute position in `inner` that the window reports as position 0.
    start: u64,
    /// Exclusive absolute end of the window in `inner`, or `None` when unbounded above.
    end: Option<u64>,
}
64
65impl<T: Seek> IoWindow<T> {
66 pub fn new(mut inner: T, range: impl RangeBounds<u64>) -> std::io::Result<IoWindow<T>> {
84 let start = match range.start_bound() {
85 Bound::Included(pos) => *pos,
86 Bound::Excluded(pos) => pos.checked_add(1).ok_or(BadRange)?,
87 Bound::Unbounded => 0,
88 };
89 let end = match range.end_bound() {
90 Bound::Included(pos) => Some(pos.checked_add(1).ok_or(BadRange)?),
91 Bound::Excluded(pos) => Some(*pos),
92 Bound::Unbounded => None,
93 };
94 if inner.stream_position()? < start {
95 inner.seek(SeekFrom::Start(start))?;
96 }
97 Ok(IoWindow { inner, start, end })
98 }
99
100 pub fn into_inner(self) -> T {
102 self.inner
103 }
104
105 pub fn get_ref(&self) -> &T {
107 &self.inner
108 }
109
110 pub fn get_mut(&mut self) -> &mut T {
117 &mut self.inner
118 }
119
120 fn reduce_buf_len(&mut self, len: usize) -> std::io::Result<usize> {
123 Ok(if let Some(end) = self.end {
124 if let Some(remaining) = end.checked_sub(self.inner.stream_position()?) {
125 match usize::try_from(remaining) {
129 Ok(remaining) => len.min(remaining),
130 Err(_) => len,
131 }
132 } else {
133 0
136 }
137 } else {
138 len
139 })
140 }
141}
142
143impl<T: Read + Seek> Read for IoWindow<T> {
144 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
145 let len = self.reduce_buf_len(buf.len())?;
146 self.inner.read(&mut buf[..len])
147 }
148}
149
150impl<T: Write + Seek> Write for IoWindow<T> {
151 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
152 let len = self.reduce_buf_len(buf.len())?;
153 self.inner.write(&buf[..len])
154 }
155
156 fn flush(&mut self) -> std::io::Result<()> {
157 self.inner.flush()
158 }
159}
160
161impl<T: Seek> Seek for IoWindow<T> {
162 fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
163 let adjusted = match pos {
164 SeekFrom::Start(pos) => SeekFrom::Start(self.start.checked_add(pos).ok_or(BadSeek)?),
165 SeekFrom::End(pos) => {
166 if let Some(end) = self.end {
167 SeekFrom::Start(checked_add_signed(end, pos).ok_or(BadSeek)?)
168 } else {
169 SeekFrom::End(pos)
170 }
171 }
172 SeekFrom::Current(0) => SeekFrom::Current(0),
173 SeekFrom::Current(pos) => SeekFrom::Start(
174 checked_add_signed(self.inner.stream_position()?, pos).ok_or(BadSeek)?,
175 ),
176 };
177 if let SeekFrom::Start(start) = adjusted {
178 if start < self.start {
179 Err(BadSeek)?;
180 }
181 }
182 Ok(self.inner.seek(adjusted)? - self.start)
183 }
184}
185
/// Adds a signed offset to an unsigned position, returning `None` on underflow or overflow.
#[inline]
fn checked_add_signed(lhs: u64, rhs: i64) -> Option<u64> {
    // `unsigned_abs` is exact even for `i64::MIN`, whose magnitude does not fit in `i64`.
    let magnitude = rhs.unsigned_abs();
    if rhs < 0 {
        lhs.checked_sub(magnitude)
    } else {
        lhs.checked_add(magnitude)
    }
}
194
/// Declares a unit marker type `$name` that converts into an [`std::io::Error`] of kind
/// `InvalidInput` carrying the fixed message `$message`, so call sites can use `?` on it.
macro_rules! err_shortcut {
    ($name:ident, $message:literal) => {
        struct $name;

        impl ::std::convert::From<$name> for ::std::io::Error {
            #[inline]
            fn from(_: $name) -> Self {
                Self::new(::std::io::ErrorKind::InvalidInput, $message)
            }
        }
    };
}

// Used by `IoWindow::new` when normalising a range bound would overflow `u64`.
err_shortcut!(BadRange, "overflowing range bound");
// Used by `IoWindow`'s `Seek` impl for targets before the window start or beyond `u64`.
err_shortcut!(
    BadSeek,
    "invalid seek to a negative or overflowing position"
);
212
#[cfg(test)]
mod tests {
    use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom, Write};

    use super::IoWindow;

    /// Exercises a bounded window (`128..256`) over a 512-byte buffer. Covers seeking from
    /// every origin, reads and writes clamped at the window end, rejected seeks before the
    /// start, and seeks past the end.
    #[test]
    fn range() -> std::io::Result<()> {
        let v = Cursor::new(vec![0; 512]);
        let mut window = IoWindow::new(v, 128..256)?;
        // Window position 0 corresponds to inner position 128.
        assert_eq!(window.stream_position()?, 0);
        assert_eq!(window.get_mut().stream_position()?, 128);

        // Seeks the window, then checks the returned position, the window's reported
        // position, and the inner stream's absolute position.
        macro_rules! t {
            ($seekfrom:expr, $windowpos:expr, $innerpos:expr) => {{
                assert_eq!(window.seek($seekfrom)?, $windowpos);
                assert_eq!(window.stream_position()?, $windowpos);
                assert_eq!(window.inner.stream_position()?, $innerpos);
            }};
        }

        t!(SeekFrom::Start(0), 0, 128);
        t!(SeekFrom::Start(32), 32, 160);
        // `End` is relative to the window end (inner position 256), not the buffer end.
        t!(SeekFrom::End(0), 128, 256);
        t!(SeekFrom::End(-32), 96, 224);
        t!(SeekFrom::Current(-32), 64, 192);

        let mut buf = [0; 16];
        // A write inside the window lands at inner offset 128 and can be read back.
        t!(SeekFrom::Start(0), 0, 128);
        assert_eq!(window.write(b"meow meow meow")?, 14);
        t!(SeekFrom::Current(0), 14, 142);
        assert!(window
            .inner
            .get_ref()
            .iter()
            .eq([0; 128].iter().chain(b"meow meow meow").chain(&[0; 370])));
        t!(SeekFrom::Current(-4), 10, 138);
        assert_eq!(window.read(&mut buf[..4])?, 4);
        assert_eq!(&buf[..4], b"meow");

        // Writes and reads near the window end are truncated at inner position 256.
        t!(SeekFrom::End(-4), 124, 252);
        assert_eq!(window.write(b"meow meow meow")?, 4);
        t!(SeekFrom::Current(0), 128, 256);
        assert_eq!(&window.inner.get_ref()[256..], &[0; 256]);
        t!(SeekFrom::End(-8), 120, 248);
        assert_eq!(window.read(&mut buf[..])?, 8);
        t!(SeekFrom::Current(0), 128, 256);
        assert_eq!(&buf[..8], b"\0\0\0\0meow");

        // Seeks that would land before the window start fail and leave the position as is.
        assert!(window.seek(SeekFrom::Current(-160)).is_err());
        t!(SeekFrom::Current(0), 128, 256);
        assert!(window.seek(SeekFrom::End(-160)).is_err());
        t!(SeekFrom::Current(0), 128, 256);

        // Seeking past the window end succeeds, but reads and writes there transfer nothing.
        t!(SeekFrom::End(64), 192, 320);
        assert_eq!(window.read(&mut [0; 64])?, 0);
        t!(SeekFrom::Current(0), 192, 320);
        assert_eq!(window.write(&[0; 64])?, 0);
        t!(SeekFrom::Current(0), 192, 320);

        // The underlying buffer still has length and capacity 512: no write extended it.
        let v = window.into_inner().into_inner();
        assert_eq!(v.len(), 512);
        assert_eq!(v.capacity(), 512);

        Ok(())
    }

    /// An unbounded window (`128..`) over an empty buffer. Creating it moves the cursor to
    /// 128, and writing there grows the buffer to 132 bytes.
    #[test]
    fn range_unbounded() -> std::io::Result<()> {
        let v = Cursor::new(Vec::new());
        let mut window = IoWindow::new(v, 128..)?;

        assert_eq!(window.inner.stream_position()?, 128);
        assert_eq!(window.inner.get_ref().len(), 0);
        window.write_all(b"meow")?;
        assert_eq!(window.inner.get_ref().len(), 132);

        // Reading from window position 0 returns only the 4 bytes written, then stops.
        window.seek(SeekFrom::Start(0))?;
        let mut buf = [0; 8];
        assert_eq!(window.read(&mut buf[..])?, 4);
        assert_eq!(&buf[..4], b"meow");

        Ok(())
    }

    /// An empty window (`0..0`) accepts no writes and yields no reads.
    #[test]
    fn zero_range() -> std::io::Result<()> {
        let mut window = IoWindow::new(Cursor::new(Vec::new()), 0..0)?;
        assert_eq!(window.write(&[0; 4])?, 0);
        assert_eq!(window.read(&mut [0; 4])?, 0);
        Ok(())
    }

    /// Nested windows compose. `32..64` inside `128..256` covers absolute bytes `160..192`,
    /// so an oversized write is cut to 32 bytes there.
    #[test]
    fn wrapped() -> std::io::Result<()> {
        let inner = IoWindow::new(Cursor::new([0; 512]), 128..256)?;
        let mut window = IoWindow::new(inner, 32..64)?;
        assert_eq!(window.write(&[42; 128])?, 32);
        assert_eq!(window.stream_position()?, 32);
        let mut inner = window.into_inner();
        assert_eq!(inner.stream_position()?, 64);
        let mut cursor = inner.into_inner();
        assert_eq!(cursor.stream_position()?, 192);
        assert!(cursor
            .get_ref()
            .iter()
            .eq([0; 160].iter().chain(&[42; 32]).chain(&[0; 320])));
        Ok(())
    }

    /// `std::io::copy` succeeds when the 19-byte source fits in the window (24 bytes). When
    /// it does not fit (8 bytes), copy fails with `WriteZero`, because the window's `write`
    /// returns 0 once full.
    #[test]
    fn copy() -> std::io::Result<()> {
        let from = b"meow meow meow meow";
        let mut to = IoWindow::new(Cursor::new([0; 32]), 0..24)?;
        std::io::copy(&mut &from[..], &mut to)?;

        let mut to = IoWindow::new(Cursor::new([0; 32]), 0..8)?;
        assert_eq!(
            std::io::copy(&mut &from[..], &mut to).unwrap_err().kind(),
            ErrorKind::WriteZero
        );

        Ok(())
    }
}