1use core::cmp;
2use io;
3use io::prelude::*;
4
/// Capacity, in bytes, of the single buffer that `BufStream` shares between reads and writes.
const BUF_SIZE: usize = 512;
6
7pub struct BufStream<T: Read+Write+Seek> {
13 inner: T,
14 buf: [u8; BUF_SIZE],
15 len: usize,
16 pos: usize,
17 write: bool,
18}
19
20impl<T: Read+Write+Seek> BufStream<T> {
21 pub fn new(inner: T) -> Self {
23 BufStream {
24 inner,
25 buf: [0; BUF_SIZE],
26 pos: 0,
27 len: 0,
28 write: false,
29 }
30 }
31
32 fn flush_buf(&mut self) -> io::Result<()> {
33 if self.write {
34 self.inner.write_all(&self.buf[..self.pos])?;
35 self.pos = 0;
36 }
37 Ok(())
38 }
39
40 fn make_reader(&mut self) -> io::Result<()> {
41 if self.write {
42 self.flush_buf()?;
43 self.write = false;
44 self.len = 0;
45 self.pos = 0;
46 }
47 Ok(())
48 }
49
50 fn make_writter(&mut self) -> io::Result<()> {
51 if !self.write {
52 self.inner.seek(io::SeekFrom::Current(-(self.len as i64 - self.pos as i64)))?;
53 self.write = true;
54 self.len = 0;
55 self.pos = 0;
56 }
57 Ok(())
58 }
59
60 fn fill_buf(&mut self) -> io::Result<&[u8]> {
61 self.make_reader()?;
62 if self.pos >= self.len {
63 debug_assert!(self.pos == self.len);
64 self.len = self.inner.read(&mut self.buf)?;
65 self.pos = 0;
66 }
67 Ok(&self.buf[self.pos..self.len])
68 }
69
70 fn consume(&mut self, amt: usize) {
71 self.pos = cmp::min(self.pos + amt, self.len);
72 }
73}
74
#[cfg(any(feature = "std", feature = "core_io/collections"))]
impl<T: Read+Write+Seek> BufRead for BufStream<T> {
    /// Forwards to the inherent `fill_buf`, which owns the buffering logic.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Self::fill_buf(self)
    }

    /// Forwards to the inherent `consume`, which clamps `amt` to the buffered bytes.
    fn consume(&mut self, amt: usize) {
        Self::consume(self, amt)
    }
}
85
86impl<T: Read+Write+Seek> Read for BufStream<T> {
87 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
88 self.make_reader()?;
90 if self.pos == self.len && buf.len() >= BUF_SIZE {
92 return self.inner.read(buf);
93 }
94 let nread = {
95 let mut rem = self.fill_buf()?;
96 rem.read(buf)?
97 };
98 self.consume(nread);
99 Ok(nread)
100 }
101}
102
103impl<T: Read+Write+Seek> Write for BufStream<T> {
104 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
105 self.make_writter()?;
107 if self.pos + buf.len() > BUF_SIZE {
108 self.flush_buf()?;
109 if buf.len() >= BUF_SIZE {
110 return self.inner.write(buf);
111 }
112 }
113 let written = (&mut self.buf[self.pos..]).write(buf)?;
114 self.pos += written;
115 Ok(written)
116 }
117
118 fn flush(&mut self) -> io::Result<()> {
119 self.flush_buf()?;
120 self.inner.flush()
121 }
122}
123
124impl<T: Read+Write+Seek> Seek for BufStream<T> {
125 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
126 self.flush_buf()?;
127 let new_pos = match pos {
128 io::SeekFrom::Current(x) => io::SeekFrom::Current(x - (self.len as i64 - self.pos as i64)),
129 _ => pos,
130 };
131 self.pos = 0;
132 self.len = 0;
133 self.inner.seek(new_pos)
134 }
135}
136
137impl<T: Read+Write+Seek> Drop for BufStream<T> {
138 fn drop(&mut self) {
139 if let Err(err) = self.flush() {
140 error!("flush failed {}", err);
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let buf = "Test data".to_string().into_bytes();
        let cur = io::Cursor::new(buf);
        let mut buf_stream = BufStream::new(cur);

        let mut data = String::new();
        buf_stream.read_to_string(&mut data).unwrap();
        assert_eq!(data, "Test data");

        buf_stream.seek(io::SeekFrom::Start(5)).unwrap();
        let mut data = String::new();
        buf_stream.read_to_string(&mut data).unwrap();
        assert_eq!(data, "data");

        buf_stream.write_all("\nHello".as_bytes()).unwrap();
        buf_stream.seek(io::SeekFrom::Start(0)).unwrap();
        let mut data = String::new();
        buf_stream.read_line(&mut data).unwrap();
        assert_eq!(data, "Test data\n");
        data.clear();
        buf_stream.read_line(&mut data).unwrap();
        assert_eq!(data, "Hello");
    }

    /// `SeekFrom::Current` must be relative to the logical position, not to `inner`, which
    /// has already read ahead past the bytes still sitting in the buffer.
    #[test]
    fn seek_current_accounts_for_buffered_reads() {
        let mut s = BufStream::new(io::Cursor::new(b"0123456789".to_vec()));
        let mut two = [0u8; 2];

        s.read_exact(&mut two).unwrap();
        assert_eq!(&two, b"01");
        assert_eq!(s.seek(io::SeekFrom::Current(0)).unwrap(), 2);

        s.read_exact(&mut two).unwrap();
        assert_eq!(&two, b"23");
        assert_eq!(s.seek(io::SeekFrom::Current(3)).unwrap(), 7);

        let mut rest = String::new();
        s.read_to_string(&mut rest).unwrap();
        assert_eq!(rest, "789");
    }

    /// Switching from reading to writing must put the written bytes at the logical position,
    /// not after the read-ahead data.
    #[test]
    fn write_after_partial_read_lands_at_logical_position() {
        let mut s = BufStream::new(io::Cursor::new(b"abcdef".to_vec()));
        let mut one = [0u8; 1];
        s.read_exact(&mut one).unwrap();
        assert_eq!(&one, b"a");

        s.write_all(b"XY").unwrap();
        s.seek(io::SeekFrom::Start(0)).unwrap();

        let mut out = String::new();
        s.read_to_string(&mut out).unwrap();
        assert_eq!(out, "aXYdef");
    }

    /// A write larger than the buffer bypasses it; ordering with earlier buffered bytes
    /// and the reported position must both stay correct.
    #[test]
    fn large_write_bypasses_buffer_and_round_trips() {
        let data: Vec<u8> = (0..2000u32).map(|i| (i % 251) as u8).collect();
        let mut s = BufStream::new(io::Cursor::new(Vec::new()));

        s.write_all(&data[..10]).unwrap();
        s.write_all(&data[10..]).unwrap();
        assert_eq!(s.seek(io::SeekFrom::Current(0)).unwrap(), 2000);

        s.seek(io::SeekFrom::Start(0)).unwrap();
        let mut back = Vec::new();
        s.read_to_end(&mut back).unwrap();
        assert_eq!(back, data);
    }
}