1use bytes::{Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncWrite};
3use futures::{Stream, Sink, Async, AsyncSink};
4use std::{io, mem};
5
6fn zeros(n: usize) -> BytesMut {
7 let mut ret = BytesMut::with_capacity(n);
8 unsafe {
9 ret.set_len(n);
10 for i in 0..n {
11 ret[i] = 0;
12 }
13 }
14 ret
15}
16
17pub struct FramedUnbuffered<T> {
21 stream: T,
22 read_state: ReadState,
23 write_state: WriteState,
24}
25
26impl<T> FramedUnbuffered<T> {
27 pub fn new(stream: T) -> FramedUnbuffered<T> {
28 FramedUnbuffered {
29 stream,
30 read_state: ReadState::ReadingSize {
31 bytes_read: 0,
32 size_buffer: [0u8; 4],
33 },
34 write_state: WriteState::WaitingForInput,
35 }
36 }
37
38 pub fn into_inner(self) -> Option<T> {
39 if let ReadState::ReadingSize { bytes_read: 0, .. } = self.read_state {
40 if let WriteState::WaitingForInput = self.write_state {
41 return Some(self.stream);
42 }
43 }
44 None
45 }
46}
47
48enum ReadState {
49 Invalid,
50 ReadingSize {
51 bytes_read: u8,
52 size_buffer: [u8; 4],
53 },
54 ReadingData {
55 bytes_read: u32,
56 data_buffer: BytesMut,
57 },
58}
59
60enum WriteState {
61 Invalid,
62 WaitingForInput,
63 WritingSize {
64 size_buffer: [u8; 4],
65 data_buffer: Bytes,
66 bytes_written: u8,
67 },
68 WritingData {
69 data_buffer: Bytes,
70 bytes_written: u32,
71 }
72}
73
74impl<T> Stream for FramedUnbuffered<T>
75where
76 T: AsyncRead,
77{
78 type Item = BytesMut;
79 type Error = io::Error;
80
81 fn poll(&mut self) -> io::Result<Async<Option<BytesMut>>> {
82 loop {
83 let read_state = mem::replace(&mut self.read_state, ReadState::Invalid);
84 match read_state {
85 ReadState::Invalid => unreachable!(),
86 ReadState::ReadingSize { mut bytes_read, mut size_buffer } => {
87 match self.stream.read(&mut size_buffer[(bytes_read as usize)..]) {
88 Ok(n) => {
89 if n == 0 {
90 if bytes_read == 0 {
91 return Ok(Async::Ready(None));
92 } else {
93 return Err(io::Error::from(io::ErrorKind::BrokenPipe));
94 }
95 }
96 bytes_read += n as u8;
97 if bytes_read == 4 {
98 let len =
99 ((size_buffer[0] as u32) << 24) +
100 ((size_buffer[1] as u32) << 16) +
101 ((size_buffer[2] as u32) << 8) +
102 (size_buffer[3] as u32);
103 self.read_state = ReadState::ReadingData {
104 bytes_read: 0,
105 data_buffer: zeros(len as usize),
106 };
107 } else {
108 self.read_state = ReadState::ReadingSize {
109 bytes_read, size_buffer,
110 };
111 }
112 },
113 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
114 self.read_state = ReadState::ReadingSize {
115 bytes_read, size_buffer,
116 };
117 return Ok(Async::NotReady);
118 }
119 Err(e) => return Err(e),
120 }
121 },
122 ReadState::ReadingData { mut bytes_read, mut data_buffer } => {
123 match self.stream.read(&mut data_buffer[(bytes_read as usize)..]) {
124 Ok(n) => {
125 if n == 0 {
126 return Err(io::Error::from(io::ErrorKind::BrokenPipe));
127 }
128 bytes_read += n as u32;
129 if bytes_read == data_buffer.len() as u32 {
130 self.read_state = ReadState::ReadingSize {
131 bytes_read: 0,
132 size_buffer: [0u8; 4],
133 };
134 return Ok(Async::Ready(Some(data_buffer)));
135 }
136 else {
137 self.read_state = ReadState::ReadingData {
138 bytes_read, data_buffer,
139 }
140 }
141 },
142 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
143 self.read_state = ReadState::ReadingData {
144 bytes_read, data_buffer,
145 };
146 return Ok(Async::NotReady);
147 }
148 Err(e) => return Err(e),
149 }
150 },
151 }
152 }
153 }
154}
155
156impl<T> Sink for FramedUnbuffered<T>
157where
158 T: AsyncWrite,
159{
160 type SinkItem = Bytes;
161 type SinkError = io::Error;
162
163 fn start_send(&mut self, data_buffer: Bytes) -> io::Result<AsyncSink<Bytes>> {
164 let write_state = mem::replace(&mut self.write_state, WriteState::Invalid);
165 match write_state {
166 WriteState::Invalid => unreachable!(),
167 WriteState::WaitingForInput => {
168 let len = data_buffer.len() as u32;
169 let size_buffer = [
170 (len >> 24) as u8,
171 ((len >> 16) & 0xff) as u8,
172 ((len >> 8) & 0xff) as u8,
173 (len & 0xff) as u8,
174 ];
175 self.write_state = WriteState::WritingSize {
176 bytes_written: 0,
177 size_buffer,
178 data_buffer,
179 };
180 return Ok(AsyncSink::Ready);
181 },
182 WriteState::WritingSize { .. } | WriteState::WritingData { .. } => {
183 self.write_state = write_state;
184 return Ok(AsyncSink::NotReady(data_buffer));
185 },
186 }
187 }
188
189 fn poll_complete(&mut self) -> io::Result<Async<()>> {
190 loop {
191 let write_state = mem::replace(&mut self.write_state, WriteState::Invalid);
192 match write_state {
193 WriteState::Invalid => unreachable!(),
194 WriteState::WaitingForInput => {
195 self.write_state = WriteState::WaitingForInput;
196 return Ok(Async::Ready(()));
197 },
198 WriteState::WritingSize { size_buffer, data_buffer, mut bytes_written } => {
199 match self.stream.write(&size_buffer[(bytes_written as usize)..]) {
200 Ok(n) => {
201 if n == 0 {
202 return Err(io::Error::from(io::ErrorKind::BrokenPipe));
203 }
204 bytes_written += n as u8;
205 if bytes_written == 4 {
206 self.write_state = WriteState::WritingData {
207 data_buffer,
208 bytes_written: 0,
209 };
210 } else {
211 self.write_state = WriteState::WritingSize {
212 size_buffer, data_buffer, bytes_written,
213 }
214 }
215 },
216 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
217 self.write_state = WriteState::WritingSize {
218 size_buffer, data_buffer, bytes_written,
219 };
220 return Ok(Async::NotReady);
221 },
222 Err(e) => return Err(e),
223 }
224 },
225 WriteState::WritingData { data_buffer, mut bytes_written } => {
226 match self.stream.write(&data_buffer[(bytes_written as usize)..]) {
227 Ok(n) => {
228 if n == 0 {
229 return Err(io::Error::from(io::ErrorKind::BrokenPipe));
230 }
231 bytes_written += n as u32;
232 if bytes_written == data_buffer.len() as u32 {
233 self.write_state = WriteState::WaitingForInput;
234 } else {
235 self.write_state = WriteState::WritingData {
236 data_buffer, bytes_written,
237 }
238 }
239 },
240 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
241 self.write_state = WriteState::WritingData {
242 data_buffer, bytes_written,
243 };
244 return Ok(Async::NotReady);
245 },
246 Err(e) => return Err(e),
247 }
248 },
249 }
250 }
251 }
252}
253