Skip to main content

fscommon/
buf_stream.rs

1use core::cmp;
2use io;
3use io::prelude::*;
4
/// Capacity, in bytes, of the fixed internal buffer used by `BufStream`.
const BUF_SIZE: usize = 512;

/// The `BufStream` struct adds buffering to an underlying file or device.
///
/// It combines the roles of `BufReader` and `BufWriter` over a single shared buffer,
/// switching between a reading mode and a writing mode as the caller alternates calls.
/// The buffer size is fixed to `BUF_SIZE` (512) bytes to avoid dynamic allocation.
/// `BufStream` tries to flush itself when dropped; an error from that flush can only be
/// logged, so call `flush` explicitly when write errors must be observed.
pub struct BufStream<T: Read+Write+Seek> {
    /// The wrapped stream.
    inner: T,
    /// `true` while `buf` holds pending writes, `false` while it holds read-ahead data.
    write: bool,
    /// Read mode: index of the next unconsumed byte in `buf`.
    /// Write mode: number of bytes buffered in `buf` and not yet written to `inner`.
    pos: usize,
    /// Read mode: number of valid bytes in `buf`. Kept at 0 in write mode.
    len: usize,
    /// Fixed-size storage shared by both modes.
    buf: [u8; BUF_SIZE],
}
19
20impl<T: Read+Write+Seek> BufStream<T> {
21    /// Creates a new `BufStream` object for a given inner stream.
22    pub fn new(inner: T) -> Self {
23        BufStream {
24            inner,
25            buf: [0; BUF_SIZE],
26            pos: 0,
27            len: 0,
28            write: false,
29        }
30    }
31
32    fn flush_buf(&mut self) -> io::Result<()> {
33        if self.write {
34            self.inner.write_all(&self.buf[..self.pos])?;
35            self.pos = 0;
36        }
37        Ok(())
38    }
39
40    fn make_reader(&mut self) -> io::Result<()> {
41        if self.write {
42            self.flush_buf()?;
43            self.write = false;
44            self.len = 0;
45            self.pos = 0;
46        }
47        Ok(())
48    }
49
50    fn make_writter(&mut self) -> io::Result<()> {
51        if !self.write {
52            self.inner.seek(io::SeekFrom::Current(-(self.len as i64 - self.pos as i64)))?;
53            self.write = true;
54            self.len = 0;
55            self.pos = 0;
56        }
57        Ok(())
58    }
59
60    fn fill_buf(&mut self) -> io::Result<&[u8]> {
61        self.make_reader()?;
62        if self.pos >= self.len {
63            debug_assert!(self.pos == self.len);
64            self.len = self.inner.read(&mut self.buf)?;
65            self.pos = 0;
66        }
67        Ok(&self.buf[self.pos..self.len])
68    }
69
70    fn consume(&mut self, amt: usize) {
71        self.pos = cmp::min(self.pos + amt, self.len);
72    }
73}
74
#[cfg(any(feature = "std", feature = "core_io/collections"))]
impl<T: Read+Write+Seek> BufRead for BufStream<T> {
    /// Returns the unconsumed part of the internal buffer, refilling it when empty.
    ///
    /// Pending writes are flushed first because the stream switches to read mode.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        // Explicitly call the inherent method of the same name (not this trait method).
        BufStream::<T>::fill_buf(self)
    }

    /// Marks `amt` bytes as consumed, clamped to the amount currently buffered.
    fn consume(&mut self, amt: usize) {
        BufStream::<T>::consume(self, amt)
    }
}
85
86impl<T: Read+Write+Seek> Read for BufStream<T> {
87    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
88        // Make sure we are in read mode
89        self.make_reader()?;
90        // Check if this read is bigger than buffer size
91        if self.pos == self.len && buf.len() >= BUF_SIZE {
92            return self.inner.read(buf);
93        }
94        let nread = {
95            let mut rem = self.fill_buf()?;
96            rem.read(buf)?
97        };
98        self.consume(nread);
99        Ok(nread)
100    }
101}
102
103impl<T: Read+Write+Seek> Write for BufStream<T> {
104    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
105        // Make sure we are in write mode
106        self.make_writter()?;
107        if self.pos + buf.len() > BUF_SIZE {
108            self.flush_buf()?;
109            if buf.len() >= BUF_SIZE {
110                return self.inner.write(buf);
111            }
112        }
113        let written = (&mut self.buf[self.pos..]).write(buf)?;
114        self.pos += written;
115        Ok(written)
116    }
117
118    fn flush(&mut self) -> io::Result<()> {
119        self.flush_buf()?;
120        self.inner.flush()
121    }
122}
123
124impl<T: Read+Write+Seek> Seek for BufStream<T> {
125    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
126        self.flush_buf()?;
127        let new_pos = match pos {
128            io::SeekFrom::Current(x) => io::SeekFrom::Current(x - (self.len as i64 - self.pos as i64)),
129            _ => pos,
130        };
131        self.pos = 0;
132        self.len = 0;
133        self.inner.seek(new_pos)
134    }
135}
136
137impl<T: Read+Write+Seek> Drop for BufStream<T> {
138    fn drop(&mut self) {
139        if let Err(err) = self.flush() {
140            error!("flush failed {}", err);
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let buf = "Test data".to_string().into_bytes();
        let cur = io::Cursor::new(buf);
        let mut buf_stream = BufStream::new(cur);

        let mut data = String::new();
        buf_stream.read_to_string(&mut data).unwrap();
        assert_eq!(data, "Test data");

        buf_stream.seek(io::SeekFrom::Start(5)).unwrap();
        let mut data = String::new();
        buf_stream.read_to_string(&mut data).unwrap();
        assert_eq!(data, "data");

        buf_stream.write_all("\nHello".as_bytes()).unwrap();
        buf_stream.seek(io::SeekFrom::Start(0)).unwrap();
        let mut data = String::new();
        buf_stream.read_line(&mut data).unwrap();
        assert_eq!(data, "Test data\n");
        data.clear();
        buf_stream.read_line(&mut data).unwrap();
        assert_eq!(data, "Hello");
    }

    /// `SeekFrom::Current` must be relative to the logical position, not to the inner
    /// stream, which has already read ahead into the buffer.
    #[test]
    fn seek_current_accounts_for_read_ahead() {
        let mut buf_stream = BufStream::new(io::Cursor::new(b"0123456789".to_vec()));
        let mut byte = [0u8; 1];
        buf_stream.read_exact(&mut byte).unwrap();
        assert_eq!(&byte, b"0");

        let new_pos = buf_stream.seek(io::SeekFrom::Current(2)).unwrap();
        assert_eq!(new_pos, 3);
        buf_stream.read_exact(&mut byte).unwrap();
        assert_eq!(&byte, b"3");
    }

    /// Switching from reading to writing must write at the logical position, not at the
    /// end of the read-ahead.
    #[test]
    fn write_after_partial_read_lands_at_logical_position() {
        let mut buf_stream = BufStream::new(io::Cursor::new(b"abcdef".to_vec()));
        let mut byte = [0u8; 1];
        buf_stream.read_exact(&mut byte).unwrap();
        assert_eq!(&byte, b"a");

        buf_stream.write_all(b"XY").unwrap();
        buf_stream.seek(io::SeekFrom::Start(0)).unwrap();
        let mut data = String::new();
        buf_stream.read_to_string(&mut data).unwrap();
        assert_eq!(data, "aXYdef");
    }

    /// Writes larger than the buffer bypass it; the data must still round-trip intact.
    #[test]
    fn large_write_round_trips() {
        let mut buf_stream = BufStream::new(io::Cursor::new(Vec::new()));
        let payload: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        buf_stream.write_all(&payload).unwrap();

        buf_stream.seek(io::SeekFrom::Start(0)).unwrap();
        let mut read_back = Vec::new();
        buf_stream.read_to_end(&mut read_back).unwrap();
        assert_eq!(read_back, payload);
    }
}
172}