use std::io;
use multiaddr::Multiaddr;
use tokio::net::TcpStream;
use super::Transport;
use crate::{
transports::{SocksConfig, SocksTransport, TcpTransport, dns::TorDnsResolver, predicate::is_onion_address},
types::TransportProtocol,
utils::network::supports_ipv6,
};
/// Transport that dials onion addresses through a Tor SOCKS proxy and all other addresses over
/// plain TCP.
///
/// Note: the derived `Default` leaves `supported_protocols` empty and no SOCKS proxy configured;
/// `TcpWithTorTransport::new` is the constructor that seeds the IP protocols.
#[derive(Default, Clone)]
pub struct TcpWithTorTransport {
    /// SOCKS transport used only for onion addresses; `None` until a proxy is configured.
    socks_transport: Option<SocksTransport>,
    /// Inner TCP transport used for listening and for every non-onion dial.
    tcp_transport: TcpTransport,
    /// Protocols reported by `Transport::supported_protocols`.
    supported_protocols: Vec<TransportProtocol>,
}
impl TcpWithTorTransport {
    /// Configures a Tor SOCKS proxy on this transport.
    ///
    /// This does three things:
    /// - builds a `SocksTransport` that `dial` uses for onion addresses;
    /// - installs a `TorDnsResolver` built from the same config on the inner TCP transport;
    /// - adds `TransportProtocol::Onion` to the supported protocols, if it is not already present.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn set_tor_socks_proxy(&mut self, socks_config: SocksConfig) -> &mut Self {
        // Both the SOCKS transport and the DNS resolver take the config by value.
        self.socks_transport = Some(SocksTransport::new(socks_config.clone()));
        self.tcp_transport.set_dns_resolver(TorDnsResolver::new(socks_config));
        if !self.supported_protocols.contains(&TransportProtocol::Onion) {
            self.supported_protocols.push(TransportProtocol::Onion);
        }
        self
    }

    /// Creates a transport with the same IP protocol support as [`new`](Self::new), plus onion
    /// dialing through the given SOCKS proxy.
    pub fn with_tor_socks_proxy(socks_config: SocksConfig) -> Self {
        // Start from `new()` rather than `Self::default()`. The derived `Default` leaves
        // `supported_protocols` empty, which would advertise only `Onion`, even though `dial`
        // still sends every non-onion address over plain TCP.
        let mut transport = Self::new();
        transport.set_tor_socks_proxy(socks_config);
        transport
    }

    /// Creates a transport without a SOCKS proxy.
    ///
    /// It advertises IPv4, plus IPv6 when `supports_ipv6()` returns true. Call
    /// [`set_tor_socks_proxy`](Self::set_tor_socks_proxy) to enable onion dialing.
    pub fn new() -> Self {
        let mut supported_protocols = vec![TransportProtocol::Ipv4];
        if supports_ipv6() {
            supported_protocols.push(TransportProtocol::Ipv6);
        }
        Self {
            supported_protocols,
            ..Default::default()
        }
    }

    /// Returns a mutable reference to the inner TCP transport, which handles listening and
    /// non-onion dials.
    pub fn tcp_transport_mut(&mut self) -> &mut TcpTransport {
        &mut self.tcp_transport
    }
}
#[crate::async_trait]
impl Transport for TcpWithTorTransport {
    type Error = io::Error;
    type Listener = <TcpTransport as Transport>::Listener;
    type Output = TcpStream;

    /// Binds a listener by delegating directly to the inner TCP transport.
    async fn listen(&self, addr: &Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> {
        self.tcp_transport.listen(addr).await
    }

    /// Dials `addr`, choosing the route by address type.
    ///
    /// Onion addresses go through the configured SOCKS transport. Every other address goes
    /// through the inner TCP transport.
    ///
    /// # Errors
    /// - `InvalidInput` if `addr` is empty.
    /// - An "other" error if `addr` is an onion address but no SOCKS proxy has been set.
    /// - Any error returned by the underlying transport's `dial`.
    async fn dial(&self, addr: &Multiaddr) -> Result<Self::Output, Self::Error> {
        if addr.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Invalid address '{addr}'"),
            ));
        }

        // Non-onion addresses never touch the SOCKS proxy.
        if !is_onion_address(addr) {
            let stream = self.tcp_transport.dial(addr).await?;
            return Ok(stream);
        }

        let Some(socks) = self.socks_transport.as_ref() else {
            return Err(io::Error::other(
                "Tor SOCKS proxy is not set for TCP transport. Cannot dial peer with onion addresses.".to_owned(),
            ));
        };
        let stream = socks.dial(addr).await?;
        Ok(stream)
    }

    /// Returns a copy of the protocols this transport currently advertises.
    fn supported_protocols(&self) -> Vec<TransportProtocol> {
        self.supported_protocols.clone()
    }
}