1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
use futures::future;
use std::net::SocketAddr;
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use trust_dns_resolver::proto::DnsHandle;
use trust_dns_resolver::{AsyncResolver, ConnectionProvider};

/// Destination of an outbound connection: either a hostname that still has to be
/// looked up, or a concrete socket address.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TargetAddr {
    /// Hostname (resolved later via DNS) and destination port.
    Host(String, u16),
    /// Already-known IP address and port.
    Addr(SocketAddr),
}

impl TargetAddr {
    pub fn from_host(host: &str, port: u16) -> TargetAddr {
        TargetAddr::Host(host.to_owned(), port)
    }

    pub fn from_addr(addr: SocketAddr) -> TargetAddr {
        TargetAddr::Addr(addr)
    }

    pub fn port(&self) -> u16 {
        match *self {
            TargetAddr::Host(_, port) => port,
            TargetAddr::Addr(addr) => addr.port(),
        }
    }

    pub fn unwrap_addr(&self) -> SocketAddr {
        match *self {
            TargetAddr::Host(_, _) => panic!("Invalid destination type"),
            TargetAddr::Addr(addr) => addr,
        }
    }

    pub async fn connect<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
        &self,
        resolver: &AsyncResolver<C, P>,
    ) -> io::Result<TcpStream> {
        let remote_addr = self.resolve(resolver).await?;
        let mut err: io::Result<TcpStream> = Err(io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "Resolved addr is empty",
        ));
        for addr in remote_addr {
            match TcpStream::connect(addr).await {
                Ok(socket) => {
                    return Ok(socket);
                }
                Err(e) => {
                    err = Err(e);
                }
            }
        }
        err
    }

    pub async fn resolve<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
        &self,
        resolver: &AsyncResolver<C, P>,
    ) -> io::Result<Vec<SocketAddr>> {
        match self {
            TargetAddr::Host(host, port) => match resolver.lookup_ip(host.as_str()).await {
                Ok(result) => Ok(result.iter().map(|x| SocketAddr::new(x, *port)).collect()),
                Err(_e) => Err(io::Error::new(
                    io::ErrorKind::AddrNotAvailable,
                    "Could't resolve host",
                )),
            },
            TargetAddr::Addr(addr) => Ok(vec![*addr]),
        }
    }
}

/// Relays data in both directions between `l` and `r`.
///
/// Each stream is split into read/write halves with `io::split`, and the four halves are
/// handed to [`relay_split`]. On success, returns `(bytes copied l -> r, bytes copied r -> l)`.
///
/// # Errors
///
/// Returns the first I/O error reported by [`relay_split`].
pub async fn relay<'a, L, R>(l: &'a mut L, r: &'a mut R) -> io::Result<(u64, u64)>
where
    L: AsyncRead + AsyncWrite + Unpin + ?Sized,
    R: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    let (mut lr, mut lw) = io::split(l);
    let (mut rr, mut rw) = io::split(r);
    relay_split(&mut lr, &mut lw, &mut rr, &mut rw).await
}

/// Runs two concurrent copies: `lr -> rw` and `rr -> lw`.
///
/// Each direction is handled by [`transfer`], which shuts its writer down once its reader
/// hits EOF. On success, returns `(bytes lr -> rw, bytes rr -> lw)`.
///
/// # Errors
///
/// Uses `futures::future::try_join`: the first error from either direction is returned,
/// and the other direction's future is dropped without being completed.
pub async fn relay_split<'a, LR, LW, RR, RW>(
    lr: &'a mut LR,
    lw: &'a mut LW,
    rr: &'a mut RR,
    rw: &'a mut RW,
) -> io::Result<(u64, u64)>
where
    LR: AsyncRead + Unpin + ?Sized,
    LW: AsyncWrite + Unpin + ?Sized,
    RR: AsyncRead + Unpin + ?Sized,
    RW: AsyncWrite + Unpin + ?Sized,
{
    // Pass the borrows straight through; no need to wrap them in another `&mut`.
    let client_to_server = transfer(lr, rw);
    let server_to_client = transfer(rr, lw);
    future::try_join(client_to_server, server_to_client).await
}

/// Copies everything from `reader` into `writer` until EOF, then shuts `writer` down.
///
/// Returns the number of bytes copied. An error from the copy is returned immediately,
/// and no shutdown is attempted in that case. An error from the shutdown itself is also
/// propagated.
pub async fn transfer<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    let copied = io::copy(reader, writer).await?;
    // Close the write side once the source is drained, so the peer sees end-of-stream.
    writer.shutdown().await.map(|()| copied)
}