1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
use futures::future;
use std::net::SocketAddr;
use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use trust_dns_resolver::proto::DnsHandle;
use trust_dns_resolver::{AsyncResolver, ConnectionProvider};

/// Destination of an outgoing connection: either a host name that still has
/// to be resolved, or an already-known socket address.
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum TargetAddr {
    /// A host name together with a port.
    Host(String, u16),
    /// A concrete socket address.
    Addr(SocketAddr),
}

impl TargetAddr {
    /// Builds a [`TargetAddr::Host`] from a host name and a port.
    pub fn from_host(host: &str, port: u16) -> TargetAddr {
        TargetAddr::Host(host.to_owned(), port)
    }

    /// Builds a [`TargetAddr::Addr`] from a socket address.
    pub fn from_addr(addr: SocketAddr) -> TargetAddr {
        TargetAddr::Addr(addr)
    }

    /// Returns the port of the target, whichever variant it is.
    pub fn port(&self) -> u16 {
        match self {
            TargetAddr::Host(_, port) => *port,
            TargetAddr::Addr(addr) => addr.port(),
        }
    }

    /// Returns the contained socket address.
    ///
    /// # Panics
    ///
    /// Panics if `self` is a [`TargetAddr::Host`].
    pub fn unwrap_addr(&self) -> SocketAddr {
        match self {
            TargetAddr::Addr(addr) => *addr,
            TargetAddr::Host(_, _) => panic!("Invalid destination type"),
        }
    }

    /// Resolves the target and opens a TCP connection to it.
    ///
    /// The resolved addresses are tried in order and the first successful
    /// connection is returned.
    ///
    /// # Errors
    ///
    /// Returns the error from [`TargetAddr::resolve`] if resolution fails,
    /// the error of the last connection attempt if every attempt fails, or an
    /// `AddrNotAvailable` error if resolution produced no address at all.
    pub async fn connect<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
        &self,
        resolver: &AsyncResolver<C, P>,
    ) -> io::Result<TcpStream> {
        // Reported only when the loop below never runs, i.e. there was
        // nothing to connect to; otherwise it is overwritten by real errors.
        let mut last_err = io::Error::new(
            io::ErrorKind::AddrNotAvailable,
            "Resolved addr is empty",
        );
        for addr in self.resolve(resolver).await? {
            match TcpStream::connect(addr).await {
                Ok(socket) => return Ok(socket),
                Err(e) => last_err = e,
            }
        }
        Err(last_err)
    }

    /// Turns the target into a list of socket addresses.
    ///
    /// A [`TargetAddr::Host`] is looked up through `resolver` and every
    /// returned IP is paired with the target's port. A [`TargetAddr::Addr`]
    /// is returned as a single-element list without touching the resolver.
    ///
    /// # Errors
    ///
    /// Returns an `AddrNotAvailable` error, carrying the resolver's own error
    /// text, when the lookup fails.
    pub async fn resolve<C: DnsHandle, P: ConnectionProvider<Conn = C>>(
        &self,
        resolver: &AsyncResolver<C, P>,
    ) -> io::Result<Vec<SocketAddr>> {
        match self {
            TargetAddr::Host(host, port) => match resolver.lookup_ip(host.as_str()).await {
                Ok(ips) => Ok(ips
                    .iter()
                    .map(|ip| SocketAddr::new(ip, *port))
                    .collect::<Vec<_>>()),
                // Keep the resolver's message so callers can see why the
                // lookup failed instead of a bare "couldn't resolve".
                Err(e) => Err(io::Error::new(
                    io::ErrorKind::AddrNotAvailable,
                    format!("Couldn't resolve host: {}", e),
                )),
            },
            TargetAddr::Addr(addr) => Ok(vec![*addr]),
        }
    }
}

/// Relays data between two bidirectional streams.
///
/// Each stream is split into a read half and a write half, which are then
/// handed to [`relay_split`].
///
/// On success returns `(bytes copied from l to r, bytes copied from r to l)`.
///
/// # Errors
///
/// Returns the first I/O error reported by either direction.
pub async fn relay<'a, L, R>(l: &'a mut L, r: &'a mut R) -> io::Result<(u64, u64)>
where
    L: AsyncRead + AsyncWrite + Unpin + ?Sized,
    R: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
    let (mut lr, mut lw) = io::split(l);
    let (mut rr, mut rw) = io::split(r);
    relay_split(&mut lr, &mut lw, &mut rr, &mut rw).await
}

/// Relays data between two streams that are already split into halves.
///
/// Runs two [`transfer`]s concurrently: `lr` into `rw`, and `rr` into `lw`.
///
/// On success returns `(bytes copied from lr to rw, bytes copied from rr to lw)`.
///
/// # Errors
///
/// Returns the first I/O error reported by either direction.
pub async fn relay_split<'a, LR, LW, RR, RW>(
    lr: &'a mut LR,
    lw: &'a mut LW,
    rr: &'a mut RR,
    rw: &'a mut RW,
) -> io::Result<(u64, u64)>
where
    LR: AsyncRead + Unpin + ?Sized,
    LW: AsyncWrite + Unpin + ?Sized,
    RR: AsyncRead + Unpin + ?Sized,
    RW: AsyncWrite + Unpin + ?Sized,
{
    // The parameters are already mutable references; pass them straight
    // through rather than borrowing them a second time.
    let client_to_server = transfer(lr, rw);
    let server_to_client = transfer(rr, lw);
    future::try_join(client_to_server, server_to_client).await
}

/// Copies everything from `reader` into `writer`, then shuts `writer` down.
///
/// Returns the number of bytes copied.
///
/// # Errors
///
/// Returns the error from the copy or from the shutdown. If the copy fails,
/// the shutdown is not attempted.
pub async fn transfer<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
    R: AsyncRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    let copied = io::copy(reader, writer).await?;
    writer.shutdown().await?;
    Ok(copied)
}