Skip to main content

ts_control/
control_dialer.rs

1use core::fmt::{Debug, Formatter};
2use std::time::Instant;
3
4use tokio::{
5    io::{AsyncRead, AsyncWrite},
6    net::TcpStream,
7};
8use tokio_util::future::FutureExt;
9use ts_bitset::BitsetDyn;
10use ts_capabilityversion::CapabilityVersion;
11use ts_http_util::{BytesBody, Http2};
12use url::Url;
13
14use crate::{DialCandidate, DialMode, DialPlan, Error, InternalErrorKind, Operation};
15
/// Manages state for control dial plan and handles selection of successive dial candidates.
///
/// Candidates are handed out one at a time by [`ControlDialer::next_dialer`]; a candidate
/// is recorded as attempted once its dialer's [`TcpDialer::dial`] is called.
pub struct ControlDialer {
    /// The most recently installed dial plan.
    plan: DialPlan,
    /// Incremented each time [`ControlDialer::update_dial_plan`] installs a plan that differs
    /// from the stored one.
    epoch: usize,
    /// When the current plan was installed (or when this dialer was created via `Default`).
    /// Each candidate's start delay is measured from this instant.
    timestamp: Instant,
    /// Set of candidate indices (positions in the plan's candidate list) that have already
    /// been dialed; such candidates are skipped by `next_dialer`.
    attempted_candidates: ts_dynbitset::DynBitset,
}
23
24impl Default for ControlDialer {
25    fn default() -> Self {
26        Self {
27            plan: DialPlan::default(),
28            epoch: 0,
29            timestamp: Instant::now(),
30            attempted_candidates: Default::default(),
31        }
32    }
33}
34
/// Creates a TCP connection on the basis of a specific [`DialCandidate`].
///
/// Produced by [`ControlDialer::next_dialer`]. A dialer is single-use: `dial` consumes it.
pub trait TcpDialer {
    /// Open a TCP connection using the [`DialCandidate`] assigned to this dialer.
    ///
    /// - `host` is used if the [`DialCandidate`] requires DNS lookup, and by the
    ///   system-DNS fallback dialer.
    ///   **Ignored** for plain IP [`DialCandidate`]s.
    /// - `port` is the TCP port to connect to.
    ///
    /// Calling this function marks the current candidate (if the dialer has one) as
    /// "attempted": the next call to [`ControlDialer::next_dialer`] will use the next
    /// available candidate.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the connection cannot be established.
    fn dial(
        self,
        host: &str,
        port: u16,
    ) -> impl Future<Output = tokio::io::Result<TcpStream>> + Send;
}
53
/// Concrete [`TcpDialer`] returned by [`ControlDialer::next_dialer`].
enum ControlTcpDialer<'a> {
    /// Connect to the caller-supplied host via system DNS. Used when the plan is
    /// `DialPlan::UseDns` or when no plan candidate is currently eligible.
    UseDns,
    /// Dial one specific candidate from the current dial plan.
    Planned {
        /// The owning `ControlDialer`'s attempted-candidate set; `dial` sets bit `index`.
        attempted: &'a mut ts_dynbitset::DynBitset,
        /// The selected candidate.
        candidate: &'a DialCandidate,
        /// Position of `candidate` in the plan's candidate list.
        index: usize,
    },
}
62
63impl Debug for ControlTcpDialer<'_> {
64    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
65        match self {
66            ControlTcpDialer::UseDns => write!(f, "TcpDialer::Dns"),
67            ControlTcpDialer::Planned { candidate, .. } => match &candidate.mode {
68                DialMode::Ip(ip) => f.debug_tuple("TcpDialer::Ip").field(ip).finish(),
69                DialMode::Ace { ip: Some(ip), host } => f
70                    .debug_tuple("TcpDialer::Ace")
71                    .field(ip)
72                    .field(host)
73                    .finish(),
74                DialMode::Ace { host, .. } => f.debug_tuple("TcpDialer::Ace").field(host).finish(),
75            },
76        }
77    }
78}
79
80impl TcpDialer for ControlTcpDialer<'_> {
81    async fn dial(self, host: &str, port: u16) -> tokio::io::Result<TcpStream> {
82        match self {
83            ControlTcpDialer::UseDns => TcpStream::connect(format!("{host}:{port}")).await,
84            ControlTcpDialer::Planned {
85                candidate,
86                attempted: used,
87                index,
88            } => {
89                used.set(index);
90
91                match candidate.mode {
92                    DialMode::Ip(ip) => {
93                        TcpStream::connect((ip, port))
94                            .timeout(candidate.timeout)
95                            .await?
96                    }
97                    DialMode::Ace { .. } => {
98                        unimplemented!()
99                    }
100                }
101            }
102        }
103    }
104}
105
106impl ControlDialer {
107    /// Update the stored dial plan with the new `plan`.
108    ///
109    /// Returns whether the dial plan changed. Resubmission of the same dial plan is
110    /// idempotent.
111    pub fn update_dial_plan(&mut self, plan: &DialPlan) -> bool {
112        if &self.plan == plan {
113            return false;
114        }
115
116        self.plan = plan.clone();
117        self.epoch += 1;
118        self.timestamp = Instant::now();
119
120        true
121    }
122
123    /// Clear the set of attempted dial candidates.
124    ///
125    /// This will cause future connection attempts to retry all available dialers in
126    /// priority order.
127    pub fn clear_attempted(&mut self) {
128        self.attempted_candidates.clear_all();
129    }
130
131    /// Get the next dialer candidate from the dial plan.
132    ///
133    /// If all dialers have already been tried, falls back to system DNS.
134    ///
135    /// NB: the returned [`TcpDialer`] does not mark its corresponding candidate as having
136    /// been attempted until [`TcpDialer::dial`] is called -- it is fine semantically to
137    /// drop the returned dialer without calling `dial`.
138    pub fn next_dialer(&mut self) -> impl TcpDialer + Debug {
139        match &self.plan {
140            DialPlan::UseDns => ControlTcpDialer::UseDns,
141            DialPlan::Plan(candidates) => {
142                let mut selected_candidate: Option<(usize, usize, &DialCandidate)> = None;
143                let now = Instant::now();
144
145                // TODO(npry): ensure candidate sorting, optimistically stop early
146                for (i, candidate) in candidates.iter().enumerate() {
147                    if self.attempted_candidates.test(i) {
148                        continue;
149                    }
150
151                    let start_after = self.timestamp + candidate.start_delay_sec;
152                    if start_after > now {
153                        continue;
154                    }
155
156                    if matches!(candidate.mode, DialMode::Ace { .. }) {
157                        // TODO(npry): ACE unsupported
158                        continue;
159                    }
160
161                    if selected_candidate.is_none_or(|(prio, _idx, elem)| prio < elem.priority) {
162                        selected_candidate = Some((candidate.priority, i, candidate));
163                    }
164                }
165
166                let (i, candidate) = match selected_candidate {
167                    Some((_prio, i, elem)) => (i, elem),
168                    None => {
169                        tracing::warn!(
170                            "no dialer candidates available: falling back to system dns"
171                        );
172                        return ControlTcpDialer::UseDns;
173                    }
174                };
175
176                ControlTcpDialer::Planned {
177                    candidate,
178                    index: i,
179                    attempted: &mut self.attempted_candidates,
180                }
181            }
182        }
183    }
184
185    /// Convenience wrapper for [`next_dialer`][ControlDialer::next_dialer] followed by
186    /// [`complete_connection`].
187    #[tracing::instrument(skip_all, fields(control_url = %url))]
188    pub async fn full_connect_next(
189        &mut self,
190        url: &Url,
191        machine_keys: &ts_keys::MachineKeyPair,
192        allow_http_key_fetch: bool,
193    ) -> Result<Http2<BytesBody>, Error> {
194        let next = self.next_dialer();
195        tracing::trace!(selected_control_dialer = ?next);
196
197        let host = url.host_str().ok_or(Error::InvalidUrl(url.clone()))?;
198        let port = url
199            .port_or_known_default()
200            .ok_or_else(|| Error::InvalidUrl(url.clone()))?;
201
202        let conn = next.dial(host, port).await.map_err(|e| {
203            tracing::error!(error = %e, %url, %host, port, "dialing tcp");
204            Error::Internal(InternalErrorKind::Io, Operation::ConnectToControlServer)
205        })?;
206
207        tracing::debug!(
208            remote_endpoint = ?conn.peer_addr(),
209            "tcp connection to control"
210        );
211
212        let client = complete_connection(url, machine_keys, conn, allow_http_key_fetch).await?;
213
214        Ok(client)
215    }
216}
217
218/// Complete a connection to control over the supplied I/O `stream`.
219///
220/// Establishes an http1 connection over `stream`, wrapping it in a TLS connection if
221/// `url`'s scheme is `https`. Then upgrades the connection over ts2021 and establishes an
222/// inner http2 connection.
223pub async fn complete_connection<Io>(
224    url: &Url,
225    machine_keys: &ts_keys::MachineKeyPair,
226    stream: Io,
227    allow_http_key_fetch: bool,
228) -> Result<Http2<BytesBody>, Error>
229where
230    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
231{
232    let h1_client = match url.scheme() {
233        "https" => {
234            let conn = ts_tls_util::connect(
235                ts_tls_util::server_name(url).ok_or_else(|| Error::InvalidUrl(url.clone()))?,
236                stream,
237            )
238            .await
239            .map_err(|e| {
240                tracing::error!(error = %e, "establishing tls connection");
241                Error::io_error(e, Operation::ConnectToControlServer)
242            })?;
243            ts_http_util::http1::connect(conn).await?
244        }
245        "http" => ts_http_util::http1::connect(stream).await?,
246        other => {
247            tracing::error!(invalid_scheme = other);
248            return Err(Error::InvalidUrl(url.clone()));
249        }
250    };
251    let control_public_key = crate::tokio::fetch_control_key(url, allow_http_key_fetch).await?;
252
253    let (handshake, init_msg) = ts_control_noise::Handshake::initialize(
254        &crate::tokio::CONTROL_PROTOCOL_VERSION,
255        &machine_keys.private,
256        &control_public_key,
257        CapabilityVersion::CURRENT,
258    );
259
260    let conn = crate::tokio::upgrade_ts2021(url, &init_msg, handshake, h1_client).await?;
261    let conn = crate::tokio::read_challenge_packet(conn).await?;
262
263    let h2_conn = ts_http_util::http2::connect(conn).await?;
264    tracing::debug!("http2 connection to control established");
265
266    Ok(h2_conn)
267}