1use std::{
4 error::Error,
5 fmt::{
6 self,
7 Display,
8 },
9 io,
10 pin::Pin,
11};
12
13use bytes::{
14 buf::UninitSlice,
15 Buf,
16 BufMut,
17 BytesMut,
18};
19
20use futures::{
21 io::{
22 AsyncRead,
23 AsyncWrite,
24 },
25 prelude::*,
26 ready,
27 task::{
28 Context,
29 Poll,
30 },
31};
32
33use super::*;
34
35pub struct Framed<S, C> {
38 stream: S,
39 codec: C,
40 read_buf: BytesMut,
41 write_buf: BytesMut,
42}
43
44impl<S, C: Encode + Decode> Framed<S, C> {
45 pub fn new(stream: S, codec: C) -> Self {
47 Framed {
48 stream,
49 codec,
50 read_buf: BytesMut::default(),
51 write_buf: BytesMut::default(),
52 }
53 }
54}
55
/// Failure reported while reading a frame.
#[derive(Debug)]
pub enum ReadFrameError<E> {
    /// The underlying stream failed (or ended in the middle of a frame).
    Io(io::Error),
    /// The codec could not decode the bytes it was given.
    Decode(E),
}

/// Flattens a read failure into a plain `io::Error`; decode failures are
/// wrapped with `ErrorKind::InvalidData`, I/O failures pass through untouched.
impl<E> From<ReadFrameError<E>> for io::Error
where
    E: Error + Send + Sync + 'static,
{
    fn from(failure: ReadFrameError<E>) -> Self {
        match failure {
            ReadFrameError::Io(inner) => inner,
            ReadFrameError::Decode(inner) => Self::new(io::ErrorKind::InvalidData, inner),
        }
    }
}

impl<E> Display for ReadFrameError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Decode(cause) => write!(f, "error decoding frame: {}", cause),
            Self::Io(cause) => write!(f, "error reading from stream: {}", cause),
        }
    }
}

impl<E> Error for ReadFrameError<E>
where
    E: Error + 'static,
{
    /// Exposes the wrapped error, whichever variant holds it, as the cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::Decode(inner) => inner,
        };
        Some(cause)
    }
}
94
95impl<S, C> Stream for Framed<S, C>
96where
97 C: Decode,
98 S: AsyncRead,
99{
100 type Item = Result<C::Item, ReadFrameError<C::Error>>;
101
102 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103 let (mut stream, codec, read_buf) = unsafe {
104 let this = self.get_unchecked_mut();
105 (
106 Pin::new_unchecked(&mut this.stream),
107 &mut this.codec,
108 &mut this.read_buf,
109 )
110 };
111 loop {
112 let empty = read_buf.is_empty();
113 if !empty {
114 let (consumed, decode_res) = codec.decode(read_buf);
115 read_buf.advance(consumed);
116 match decode_res {
117 DecodeResult::Ok(value) => {
118 return Poll::Ready(Some(Ok(value)));
119 }
120 DecodeResult::Err(e) => {
121 return Poll::Ready(Some(Err(ReadFrameError::Decode(e))));
122 }
123 DecodeResult::UnexpectedEnd => {}
124 }
125 }
126
127 read_buf.reserve(1);
129
130 let eof = unsafe {
133 let n = {
134 let b = zero_buf(read_buf.chunk_mut());
135 match ready!(stream.as_mut().poll_read(cx, b)).map_err(ReadFrameError::Io) {
136 Err(e) => return Poll::Ready(Some(Err(e))),
137 Ok(n) => n,
138 }
139 };
140
141 read_buf.advance_mut(n);
142
143 n == 0
144 };
145
146 if eof {
147 if empty {
148 return Poll::Ready(None);
149 } else {
150 return Poll::Ready(Some(Err(ReadFrameError::Io(
151 io::ErrorKind::UnexpectedEof.into(),
152 ))));
153 }
154 }
155 }
156 }
157}
158
/// Failure reported while writing a frame.
#[derive(Debug)]
pub enum WriteFrameError<E> {
    /// The underlying stream failed (or accepted zero bytes).
    Io(io::Error),
    /// The codec could not encode the item it was given.
    Encode(E),
}

/// Flattens a write failure into a plain `io::Error`; encode failures are
/// wrapped with `ErrorKind::InvalidInput`, I/O failures pass through untouched.
impl<E> From<WriteFrameError<E>> for io::Error
where
    E: Error + Send + Sync + 'static,
{
    fn from(failure: WriteFrameError<E>) -> Self {
        match failure {
            WriteFrameError::Io(inner) => inner,
            WriteFrameError::Encode(inner) => Self::new(io::ErrorKind::InvalidInput, inner),
        }
    }
}

impl<E> Display for WriteFrameError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Encode(cause) => write!(f, "error encoding frame: {}", cause),
            Self::Io(cause) => write!(f, "error writing to stream: {}", cause),
        }
    }
}

impl<E> Error for WriteFrameError<E>
where
    E: Error + 'static,
{
    /// Exposes the wrapped error, whichever variant holds it, as the cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(inner) => inner,
            Self::Encode(inner) => inner,
        };
        Some(cause)
    }
}
197
198impl<S, C> Sink<C::Item> for Framed<S, C>
199where
200 C: Encode,
201 C::Error: std::fmt::Debug,
202 S: AsyncWrite,
203{
204 type Error = WriteFrameError<C::Error>;
205 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206 let (buffer, mut stream) = unsafe {
207 let this = self.get_unchecked_mut();
208 (&mut this.write_buf, Pin::new_unchecked(&mut this.stream))
209 };
210 loop {
211 if buffer.len() == 0 {
212 return Poll::Ready(Ok(()));
213 }
214 let written = ready!(stream
215 .as_mut()
216 .poll_write(cx, &buffer)
217 .map_err(WriteFrameError::Io))?;
218 if written == 0 {
219 return Poll::Ready(Err(WriteFrameError::Io(io::ErrorKind::WriteZero.into())));
220 }
221 buffer.advance(written);
222 }
223 }
224 fn start_send(self: Pin<&mut Self>, item: C::Item) -> Result<(), Self::Error> {
225 let (buffer, codec) = unsafe {
226 let this = self.get_unchecked_mut();
227 (&mut this.write_buf, &mut this.codec)
228 };
229 codec.reset();
230 loop {
231 let b = zero_buf(buffer.chunk_mut());
232 match codec.encode(&item, b) {
233 EncodeResult::Ok(len) => {
234 unsafe { buffer.advance_mut(len) };
238 return Ok(());
239 }
240 EncodeResult::Err(e) => return Err(WriteFrameError::Encode(e)),
241 EncodeResult::Overflow(0) => buffer.reserve(buffer.remaining_mut() * 2),
242 EncodeResult::Overflow(new_size) => buffer.reserve(new_size),
243 }
244 }
245 }
246 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
247 ready!(self.as_mut().poll_ready(cx))?;
248
249 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().stream) }
250 .poll_flush(cx)
251 .map_err(WriteFrameError::Io)
252 }
253
254 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
255 ready!(self.as_mut().poll_flush(cx))?;
256
257 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().stream) }
258 .poll_close(cx)
259 .map_err(WriteFrameError::Io)
260 }
261}
262
263fn zero_buf(b: &mut UninitSlice) -> &mut [u8] {
264 for i in 0..b.len() {
265 b.write_byte(i, 0)
266 }
267 unsafe { std::mem::transmute(b) }
268}
269
#[cfg(test)]
mod test {
    use std::str::Utf8Error;

    use futures::{
        io::Cursor,
        prelude::*,
    };

    use super::*;

    /// Minimal codec used by the tests: every frame is one UTF-8 line
    /// terminated by `\n`.
    struct LineCodec;

    impl Encode for LineCodec {
        type Item = String;
        type Error = ();

        /// Writes `item` followed by a newline, or asks for more room when
        /// `buf` is too small.
        fn encode(&mut self, item: &String, buf: &mut [u8]) -> EncodeResult<()> {
            let payload = item.as_bytes();
            let needed = payload.len() + 1;
            if needed > buf.len() {
                return EncodeResult::Overflow(needed);
            }
            let (body, rest) = buf.split_at_mut(payload.len());
            body.copy_from_slice(payload);
            rest[0] = b'\n';
            Ok(needed).into()
        }
    }

    impl Decode for LineCodec {
        type Item = String;
        type Error = Utf8Error;

        /// Decodes everything up to the first newline; the newline itself is
        /// consumed but not part of the frame.
        fn decode(&mut self, buf: &mut [u8]) -> (usize, DecodeResult<String, Utf8Error>) {
            match buf.iter().position(|&byte| byte == b'\n') {
                None => (0, DecodeResult::UnexpectedEnd),
                Some(end) => {
                    let line = std::str::from_utf8(&buf[..end]).map(String::from);
                    (end + 1, line.into())
                }
            }
        }
    }

    const SHAKESPEARE: &str = r#"Now is the winter of our discontent
Made glorious summer by this sun of York.
Some are born great, some achieve greatness
And some have greatness thrust upon them.
Friends, Romans, countrymen - lend me your ears!
I come not to praise Caesar, but to bury him.
The evil that men do lives after them
The good is oft interred with their bones.
 It is a tale
Told by an idiot, full of sound and fury
Signifying nothing.
Ay me! For aught that I could ever read,
Could ever hear by tale or history,
The course of true love never did run smooth.
I have full cause of weeping, but this heart
Shall break into a hundred thousand flaws,
Or ere I'll weep.-O Fool, I shall go mad!
 Each your doing,
So singular in each particular,
Crowns what you are doing in the present deed,
That all your acts are queens.
"#;

    /// Reading the whole text through `Framed` must yield exactly its lines.
    #[async_std::test]
    async fn test_framed_stream() {
        let expected: Vec<String> = SHAKESPEARE.lines().map(String::from).collect();

        let source = Cursor::new(SHAKESPEARE.as_bytes().to_vec());
        let mut framed = Framed::new(source, LineCodec);

        let mut actual = Vec::new();
        while let Some(frame) = framed.next().await {
            actual.push(frame.unwrap());
        }

        assert_eq!(actual, expected);
    }

    /// Sending every line through `Framed` must reproduce the original text.
    #[async_std::test]
    async fn test_framed_sink() {
        let mut actual = vec![0u8; SHAKESPEARE.len()];
        {
            let sink = Cursor::new(&mut actual);
            let mut framed = Framed::new(sink, LineCodec);
            for line in SHAKESPEARE.lines() {
                framed.send(line.to_owned()).await.unwrap();
            }
        }
        assert_eq!(std::str::from_utf8(&actual).unwrap(), SHAKESPEARE);
    }
}