1use core::fmt::{Debug, Formatter};
2use std::time::Instant;
3
4use tokio::{
5 io::{AsyncRead, AsyncWrite},
6 net::TcpStream,
7};
8use tokio_util::future::FutureExt;
9use ts_bitset::BitsetDyn;
10use ts_capabilityversion::CapabilityVersion;
11use ts_http_util::{BytesBody, Http2};
12use url::Url;
13
14use crate::{DialCandidate, DialMode, DialPlan, Error, InternalErrorKind, Operation};
15
16pub struct ControlDialer {
18 plan: DialPlan,
19 epoch: usize,
20 timestamp: Instant,
21 attempted_candidates: ts_dynbitset::DynBitset,
22}
23
24impl Default for ControlDialer {
25 fn default() -> Self {
26 Self {
27 plan: DialPlan::default(),
28 epoch: 0,
29 timestamp: Instant::now(),
30 attempted_candidates: Default::default(),
31 }
32 }
33}
34
/// Something that can open a TCP connection for a given host and port.
pub trait TcpDialer {
    /// Opens a TCP connection for `host`:`port`, consuming the dialer.
    ///
    /// Implementations decide how `host` is used: they may resolve it themselves or connect
    /// to an address chosen by other means. The returned future must be `Send` so it can be
    /// driven from a multi-threaded runtime.
    fn dial(
        self,
        host: &str,
        port: u16,
    ) -> impl Future<Output = tokio::io::Result<TcpStream>> + Send;
}
53
54enum ControlTcpDialer<'a> {
55 UseDns,
56 Planned {
57 attempted: &'a mut ts_dynbitset::DynBitset,
58 candidate: &'a DialCandidate,
59 index: usize,
60 },
61}
62
63impl Debug for ControlTcpDialer<'_> {
64 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
65 match self {
66 ControlTcpDialer::UseDns => write!(f, "TcpDialer::Dns"),
67 ControlTcpDialer::Planned { candidate, .. } => match &candidate.mode {
68 DialMode::Ip(ip) => f.debug_tuple("TcpDialer::Ip").field(ip).finish(),
69 DialMode::Ace { ip: Some(ip), host } => f
70 .debug_tuple("TcpDialer::Ace")
71 .field(ip)
72 .field(host)
73 .finish(),
74 DialMode::Ace { host, .. } => f.debug_tuple("TcpDialer::Ace").field(host).finish(),
75 },
76 }
77 }
78}
79
80impl TcpDialer for ControlTcpDialer<'_> {
81 async fn dial(self, host: &str, port: u16) -> tokio::io::Result<TcpStream> {
82 match self {
83 ControlTcpDialer::UseDns => TcpStream::connect(format!("{host}:{port}")).await,
84 ControlTcpDialer::Planned {
85 candidate,
86 attempted: used,
87 index,
88 } => {
89 used.set(index);
90
91 match candidate.mode {
92 DialMode::Ip(ip) => {
93 TcpStream::connect((ip, port))
94 .timeout(candidate.timeout)
95 .await?
96 }
97 DialMode::Ace { .. } => {
98 unimplemented!()
99 }
100 }
101 }
102 }
103 }
104}
105
106impl ControlDialer {
107 pub fn update_dial_plan(&mut self, plan: &DialPlan) -> bool {
112 if &self.plan == plan {
113 return false;
114 }
115
116 self.plan = plan.clone();
117 self.epoch += 1;
118 self.timestamp = Instant::now();
119
120 true
121 }
122
123 pub fn clear_attempted(&mut self) {
128 self.attempted_candidates.clear_all();
129 }
130
131 pub fn next_dialer(&mut self) -> impl TcpDialer + Debug {
139 match &self.plan {
140 DialPlan::UseDns => ControlTcpDialer::UseDns,
141 DialPlan::Plan(candidates) => {
142 let mut selected_candidate: Option<(usize, usize, &DialCandidate)> = None;
143 let now = Instant::now();
144
145 for (i, candidate) in candidates.iter().enumerate() {
147 if self.attempted_candidates.test(i) {
148 continue;
149 }
150
151 let start_after = self.timestamp + candidate.start_delay_sec;
152 if start_after > now {
153 continue;
154 }
155
156 if matches!(candidate.mode, DialMode::Ace { .. }) {
157 continue;
159 }
160
161 if selected_candidate.is_none_or(|(prio, _idx, elem)| prio < elem.priority) {
162 selected_candidate = Some((candidate.priority, i, candidate));
163 }
164 }
165
166 let (i, candidate) = match selected_candidate {
167 Some((_prio, i, elem)) => (i, elem),
168 None => {
169 tracing::warn!(
170 "no dialer candidates available: falling back to system dns"
171 );
172 return ControlTcpDialer::UseDns;
173 }
174 };
175
176 ControlTcpDialer::Planned {
177 candidate,
178 index: i,
179 attempted: &mut self.attempted_candidates,
180 }
181 }
182 }
183 }
184
185 #[tracing::instrument(skip_all, fields(control_url = %url))]
188 pub async fn full_connect_next(
189 &mut self,
190 url: &Url,
191 machine_keys: &ts_keys::MachineKeyPair,
192 allow_http_key_fetch: bool,
193 ) -> Result<Http2<BytesBody>, Error> {
194 let next = self.next_dialer();
195 tracing::trace!(selected_control_dialer = ?next);
196
197 let host = url.host_str().ok_or(Error::InvalidUrl(url.clone()))?;
198 let port = url
199 .port_or_known_default()
200 .ok_or_else(|| Error::InvalidUrl(url.clone()))?;
201
202 let conn = next.dial(host, port).await.map_err(|e| {
203 tracing::error!(error = %e, %url, %host, port, "dialing tcp");
204 Error::Internal(InternalErrorKind::Io, Operation::ConnectToControlServer)
205 })?;
206
207 tracing::debug!(
208 remote_endpoint = ?conn.peer_addr(),
209 "tcp connection to control"
210 );
211
212 let client = complete_connection(url, machine_keys, conn, allow_http_key_fetch).await?;
213
214 Ok(client)
215 }
216}
217
218pub async fn complete_connection<Io>(
224 url: &Url,
225 machine_keys: &ts_keys::MachineKeyPair,
226 stream: Io,
227 allow_http_key_fetch: bool,
228) -> Result<Http2<BytesBody>, Error>
229where
230 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
231{
232 let h1_client = match url.scheme() {
233 "https" => {
234 let conn = ts_tls_util::connect(
235 ts_tls_util::server_name(url).ok_or_else(|| Error::InvalidUrl(url.clone()))?,
236 stream,
237 )
238 .await
239 .map_err(|e| {
240 tracing::error!(error = %e, "establishing tls connection");
241 Error::io_error(e, Operation::ConnectToControlServer)
242 })?;
243 ts_http_util::http1::connect(conn).await?
244 }
245 "http" => ts_http_util::http1::connect(stream).await?,
246 other => {
247 tracing::error!(invalid_scheme = other);
248 return Err(Error::InvalidUrl(url.clone()));
249 }
250 };
251 let control_public_key = crate::tokio::fetch_control_key(url, allow_http_key_fetch).await?;
252
253 let (handshake, init_msg) = ts_control_noise::Handshake::initialize(
254 &crate::tokio::CONTROL_PROTOCOL_VERSION,
255 &machine_keys.private,
256 &control_public_key,
257 CapabilityVersion::CURRENT,
258 );
259
260 let conn = crate::tokio::upgrade_ts2021(url, &init_msg, handshake, h1_client).await?;
261 let conn = crate::tokio::read_challenge_packet(conn).await?;
262
263 let h2_conn = ts_http_util::http2::connect(conn).await?;
264 tracing::debug!("http2 connection to control established");
265
266 Ok(h2_conn)
267}