use core::fmt::{Debug, Formatter};
use std::time::Instant;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use tokio_util::future::FutureExt;
use ts_bitset::BitsetDyn;
use ts_capabilityversion::CapabilityVersion;
use ts_http_util::{BytesBody, Http2};
use url::Url;
use crate::{DialCandidate, DialMode, DialPlan, tokio::ConnectionError};
/// Chooses how to reach the control server, following the current [`DialPlan`].
///
/// Remembers which plan candidates have already been dialed so that repeated calls to
/// `next_dialer` move on to candidates that have not been tried yet.
pub struct ControlDialer {
    /// The dial plan currently in effect.
    plan: DialPlan,
    /// Counter bumped by `update_dial_plan` each time a different plan is installed.
    epoch: usize,
    /// Reference point for candidate start delays (set at construction and on plan change).
    timestamp: Instant,
    /// Bit `i` is set once the candidate at position `i` of the plan has been dialed.
    attempted_candidates: ts_dynbitset::DynBitset,
}

impl Default for ControlDialer {
    /// Default plan, epoch zero, nothing attempted, timestamped at construction.
    fn default() -> Self {
        Self {
            plan: Default::default(),
            epoch: 0,
            timestamp: Instant::now(),
            attempted_candidates: ts_dynbitset::DynBitset::default(),
        }
    }
}
/// A strategy for opening a TCP connection to the control server.
pub trait TcpDialer {
    /// Consumes the dialer and opens a TCP connection for `host` and `port`.
    ///
    /// An implementation may connect to an address it already knows instead of resolving
    /// `host`. The returned future is required to be `Send`.
    fn dial(
        self,
        host: &str,
        port: u16,
    ) -> impl Future<Output = tokio::io::Result<TcpStream>> + Send;
}
/// The concrete [`TcpDialer`] handed out by `ControlDialer::next_dialer`.
enum ControlTcpDialer<'a> {
    /// Connect to the control host by name, resolving it through the system resolver.
    UseDns,
    /// Dial one specific candidate taken from the dial plan.
    Planned {
        /// Position of `candidate` in the plan's candidate list.
        index: usize,
        /// The plan candidate to dial.
        candidate: &'a DialCandidate,
        /// The owning dialer's attempted-candidate set; `index` is marked here when dialed.
        attempted: &'a mut ts_dynbitset::DynBitset,
    },
}
impl Debug for ControlTcpDialer<'_> {
    /// Shows the dial strategy and, for planned dials, the candidate's target.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let ControlTcpDialer::Planned { candidate, .. } = self else {
            return write!(f, "TcpDialer::Dns");
        };
        match &candidate.mode {
            DialMode::Ip(addr) => f.debug_tuple("TcpDialer::Ip").field(addr).finish(),
            DialMode::Ace { ip, host } => {
                let mut tuple = f.debug_tuple("TcpDialer::Ace");
                // The IP is only rendered when the candidate carries one.
                if let Some(addr) = ip {
                    tuple.field(addr);
                }
                tuple.field(host).finish()
            }
        }
    }
}
impl TcpDialer for ControlTcpDialer<'_> {
async fn dial(self, host: &str, port: u16) -> tokio::io::Result<TcpStream> {
match self {
ControlTcpDialer::UseDns => TcpStream::connect(format!("{host}:{port}")).await,
ControlTcpDialer::Planned {
candidate,
attempted: used,
index,
} => {
used.set(index);
match candidate.mode {
DialMode::Ip(ip) => {
TcpStream::connect((ip, port))
.timeout(candidate.timeout)
.await?
}
DialMode::Ace { .. } => {
unimplemented!()
}
}
}
}
}
}
impl ControlDialer {
    /// Installs `plan` as the active dial plan if it differs from the current one.
    ///
    /// When the plan changes:
    /// - the epoch is bumped;
    /// - the plan timestamp, from which candidate start delays are measured, is reset to now;
    /// - the attempted-candidate set is cleared.
    ///
    /// The attempted bits are indices into the *previous* plan's candidate list. Left in
    /// place, they would suppress unrelated candidates of the new plan.
    ///
    /// Returns `true` if the plan was replaced, `false` if it was already current.
    pub fn update_dial_plan(&mut self, plan: &DialPlan) -> bool {
        if &self.plan == plan {
            return false;
        }
        self.plan = plan.clone();
        self.epoch += 1;
        self.timestamp = Instant::now();
        self.attempted_candidates.clear_all();
        true
    }

    /// Forgets which candidates have been attempted, making all of them eligible again.
    pub fn clear_attempted(&mut self) {
        self.attempted_candidates.clear_all();
    }

    /// Picks the dialer to use for the next connection attempt.
    ///
    /// With [`DialPlan::UseDns`] the control host is resolved by name. With
    /// [`DialPlan::Plan`], the eligible candidate with the highest `priority` is chosen; on
    /// ties, the one earliest in the list wins. A candidate is eligible when all of these hold:
    /// - it has not been attempted yet;
    /// - its start delay has elapsed since the plan was installed;
    /// - it is not an ACE candidate (ACE dialing is unsupported).
    ///
    /// If nothing is eligible, this logs a warning and falls back to name resolution.
    ///
    /// The chosen candidate is only marked as attempted when the returned dialer is used.
    pub fn next_dialer(&mut self) -> impl TcpDialer + Debug {
        match &self.plan {
            DialPlan::UseDns => ControlTcpDialer::UseDns,
            DialPlan::Plan(candidates) => {
                let now = Instant::now();
                // (index, candidate) of the best eligible candidate seen so far.
                let mut best: Option<(usize, &DialCandidate)> = None;
                for (i, candidate) in candidates.iter().enumerate() {
                    let eligible = !self.attempted_candidates.test(i)
                        && self.timestamp + candidate.start_delay_sec <= now
                        && !matches!(candidate.mode, DialMode::Ace { .. });
                    if !eligible {
                        continue;
                    }
                    // Compare the chosen candidate against *this* candidate's priority; the
                    // strict `<` keeps the earliest candidate on ties.
                    if best.is_none_or(|(_, chosen)| chosen.priority < candidate.priority) {
                        best = Some((i, candidate));
                    }
                }
                let Some((index, candidate)) = best else {
                    tracing::warn!(
                        "no dialer candidates available: falling back to system dns"
                    );
                    return ControlTcpDialer::UseDns;
                };
                ControlTcpDialer::Planned {
                    candidate,
                    index,
                    attempted: &mut self.attempted_candidates,
                }
            }
        }
    }

    /// Dials the control server at `url` with [`Self::next_dialer`], then completes the
    /// connection via [`complete_connection`]. That covers TLS for `https`, the ts2021
    /// upgrade and the HTTP/2 handshake.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionError::ConnectionFailed`] in these cases:
    /// - `url` has no host;
    /// - `url` has no port and no known default port for its scheme;
    /// - the TCP dial fails.
    ///
    /// Errors from [`complete_connection`] are propagated.
    #[tracing::instrument(skip_all, fields(control_url = %url))]
    pub async fn full_connect_next(
        &mut self,
        url: &Url,
        machine_keys: &ts_keys::MachineKeyPair,
    ) -> Result<Http2<BytesBody>, ConnectionError> {
        let next = self.next_dialer();
        tracing::trace!(selected_control_dialer = ?next);
        let host = url.host_str().ok_or(ConnectionError::ConnectionFailed)?;
        let port = url
            .port_or_known_default()
            .ok_or(ConnectionError::ConnectionFailed)?;
        let conn = next.dial(host, port).await.map_err(|e| {
            tracing::error!(error = %e, %url, %host, port, "dialing tcp");
            ConnectionError::ConnectionFailed
        })?;
        tracing::debug!(
            remote_endpoint = ?conn.peer_addr(),
            "tcp connection to control"
        );
        let client = complete_connection(url, machine_keys, conn).await?;
        Ok(client)
    }
}
pub async fn complete_connection<Io>(
url: &Url,
machine_keys: &ts_keys::MachineKeyPair,
stream: Io,
) -> Result<Http2<BytesBody>, ConnectionError>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let h1_client = match url.scheme() {
"https" => {
let conn = ts_tls_util::connect(
ts_tls_util::server_name(url).ok_or(ConnectionError::ConnectionFailed)?,
stream,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "establishing tls connection");
ConnectionError::ConnectionFailed
})?;
ts_http_util::http1::connect(conn).await?
}
"http" => ts_http_util::http1::connect(stream).await?,
other => {
tracing::error!(invalid_scheme = other);
return Err(ConnectionError::ConnectionFailed);
}
};
let control_public_key = crate::tokio::fetch_control_key(url).await?;
let (handshake, init_msg) = ts_control_noise::Handshake::initialize(
&crate::tokio::CONTROL_PROTOCOL_VERSION,
&machine_keys.private,
&control_public_key,
CapabilityVersion::CURRENT,
);
let mut conn = crate::tokio::upgrade_ts2021(url, &init_msg, handshake, h1_client).await?;
let _challenge_packet = crate::tokio::read_challenge_packet(&mut conn).await?;
let h2_conn = ts_http_util::http2::connect(conn).await?;
tracing::debug!("http2 connection to control established");
Ok(h2_conn)
}