1mod async_buffer;
2mod frame;
3
4use async_buffer::AsyncBuffer;
5use frame::{Frame, ParseError};
6use futures::{
7 io::{AsyncRead, AsyncWrite},
8 Future,
9};
10use std::{
11 error::Error,
12 fmt::Display,
13 pin::Pin,
14 task::{Context, Poll},
15};
16
17#[derive(Debug)]
18pub enum SinkError {
19 Write(std::io::Error),
20 Read(std::io::Error),
21 LimitExceeded,
22 Parse(ParseError),
23 Closed,
24}
25
26impl Display for SinkError {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 SinkError::Write(e) => write!(f, "Write Error: {}", e),
30 SinkError::Read(e) => write!(f, "Read Error: {}", e),
31 SinkError::LimitExceeded => write!(f, "Limit Exceeded"),
32 SinkError::Parse(e) => write!(f, "Parse Error: {}", e),
33 SinkError::Closed => write!(f, "Stream Error: poll after closed"),
34 }
35 }
36}
37
38impl Error for SinkError {}
39
/// Lifecycle state of a [`MessageSink`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SinkStatus {
    /// Reading and writing normally.
    Open,
    /// Shutdown requested; polling the sink closes the underlying stream.
    Closing,
    /// The stream has been closed; further polls yield `SinkError::Closed`.
    Closed,
}
45
46pub struct MessageSink<S>
47where
48 S: AsyncRead + AsyncWrite + Unpin,
49{
50 stream: S,
51 read_buffer: Vec<u8>,
52 write_buffer: AsyncBuffer,
53 scratch: [u8; 1024],
54 status: SinkStatus,
55 limit: usize,
56}
57
58impl<S> MessageSink<S>
59where
60 S: AsyncRead + AsyncWrite + Unpin,
61{
62 pub fn new(socket: S) -> Self {
63 Self {
64 stream: socket,
65 read_buffer: Default::default(),
66 write_buffer: Default::default(),
67 scratch: [0; 1024],
68 status: SinkStatus::Open,
69 limit: usize::MAX,
70 }
71 }
72 pub fn limit(&mut self, length: usize) {
73 self.limit = length;
74 }
75 pub fn write(&mut self, message: Vec<u8>) -> Result<(), ParseError> {
76 let message: Vec<u8> = Frame::new(message).try_into()?;
77 self.write_buffer.extend(message);
78 Ok(())
79 }
80 pub fn close(&mut self) {
81 self.status = SinkStatus::Closing;
82 }
83}
84
85impl<S> Future for MessageSink<S>
86where
87 S: AsyncRead + AsyncWrite + Unpin,
88{
89 type Output = Result<Vec<u8>, SinkError>;
90 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91 let sink = self.get_mut();
92 let buffer = sink.write_buffer.as_ref();
93 match sink.status {
94 SinkStatus::Open => {}
95 SinkStatus::Closing => {
96 let stream = Pin::new(&mut sink.stream);
97 match stream.poll_close(cx) {
98 Poll::Pending => return Poll::Pending,
99 Poll::Ready(_) => {
100 sink.status = SinkStatus::Closed;
101 return Poll::Ready(Err(SinkError::Closed));
102 }
103 }
104 }
105 SinkStatus::Closed => {
106 return Poll::Ready(Err(SinkError::Closed));
107 }
108 }
109 let stream = Pin::new(&mut sink.stream);
110 match stream.poll_write(cx, buffer) {
111 Poll::Ready(Ok(length)) => {
112 sink.write_buffer.drain(0..length);
113 }
114 Poll::Ready(Err(e)) => {
115 sink.close();
116 return Poll::Ready(Err(SinkError::Write(e)));
117 }
118 Poll::Pending => {}
119 };
120 sink.write_buffer.set_waker(cx);
121 loop {
122 let stream = Pin::new(&mut sink.stream);
123 match stream.poll_read(cx, &mut sink.scratch) {
124 Poll::Ready(Ok(length)) => {
125 if sink.read_buffer.len() + length > sink.limit {
126 sink.close();
127 return Poll::Ready(Err(SinkError::LimitExceeded));
128 }
129 sink.read_buffer.extend(&sink.scratch[0..length]);
130 }
131 Poll::Ready(Err(e)) => {
132 sink.close();
133 return Poll::Ready(Err(SinkError::Read(e)));
134 }
135 Poll::Pending => {
136 break;
137 }
138 };
139 match Frame::try_from(&mut sink.read_buffer) {
140 Ok(frame) => return Poll::Ready(Ok(frame.into_message())),
141 Err(ParseError::NotReady) => {}
142 Err(e) => {
143 sink.close();
144 return Poll::Ready(Err(SinkError::Parse(e)));
145 }
146 }
147 }
148 match Frame::try_from(&mut sink.read_buffer) {
149 Ok(frame) => return Poll::Ready(Ok(frame.into_message())),
150 Err(ParseError::NotReady) => {}
151 Err(e) => {
152 sink.close();
153 return Poll::Ready(Err(SinkError::Parse(e)));
154 }
155 }
156 Poll::Pending
157 }
158}
159
#[cfg(test)]
mod message_sink {
    use super::*;
    use futures::FutureExt;
    use futures_ringbuf::RingBuffer;
    use rand::RngCore;

    /// Returns `len` random bytes to use as a message payload.
    fn random(len: usize) -> Vec<u8> {
        let mut bytes = vec![0; len];
        rand::thread_rng().fill_bytes(&mut bytes);
        bytes
    }

    /// A written message loops back through the ring buffer and is parsed intact.
    #[tokio::test]
    async fn parse() {
        let stream = RingBuffer::new(1024);
        let mut sink = MessageSink::new(stream);
        let message = random(128);
        sink.write(message.clone()).unwrap();
        let received = sink.await.unwrap();
        assert_eq!(message, received);
    }

    /// With nothing written, the sink has no message to yield and stays pending.
    #[tokio::test]
    async fn not_ready() {
        let stream = RingBuffer::new(1024);
        let sink = MessageSink::new(stream);
        assert!(
            sink.now_or_never().is_none(),
            "expected sink to not be ready"
        );
    }

    /// Several queued messages are yielded one per await, in the order they were written.
    #[tokio::test]
    async fn parse_multiple() {
        let messages = [random(128), random(128), random(128)];
        let stream = RingBuffer::new(1024);
        let mut sink = MessageSink::new(stream);
        for message in &messages {
            sink.write(message.clone()).unwrap();
        }
        for expected in messages {
            // Awaiting through `&mut` polls the same sink again without consuming it.
            let received = (&mut sink).await.unwrap();
            assert_eq!(expected, received);
        }
    }

    /// Incoming data larger than the configured limit fails with `LimitExceeded`.
    #[tokio::test]
    async fn limit() {
        let stream = RingBuffer::new(1024);
        let mut sink = MessageSink::new(stream);
        sink.limit(128);
        sink.write(random(256)).unwrap();
        match sink.await {
            Err(SinkError::LimitExceeded) => {}
            Err(e) => panic!("unexpected error {}", e),
            Ok(_) => panic!("unexpected success"),
        };
    }
}