use crate::tunnel::{TunnelCtx, TunnelTarget};
use async_trait::async_trait;
use log::{debug, error, info};
use rand::prelude::thread_rng;
use rand::Rng;
use serde::export::PhantomData;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, Error, ErrorKind};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tokio::time::Duration;
/// Opens the outbound stream that a tunnel forwards traffic to.
#[async_trait]
pub trait TargetConnector {
    /// The kind of target this connector knows how to reach.
    type Target: TunnelTarget + Send + Sync;
    /// The stream produced by a successful connection.
    type Stream: AsyncRead + AsyncWrite + Send + 'static;

    /// Establishes a connection to `target`, returning the connected stream.
    async fn connect(&mut self, target: &Self::Target) -> io::Result<Self::Stream>;
}
#[async_trait]
pub trait DnsResolver {
async fn resolve(&mut self, target: &str) -> io::Result<SocketAddr>;
}
/// A connector that resolves a target's address with a pluggable
/// [`DnsResolver`] and then opens a plain TCP connection to the result.
///
/// `D` is the tunnel target type this connector accepts. It is tracked only at
/// the type level through `_phantom_target`.
#[derive(Clone, Builder)]
pub struct SimpleTcpConnector<D, R: DnsResolver> {
    /// Upper bound on a single TCP connect attempt. DNS resolution is not
    /// included in this bound.
    connect_timeout: Duration,
    /// Tunnel context; included in the timeout log message.
    tunnel_ctx: TunnelCtx,
    /// Resolver used to turn the target's address string into a `SocketAddr`.
    dns_resolver: R,
    // Marker for the otherwise-unused `D` parameter; `setter(skip)` means the
    // generated builder exposes no setter for it.
    #[builder(setter(skip))]
    _phantom_target: PhantomData<D>,
}
/// A cached resolution: every address a host resolved to, paired with the
/// entry's expiration time in milliseconds measured from the resolver's
/// `start_time`.
type CachedSocketAddrs = (Vec<SocketAddr>, u128);

/// A [`DnsResolver`] that memoizes lookups in memory for roughly `ttl`.
///
/// The cache lives behind an `Arc`, so clones of a resolver share one cache.
#[derive(Clone)]
pub struct SimpleCachingDnsResolver {
    /// Reference point for every expiration timestamp stored in `cache`.
    start_time: Instant,
    /// How long a freshly resolved entry is considered valid.
    ttl: Duration,
    /// Address string -> (resolved addresses, expiration in ms since `start_time`).
    cache: Arc<RwLock<HashMap<String, CachedSocketAddrs>>>,
}
#[async_trait]
impl<D, R> TargetConnector for SimpleTcpConnector<D, R>
where
    D: TunnelTarget<Addr = String> + Send + Sync + Sized,
    R: DnsResolver + Send + Sync + 'static,
{
    type Target = D;
    type Stream = TcpStream;

    /// Resolves `target` through the configured DNS resolver, then opens a TCP
    /// connection to the resolved address within `connect_timeout`.
    ///
    /// # Errors
    /// - any error returned by the DNS resolver;
    /// - the I/O error from `TcpStream::connect` or from setting `TCP_NODELAY`;
    /// - `ErrorKind::TimedOut` if the connection is not established within
    ///   `connect_timeout`.
    async fn connect(&mut self, target: &Self::Target) -> io::Result<Self::Stream> {
        let target_addr = &target.target_addr();
        let addr = self.dns_resolver.resolve(target_addr).await?;
        match timeout(self.connect_timeout, TcpStream::connect(addr)).await {
            Ok(connect_result) => {
                let stream = connect_result?;
                // `nodelay()` only queries the option. `set_nodelay(true)`
                // actually disables Nagle's algorithm, so small writes are
                // sent immediately instead of being coalesced.
                stream.set_nodelay(true)?;
                Ok(stream)
            }
            Err(_elapsed) => {
                error!(
                    "Timeout connecting to {}, {}, CTX={}",
                    addr, target_addr, self.tunnel_ctx
                );
                Err(Error::from(ErrorKind::TimedOut))
            }
        }
    }
}
#[async_trait]
impl DnsResolver for SimpleCachingDnsResolver {
    /// Returns a cached address for `target` when a non-expired entry exists.
    /// Otherwise it performs a fresh lookup and caches the result.
    ///
    /// # Errors
    /// On a cache miss, propagates the lookup error. This includes
    /// `ErrorKind::AddrNotAvailable` when the host resolves to no addresses.
    async fn resolve(&mut self, target: &str) -> io::Result<SocketAddr> {
        if let Some(addr) = self.try_find(target).await {
            return Ok(addr);
        }
        self.resolve_and_cache(target).await
    }
}
impl<D, R> SimpleTcpConnector<D, R>
where
R: DnsResolver,
{
pub fn new(dns_resolver: R, connect_timeout: Duration, tunnel_ctx: TunnelCtx) -> Self {
Self {
dns_resolver,
connect_timeout,
tunnel_ctx,
_phantom_target: PhantomData,
}
}
}
impl SimpleCachingDnsResolver {
    /// Creates a resolver with an empty cache. Entries stay valid for `ttl`,
    /// plus up to 5 seconds of random jitter applied at lookup time.
    pub fn new(ttl: Duration) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            ttl,
            start_time: Instant::now(),
        }
    }

    /// Picks one of `addrs` uniformly at random, spreading connections over
    /// every address a host resolved to.
    ///
    /// # Panics
    /// Panics if `addrs` is empty. Callers only pass non-empty lists, because
    /// `resolve` rejects empty lookup results before anything is cached.
    fn pick(&self, addrs: &[SocketAddr]) -> SocketAddr {
        // `gen_range` samples without the modulo bias of `gen() % len`.
        addrs[thread_rng().gen_range(0..addrs.len())]
    }

    /// Returns a randomly picked cached address for `target`. Returns `None`
    /// when there is no entry or the entry has expired.
    async fn try_find(&mut self, target: &str) -> Option<SocketAddr> {
        let map = self.cache.read().await;
        let (cached, expiration) = map.get(target)?;
        // NOTE(review): extending the expiry by a random 0..5000 ms presumably
        // staggers entries cached at the same moment, so they are not all
        // re-resolved at once.
        let expiration_with_jitter = *expiration + thread_rng().gen_range(0..5_000);
        if self.start_time.elapsed().as_millis() < expiration_with_jitter {
            Some(self.pick(cached))
        } else {
            None
        }
    }

    /// Performs a fresh lookup for `target` and stores all resolved addresses
    /// with an expiration of now + `ttl`. Returns one of them at random.
    ///
    /// # Errors
    /// Propagates any error from `resolve`.
    async fn resolve_and_cache(&mut self, target: &str) -> io::Result<SocketAddr> {
        let resolved = SimpleCachingDnsResolver::resolve(target).await?;
        // Pick before caching so the address list can be moved into the map
        // rather than cloned.
        let picked = self.pick(&resolved);
        let expiration = self.start_time.elapsed().as_millis() + self.ttl.as_millis();
        self.cache
            .write()
            .await
            .insert(target.to_string(), (resolved, expiration));
        Ok(picked)
    }

    /// Looks up `target` via `tokio::net::lookup_host`.
    ///
    /// # Errors
    /// Propagates lookup failures. Returns `ErrorKind::AddrNotAvailable` when
    /// the lookup succeeds but yields no addresses.
    async fn resolve(target: &str) -> io::Result<Vec<SocketAddr>> {
        debug!("Resolving DNS {}", target,);
        let resolved: Vec<SocketAddr> = tokio::net::lookup_host(target).await?.collect();
        info!("Resolved DNS {} to {:?}", target, resolved);
        if resolved.is_empty() {
            error!("Cannot resolve DNS {}", target,);
            return Err(Error::from(ErrorKind::AddrNotAvailable));
        }
        Ok(resolved)
    }
}