use std::{
fmt::{Debug, Formatter},
io,
sync::Arc,
};
use log::debug;
use tokio::net::TcpStream;
use crate::{
multiaddr::Multiaddr,
socks::{self, Socks5Client},
transports::{Transport, dns::SystemDnsResolver, predicate::Predicate, tcp::TcpTransport},
types::TransportProtocol,
utils::network::supports_ipv6,
};
/// Log target for messages emitted by the SOCKS transport in this module.
const LOG_TARGET: &str = "comms::transports::socks";
/// Configuration for [`SocksTransport`]: where the SOCKS proxy lives, how to authenticate with
/// it, and which destination addresses should skip the proxy entirely.
#[derive(Clone)]
pub struct SocksConfig {
    /// Address of the SOCKS proxy; outbound (non-bypassed) dials first connect here over TCP.
    pub proxy_address: Multiaddr,
    /// Authentication handed to the SOCKS5 client before it asks the proxy to connect.
    pub authentication: socks::Authentication,
    /// Evaluated against every dial address. When it returns `true` the proxy is bypassed and
    /// the address is dialled directly over plain TCP.
    pub proxy_bypass_predicate: Arc<dyn Predicate<Multiaddr> + Send + Sync>,
}
impl Debug for SocksConfig {
    /// Formats the config for diagnostics.
    ///
    /// The bypass predicate is a trait object and is rendered as the placeholder `"..."` rather
    /// than being formatted.
    /// NOTE(review): `authentication` is printed through its own `Debug` impl — confirm that
    /// impl does not expose credentials if this config is ever logged.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SocksConfig");
        builder.field("proxy_address", &self.proxy_address);
        builder.field("authentication", &self.authentication);
        builder.field("proxy_bypass_predicate", &"...");
        builder.finish()
    }
}
/// Transport that dials outbound connections through a SOCKS5 proxy. Addresses matched by the
/// config's bypass predicate are dialled directly over TCP instead. Listening is delegated to
/// the underlying TCP transport.
#[derive(Clone)]
pub struct SocksTransport {
    /// Proxy address, authentication and bypass predicate.
    socks_config: SocksConfig,
    /// Underlying TCP transport: used to reach the proxy, for bypassed dials, and for listening.
    tcp_transport: TcpTransport,
    /// Protocols reported by `supported_protocols`; computed once in `new`.
    supported_protocols: Vec<TransportProtocol>,
}
impl SocksTransport {
    /// Creates a SOCKS transport from `socks_config`.
    ///
    /// IPv4 and onion are always reported as supported protocols. IPv6 is added at the end of
    /// the list only when `supports_ipv6()` returns `true`.
    pub fn new(socks_config: SocksConfig) -> Self {
        let supported_protocols = if supports_ipv6() {
            vec![TransportProtocol::Ipv4, TransportProtocol::Onion, TransportProtocol::Ipv6]
        } else {
            vec![TransportProtocol::Ipv4, TransportProtocol::Onion]
        };
        let tcp_transport = Self::create_socks_tcp_transport();

        Self {
            socks_config,
            tcp_transport,
            supported_protocols,
        }
    }

    /// Builds the TCP transport that underlies the SOCKS transport, configured to resolve names
    /// with the system DNS resolver.
    pub fn create_socks_tcp_transport() -> TcpTransport {
        let mut transport = TcpTransport::new();
        transport.set_dns_resolver(SystemDnsResolver);
        transport
    }

    /// Opens a TCP connection to the proxy at `socks_config.proxy_address` using `tcp`. It then
    /// configures a SOCKS5 client with the configured authentication and asks the proxy to
    /// connect to `dest_addr`.
    ///
    /// # Errors
    /// Fails if the proxy cannot be dialled. Errors from the SOCKS client (setting
    /// authentication or connecting) are wrapped with `io::Error::other`.
    async fn socks_connect(
        tcp: TcpTransport,
        socks_config: &SocksConfig,
        dest_addr: &Multiaddr,
    ) -> io::Result<TcpStream> {
        let proxy_stream = tcp.dial(&socks_config.proxy_address).await?;
        let mut client = Socks5Client::new(proxy_stream);

        let auth = socks_config.authentication.clone();
        if let Err(err) = client.with_authentication(auth) {
            return Err(io::Error::other(err));
        }

        // The second element returned by `connect` is not needed here.
        let (socket, _) = client.connect(dest_addr).await.map_err(io::Error::other)?;
        Ok(socket)
    }
}
#[crate::async_trait]
impl Transport for SocksTransport {
    type Error = <TcpTransport as Transport>::Error;
    type Listener = <TcpTransport as Transport>::Listener;
    type Output = <TcpTransport as Transport>::Output;

    /// Listening never goes through the proxy: the address is bound on the underlying TCP
    /// transport.
    async fn listen(&self, addr: &Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> {
        let tcp = &self.tcp_transport;
        tcp.listen(addr).await
    }

    /// Dials `addr` through the SOCKS proxy. If the bypass predicate matches `addr`, it is
    /// dialled directly over TCP instead.
    async fn dial(&self, addr: &Multiaddr) -> Result<Self::Output, Self::Error> {
        let bypass_proxy = self.socks_config.proxy_bypass_predicate.check(addr);
        if !bypass_proxy {
            let stream =
                Self::socks_connect(self.tcp_transport.clone(), &self.socks_config, addr).await?;
            return Ok(stream);
        }

        debug!(target: LOG_TARGET, "SOCKS proxy bypassed for '{addr}'. Using TCP.");
        self.tcp_transport.dial(addr).await
    }

    /// Returns a copy of the protocol list computed when the transport was created.
    fn supported_protocols(&self) -> Vec<TransportProtocol> {
        self.supported_protocols.to_vec()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{socks::Authentication, transports::predicate::FalsePredicate};

    /// `new` stores the supplied config unchanged and always advertises IPv4 and onion support.
    #[test]
    fn new() {
        let proxy_address = "/ip4/127.0.0.1/tcp/1234".parse::<Multiaddr>().unwrap();
        let transport = SocksTransport::new(SocksConfig {
            proxy_address: proxy_address.clone(),
            authentication: Default::default(),
            proxy_bypass_predicate: Arc::new(FalsePredicate::new()),
        });
        assert_eq!(transport.socks_config.proxy_address, proxy_address);
        assert_eq!(transport.socks_config.authentication, Authentication::None);

        // IPv4 and onion are unconditionally supported; IPv6 depends on the host, so it is
        // not asserted here.
        let protocols = &transport.supported_protocols;
        assert!(protocols.iter().any(|p| matches!(p, TransportProtocol::Ipv4)));
        assert!(protocols.iter().any(|p| matches!(p, TransportProtocol::Onion)));
    }
}