libcoreinst/io/
tee.rs

1// Copyright 2019 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Wrappers for splitting I/O streams
16
17use std::io::{self, Read, Write};
18
/// Reader wrapper that copies data to a writer as a side effect.
///
/// Each successful read from this wrapper first reads from `source`, then
/// writes exactly the bytes that were read to `dest`. As long as no write
/// fails, `dest` therefore accumulates a copy of everything the caller has
/// consumed.
///
/// The `Read`/`Write` bounds are placed on the `impl` blocks rather than on
/// the struct itself, so generic code can name `TeeReader<R, W>` without
/// having to repeat them.
#[derive(Debug)]
pub struct TeeReader<R, W> {
    /// Stream the data is read from.
    source: R,
    /// Stream that receives a copy of every byte read from `source`.
    dest: W,
}
24
25impl<R: Read, W: Write> TeeReader<R, W> {
26    pub fn new(source: R, dest: W) -> Self {
27        Self { source, dest }
28    }
29
30    pub fn into_inner(self) -> (R, W) {
31        (self.source, self.dest)
32    }
33}
34
35impl<R: Read, W: Write> Read for TeeReader<R, W> {
36    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
37        if buf.is_empty() {
38            return Ok(0);
39        }
40        let count = self.source.read(buf)?;
41        self.dest.write_all(&buf[..count])?;
42        Ok(count)
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Do some I/O of different sizes, reach EOF, and check that both
    /// copies of the output are correct
    #[test]
    fn tee_reader() {
        const COUNT: usize = 100;
        let src: Vec<u8> = (0..COUNT as u8).collect();
        let mut buf = [0; 2 * COUNT];
        let mut off = 0;
        let mut tee = TeeReader::new(&*src, Vec::new());
        for i in 2.. {
            off += tee.read(&mut buf[off..off + i]).unwrap();
            assert!(off <= COUNT);
            if off == COUNT {
                break;
            }
        }
        assert_eq!(src, buf[..COUNT]);
        let (_, dest) = tee.into_inner();
        assert_eq!(src, dest);
    }

    /// A zero-length read must succeed without consuming anything from the
    /// source or writing anything to the destination.
    #[test]
    fn tee_reader_empty_buf() {
        let src = [1u8, 2, 3];
        let mut tee = TeeReader::new(&src[..], Vec::new());
        assert_eq!(tee.read(&mut [0u8; 0]).unwrap(), 0);
        let (rest, dest) = tee.into_inner();
        assert_eq!(rest, &src[..]);
        assert!(dest.is_empty());
    }

    /// Writer that rejects every write, used to exercise error propagation.
    struct FailingWriter;

    impl Write for FailingWriter {
        fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
            Err(io::ErrorKind::BrokenPipe.into())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// An error from the destination writer must surface from `read`.
    #[test]
    fn tee_reader_write_error() {
        let src = [1u8, 2, 3];
        let mut tee = TeeReader::new(&src[..], FailingWriter);
        let mut buf = [0u8; 3];
        let err = tee.read(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
    }
}