bytecodec/
io.rs

1//! I/O (i.e., `Read` and `Write` traits) related module.
2use crate::{ByteCount, Decode, Encode, Eos, Error, ErrorKind, Result};
3#[cfg(feature = "tokio-async")]
4use pin_project::pin_project;
5use std::cmp;
6use std::io::{self, Read, Write};
7
8/// An extension of `Decode` trait to aid decodings involving I/O.
9pub trait IoDecodeExt: Decode {
10    /// Consumes bytes from the given read buffer and proceeds the decoding process.
11    fn decode_from_read_buf<B>(&mut self, buf: &mut ReadBuf<B>) -> Result<()>
12    where
13        B: AsRef<[u8]>,
14    {
15        let eos = Eos::new(buf.stream_state.is_eos());
16        let size = track!(self.decode(&buf.inner.as_ref()[buf.head..buf.tail], eos))?;
17        buf.head += size;
18        if buf.head == buf.tail {
19            buf.head = 0;
20            buf.tail = 0;
21        }
22        Ok(())
23    }
24
25    /// Decodes an item from the given reader.
26    ///
27    /// This method reads only minimal bytes required to decode an item.
28    ///
29    /// Note that this is a blocking method.
30    fn decode_exact<R: Read>(&mut self, mut reader: R) -> Result<Self::Item> {
31        let mut buf = [0; 1024];
32        loop {
33            let mut size = match self.requiring_bytes() {
34                ByteCount::Finite(n) => cmp::min(n, buf.len() as u64) as usize,
35                ByteCount::Infinite => buf.len(),
36                ByteCount::Unknown => 1,
37            };
38            let eos = if size != 0 {
39                size = track!(reader.read(&mut buf[..size]).map_err(Error::from))?;
40                Eos::new(size == 0)
41            } else {
42                Eos::new(false)
43            };
44
45            let consumed = track!(self.decode(&buf[..size], eos))?;
46            track_assert_eq!(consumed, size, ErrorKind::InconsistentState; self.is_idle(), eos);
47            if self.is_idle() {
48                let item = track!(self.finish_decoding())?;
49                return Ok(item);
50            }
51        }
52    }
53}
54impl<T: Decode> IoDecodeExt for T {}
55
56/// An extension of `Encode` trait to aid encodings involving I/O.
57pub trait IoEncodeExt: Encode {
58    /// Encodes the items remaining in the encoder and
59    /// writes the encoded bytes to the given write buffer.
60    fn encode_to_write_buf<B>(&mut self, buf: &mut WriteBuf<B>) -> Result<()>
61    where
62        B: AsMut<[u8]>,
63    {
64        let eos = Eos::new(buf.stream_state.is_eos());
65        let size = track!(self.encode(&mut buf.inner.as_mut()[buf.tail..], eos))?;
66        buf.tail += size;
67        Ok(())
68    }
69
70    /// Encodes the items remaining in the encoder and
71    /// writes the encoded bytes to the given write buffer.
72    /// If the write buffer is full and the writing cannot be performed,
73    /// the given WriteBuf will memorize cx's `Waker`.
74    /// This `Waker`'s `wake` will later be called when the `WriteBuf` regains its free space.
75    #[cfg(feature = "tokio-async")]
76    fn encode_to_write_buf_async<B>(
77        &mut self,
78        buf: &mut WriteBuf<B>,
79        cx: &mut std::task::Context,
80    ) -> Result<()>
81    where
82        B: AsMut<[u8]>,
83    {
84        let eos = Eos::new(buf.stream_state.is_eos());
85        let size = track!(self.encode(&mut buf.inner.as_mut()[buf.tail..], eos))?;
86        buf.tail += size;
87        buf.waker = Some(cx.waker().clone());
88        Ok(())
89    }
90
91    /// Encodes all of the items remaining in the encoder and
92    /// writes the encoded bytes to the given writer.
93    ///
94    /// Note that this is a blocking method.
95    fn encode_all<W: Write>(&mut self, mut writer: W) -> Result<()> {
96        let mut buf = [0; 1024];
97        while !self.is_idle() {
98            let size = track!(self.encode(&mut buf[..], Eos::new(false)))?;
99            track!(writer.write_all(&buf[..size]).map_err(Error::from))?;
100            if !self.is_idle() {
101                track_assert_ne!(size, 0, ErrorKind::Other);
102            }
103        }
104        Ok(())
105    }
106}
107impl<T: Encode> IoEncodeExt for T {}
108
/// State of I/O streams.
///
/// `ReadBuf::fill` and `WriteBuf::flush` record here the outcome of the most recent
/// operation on the underlying stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub enum StreamState {
    /// The last operation transferred bytes normally.
    Normal,
    /// The stream reported end of stream (a zero-byte read or write).
    Eos,
    /// The last operation failed with `io::ErrorKind::WouldBlock`.
    WouldBlock,
    /// The last operation failed with some other I/O error.
    Error,
}
impl StreamState {
    /// Returns `true` if the state is `Normal`, otherwise `false`.
    pub fn is_normal(self) -> bool {
        matches!(self, StreamState::Normal)
    }

    /// Returns `true` if the state is `Error`, otherwise `false`.
    pub fn is_error(self) -> bool {
        matches!(self, StreamState::Error)
    }

    /// Returns `true` if the state is `Eos`, otherwise `false`.
    pub fn is_eos(self) -> bool {
        matches!(self, StreamState::Eos)
    }

    /// Returns `true` if the state is `WouldBlock`, otherwise `false`.
    pub fn would_block(self) -> bool {
        matches!(self, StreamState::WouldBlock)
    }
}
139
140/// Read buffer.
141#[derive(Debug)]
142pub struct ReadBuf<B> {
143    pub(crate) inner: B,
144    pub(crate) head: usize,
145    pub(crate) tail: usize,
146    pub(crate) stream_state: StreamState,
147}
148impl<B: AsRef<[u8]> + AsMut<[u8]>> ReadBuf<B> {
149    /// Makes a new `ReadBuf` instance.
150    pub fn new(inner: B) -> Self {
151        ReadBuf {
152            inner,
153            head: 0,
154            tail: 0,
155            stream_state: StreamState::Normal,
156        }
157    }
158
159    /// Returns the number of filled bytes in the buffer.
160    pub fn len(&self) -> usize {
161        self.tail - self.head
162    }
163
164    /// Returns the free space of the buffer.
165    ///
166    /// Invariant: `self.len() + self.room() <= self.capacity()`
167    pub fn room(&self) -> usize {
168        self.inner.as_ref().len() - self.tail
169    }
170
171    /// Returns the capacity of the buffer.
172    pub fn capacity(&self) -> usize {
173        self.inner.as_ref().len()
174    }
175
176    /// Returns `true` if the buffer is empty, otherwise `false`.
177    pub fn is_empty(&self) -> bool {
178        self.tail == 0
179    }
180
181    /// Returns `true` if the buffer is full, otherwise `false`.
182    pub fn is_full(&self) -> bool {
183        self.tail == self.inner.as_ref().len()
184    }
185
186    /// Returns the state of the stream that operated in the last `fill()` call.
187    pub fn stream_state(&self) -> StreamState {
188        self.stream_state
189    }
190
191    /// Returns a mutable reference to the `StreamState` instance.
192    pub fn stream_state_mut(&mut self) -> &mut StreamState {
193        &mut self.stream_state
194    }
195
196    /// Fills the read buffer by reading bytes from the given reader.
197    ///
198    /// The fill process continues until one of the following condition is satisfied:
199    /// - The read buffer became full
200    /// - A read operation returned a `WouldBlock` error
201    /// - The input stream has reached EOS
202    pub fn fill<R: Read>(&mut self, mut reader: R) -> Result<()> {
203        while !self.is_full() {
204            match reader.read(&mut self.inner.as_mut()[self.tail..]) {
205                Err(e) => {
206                    if e.kind() == io::ErrorKind::WouldBlock {
207                        self.stream_state = StreamState::WouldBlock;
208                        break;
209                    } else {
210                        self.stream_state = StreamState::Error;
211                        return Err(track!(Error::from(e)));
212                    }
213                }
214                Ok(0) => {
215                    self.stream_state = StreamState::Eos;
216                    break;
217                }
218                Ok(size) => {
219                    self.stream_state = StreamState::Normal;
220                    self.tail += size;
221                }
222            }
223        }
224        Ok(())
225    }
226
227    /// Returns a reference to the inner bytes of the buffer.
228    pub fn inner_ref(&self) -> &B {
229        &self.inner
230    }
231
232    /// Returns a mutable reference to the inner bytes of the buffer.
233    pub fn inner_mut(&mut self) -> &mut B {
234        &mut self.inner
235    }
236
237    /// Takes ownership of `ReadBuf` and returns the inner bytes of the buffer.
238    pub fn into_inner(self) -> B {
239        self.inner
240    }
241}
242impl<B: AsRef<[u8]> + AsMut<[u8]>> Read for ReadBuf<B> {
243    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
244        let size = cmp::min(buf.len(), self.len());
245        buf[..size].copy_from_slice(&self.inner.as_ref()[self.head..][..size]);
246        self.head += size;
247        if self.head == self.tail {
248            self.head = 0;
249            self.tail = 0;
250        }
251        Ok(size)
252    }
253}
254
255/// Write buffer.
256#[derive(Debug)]
257pub struct WriteBuf<B> {
258    pub(crate) inner: B,
259    pub(crate) head: usize,
260    pub(crate) tail: usize,
261    pub(crate) stream_state: StreamState,
262    #[cfg(feature = "tokio-async")]
263    pub(crate) waker: Option<std::task::Waker>,
264}
265impl<B: AsRef<[u8]> + AsMut<[u8]>> WriteBuf<B> {
266    /// Makes a new `WriteBuf` instance.
267    pub fn new(inner: B) -> Self {
268        WriteBuf {
269            inner,
270            head: 0,
271            tail: 0,
272            stream_state: StreamState::Normal,
273            #[cfg(feature = "tokio-async")]
274            waker: None,
275        }
276    }
277
278    /// Returns the number of encoded bytes in the buffer.
279    pub fn len(&self) -> usize {
280        self.tail - self.head
281    }
282
283    /// Returns the free space of the buffer.
284    ///
285    /// Invariant: `self.len() + self.room() <= self.capacity()`
286    pub fn room(&self) -> usize {
287        self.inner.as_ref().len() - self.tail
288    }
289
290    /// Returns the capacity of the buffer.
291    pub fn capacity(&self) -> usize {
292        self.inner.as_ref().len()
293    }
294
295    /// Returns `true` if the buffer is empty, otherwise `false`.
296    pub fn is_empty(&self) -> bool {
297        self.tail == 0
298    }
299
300    /// Returns `true` if the buffer is full, otherwise `false`.
301    pub fn is_full(&self) -> bool {
302        self.tail == self.inner.as_ref().len()
303    }
304
305    /// Returns the state of the stream that operated in the last `flush()` call.
306    pub fn stream_state(&self) -> StreamState {
307        self.stream_state
308    }
309
310    /// Returns a mutable reference to the `StreamState` instance.
311    pub fn stream_state_mut(&mut self) -> &mut StreamState {
312        &mut self.stream_state
313    }
314
315    /// Writes the encoded bytes contained in this buffer to the given writer.
316    ///
317    /// The written bytes will be removed from the buffer.
318    ///
319    /// The flush process continues until one of the following condition is satisfied:
320    /// - The write buffer became empty
321    /// - A write operation returned a `WouldBlock` error
322    /// - The output stream has reached EOS
323    pub fn flush<W: Write>(&mut self, mut writer: W) -> Result<()> {
324        while !self.is_empty() {
325            match writer.write(&self.inner.as_ref()[self.head..self.tail]) {
326                Err(e) => {
327                    if e.kind() == io::ErrorKind::WouldBlock {
328                        self.stream_state = StreamState::WouldBlock;
329                        break;
330                    } else {
331                        self.stream_state = StreamState::Error;
332                        return Err(track!(Error::from(e)));
333                    }
334                }
335                Ok(0) => {
336                    self.stream_state = StreamState::Eos;
337                    break;
338                }
339                Ok(size) => {
340                    self.stream_state = StreamState::Normal;
341                    self.head += size;
342                    if self.head == self.tail {
343                        self.head = 0;
344                        self.tail = 0;
345                    }
346                }
347            }
348        }
349        #[cfg(feature = "tokio-async")]
350        if !self.is_full() {
351            if let Some(ref waker) = self.waker {
352                waker.wake_by_ref();
353            }
354        }
355        Ok(())
356    }
357
358    /// Returns a reference to the inner bytes of the buffer.
359    pub fn inner_ref(&self) -> &B {
360        &self.inner
361    }
362
363    /// Returns a mutable reference to the inner bytes of the buffer.
364    pub fn inner_mut(&mut self) -> &mut B {
365        &mut self.inner
366    }
367
368    /// Takes ownership of `ReadBuf` and returns the inner bytes of the buffer.
369    pub fn into_inner(self) -> B {
370        self.inner
371    }
372}
373impl<B: AsRef<[u8]> + AsMut<[u8]>> Write for WriteBuf<B> {
374    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
375        let size = cmp::min(buf.len(), self.room());
376        self.inner.as_mut()[self.tail..][..size].copy_from_slice(&buf[..size]);
377        self.tail += size;
378        Ok(size)
379    }
380
381    fn flush(&mut self) -> io::Result<()> {
382        Ok(())
383    }
384}
385
/// Buffered I/O stream.
///
/// Pairs an inner stream with a read buffer, which is filled from the stream, and a
/// write buffer, which is flushed into it (see `execute_io`).
#[cfg_attr(feature = "tokio-async", pin_project)]
#[derive(Debug)]
pub struct BufferedIo<T> {
    /// The underlying stream; a structurally pinned field when `tokio-async` is enabled.
    #[cfg_attr(feature = "tokio-async", pin)]
    pub(crate) stream: T,
    /// Bytes read from `stream` that have not been consumed yet.
    pub(crate) rbuf: ReadBuf<Vec<u8>>,
    /// Encoded bytes that have not been written to `stream` yet.
    pub(crate) wbuf: WriteBuf<Vec<u8>>,
}
395impl<T: Read + Write> BufferedIo<T> {
396    /// Executes an I/O operation on the inner stream.
397    ///
398    /// "I/O operation" means "filling the read buffer" and "flushing the write buffer".
399    pub fn execute_io(&mut self) -> Result<()> {
400        track!(self.rbuf.fill(&mut self.stream))?;
401        track!(self.wbuf.flush(&mut self.stream))?;
402        Ok(())
403    }
404}
405
406impl<T> BufferedIo<T> {
407    /// Makes a new `BufferedIo` instance.
408    pub fn new(stream: T, read_buf_size: usize, write_buf_size: usize) -> Self {
409        BufferedIo {
410            stream,
411            rbuf: ReadBuf::new(vec![0; read_buf_size]),
412            wbuf: WriteBuf::new(vec![0; write_buf_size]),
413        }
414    }
415
416    /// Returns `true` if the inner stream reaches EOS, otherwise `false`.
417    pub fn is_eos(&self) -> bool {
418        self.rbuf.stream_state().is_eos() || self.wbuf.stream_state().is_eos()
419    }
420
421    /// Returns `true` if the previous I/O operation on the inner stream would block, otherwise `false`.
422    pub fn would_block(&self) -> bool {
423        self.rbuf.stream_state().would_block()
424            && (self.wbuf.is_empty() || self.wbuf.stream_state().would_block())
425    }
426
427    /// Returns a reference to the read buffer of the instance.
428    pub fn read_buf_ref(&self) -> &ReadBuf<Vec<u8>> {
429        &self.rbuf
430    }
431
432    /// Returns a mutable reference to the read buffer of the instance.
433    pub fn read_buf_mut(&mut self) -> &mut ReadBuf<Vec<u8>> {
434        &mut self.rbuf
435    }
436
437    /// Returns a reference to the write buffer of the instance.
438    pub fn write_buf_ref(&self) -> &WriteBuf<Vec<u8>> {
439        &self.wbuf
440    }
441
442    /// Returns a mutable reference to the write buffer of the instance.
443    pub fn write_buf_mut(&mut self) -> &mut WriteBuf<Vec<u8>> {
444        &mut self.wbuf
445    }
446
447    /// Returns a reference to the inner stream of the instance.
448    pub fn stream_ref(&self) -> &T {
449        &self.stream
450    }
451
452    /// Returns a mutable reference to the inner stream of the instance.
453    pub fn stream_mut(&mut self) -> &mut T {
454        &mut self.stream
455    }
456
457    /// Takes ownership of the instance, and returns the inner stream.
458    pub fn into_stream(self) -> T {
459        self.stream
460    }
461}
462
#[cfg(test)]
mod test {
    use super::*;
    use crate::bytes::{Utf8Decoder, Utf8Encoder};
    use crate::EncodeExt;
    use std::io::{Read, Write};

    #[test]
    fn decode_from_read_buf_works() {
        // A 3-byte source fits entirely, and the follow-up zero-byte read marks EOS.
        let mut rbuf = ReadBuf::new(vec![0; 1024]);
        track_try_unwrap!(rbuf.fill(&b"foo"[..]));
        assert_eq!(rbuf.len(), 3);
        assert_eq!(rbuf.stream_state(), StreamState::Eos);

        let mut utf8 = Utf8Decoder::new();
        track_try_unwrap!(utf8.decode_from_read_buf(&mut rbuf));
        let decoded = track_try_unwrap!(utf8.finish_decoding());
        assert_eq!(decoded, "foo");
    }

    #[test]
    fn read_from_read_buf_works() {
        let mut rbuf = ReadBuf::new(vec![0; 1024]);
        track_try_unwrap!(rbuf.fill(&b"foo"[..]));
        assert_eq!(rbuf.len(), 3);
        assert_eq!(rbuf.stream_state(), StreamState::Eos);

        // Draining through `Read` must hand back every buffered byte.
        let mut drained = Vec::new();
        rbuf.read_to_end(&mut drained).unwrap();
        assert_eq!(drained, b"foo");
        assert_eq!(rbuf.len(), 0);
    }

    #[test]
    fn encode_to_write_buf_works() {
        let mut utf8 = track_try_unwrap!(Utf8Encoder::with_item("foo"));

        let mut wbuf = WriteBuf::new(vec![0; 1024]);
        track_try_unwrap!(utf8.encode_to_write_buf(&mut wbuf));
        assert_eq!(wbuf.len(), 3);

        let mut sink = Vec::new();
        track_try_unwrap!(wbuf.flush(&mut sink));
        assert_eq!(wbuf.len(), 0);
        assert_eq!(wbuf.stream_state(), StreamState::Normal);
        assert_eq!(sink, b"foo");
    }

    #[test]
    fn write_to_write_buf_works() {
        let mut wbuf = WriteBuf::new(vec![0; 1024]);
        wbuf.write_all(b"foo").unwrap();
        assert_eq!(wbuf.len(), 3);

        // The inherent `flush(writer)` moves the buffered bytes into `sink`.
        let mut sink = Vec::new();
        track_try_unwrap!(wbuf.flush(&mut sink));
        assert_eq!(wbuf.len(), 0);
        assert_eq!(wbuf.stream_state(), StreamState::Normal);
        assert_eq!(sink, b"foo");
    }
}