1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
use futures_util::{ready, FutureExt};
use ppp::v2::{Addresses, Header, ParseError};
use std::future::Future;
use std::io::{Error as IoError, ErrorKind};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};
use tokio_util::io::poll_read_buf;
/// Extension methods for reading a PROXY protocol v2 header off the front of
/// an [`AsyncRead`] stream.
pub trait Ext {
    /// Starts reading a PROXY protocol v2 header from a pinned stream.
    ///
    /// The returned future resolves to a [`ProxyProtoStream`] carrying the
    /// parsed addresses and yielding the bytes that follow the header.
    fn remote_addr(self: Pin<&mut Self>) -> ProxyProtoFuture<'_, Self>;

    /// Same as [`Ext::remote_addr`], for streams that are [`Unpin`] and can
    /// therefore be pinned in place from a plain `&mut` borrow.
    fn remote_addr_unpin(&mut self) -> ProxyProtoFuture<'_, Self>
    where
        Self: Unpin;
}

impl<T> Ext for T
where
    T: AsyncRead + Sized,
{
    fn remote_addr(self: Pin<&mut Self>) -> ProxyProtoFuture<'_, Self> {
        // Nothing has been read yet; the future fills this buffer as it polls.
        let buf = Vec::new();
        ProxyProtoFuture {
            inner: Some(self),
            buf,
        }
    }

    fn remote_addr_unpin(&mut self) -> ProxyProtoFuture<'_, Self>
    where
        Self: Unpin,
    {
        let pinned = Pin::new(self);
        <Self as Ext>::remote_addr(pinned)
    }
}
/// Future returned by [`Ext::remote_addr`] and [`Ext::remote_addr_unpin`];
/// resolves to a [`ProxyProtoStream`] once a PROXY header has been parsed.
#[derive(Debug)]
pub struct ProxyProtoFuture<'a, T>
where
    T: ?Sized,
{
    /// The stream being read; taken (left as `None`) when the future completes.
    inner: Option<Pin<&'a mut T>>,
    /// Bytes read from `inner` so far.
    buf: Vec<u8>,
}

// Only a pinned *reference* and a `Vec` are stored, so moving the future
// itself never moves the underlying stream.
impl<T> Unpin for ProxyProtoFuture<'_, T> {}
impl<'a, T> Future for ProxyProtoFuture<'a, T>
where
    T: AsyncRead,
{
    type Output = Result<ProxyProtoStream<'a, T>, IoError>;

    /// Reads from the wrapped stream until a complete PROXY protocol v2 header
    /// is buffered, then hands the stream over as a [`ProxyProtoStream`]
    /// (including any bytes that were read past the end of the header).
    ///
    /// # Errors
    ///
    /// - I/O errors from the wrapped stream are returned unchanged.
    /// - EOF before a full header yields `ErrorKind::Other` wrapping
    ///   `ParseError::Incomplete(n)`, where `n` is the number of bytes read.
    /// - Any other header parse failure yields `ErrorKind::Other` wrapping the
    ///   `ParseError`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Loop instead of re-polling recursively: a peer that sends the header
        // in many small chunks must not grow the stack once per chunk.
        loop {
            let inner = match this.inner.as_mut() {
                Some(inner) => inner.as_mut(),
                None => panic!("future polled after completion"),
            };
            let added = ready!(poll_read_buf(inner, cx, &mut this.buf))?;
            if added == 0 {
                return Poll::Ready(Err(IoError::new(
                    ErrorKind::Other,
                    ParseError::Incomplete(this.buf.len()),
                )));
            }
            let (addr, start_of_data) = match Header::try_from(this.buf.as_slice()) {
                Ok(header) => {
                    // Read the length before moving `addresses` out of the header.
                    let header_len = header.len();
                    (header.addresses, header_len)
                }
                // Not enough bytes yet: read more.
                Err(ParseError::Incomplete(_)) => continue,
                Err(e) => return Poll::Ready(Err(IoError::new(ErrorKind::Other, e))),
            };
            let inner = this
                .inner
                .take()
                .expect("inner stream is present until the future completes");
            return Poll::Ready(Ok(ProxyProtoStream {
                addr,
                inner,
                buf: std::mem::take(&mut this.buf),
                start_of_data,
            }));
        }
    }
}
/// A stream whose PROXY protocol v2 header has already been consumed.
///
/// Reading from it through [`AsyncRead`] yields the bytes that followed the
/// header, starting with any that were read together with the header.
#[derive(Debug)]
pub struct ProxyProtoStream<'a, T> {
    inner: Pin<&'a mut T>,
    /// Bytes read while parsing the header (header included).
    buf: Vec<u8>,
    /// Offset into `buf` where the payload after the header begins.
    start_of_data: usize,
    /// Addresses carried in the parsed PROXY header.
    pub addr: Addresses,
}

impl<T> ProxyProtoStream<'_, T> {
    /// Direct access to the wrapped stream.
    ///
    /// Reads through this handle bypass any payload bytes still buffered from
    /// header parsing; read through `self` to receive those first.
    pub fn inner(&mut self) -> Pin<&mut T> {
        Pin::as_mut(&mut self.inner)
    }
}

// Holds only a pinned reference and owned bytes, so it is safe to move.
impl<T> Unpin for ProxyProtoStream<'_, T> {}
impl<'a, T> AsyncRead for ProxyProtoStream<'a, T>
where
    T: AsyncRead,
{
    /// Yields payload bytes that were read together with the PROXY header
    /// before reading from the wrapped stream again.
    ///
    /// Buffered bytes are copied only as far as `buf` has room; the remainder
    /// is kept for the next call. When buffered bytes are delivered the call
    /// returns `Ready` immediately, so data already written into `buf` is never
    /// paired with a `Poll::Pending` result (which would make callers lose it).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.get_mut();
        if this.start_of_data < this.buf.len() {
            let leftover = &this.buf[this.start_of_data..];
            // `put_slice` panics if the slice exceeds the space left in `buf`.
            let n = leftover.len().min(buf.remaining());
            buf.put_slice(&leftover[..n]);
            this.start_of_data += n;
            if this.start_of_data == this.buf.len() {
                // Fully drained: release the header buffer.
                this.buf = Vec::new();
                this.start_of_data = 0;
            }
            return Poll::Ready(Ok(()));
        }
        if !this.buf.is_empty() {
            // Header consumed with no payload left over: release the buffer.
            this.buf = Vec::new();
            this.start_of_data = 0;
        }
        this.inner.as_mut().poll_read(cx, buf)
    }
}
#[cfg(test)]
mod tests {
    use super::Ext;
    use ppp::v2::{ParseError, PROTOCOL_PREFIX};
    use std::io::ErrorKind;
    use tokio::io::AsyncReadExt;

    /// Header bytes following the 12-byte signature: version 2 / PROXY command,
    /// AF_INET + datagram, 16 bytes of address data
    /// (127.0.0.1:80 -> 192.168.1.1:443) ending in a 4-byte TLV.
    const HEADER_BODY: [u8; 20] = [
        0x21, 0x12, 0, 16, 127, 0, 0, 1, 192, 168, 1, 1, 0, 80, 1, 187, 4, 0, 1, 42,
    ];

    #[tokio::test]
    async fn test() {
        // Signature only: EOF arrives before the header is complete.
        let mut wire = Vec::from(PROTOCOL_PREFIX);
        let io_err = (&mut &*wire).remote_addr_unpin().await.unwrap_err();
        let parse_err = io_err
            .into_inner()
            .unwrap()
            .downcast::<ParseError>()
            .unwrap();
        assert!(matches!(*parse_err, ParseError::Incomplete(12)));

        // Complete header, no payload: the first payload read hits EOF.
        wire.extend(HEADER_BODY);
        {
            let mut reader = &*wire;
            let mut proxied = (&mut reader).remote_addr_unpin().await.unwrap();
            let eof = proxied.read_u8().await.unwrap_err();
            assert_eq!(eof.kind(), ErrorKind::UnexpectedEof);
        }

        // One payload byte after the header is delivered, then the wrapped
        // stream is exhausted.
        wire.push(10);
        let mut reader = &*wire;
        let mut proxied = (&mut reader).remote_addr_unpin().await.unwrap();
        let payload = proxied.read_u8().await.unwrap();
        assert_eq!(10, payload);
        let eof = proxied.inner().read_u8().await.unwrap_err();
        assert_eq!(eof.kind(), ErrorKind::UnexpectedEof);
        assert!(!proxied.addr.is_empty());
    }
}