1use futures::future;
2use std::net::SocketAddr;
3use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
4use tokio::net::TcpStream;
5use trust_dns_resolver::proto::DnsHandle;
6use trust_dns_resolver::{AsyncResolver, ConnectionProvider};
7
/// A connection destination: either a host name still to be resolved, or a
/// concrete socket address.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TargetAddr {
    /// Host name (resolved later via DNS) together with the destination port.
    Host(String, u16),
    /// Already-resolved IP address and port.
    Addr(SocketAddr),
}
13
14impl TargetAddr {
15 pub fn from_host(host: &str, port: u16) -> TargetAddr {
16 TargetAddr::Host(host.to_owned(), port)
17 }
18
19 pub fn from_addr(addr: SocketAddr) -> TargetAddr {
20 TargetAddr::Addr(addr)
21 }
22
23 pub fn port(&self) -> u16 {
24 match *self {
25 TargetAddr::Host(_, port) => port,
26 TargetAddr::Addr(addr) => addr.port(),
27 }
28 }
29
30 pub fn unwrap_addr(&self) -> SocketAddr {
31 match *self {
32 TargetAddr::Host(_, _) => panic!("Invalid destination type"),
33 TargetAddr::Addr(addr) => addr,
34 }
35 }
36
37 pub async fn connect<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
38 &self,
39 resolver: &AsyncResolver<C, P>,
40 ) -> io::Result<TcpStream> {
41 let remote_addr = self.resolve(resolver).await?;
42 let mut err: io::Result<TcpStream> = Err(io::Error::new(
43 io::ErrorKind::AddrNotAvailable,
44 "Resolved addr is empty",
45 ));
46 for addr in remote_addr {
47 match TcpStream::connect(addr).await {
48 Ok(socket) => {
49 return Ok(socket);
50 }
51 Err(e) => {
52 err = Err(e);
53 }
54 }
55 }
56 err
57 }
58
59 pub async fn resolve<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
60 &self,
61 resolver: &AsyncResolver<C, P>,
62 ) -> io::Result<Vec<SocketAddr>> {
63 match self {
64 TargetAddr::Host(host, port) => match resolver.lookup_ip(host.as_str()).await {
65 Ok(result) => Ok(result.iter().map(|x| SocketAddr::new(x, *port)).collect()),
66 Err(_e) => Err(io::Error::new(
67 io::ErrorKind::AddrNotAvailable,
68 "Could't resolve host",
69 )),
70 },
71 TargetAddr::Addr(addr) => Ok(vec![*addr]),
72 }
73 }
74}
75
76pub async fn relay<'a, L, R>(l: &'a mut L, r: &'a mut R) -> io::Result<(u64, u64)>
77 where
78 L: AsyncRead + AsyncWrite + Unpin + ?Sized,
79 R: AsyncRead + AsyncWrite + Unpin + ?Sized,
80{
81 let (mut lr, mut lw) = io::split(l);
82 let (mut rr, mut rw) = io::split(r);
83 return relay_split(&mut lr, &mut lw, &mut rr, &mut rw).await;
84}
85
86pub async fn relay_split<'a, LR, LW, RR, RW>(
87 mut lr: &'a mut LR,
88 mut lw: &'a mut LW,
89 mut rr: &'a mut RR,
90 mut rw: &'a mut RW,
91) -> io::Result<(u64, u64)>
92 where
93 LR: AsyncRead + Unpin + ?Sized,
94 LW: AsyncWrite + Unpin + ?Sized,
95 RR: AsyncRead + Unpin + ?Sized,
96 RW: AsyncWrite + Unpin + ?Sized,
97{
98 let client_to_server = transfer(&mut lr, &mut rw);
99 let server_to_client = transfer(&mut rr, &mut lw);
100 return future::try_join(client_to_server, server_to_client).await;
101}
102
103pub async fn transfer<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
104 where
105 R: AsyncRead + Unpin + ?Sized,
106 W: AsyncWrite + Unpin + ?Sized,
107{
108 let len = io::copy(reader, writer).await?;
109 writer.shutdown().await?;
110 Ok(len)
111}