future_utils/
framed_unbuffered.rs

1use bytes::{Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncWrite};
3use futures::{Stream, Sink, Async, AsyncSink};
4use std::{io, mem};
5
6fn zeros(n: usize) -> BytesMut {
7    let mut ret = BytesMut::with_capacity(n);
8    unsafe {
9        ret.set_len(n);
10        for i in 0..n {
11            ret[i] = 0;
12        }
13    }
14    ret
15}
16
17/// An alternative to tokio_io's `Framed` which doesn't internally buffer data.
18/// This gives it much lower performance but means that you can use `.into_inner()` without losing
19/// data.
20pub struct FramedUnbuffered<T> {
21    stream: T,
22    read_state: ReadState,
23    write_state: WriteState,
24}
25
26impl<T> FramedUnbuffered<T> {
27    pub fn new(stream: T) -> FramedUnbuffered<T> {
28        FramedUnbuffered {
29            stream,
30            read_state: ReadState::ReadingSize {
31                bytes_read: 0,
32                size_buffer: [0u8; 4],
33            },
34            write_state: WriteState::WaitingForInput,
35        }
36    }
37
38    pub fn into_inner(self) -> Option<T> {
39        if let ReadState::ReadingSize { bytes_read: 0, .. } = self.read_state {
40            if let WriteState::WaitingForInput = self.write_state {
41                return Some(self.stream);
42            }
43        }
44        None
45    }
46}
47
48enum ReadState {
49    Invalid,
50    ReadingSize {
51        bytes_read: u8,
52        size_buffer: [u8; 4],
53    },
54    ReadingData {
55        bytes_read: u32,
56        data_buffer: BytesMut,
57    },
58}
59
60enum WriteState {
61    Invalid,
62    WaitingForInput,
63    WritingSize {
64        size_buffer: [u8; 4],
65        data_buffer: Bytes,
66        bytes_written: u8,
67    },
68    WritingData {
69        data_buffer: Bytes,
70        bytes_written: u32,
71    }
72}
73
74impl<T> Stream for FramedUnbuffered<T>
75where
76    T: AsyncRead,
77{
78    type Item = BytesMut;
79    type Error = io::Error;
80
81    fn poll(&mut self) -> io::Result<Async<Option<BytesMut>>> {
82        loop {
83            let read_state = mem::replace(&mut self.read_state, ReadState::Invalid);
84            match read_state {
85                ReadState::Invalid => unreachable!(),
86                ReadState::ReadingSize { mut bytes_read, mut size_buffer } => {
87                    match self.stream.read(&mut size_buffer[(bytes_read as usize)..]) {
88                        Ok(n) => {
89                            if n == 0 {
90                                if bytes_read == 0 {
91                                    return Ok(Async::Ready(None));
92                                } else {
93                                    return Err(io::Error::from(io::ErrorKind::BrokenPipe));
94                                }
95                            }
96                            bytes_read += n as u8;
97                            if bytes_read == 4 {
98                                let len =
99                                    ((size_buffer[0] as u32) << 24) +
100                                    ((size_buffer[1] as u32) << 16) +
101                                    ((size_buffer[2] as u32) << 8) +
102                                    (size_buffer[3] as u32);
103                                self.read_state = ReadState::ReadingData {
104                                    bytes_read: 0,
105                                    data_buffer: zeros(len as usize),
106                                };
107                            } else {
108                                self.read_state = ReadState::ReadingSize {
109                                    bytes_read, size_buffer,
110                                };
111                            }
112                        },
113                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
114                            self.read_state = ReadState::ReadingSize {
115                                bytes_read, size_buffer,
116                            };
117                            return Ok(Async::NotReady);
118                        }
119                        Err(e) => return Err(e),
120                    }
121                },
122                ReadState::ReadingData { mut bytes_read, mut data_buffer } => {
123                    match self.stream.read(&mut data_buffer[(bytes_read as usize)..]) {
124                        Ok(n) => {
125                            if n == 0 {
126                                return Err(io::Error::from(io::ErrorKind::BrokenPipe));
127                            }
128                            bytes_read += n as u32;
129                            if bytes_read == data_buffer.len() as u32 {
130                                self.read_state = ReadState::ReadingSize {
131                                    bytes_read: 0,
132                                    size_buffer: [0u8; 4],
133                                };
134                                return Ok(Async::Ready(Some(data_buffer)));
135                            }
136                            else {
137                                self.read_state = ReadState::ReadingData {
138                                    bytes_read, data_buffer,
139                                }
140                            }
141                        },
142                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
143                            self.read_state = ReadState::ReadingData {
144                                bytes_read, data_buffer,
145                            };
146                            return Ok(Async::NotReady);
147                        }
148                        Err(e) => return Err(e),
149                    }
150                },
151            }
152        }
153    }
154}
155
156impl<T> Sink for FramedUnbuffered<T>
157where
158    T: AsyncWrite,
159{
160    type SinkItem = Bytes;
161    type SinkError = io::Error;
162
163    fn start_send(&mut self, data_buffer: Bytes) -> io::Result<AsyncSink<Bytes>> {
164        let write_state = mem::replace(&mut self.write_state, WriteState::Invalid);
165        match write_state {
166            WriteState::Invalid => unreachable!(),
167            WriteState::WaitingForInput => {
168                let len = data_buffer.len() as u32;
169                let size_buffer = [
170                    (len >> 24) as u8,
171                    ((len >> 16) & 0xff) as u8,
172                    ((len >> 8) & 0xff) as u8,
173                    (len & 0xff) as u8,
174                ];
175                self.write_state = WriteState::WritingSize {
176                    bytes_written: 0,
177                    size_buffer,
178                    data_buffer,
179                };
180                return Ok(AsyncSink::Ready);
181            },
182            WriteState::WritingSize { .. } | WriteState::WritingData { .. } => {
183                self.write_state = write_state;
184                return Ok(AsyncSink::NotReady(data_buffer));
185            },
186        }
187    }
188
189    fn poll_complete(&mut self) -> io::Result<Async<()>> {
190        loop {
191            let write_state = mem::replace(&mut self.write_state, WriteState::Invalid);
192            match write_state {
193                WriteState::Invalid => unreachable!(),
194                WriteState::WaitingForInput => {
195                    self.write_state = WriteState::WaitingForInput;
196                    return Ok(Async::Ready(()));
197                },
198                WriteState::WritingSize { size_buffer, data_buffer, mut bytes_written } => {
199                    match self.stream.write(&size_buffer[(bytes_written as usize)..]) {
200                        Ok(n) => {
201                            if n == 0 {
202                                return Err(io::Error::from(io::ErrorKind::BrokenPipe));
203                            }
204                            bytes_written += n as u8;
205                            if bytes_written == 4 {
206                                self.write_state = WriteState::WritingData {
207                                    data_buffer,
208                                    bytes_written: 0,
209                                };
210                            } else {
211                                self.write_state = WriteState::WritingSize {
212                                    size_buffer, data_buffer, bytes_written,
213                                }
214                            }
215                        },
216                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
217                            self.write_state = WriteState::WritingSize {
218                                size_buffer, data_buffer, bytes_written,
219                            };
220                            return Ok(Async::NotReady);
221                        },
222                        Err(e) => return Err(e),
223                    }
224                },
225                WriteState::WritingData { data_buffer, mut bytes_written } => {
226                    match self.stream.write(&data_buffer[(bytes_written as usize)..]) {
227                        Ok(n) => {
228                            if n == 0 {
229                                return Err(io::Error::from(io::ErrorKind::BrokenPipe));
230                            }
231                            bytes_written += n as u32;
232                            if bytes_written == data_buffer.len() as u32 {
233                                self.write_state = WriteState::WaitingForInput;
234                            } else {
235                                self.write_state = WriteState::WritingData {
236                                    data_buffer, bytes_written,
237                                }
238                            }
239                        },
240                        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
241                            self.write_state = WriteState::WritingData {
242                                data_buffer, bytes_written,
243                            };
244                            return Ok(Async::NotReady);
245                        },
246                        Err(e) => return Err(e),
247                    }
248                },
249            }
250        }
251    }
252}
253