Skip to main content

ppp_stream/
lib.rs

1use futures_util::{ready, FutureExt};
2use ppp::v2::{Addresses, Header, ParseError};
3use std::future::Future;
4use std::io::{Error as IoError, ErrorKind};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio_util::io::poll_read_buf;
9
10pub trait Ext {
11    fn remote_addr_owned(self) -> PPPFuture<Self>
12    where
13        Self: Sized;
14    fn remote_addr(self: Pin<&mut Self>) -> PPPRefFuture<'_, Self>;
15    fn remote_addr_unpin(&mut self) -> PPPRefFuture<'_, Self>
16    where
17        Self: Unpin;
18}
19
20impl<T> Ext for T
21where
22    T: AsyncRead,
23{
24    fn remote_addr_owned(self) -> PPPFuture<Self>
25    where
26        Self: Sized,
27    {
28        PPPFuture {
29            inner: Some(self),
30            buf: vec![],
31        }
32    }
33
34    fn remote_addr(self: Pin<&mut Self>) -> PPPRefFuture<'_, Self> {
35        PPPRefFuture {
36            inner: Some(self),
37            buf: vec![],
38        }
39    }
40
41    fn remote_addr_unpin(&mut self) -> PPPRefFuture<'_, Self>
42    where
43        Self: Unpin,
44    {
45        Pin::new(self).remote_addr()
46    }
47}
48
/// Future returned by [`Ext::remote_addr_owned`].
///
/// Owns the reader in `inner` until a complete header has been parsed, and
/// keeps the bytes read so far in `buf` between polls. It resolves to a
/// [`PPPStream`]; its `Future` impl requires `T: Unpin`.
#[derive(Debug)]
pub struct PPPFuture<T> {
    inner: Option<T>,
    buf: Vec<u8>,
}

// Both fields are plain data, so the future can move whenever the reader can.
impl<T: Unpin> Unpin for PPPFuture<T> {}
55
56impl<T> Future for PPPFuture<T>
57where
58    T: AsyncRead + Unpin,
59{
60    type Output = Result<PPPStream<T>, IoError>;
61
62    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63        let this = self.get_mut();
64
65        let inner = match &mut this.inner {
66            None => panic!("Future polled after completion"),
67            Some(inner) => inner,
68        };
69        let buf = std::mem::take(&mut this.buf);
70
71        let mut fut = PPPRefFuture {
72            inner: Some(Pin::new(inner)),
73            buf,
74        };
75        let res = fut.poll_unpin(cx);
76
77        this.buf = fut.buf;
78
79        let PPPRefStream {
80            start_of_data,
81            addr,
82            data,
83            ..
84        } = ready!(res)?;
85
86        return Poll::Ready(Ok(PPPStream {
87            inner: this.inner.take().unwrap(),
88            start_of_data,
89            data,
90            addr,
91        }));
92    }
93}
94
95impl<'a, T> AsyncWrite for PPPRefStream<'a, T>
96where
97    T: AsyncWrite,
98{
99    fn poll_write(
100        self: Pin<&mut Self>,
101        cx: &mut Context<'_>,
102        buf: &[u8],
103    ) -> Poll<Result<usize, IoError>> {
104        let this = self.get_mut();
105        this.inner.as_mut().poll_write(cx, buf)
106    }
107
108    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
109        let this = self.get_mut();
110        this.inner.as_mut().poll_flush(cx)
111    }
112
113    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
114        let this = self.get_mut();
115        this.inner.as_mut().poll_shutdown(cx)
116    }
117}
118
119pub struct PPPStream<T> {
120    inner: T,
121    data: Vec<u8>,
122    start_of_data: usize,
123    pub addr: Addresses,
124}
125
126impl<T> Unpin for PPPStream<T> {}
127
128impl<T> AsyncRead for PPPStream<T>
129where
130    T: AsyncRead + Unpin,
131{
132    fn poll_read(
133        self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135        buf: &mut ReadBuf<'_>,
136    ) -> Poll<std::io::Result<()>> {
137        let this = self.get_mut();
138        let data = std::mem::take(&mut this.data);
139
140        let mut stream = PPPRefStream {
141            inner: Pin::new(&mut this.inner),
142            addr: Addresses::Unspecified,
143            data,
144            start_of_data: this.start_of_data,
145        };
146
147        let res = Pin::new(&mut stream).poll_read(cx, buf);
148        this.data = stream.data;
149
150        return res;
151    }
152}
153
/// Future returned by [`Ext::remote_addr`] and [`Ext::remote_addr_unpin`].
///
/// Holds a pinned borrow of the reader, which is handed to the resulting
/// stream once the header parses, plus the bytes read so far.
#[derive(Debug)]
pub struct PPPRefFuture<'a, T: ?Sized> {
    inner: Option<Pin<&'a mut T>>,
    buf: Vec<u8>,
}

// Only a pinned *reference* is stored, so moving the future never moves `T`.
impl<T> Unpin for PPPRefFuture<'_, T> {}
161
162impl<'a, T> Future for PPPRefFuture<'a, T>
163where
164    T: AsyncRead,
165{
166    type Output = Result<PPPRefStream<'a, T>, IoError>;
167
168    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
169        let this = self.get_mut();
170        let buf = &mut this.buf;
171        let inner = match &mut this.inner {
172            Some(inner) => inner.as_mut(),
173            None => panic!("future polled after completion"),
174        };
175
176        let added = ready!(poll_read_buf(inner, cx, buf))?;
177        // stream is eof
178        if added == 0 {
179            return Poll::Ready(Err(IoError::new(
180                ErrorKind::Other,
181                ParseError::Incomplete(buf.len()),
182            )));
183        }
184        let res = match Header::try_from(buf.as_ref()) {
185            Err(ParseError::Incomplete(_)) => return this.poll_unpin(cx),
186            Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::Other, e))),
187            Ok(res) => res,
188        };
189
190        let addr = res.addresses;
191        let start_of_data = res.len();
192
193        let data = std::mem::take(buf);
194        let inner = this.inner.take().unwrap();
195
196        let stream = PPPRefStream {
197            addr,
198            inner,
199            data,
200            start_of_data,
201        };
202
203        return Poll::Ready(Ok(stream));
204    }
205}
206
207#[derive(Debug)]
208pub struct PPPRefStream<'a, T> {
209    inner: Pin<&'a mut T>,
210    data: Vec<u8>,
211    start_of_data: usize,
212    pub addr: Addresses,
213}
214
215impl<'a, T> PPPRefStream<'a, T> {
216    pub fn inner(&mut self) -> Pin<&mut T> {
217        return self.inner.as_mut();
218    }
219}
220
221impl<'a, T> Unpin for PPPRefStream<'a, T> {}
222
223impl<'a, T> AsyncRead for PPPRefStream<'a, T>
224where
225    T: AsyncRead,
226{
227    fn poll_read(
228        self: Pin<&mut Self>,
229        cx: &mut Context<'_>,
230        buf: &mut ReadBuf<'_>,
231    ) -> Poll<std::io::Result<()>> {
232        let this = self.get_mut();
233        let start_of_data = this.start_of_data;
234
235        if this.data.len() > 0 && start_of_data < this.data.len() {
236            if buf.remaining() < this.data.len() - start_of_data {
237                let end_len = start_of_data + buf.remaining();
238                buf.put_slice(&this.data[start_of_data..end_len]);
239                this.start_of_data = end_len;
240            } else {
241                buf.put_slice(&this.data[start_of_data..]);
242                this.data = Vec::new();
243            }
244
245            return Poll::Ready(Ok(()));
246        } else if this.data.len() > 0 {
247            this.data = Vec::new()
248        }
249
250        this.inner.as_mut().poll_read(cx, buf)
251    }
252}
253
#[cfg(test)]
mod tests {
    use ppp::v2::{ParseError, PROTOCOL_PREFIX};
    use std::io::ErrorKind;
    use tokio::io::AsyncReadExt;

    use super::Ext;

    /// Header bytes that follow `PROTOCOL_PREFIX`: a version/command byte, a
    /// family/protocol byte, a big-endian length of 16, then those 16 bytes
    /// of address and TLV data.
    const HEADER_TAIL: [u8; 20] = [
        0x21, 0x12, 0, 16, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 4, 0, 1, 42,
    ];

    /// Returns `PROTOCOL_PREFIX`, then `HEADER_TAIL`, then `payload`.
    fn with_header(payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::from(PROTOCOL_PREFIX);
        bytes.extend_from_slice(&HEADER_TAIL);
        bytes.extend_from_slice(payload);
        bytes
    }

    #[tokio::test]
    async fn test_small_buffer() {
        let bytes = with_header(&[10, 20, 30, 40, 50, 60]);
        let mut reader = bytes.as_slice();
        let mut stream = (&mut reader).remote_addr_unpin().await.unwrap();

        // Reads smaller than the buffered payload must come out in order.
        assert_eq!(10, stream.read_u8().await.unwrap());

        let mut middle = vec![0; 4];
        stream.read_exact(&mut middle).await.unwrap();
        assert_eq!(vec![20, 30, 40, 50], middle);

        assert_eq!(60, stream.read_u8().await.unwrap());
    }

    #[tokio::test]
    async fn test() {
        // Prefix only: EOF arrives before the header is complete.
        let prefix_only = Vec::from(PROTOCOL_PREFIX);
        let err = (&mut &*prefix_only).remote_addr_unpin().await.unwrap_err();
        let parse_err = err.into_inner().unwrap().downcast::<ParseError>().unwrap();
        assert!(matches!(*parse_err, ParseError::Incomplete(12)));

        // Complete header without payload: the first read hits EOF.
        let header_only = with_header(&[]);
        let mut reader = &*header_only;
        let mut stream = (&mut reader).remote_addr_unpin().await.unwrap();
        let err = stream.read_u8().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);

        // Complete header plus a single payload byte.
        let with_payload = with_header(&[10]);
        let mut reader = &*with_payload;
        let mut stream = (&mut reader).remote_addr_unpin().await.unwrap();
        assert_eq!(10, stream.read_u8().await.unwrap());

        // Reading the wrapped reader directly skips the buffer; it is already
        // exhausted at this point.
        let err = stream.inner().read_u8().await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);

        assert!(!stream.addr.is_empty());
    }
}