use std::{net::SocketAddr, time::Duration};
use async_trait::async_trait;
use futures::{FutureExt, StreamExt, TryFutureExt, stream::FuturesUnordered};
use safelog::sensitive as sv;
use tor_error::bad_api_usage;
use tor_linkspec::{ChannelMethod, HasChanMethod, OwnedChanTarget};
use tor_proto::peer::PeerAddr;
use tor_rtcompat::{NetStreamProvider, Runtime};
use tracing::{instrument, trace};
use crate::{Error, err::ConnectError};
/// Transport that reaches `ChannelMethod::Direct` targets, either by
/// connecting straight to one of the target's addresses or by going through
/// a configured outbound proxy.
#[derive(Clone, Debug)]
pub(crate) struct DefaultTransport<R>
where
    R: Runtime,
{
    /// Runtime used to sleep and to open network streams.
    runtime: R,
    /// Proxy to route every outgoing connection through, if one is configured.
    outbound_proxy: Option<crate::config::ProxyProtocol>,
}
impl<R: Runtime> DefaultTransport<R> {
pub(crate) fn new(runtime: R, outbound_proxy: Option<crate::config::ProxyProtocol>) -> Self {
Self {
runtime,
outbound_proxy,
}
}
}
#[async_trait]
impl<R> crate::transport::TransportImplHelper for DefaultTransport<R>
where
    R: Runtime,
{
    /// Streams produced by this transport are the runtime's own network streams.
    type Stream = <R as NetStreamProvider>::Stream;

    /// Open a stream to `target`, which must use `ChannelMethod::Direct`.
    ///
    /// The target's addresses are handed to `connect_to_one` together with
    /// this transport's runtime and outbound proxy setting.
    ///
    /// Returns the address that was actually connected to (as a `PeerAddr`)
    /// and the resulting stream.
    ///
    /// # Errors
    ///
    /// Returns `Error::UnusableTarget` when the target's channel method is
    /// anything other than `Direct`, and propagates whatever error
    /// `connect_to_one` reports.
    #[instrument(skip_all, level = "trace")]
    async fn connect(&self, target: &OwnedChanTarget) -> crate::Result<(PeerAddr, Self::Stream)> {
        let method = target.chan_method();
        let candidates: Vec<_> = match method {
            ChannelMethod::Direct(list) => list,
            // NOTE(review): the wildcard is presumably unreachable in builds
            // where `Direct` is the only variant — confirm against the
            // `ChannelMethod` definition.
            #[allow(unreachable_patterns)]
            _ => {
                return Err(Error::UnusableTarget(bad_api_usage!(
                    "Used default transport implementation for an unsupported transport."
                )));
            }
        };

        trace!("Launching direct connection for {}", target);

        let outcome = connect_to_one(&self.runtime, &candidates, &self.outbound_proxy).await;
        let (stream, addr) = outcome?;
        let peer: PeerAddr = addr.into();
        Ok((peer, stream))
    }
}
/// Base stagger between parallel connection attempts: `connect_to_one` makes
/// the attempt for the address at index `i` wait `i * CONNECTION_DELAY`
/// before it starts.
static CONNECTION_DELAY: Duration = Duration::from_millis(150);
#[instrument(skip_all, level = "trace")]
async fn connect_to_one<R: Runtime>(
rt: &R,
addrs: &[SocketAddr],
outbound_proxy: &Option<crate::config::ProxyProtocol>,
) -> crate::Result<(<R as NetStreamProvider>::Stream, SocketAddr)> {
if addrs.is_empty() {
return Err(Error::UnusableTarget(bad_api_usage!(
"No addresses for chosen relay"
)));
}
let mut connections = addrs
.iter()
.enumerate()
.map(|(i, a)| {
let delay = rt.sleep(CONNECTION_DELAY * i as u32);
let proxy = outbound_proxy.clone();
delay.then(move |_| {
tracing::debug!("Connecting to {}", a);
let a = *a;
async move {
let stream = if let Some(ref protocol) = proxy {
let target = tor_linkspec::PtTargetAddr::IpPort(a);
match protocol {
crate::config::ProxyProtocol::Socks {
version,
auth,
addr,
} => {
let proto = super::proxied::Protocol::Socks(*version, auth.clone());
super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
}
crate::config::ProxyProtocol::HttpConnect { addr, credentials } => {
let auth = credentials.as_ref().map(|cred| {
(
safelog::Sensitive::new(cred.username.clone()),
safelog::Sensitive::new(
cred.password.clone().unwrap_or_default(),
),
)
});
let proto = super::proxied::Protocol::HttpConnect { auth };
super::proxied::connect_via_proxy(rt, addr, &proto, &target).await?
}
}
} else {
rt.connect(&a).await?
};
Ok((stream, a))
}
.map_err(move |e: ConnectError| (e, a))
})
})
.collect::<FuturesUnordered<_>>();
let mut ret = None;
let mut errors: Vec<(ConnectError, SocketAddr)> = vec![];
while let Some(result) = connections.next().await {
match result {
Ok(s) => {
ret = Some(s);
break;
}
Err((e, a)) => {
errors.push((e, a));
}
}
}
drop(connections);
ret.ok_or_else(|| Error::Connect {
addresses: errors
.into_iter()
.map(|(e, a)| (sv(a.to_string()), e))
.collect(),
})
}
#[cfg(test)]
mod test {
// Lints relaxed for test code only.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
use std::str::FromStr;
use tor_rtcompat::{SleepProviderExt, test_with_one_runtime};
use tor_rtmock::net::MockNetwork;
use super::*;
// Exercises `connect_to_one` (always without a proxy) against a mock network
// that contains listening, unregistered and blackholed addresses.
#[test]
fn test_connect_one() {
// Address given to the client side of the mock network.
let client_addr = "192.0.1.16".parse().unwrap();
// Roles of the target addresses, as set up below:
//   addr1, addr4 - the mock server listens on these;
//   addr2        - never registered with the mock network;
//   addr3        - registered as a blackhole.
let addr1 = SocketAddr::from_str("192.0.2.17:443").unwrap();
let addr2 = SocketAddr::from_str("192.0.3.18:443").unwrap();
let addr3 = SocketAddr::from_str("192.0.4.19:443").unwrap();
let addr4 = SocketAddr::from_str("192.0.9.9:443").unwrap();
test_with_one_runtime!(|rt| async move {
// Client and server runtimes share one mock network.
let network = MockNetwork::new();
let client_rt = network
.builder()
.add_address(client_addr)
.runtime(rt.clone());
let server_rt = network
.builder()
.add_address(addr1.ip())
.add_address(addr4.ip())
.runtime(rt.clone());
let listen_options = Default::default();
// Keep both listeners bound to a name so they live until the test ends.
let _listener = server_rt
.mock_net()
.listen(&addr1, &listen_options)
.await
.unwrap();
let _listener2 = server_rt
.mock_net()
.listen(&addr4, &listen_options)
.await
.unwrap();
network.add_blackhole(addr3).unwrap();
// An empty address list must be rejected.
let failure = connect_to_one(&client_rt, &[], &None).await;
assert!(failure.is_err());
// Whenever addr1 is the only listening address in the list, the
// connection must end up on addr1 regardless of its position.
for addresses in [
&[addr1][..],
&[addr1, addr2][..],
&[addr2, addr1][..],
&[addr1, addr3][..],
&[addr3, addr1][..],
&[addr1, addr2, addr3][..],
&[addr3, addr2, addr1][..],
] {
let (_conn, addr) = connect_to_one(&client_rt, addresses, &None).await.unwrap();
assert_eq!(addr, addr1);
}
// Lists without any listening address must never yield a connection.
for addresses in [
&[addr2][..],
&[addr2, addr3][..],
&[addr3, addr2][..],
&[addr3][..],
] {
// If the blackholed addr3 is in the list, the 300 ms timeout is
// expected to fire; otherwise `connect_to_one` is expected to
// return an error before the timeout elapses.
let expect_timeout = addresses.contains(&addr3);
let failure = rt
.timeout(
Duration::from_millis(300),
connect_to_one(&client_rt, addresses, &None),
)
.await;
if expect_timeout {
assert!(failure.is_err());
} else {
assert!(failure.unwrap().is_err());
}
}
// With two listening addresses, the one listed first is expected to win.
let (_conn, addr) = connect_to_one(&client_rt, &[addr1, addr4], &None)
.await
.unwrap();
assert_eq!(addr, addr1);
let (_conn, addr) = connect_to_one(&client_rt, &[addr4, addr1], &None)
.await
.unwrap();
assert_eq!(addr, addr4);
});
}
}