use std::fmt;
use std::io::Cursor;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use crate::types::ProxyInfo;
/// Addressing information for an accepted connection, optionally enriched by a PROXY header.
///
/// Built from a [`ProxiedStream`] (via `connect_info`/`From`), from a bare [`SocketAddr`]
/// for direct connections, or from a `(ProxyInfo, SocketAddr)` pair.
#[derive(Debug, Clone)]
pub struct ProxyConnectInfo {
    /// Address reported as the connection source by the PROXY header when one was present
    /// and carried an inet source; otherwise the same value as `peer_addr`.
    pub client_addr: SocketAddr,
    /// Address of the socket peer the connection was actually accepted from.
    pub peer_addr: SocketAddr,
    /// Parsed PROXY header, or `None` for a direct (non-proxied) connection.
    pub proxy_info: Option<ProxyInfo>,
}
/// A `TcpStream` wrapper that first yields bytes already read off the socket (presumably the
/// data that followed the PROXY header during parsing — confirm with the constructor's caller)
/// and then reads from the socket itself.
#[derive(Debug)]
pub struct ProxiedStream {
    /// The underlying accepted socket; all writes go straight to it.
    inner: TcpStream,
    /// Bytes to hand to readers before any further socket reads; the cursor position marks
    /// how much has already been consumed.
    leftover: Cursor<Vec<u8>>,
    /// Parsed PROXY header, if the connection carried one.
    proxy_info: Option<ProxyInfo>,
    /// Address of the directly connected socket peer.
    peer_addr: SocketAddr,
}
impl ProxiedStream {
    /// Wraps an accepted socket together with any bytes already consumed from it.
    ///
    /// `leftover` is replayed to readers before the socket is read again; `proxy_info` is the
    /// parsed PROXY header (if any) and `peer_addr` the socket's direct peer address.
    pub(crate) fn new(
        inner: TcpStream,
        leftover: Vec<u8>,
        proxy_info: Option<ProxyInfo>,
        peer_addr: SocketAddr,
    ) -> Self {
        let leftover = Cursor::new(leftover);
        Self {
            inner,
            leftover,
            proxy_info,
            peer_addr,
        }
    }

    /// Returns the parsed PROXY header, if this connection carried one.
    pub fn proxy_info(&self) -> Option<&ProxyInfo> {
        self.proxy_info.as_ref()
    }

    /// Returns the originating client address.
    ///
    /// This is the PROXY header's inet source when available, falling back to the direct
    /// socket peer address otherwise.
    pub fn client_addr(&self) -> SocketAddr {
        let fallback = self.peer_addr;
        let Some(header) = &self.proxy_info else {
            return fallback;
        };
        header.source_inet().unwrap_or(fallback)
    }

    /// Returns the address of the directly connected socket peer.
    pub fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Borrows the underlying socket.
    ///
    /// Note: any I/O performed directly on this handle bypasses the buffered leftover bytes.
    pub fn inner(&self) -> &TcpStream {
        &self.inner
    }

    /// Snapshots this connection's addressing information into a [`ProxyConnectInfo`].
    pub fn connect_info(&self) -> ProxyConnectInfo {
        let client_addr = self.client_addr();
        let proxy_info = self.proxy_info.clone();
        ProxyConnectInfo {
            client_addr,
            peer_addr: self.peer_addr(),
            proxy_info,
        }
    }
}
impl ProxyConnectInfo {
    /// Returns just the IP portion of the originating client address.
    pub fn client_ip(&self) -> IpAddr {
        let addr = self.client_addr;
        addr.ip()
    }

    /// Returns `true` when a PROXY header was attached to this connection.
    pub fn is_proxied(&self) -> bool {
        matches!(self.proxy_info, Some(_))
    }
}
impl fmt::Display for ProxyConnectInfo {
    /// Renders `"<client> via <peer>"` for proxied connections and `"<client> (direct)"`
    /// otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.proxy_info {
            Some(_) => write!(f, "{} via {}", self.client_addr, self.peer_addr),
            None => write!(f, "{} (direct)", self.client_addr),
        }
    }
}
impl From<&ProxiedStream> for ProxyConnectInfo {
fn from(stream: &ProxiedStream) -> Self {
stream.connect_info()
}
}
impl From<SocketAddr> for ProxyConnectInfo {
    /// Describes a direct connection: client and peer are the same address and no PROXY
    /// header is attached.
    fn from(addr: SocketAddr) -> Self {
        let (client_addr, peer_addr) = (addr, addr);
        Self {
            proxy_info: None,
            client_addr,
            peer_addr,
        }
    }
}
impl From<(ProxyInfo, SocketAddr)> for ProxyConnectInfo {
    /// Builds connection info from a parsed PROXY header and the direct peer address.
    ///
    /// The client address is the header's inet source when present, otherwise the peer
    /// address.
    fn from(pair: (ProxyInfo, SocketAddr)) -> Self {
        let (header, peer_addr) = pair;
        let client_addr = match header.source_inet() {
            Some(source) => source,
            None => peer_addr,
        };
        Self {
            client_addr,
            peer_addr,
            proxy_info: Some(header),
        }
    }
}
impl AsyncRead for ProxiedStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let leftover_data = this.leftover.get_ref();
let leftover_pos = this.leftover.position() as usize;
let leftover_remaining = leftover_data.len() - leftover_pos;
if leftover_remaining > 0 {
let to_copy = leftover_remaining.min(buf.remaining());
buf.put_slice(&leftover_data[leftover_pos..leftover_pos + to_copy]);
this.leftover.set_position((leftover_pos + to_copy) as u64);
return Poll::Ready(Ok(()));
}
Pin::new(&mut this.inner).poll_read(cx, buf)
}
}
impl AsyncWrite for ProxiedStream {
    /// Writes go straight to the underlying socket; the leftover buffer only affects reads.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
    }

    /// Forwards vectored writes to the socket.
    ///
    /// Without this override the trait's default implementation writes only the first
    /// non-empty slice per call, defeating scatter/gather I/O that `TcpStream` supports.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
    }

    /// Reports the inner socket's vectored-write capability so callers can choose
    /// vectored writes when they are efficient.
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    /// Flushes the underlying socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
    }

    /// Shuts down the write half of the underlying socket.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}