1use std::sync::Arc;
16use std::time::Duration;
17
18use quinn::Endpoint;
19use tokio::sync::oneshot;
20use tracing::{debug, info, warn};
21use uuid::Uuid;
22
23use crate::api::{request_tunnel, DEFAULT_SERVICE_URL, DEFAULT_USER_AGENT};
24use crate::edge::{discover, IpVersionFilter};
25use crate::error::TunnelError;
26use crate::pool::Pool;
27use crate::quic_dial::{build_endpoint, dial_any};
28use crate::rpc::{register_connection, ConnectionOptions, ControlSession, TunnelAuth};
29use crate::supervisor::{self, SupervisorExit, SupervisorMetrics};
30
/// Default time budget for bringing a tunnel up.
///
/// `QuickTunnelManager::new` installs this as `discovery_timeout`, which
/// `QuickTunnelManager::start` applies to the whole start-up sequence.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

/// Grace period used by `QuickTunnelHandle::shutdown` and handed to the
/// control session's graceful shutdown inside the reactor loop.
pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(30);

/// Maximum number of consecutive (re)connect attempts a reactor makes before
/// it gives up and exits.
pub const MAX_RECONNECT_ATTEMPTS: u32 = 10;

/// Number of edge connections a `QuickTunnelManager` opens unless overridden
/// via `with_ha_connections`.
pub const DEFAULT_HA_CONNECTIONS: u8 = 2;

/// Upper bound that `with_ha_connections` clamps its argument to.
pub const MAX_HA_CONNECTIONS: u8 = 4;
52
53pub const CLIENT_VERSION: &str = concat!("cloudflare-quick-tunnel/", env!("CARGO_PKG_VERSION"));
55
/// Point-in-time copy of a tunnel's counters, as returned by
/// `QuickTunnelHandle::metrics`.
///
/// The snapshot is plain data: it does not update after it has been taken.
/// `PartialEq`/`Eq` are derived so two snapshots can be compared directly
/// (e.g. "did anything change since the last poll?").
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TunnelMetrics {
    /// First value of the supervisor metrics snapshot
    /// (presumably the number of streams handled so far; confirm against
    /// `SupervisorMetrics::snapshot`).
    pub streams_total: u64,
    /// Second value of the supervisor metrics snapshot
    /// (presumably bytes received; confirm against `SupervisorMetrics::snapshot`).
    pub bytes_in: u64,
    /// Third value of the supervisor metrics snapshot
    /// (presumably bytes sent; confirm against `SupervisorMetrics::snapshot`).
    pub bytes_out: u64,
    /// Number of successful reconnects; the reactor loop increments the
    /// underlying counter once per reconnect that succeeds.
    pub reconnects: u64,
}
63
64pub struct QuickTunnelHandle {
65 pub url: String,
66 pub tunnel_id: Uuid,
67 pub account_tag: String,
68 pub location: String,
72 shutdown: Arc<tokio::sync::Notify>,
76 reactors: Vec<tokio::task::JoinHandle<()>>,
77 metrics_view: SupervisorMetrics,
78 reconnects: Arc<std::sync::atomic::AtomicU64>,
79}
80
81impl QuickTunnelHandle {
82 pub fn metrics(&self) -> TunnelMetrics {
83 let (s, i, o) = self.metrics_view.snapshot();
84 TunnelMetrics {
85 streams_total: s,
86 bytes_in: i,
87 bytes_out: o,
88 reconnects: self.reconnects.load(std::sync::atomic::Ordering::Relaxed),
89 }
90 }
91
92 pub async fn shutdown_with(mut self, _grace: Duration) -> Result<(), TunnelError> {
95 self.shutdown.notify_waiters();
101 for j in self.reactors.drain(..) {
102 j.await
103 .map_err(|e| TunnelError::Internal(format!("reactor join: {e}")))?;
104 }
105 Ok(())
106 }
107
108 pub async fn shutdown(self) -> Result<(), TunnelError> {
109 self.shutdown_with(DEFAULT_GRACE_PERIOD).await
110 }
111}
112
113impl Drop for QuickTunnelHandle {
114 fn drop(&mut self) {
115 self.shutdown.notify_waiters();
118 }
119}
120
/// Configuration for starting a quick tunnel.
///
/// Build one with `QuickTunnelManager::new`, adjust it with the `with_*`
/// methods, then call `start`. All fields are plain std types, so the
/// configuration can be cloned and debug-printed.
#[derive(Debug, Clone)]
pub struct QuickTunnelManager {
    /// Local port the tunnel forwards to; handed to the connection pool and
    /// the reactors.
    pub local_port: u16,
    /// Time budget applied to the whole `start()` call.
    pub discovery_timeout: Duration,
    /// URL of the service asked for a tunnel.
    pub service_url: String,
    /// User-agent string sent with the tunnel request.
    pub user_agent: String,
    /// Number of edge connections to open; `with_ha_connections` keeps this
    /// within `1..=MAX_HA_CONNECTIONS`.
    pub ha_connections: u8,
}
128
129impl QuickTunnelManager {
130 pub fn new(local_port: u16) -> Self {
131 Self {
132 local_port,
133 discovery_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
134 service_url: DEFAULT_SERVICE_URL.into(),
135 user_agent: DEFAULT_USER_AGENT.into(),
136 ha_connections: DEFAULT_HA_CONNECTIONS,
137 }
138 }
139
140 pub fn with_timeout(mut self, d: Duration) -> Self {
141 self.discovery_timeout = d;
142 self
143 }
144
145 pub fn with_service_url(mut self, url: impl Into<String>) -> Self {
146 self.service_url = url.into();
147 self
148 }
149
150 pub fn with_user_agent(mut self, ua: impl Into<String>) -> Self {
151 self.user_agent = ua.into();
152 self
153 }
154
155 pub fn with_ha_connections(mut self, n: u8) -> Self {
160 self.ha_connections = n.clamp(1, MAX_HA_CONNECTIONS);
161 self
162 }
163
164 pub async fn start(self) -> Result<QuickTunnelHandle, TunnelError> {
165 tokio::time::timeout(self.discovery_timeout, self.start_inner())
168 .await
169 .map_err(|_| TunnelError::Internal("start() exceeded discovery_timeout".into()))?
170 }
171
172 async fn start_inner(self) -> Result<QuickTunnelHandle, TunnelError> {
173 let tunnel = request_tunnel(&self.service_url, &self.user_agent).await?;
177 info!(hostname = %tunnel.hostname, id = %tunnel.id, ha = self.ha_connections, "got quick tunnel");
178 let tunnel_id = Uuid::parse_str(&tunnel.id)
179 .map_err(|e| TunnelError::Internal(format!("tunnel.id is not a uuid: {e}")))?;
180 let url = if tunnel.hostname.starts_with("https://") {
181 tunnel.hostname.clone()
182 } else {
183 format!("https://{}", tunnel.hostname)
184 };
185
186 let auth = TunnelAuth {
187 account_tag: tunnel.account_tag.clone(),
188 tunnel_secret: tunnel.secret.clone(),
189 };
190
191 let endpoint = build_endpoint()?;
194
195 let (conn0, control0, location0) =
199 connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, 0, false).await?;
200 info!(%location0, conn_index = 0, "first registration succeeded");
201
202 let metrics = SupervisorMetrics::default();
203 let reconnects = Arc::new(std::sync::atomic::AtomicU64::new(0));
204 let shutdown = Arc::new(tokio::sync::Notify::new());
205 let pool = Arc::new(Pool::new(self.local_port));
209
210 let mut reactors = Vec::with_capacity(self.ha_connections as usize);
213 reactors.push(tokio::spawn(reactor_loop(
214 self.local_port,
215 endpoint.clone(),
216 auth.clone(),
217 tunnel_id,
218 0,
219 metrics.clone(),
220 reconnects.clone(),
221 pool.clone(),
222 conn0,
223 control0,
224 shutdown.clone(),
225 )));
226
227 for idx in 1..self.ha_connections {
232 let endpoint = endpoint.clone();
233 let auth = auth.clone();
234 let metrics = metrics.clone();
235 let reconnects = reconnects.clone();
236 let shutdown = shutdown.clone();
237 let pool = pool.clone();
238 let local_port = self.local_port;
239 reactors.push(tokio::spawn(async move {
240 match connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, idx, false).await {
241 Ok((conn, control, location)) => {
242 info!(%location, conn_index = idx, "HA registration succeeded");
243 reactor_loop(
244 local_port, endpoint, auth, tunnel_id, idx, metrics, reconnects, pool,
245 conn, control, shutdown,
246 )
247 .await;
248 }
249 Err(e) => {
250 warn!(error = %e, conn_index = idx, "HA registration failed; will retry");
251 reactor_loop_after_failure(
252 local_port, endpoint, auth, tunnel_id, idx, metrics, reconnects, pool,
253 shutdown,
254 )
255 .await;
256 }
257 }
258 }));
259 }
260
261 Ok(QuickTunnelHandle {
262 url,
263 tunnel_id,
264 account_tag: tunnel.account_tag,
265 location: location0,
266 shutdown,
267 reactors,
268 metrics_view: metrics,
269 reconnects,
270 })
271 }
272}
273
274async fn connect_cycle(
280 endpoint: &Endpoint,
281 auth: &TunnelAuth,
282 tunnel_id: Uuid,
283 client_version: &str,
284 conn_index: u8,
285 replace_existing: bool,
286) -> Result<(quinn::Connection, ControlSession, String), TunnelError> {
287 let edges = discover(IpVersionFilter::Auto).await?;
288 let cap = edges.len().min(5);
289 let conn = dial_any(endpoint, &edges[..cap]).await?;
290
291 let mut options = ConnectionOptions::default_for_quick_tunnel(client_version);
292 options.replace_existing = replace_existing;
293
294 let (details, control) =
295 register_connection(&conn, auth, tunnel_id, conn_index, &options).await?;
296 Ok((conn, control, details.location))
297}
298
299#[allow(clippy::too_many_arguments)]
300async fn reactor_loop(
301 local_port: u16,
302 endpoint: Endpoint,
303 auth: TunnelAuth,
304 tunnel_id: Uuid,
305 conn_index: u8,
306 metrics: SupervisorMetrics,
307 reconnects: Arc<std::sync::atomic::AtomicU64>,
308 pool: Arc<Pool>,
309 mut conn: quinn::Connection,
310 mut control: ControlSession,
311 shutdown: Arc<tokio::sync::Notify>,
312) {
313 debug!(conn_index, "reactor loop started");
314 loop {
315 let (sup_tx, sup_rx) = oneshot::channel();
317 let metrics_for_cycle = metrics.clone();
318 let shutdown_wait = shutdown.notified();
319 tokio::pin!(shutdown_wait);
320 let exit = tokio::select! {
321 biased;
322 _ = &mut shutdown_wait => {
323 let _ = sup_tx.send(());
327 SupervisorExit::Shutdown
328 }
329 exit = supervisor::run(conn, local_port, metrics_for_cycle, pool.clone(), sup_rx) => exit,
330 };
331
332 match exit {
333 SupervisorExit::Shutdown => {
334 control.shutdown_graceful(DEFAULT_GRACE_PERIOD).await;
335 debug!(conn_index, "reactor: clean shutdown");
336 return;
337 }
338 SupervisorExit::ConnectionLost => {
339 drop(control);
340
341 let mut attempt = 0u32;
342 loop {
343 attempt += 1;
344 if attempt > MAX_RECONNECT_ATTEMPTS {
345 warn!(
346 conn_index,
347 "reactor: giving up after {} reconnect attempts",
348 MAX_RECONNECT_ATTEMPTS
349 );
350 return;
351 }
352 let delay = backoff(attempt);
353 warn!(conn_index, attempt, ?delay, "reactor: scheduling reconnect");
354 let shutdown_wait = shutdown.notified();
355 tokio::pin!(shutdown_wait);
356 tokio::select! {
357 biased;
358 _ = shutdown_wait => {
359 debug!(conn_index, "reactor: shutdown during reconnect backoff");
360 return;
361 }
362 _ = tokio::time::sleep(delay) => {}
363 }
364 match connect_cycle(
365 &endpoint,
366 &auth,
367 tunnel_id,
368 CLIENT_VERSION,
369 conn_index,
370 true,
371 )
372 .await
373 {
374 Ok((new_conn, new_control, new_loc)) => {
375 info!(conn_index, attempt, location = %new_loc, "reactor: reconnect succeeded");
376 reconnects.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
377 conn = new_conn;
378 control = new_control;
379 break;
380 }
381 Err(e) => {
382 warn!(attempt, error = %e, "reactor: reconnect failed");
383 }
385 }
386 }
387 }
388 }
389 }
390}
391
392#[allow(clippy::too_many_arguments)]
396async fn reactor_loop_after_failure(
397 local_port: u16,
398 endpoint: Endpoint,
399 auth: TunnelAuth,
400 tunnel_id: Uuid,
401 conn_index: u8,
402 metrics: SupervisorMetrics,
403 reconnects: Arc<std::sync::atomic::AtomicU64>,
404 pool: Arc<Pool>,
405 shutdown: Arc<tokio::sync::Notify>,
406) {
407 let mut attempt = 0u32;
408 loop {
409 attempt += 1;
410 if attempt > MAX_RECONNECT_ATTEMPTS {
411 warn!(
412 conn_index,
413 "HA reactor: giving up after {} initial-register attempts", MAX_RECONNECT_ATTEMPTS
414 );
415 return;
416 }
417 let delay = backoff(attempt);
418 warn!(
419 conn_index,
420 attempt,
421 ?delay,
422 "HA reactor: scheduling initial register retry"
423 );
424 let shutdown_wait = shutdown.notified();
425 tokio::pin!(shutdown_wait);
426 tokio::select! {
427 biased;
428 _ = shutdown_wait => return,
429 _ = tokio::time::sleep(delay) => {}
430 }
431 let result = connect_cycle(
433 &endpoint,
434 &auth,
435 tunnel_id,
436 CLIENT_VERSION,
437 conn_index,
438 false,
439 )
440 .await;
441 match result {
442 Ok((conn, control, location)) => {
443 info!(conn_index, %location, "HA leg eventually registered after {attempt} retries");
444 let shutdown = shutdown.clone();
447 reactor_loop(
448 local_port, endpoint, auth, tunnel_id, conn_index, metrics, reconnects, pool,
449 conn, control, shutdown,
450 )
451 .await;
452 return;
453 }
454 Err(e) => warn!(conn_index, attempt, error = %e, "HA register retry failed"),
455 }
456 }
457}
458
/// Delay to wait before reconnect attempt number `attempt` (1-based).
///
/// The delay doubles with each attempt, starting at one second, and is
/// capped at 30 seconds. An attempt number of 0 is treated like 1.
fn backoff(attempt: u32) -> Duration {
    const CAP_SECS: u64 = 30;

    let doublings = attempt.saturating_sub(1);
    // `checked_shl` yields `None` once the shift reaches the bit width of
    // u64; by then the delay is far past the cap anyway.
    let secs = match 1u64.checked_shl(doublings) {
        Some(uncapped) => uncapped.min(CAP_SECS),
        None => CAP_SECS,
    };
    Duration::from_secs(secs)
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// The delay doubles from one second and saturates at thirty.
    #[test]
    fn backoff_curve() {
        // (attempt number, expected delay in seconds)
        let expected: [(u32, u64); 7] = [
            (1, 1),
            (2, 2),
            (3, 4),
            (4, 8),
            (5, 16),
            (6, 30),
            (20, 30),
        ];
        for (attempt, secs) in expected {
            assert_eq!(backoff(attempt), Duration::from_secs(secs));
        }
    }
}