message_sink/
lib.rs

1mod async_buffer;
2mod frame;
3
4use async_buffer::AsyncBuffer;
5use frame::{Frame, ParseError};
6use futures::{
7    io::{AsyncRead, AsyncWrite},
8    Future,
9};
10use std::{
11    error::Error,
12    fmt::Display,
13    pin::Pin,
14    task::{Context, Poll},
15};
16
17#[derive(Debug)]
18pub enum SinkError {
19    Write(std::io::Error),
20    Read(std::io::Error),
21    LimitExceeded,
22    Parse(ParseError),
23    Closed,
24}
25
26impl Display for SinkError {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            SinkError::Write(e) => write!(f, "Write Error: {}", e),
30            SinkError::Read(e) => write!(f, "Read Error: {}", e),
31            SinkError::LimitExceeded => write!(f, "Limit Exceeded"),
32            SinkError::Parse(e) => write!(f, "Parse Error: {}", e),
33            SinkError::Closed => write!(f, "Stream Error: poll after closed"),
34        }
35    }
36}
37
38impl Error for SinkError {}
39
/// Lifecycle state of a [`MessageSink`].
///
/// A sink starts `Open`, moves to `Closing` when a close is requested (explicitly
/// or after an error while polling), and becomes `Closed` once the underlying
/// stream's `poll_close` completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SinkStatus {
    /// Reading and writing are active.
    Open,
    /// A close has been requested; subsequent polls drive the stream's shutdown.
    Closing,
    /// The stream has been shut down; polling yields `SinkError::Closed`.
    Closed,
}
45
46pub struct MessageSink<S>
47where
48    S: AsyncRead + AsyncWrite + Unpin,
49{
50    stream: S,
51    read_buffer: Vec<u8>,
52    write_buffer: AsyncBuffer,
53    scratch: [u8; 1024],
54    status: SinkStatus,
55    limit: usize,
56}
57
58impl<S> MessageSink<S>
59where
60    S: AsyncRead + AsyncWrite + Unpin,
61{
62    pub fn new(socket: S) -> Self {
63        Self {
64            stream: socket,
65            read_buffer: Default::default(),
66            write_buffer: Default::default(),
67            scratch: [0; 1024],
68            status: SinkStatus::Open,
69            limit: usize::MAX,
70        }
71    }
72    pub fn limit(&mut self, length: usize) {
73        self.limit = length;
74    }
75    pub fn write(&mut self, message: Vec<u8>) -> Result<(), ParseError> {
76        let message: Vec<u8> = Frame::new(message).try_into()?;
77        self.write_buffer.extend(message);
78        Ok(())
79    }
80    pub fn close(&mut self) {
81        self.status = SinkStatus::Closing;
82    }
83}
84
85impl<S> Future for MessageSink<S>
86where
87    S: AsyncRead + AsyncWrite + Unpin,
88{
89    type Output = Result<Vec<u8>, SinkError>;
90    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91        let sink = self.get_mut();
92        let buffer = sink.write_buffer.as_ref();
93        match sink.status {
94            SinkStatus::Open => {}
95            SinkStatus::Closing => {
96                let stream = Pin::new(&mut sink.stream);
97                match stream.poll_close(cx) {
98                    Poll::Pending => return Poll::Pending,
99                    Poll::Ready(_) => {
100                        sink.status = SinkStatus::Closed;
101                        return Poll::Ready(Err(SinkError::Closed));
102                    }
103                }
104            }
105            SinkStatus::Closed => {
106                return Poll::Ready(Err(SinkError::Closed));
107            }
108        }
109        let stream = Pin::new(&mut sink.stream);
110        match stream.poll_write(cx, buffer) {
111            Poll::Ready(Ok(length)) => {
112                sink.write_buffer.drain(0..length);
113            }
114            Poll::Ready(Err(e)) => {
115                sink.close();
116                return Poll::Ready(Err(SinkError::Write(e)));
117            }
118            Poll::Pending => {}
119        };
120        sink.write_buffer.set_waker(cx);
121        loop {
122            let stream = Pin::new(&mut sink.stream);
123            match stream.poll_read(cx, &mut sink.scratch) {
124                Poll::Ready(Ok(length)) => {
125                    if sink.read_buffer.len() + length > sink.limit {
126                        sink.close();
127                        return Poll::Ready(Err(SinkError::LimitExceeded));
128                    }
129                    sink.read_buffer.extend(&sink.scratch[0..length]);
130                }
131                Poll::Ready(Err(e)) => {
132                    sink.close();
133                    return Poll::Ready(Err(SinkError::Read(e)));
134                }
135                Poll::Pending => {
136                    break;
137                }
138            };
139            match Frame::try_from(&mut sink.read_buffer) {
140                Ok(frame) => return Poll::Ready(Ok(frame.into_message())),
141                Err(ParseError::NotReady) => {}
142                Err(e) => {
143                    sink.close();
144                    return Poll::Ready(Err(SinkError::Parse(e)));
145                }
146            }
147        }
148        match Frame::try_from(&mut sink.read_buffer) {
149            Ok(frame) => return Poll::Ready(Ok(frame.into_message())),
150            Err(ParseError::NotReady) => {}
151            Err(e) => {
152                sink.close();
153                return Poll::Ready(Err(SinkError::Parse(e)));
154            }
155        }
156        Poll::Pending
157    }
158}
159
#[cfg(test)]
mod message_sink {
    use super::*;
    use futures::FutureExt;
    use futures_ringbuf::RingBuffer;
    use rand::RngCore;

    /// Returns `len` random bytes to use as a message payload.
    fn random(len: usize) -> Vec<u8> {
        let mut bytes = vec![0; len];
        rand::thread_rng().fill_bytes(&mut bytes);
        bytes
    }

    /// A message written into a loopback ring buffer is read back unchanged.
    #[tokio::test]
    async fn parse() {
        let mut sink = MessageSink::new(RingBuffer::new(1024));
        let message = random(128);
        sink.write(message.clone()).unwrap();
        let received = sink.await.unwrap();
        assert_eq!(message, received);
    }

    /// With nothing queued and nothing received, the sink stays pending.
    #[tokio::test]
    async fn not_ready() {
        let sink = MessageSink::new(RingBuffer::new(1024));
        assert!(
            sink.now_or_never().is_none(),
            "expected sink to not be ready"
        );
    }

    /// Several queued messages come back one per await, in order.
    #[tokio::test]
    async fn parse_multiple() {
        let messages = [random(128), random(128), random(128)];
        let mut sink = MessageSink::new(RingBuffer::new(1024));
        for message in &messages {
            sink.write(message.clone()).unwrap();
        }
        // `MessageSink` is `Unpin` (its `poll` relies on `Pin::get_mut`), so it can
        // be awaited repeatedly through `&mut` without a shared lock.
        for message in messages {
            let received = (&mut sink).await.unwrap();
            assert_eq!(message, received);
        }
    }

    /// Incoming data larger than the configured limit is rejected.
    #[tokio::test]
    async fn limit() {
        let mut sink = MessageSink::new(RingBuffer::new(1024));
        sink.limit(128);
        sink.write(random(256)).unwrap();
        match sink.await {
            Err(SinkError::LimitExceeded) => {}
            Err(e) => panic!("unexpected error {}", e),
            Ok(_) => panic!("unexpected success"),
        };
    }
}