proxy_relay/
lib.rs

1use futures::future;
2use std::net::SocketAddr;
3use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
4use tokio::net::TcpStream;
5use trust_dns_resolver::proto::DnsHandle;
6use trust_dns_resolver::{AsyncResolver, ConnectionProvider};
7
/// Destination of a relayed connection.
///
/// Either a hostname that still has to be resolved before connecting, or an
/// already-resolved socket address.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum TargetAddr {
    /// Hostname to resolve, together with the destination port.
    Host(String, u16),
    /// Concrete socket address; needs no resolution.
    Addr(SocketAddr),
}
13
14impl TargetAddr {
15    pub fn from_host(host: &str, port: u16) -> TargetAddr {
16        TargetAddr::Host(host.to_owned(), port)
17    }
18
19    pub fn from_addr(addr: SocketAddr) -> TargetAddr {
20        TargetAddr::Addr(addr)
21    }
22
23    pub fn port(&self) -> u16 {
24        match *self {
25            TargetAddr::Host(_, port) => port,
26            TargetAddr::Addr(addr) => addr.port(),
27        }
28    }
29
30    pub fn unwrap_addr(&self) -> SocketAddr {
31        match *self {
32            TargetAddr::Host(_, _) => panic!("Invalid destination type"),
33            TargetAddr::Addr(addr) => addr,
34        }
35    }
36
37    pub async fn connect<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
38        &self,
39        resolver: &AsyncResolver<C, P>,
40    ) -> io::Result<TcpStream> {
41        let remote_addr = self.resolve(resolver).await?;
42        let mut err: io::Result<TcpStream> = Err(io::Error::new(
43            io::ErrorKind::AddrNotAvailable,
44            "Resolved addr is empty",
45        ));
46        for addr in remote_addr {
47            match TcpStream::connect(addr).await {
48                Ok(socket) => {
49                    return Ok(socket);
50                }
51                Err(e) => {
52                    err = Err(e);
53                }
54            }
55        }
56        err
57    }
58
59    pub async fn resolve<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
60        &self,
61        resolver: &AsyncResolver<C, P>,
62    ) -> io::Result<Vec<SocketAddr>> {
63        match self {
64            TargetAddr::Host(host, port) => match resolver.lookup_ip(host.as_str()).await {
65                Ok(result) => Ok(result.iter().map(|x| SocketAddr::new(x, *port)).collect()),
66                Err(_e) => Err(io::Error::new(
67                    io::ErrorKind::AddrNotAvailable,
68                    "Could't resolve host",
69                )),
70            },
71            TargetAddr::Addr(addr) => Ok(vec![*addr]),
72        }
73    }
74}
75
76pub async fn relay<'a, L, R>(l: &'a mut L, r: &'a mut R) -> io::Result<(u64, u64)>
77    where
78        L: AsyncRead + AsyncWrite + Unpin + ?Sized,
79        R: AsyncRead + AsyncWrite + Unpin + ?Sized,
80{
81    let (mut lr, mut lw) = io::split(l);
82    let (mut rr, mut rw) = io::split(r);
83    return relay_split(&mut lr, &mut lw, &mut rr, &mut rw).await;
84}
85
86pub async fn relay_split<'a, LR, LW, RR, RW>(
87    mut lr: &'a mut LR,
88    mut lw: &'a mut LW,
89    mut rr: &'a mut RR,
90    mut rw: &'a mut RW,
91) -> io::Result<(u64, u64)>
92    where
93        LR: AsyncRead + Unpin + ?Sized,
94        LW: AsyncWrite + Unpin + ?Sized,
95        RR: AsyncRead + Unpin + ?Sized,
96        RW: AsyncWrite + Unpin + ?Sized,
97{
98    let client_to_server = transfer(&mut lr, &mut rw);
99    let server_to_client = transfer(&mut rr, &mut lw);
100    return future::try_join(client_to_server, server_to_client).await;
101}
102
103pub async fn transfer<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
104    where
105        R: AsyncRead + Unpin + ?Sized,
106        W: AsyncWrite + Unpin + ?Sized,
107{
108    let len = io::copy(reader, writer).await?;
109    writer.shutdown().await?;
110    Ok(len)
111}