1use futures_util::{ready, FutureExt};
2use ppp::v2::{Addresses, Header, ParseError};
3use std::future::Future;
4use std::io::{Error as IoError, ErrorKind};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio_util::io::poll_read_buf;
9
10pub trait Ext {
11 fn remote_addr_owned(self) -> PPPFuture<Self>
12 where
13 Self: Sized;
14 fn remote_addr(self: Pin<&mut Self>) -> PPPRefFuture<'_, Self>;
15 fn remote_addr_unpin(&mut self) -> PPPRefFuture<'_, Self>
16 where
17 Self: Unpin;
18}
19
20impl<T> Ext for T
21where
22 T: AsyncRead,
23{
24 fn remote_addr_owned(self) -> PPPFuture<Self>
25 where
26 Self: Sized,
27 {
28 PPPFuture {
29 inner: Some(self),
30 buf: vec![],
31 }
32 }
33
34 fn remote_addr(self: Pin<&mut Self>) -> PPPRefFuture<'_, Self> {
35 PPPRefFuture {
36 inner: Some(self),
37 buf: vec![],
38 }
39 }
40
41 fn remote_addr_unpin(&mut self) -> PPPRefFuture<'_, Self>
42 where
43 Self: Unpin,
44 {
45 Pin::new(self).remote_addr()
46 }
47}
48
/// Future returned by [`Ext::remote_addr_owned`]. It owns the stream and
/// resolves to a [`PPPStream`] once a PROXY v2 header has been parsed.
pub struct PPPFuture<T> {
    /// The stream being read; taken (left `None`) when the future resolves
    /// successfully.
    inner: Option<T>,
    /// Bytes read so far, carried across polls until the header is complete.
    buf: Vec<u8>,
}

impl<T> Unpin for PPPFuture<T> where T: Unpin {}
55
56impl<T> Future for PPPFuture<T>
57where
58 T: AsyncRead + Unpin,
59{
60 type Output = Result<PPPStream<T>, IoError>;
61
62 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
63 let this = self.get_mut();
64
65 let inner = match &mut this.inner {
66 None => panic!("Future polled after completion"),
67 Some(inner) => inner,
68 };
69 let buf = std::mem::take(&mut this.buf);
70
71 let mut fut = PPPRefFuture {
72 inner: Some(Pin::new(inner)),
73 buf,
74 };
75 let res = fut.poll_unpin(cx);
76
77 this.buf = fut.buf;
78
79 let PPPRefStream {
80 start_of_data,
81 addr,
82 data,
83 ..
84 } = ready!(res)?;
85
86 return Poll::Ready(Ok(PPPStream {
87 inner: this.inner.take().unwrap(),
88 start_of_data,
89 data,
90 addr,
91 }));
92 }
93}
94
95impl<'a, T> AsyncWrite for PPPRefStream<'a, T>
96where
97 T: AsyncWrite,
98{
99 fn poll_write(
100 self: Pin<&mut Self>,
101 cx: &mut Context<'_>,
102 buf: &[u8],
103 ) -> Poll<Result<usize, IoError>> {
104 let this = self.get_mut();
105 this.inner.as_mut().poll_write(cx, buf)
106 }
107
108 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
109 let this = self.get_mut();
110 this.inner.as_mut().poll_flush(cx)
111 }
112
113 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
114 let this = self.get_mut();
115 this.inner.as_mut().poll_shutdown(cx)
116 }
117}
118
119pub struct PPPStream<T> {
120 inner: T,
121 data: Vec<u8>,
122 start_of_data: usize,
123 pub addr: Addresses,
124}
125
126impl<T> Unpin for PPPStream<T> {}
127
128impl<T> AsyncRead for PPPStream<T>
129where
130 T: AsyncRead + Unpin,
131{
132 fn poll_read(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 buf: &mut ReadBuf<'_>,
136 ) -> Poll<std::io::Result<()>> {
137 let this = self.get_mut();
138 let data = std::mem::take(&mut this.data);
139
140 let mut stream = PPPRefStream {
141 inner: Pin::new(&mut this.inner),
142 addr: Addresses::Unspecified,
143 data,
144 start_of_data: this.start_of_data,
145 };
146
147 let res = Pin::new(&mut stream).poll_read(cx, buf);
148 this.data = stream.data;
149
150 return res;
151 }
152}
153
/// Future returned by [`Ext::remote_addr`] and [`Ext::remote_addr_unpin`].
/// It borrows the pinned stream and resolves to a [`PPPRefStream`].
#[derive(Debug)]
pub struct PPPRefFuture<'a, T: ?Sized> {
    /// The borrowed stream; taken (left `None`) when the future resolves
    /// successfully.
    inner: Option<Pin<&'a mut T>>,
    /// Bytes read so far, kept across polls until a full header is buffered.
    buf: Vec<u8>,
}

impl<'a, T> Unpin for PPPRefFuture<'a, T> {}
161
162impl<'a, T> Future for PPPRefFuture<'a, T>
163where
164 T: AsyncRead,
165{
166 type Output = Result<PPPRefStream<'a, T>, IoError>;
167
168 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
169 let this = self.get_mut();
170 let buf = &mut this.buf;
171 let inner = match &mut this.inner {
172 Some(inner) => inner.as_mut(),
173 None => panic!("future polled after completion"),
174 };
175
176 let added = ready!(poll_read_buf(inner, cx, buf))?;
177 if added == 0 {
179 return Poll::Ready(Err(IoError::new(
180 ErrorKind::Other,
181 ParseError::Incomplete(buf.len()),
182 )));
183 }
184 let res = match Header::try_from(buf.as_ref()) {
185 Err(ParseError::Incomplete(_)) => return this.poll_unpin(cx),
186 Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::Other, e))),
187 Ok(res) => res,
188 };
189
190 let addr = res.addresses;
191 let start_of_data = res.len();
192
193 let data = std::mem::take(buf);
194 let inner = this.inner.take().unwrap();
195
196 let stream = PPPRefStream {
197 addr,
198 inner,
199 data,
200 start_of_data,
201 };
202
203 return Poll::Ready(Ok(stream));
204 }
205}
206
207#[derive(Debug)]
208pub struct PPPRefStream<'a, T> {
209 inner: Pin<&'a mut T>,
210 data: Vec<u8>,
211 start_of_data: usize,
212 pub addr: Addresses,
213}
214
215impl<'a, T> PPPRefStream<'a, T> {
216 pub fn inner(&mut self) -> Pin<&mut T> {
217 return self.inner.as_mut();
218 }
219}
220
221impl<'a, T> Unpin for PPPRefStream<'a, T> {}
222
223impl<'a, T> AsyncRead for PPPRefStream<'a, T>
224where
225 T: AsyncRead,
226{
227 fn poll_read(
228 self: Pin<&mut Self>,
229 cx: &mut Context<'_>,
230 buf: &mut ReadBuf<'_>,
231 ) -> Poll<std::io::Result<()>> {
232 let this = self.get_mut();
233 let start_of_data = this.start_of_data;
234
235 if this.data.len() > 0 && start_of_data < this.data.len() {
236 if buf.remaining() < this.data.len() - start_of_data {
237 let end_len = start_of_data + buf.remaining();
238 buf.put_slice(&this.data[start_of_data..end_len]);
239 this.start_of_data = end_len;
240 } else {
241 buf.put_slice(&this.data[start_of_data..]);
242 this.data = Vec::new();
243 }
244
245 return Poll::Ready(Ok(()));
246 } else if this.data.len() > 0 {
247 this.data = Vec::new()
248 }
249
250 this.inner.as_mut().poll_read(cx, buf)
251 }
252}
253
#[cfg(test)]
mod tests {
    use ppp::v2::{ParseError, PROTOCOL_PREFIX};
    use std::io::ErrorKind;
    use tokio::io::AsyncReadExt;

    use super::Ext;

    /// Everything after the 12-byte prefix of a v2 PROXY/TCP-over-IPv4 header:
    /// version/command, family/protocol, a 16-byte length, then
    /// 127.0.0.1:80 -> 192.168.1.1:443 and one 4-byte TLV.
    const HEADER_BODY: [u8; 20] = [
        0x21, 0x12, 0, 16, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 4, 0, 1, 42,
    ];

    /// A complete header: protocol prefix followed by `HEADER_BODY`.
    fn full_header() -> Vec<u8> {
        let mut bytes = Vec::from(PROTOCOL_PREFIX);
        bytes.extend(HEADER_BODY);
        bytes
    }

    #[tokio::test]
    async fn test_small_buffer() {
        let mut bytes = full_header();
        bytes.extend([10, 20, 30, 40, 50, 60]);

        let mut reader = bytes.as_slice();
        let mut proxied = (&mut reader).remote_addr_unpin().await.unwrap();

        // Payload is replayed in pieces smaller than what was buffered.
        assert_eq!(10, proxied.read_u8().await.unwrap());

        let mut chunk = vec![0; 4];
        proxied.read_exact(&mut chunk).await.unwrap();
        let expected = vec![20, 30, 40, 50];
        assert_eq!(expected, chunk);

        assert_eq!(60, proxied.read_u8().await.unwrap());
    }

    #[tokio::test]
    async fn test() {
        let mut bytes = Vec::from(PROTOCOL_PREFIX);

        // Prefix only: EOF arrives before the header is complete.
        let failure = (&mut &*bytes).remote_addr_unpin().await.unwrap_err();
        let parse_err = failure
            .into_inner()
            .unwrap()
            .downcast::<ParseError>()
            .unwrap();
        assert!(matches!(*parse_err, ParseError::Incomplete(12)));

        // Exactly one header and no payload: the first read hits EOF.
        bytes.extend(HEADER_BODY);
        let mut reader = &*bytes;
        let mut proxied = (&mut reader).remote_addr_unpin().await.unwrap();
        let eof = proxied.read_u8().await.unwrap_err();
        assert_eq!(eof.kind(), ErrorKind::UnexpectedEof);

        // Header plus one payload byte.
        bytes.push(10);
        let mut reader = &*bytes;
        let mut proxied = (&mut reader).remote_addr_unpin().await.unwrap();
        assert_eq!(10, proxied.read_u8().await.unwrap());

        // The inner stream itself has been fully drained.
        let eof = proxied.inner().read_u8().await.unwrap_err();
        assert_eq!(eof.kind(), ErrorKind::UnexpectedEof);

        assert!(!proxied.addr.is_empty());
    }
}
314}