Skip to main content

cloudflare_quick_tunnel/
manager.rs

1//! Top-level orchestrator. `QuickTunnelManager::start()` runs:
2//!
3//!   1. POST `/tunnel`               → `api::request_tunnel`
4//!   2. Edge discovery               → `edge::discover`
5//!   3. QUIC dial + register         → `connect_cycle` (helper)
6//!   4. Spawn reactor task           → owns the QUIC connection,
7//!      runs the supervisor, and reconnects on edge drop.
8//!   5. Return handle holding `url` + shutdown channel.
9//!
10//! The reactor task is the long-lived owner: it cycles between
11//! "supervise current connection" and "reconnect with backoff +
12//! `replace_existing=true`" until the operator signals shutdown or
13//! the reconnect attempt count exhausts.
14
15use std::time::Duration;
16
17use quinn::Endpoint;
18use tokio::sync::oneshot;
19use tracing::{debug, info, warn};
20use uuid::Uuid;
21
22use crate::api::{request_tunnel, DEFAULT_SERVICE_URL, DEFAULT_USER_AGENT};
23use crate::edge::{discover, IpVersionFilter};
24use crate::error::TunnelError;
25use crate::quic_dial::{build_endpoint, dial_any};
26use crate::rpc::{register_connection, ConnectionOptions, ControlSession, TunnelAuth};
27use crate::supervisor::{self, SupervisorExit, SupervisorMetrics};
28
/// Default budget for POST + discovery + handshake + register.
///
/// Used as the initial `discovery_timeout` of a `QuickTunnelManager`,
/// which bounds the whole of `QuickTunnelManager::start()`.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

/// Default budget for the `unregisterConnection` RPC on shutdown.
///
/// This is the value the reactor hands to
/// `ControlSession::shutdown_graceful` on the way out.
pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(30);

/// Hard cap on consecutive reconnect failures before the reactor
/// gives up. Every attempt is preceded by a backoff sleep that doubles
/// from 1s and is capped at 30s (1, 2, 4, 8, 16, 30, 30, 30, 30, 30),
/// so 10 attempts add up to 181s (about three minutes) of waiting,
/// plus whatever time the connection attempts themselves take.
pub const MAX_RECONNECT_ATTEMPTS: u32 = 10;
39
40/// Crate version, baked into `ConnectionOptions.client.version`.
41pub const CLIENT_VERSION: &str = concat!("cloudflare-quick-tunnel/", env!("CARGO_PKG_VERSION"));
42
/// Point-in-time copy of the tunnel counters, as returned by
/// `QuickTunnelHandle::metrics`.
#[derive(Clone, Debug, Default)]
pub struct TunnelMetrics {
    /// Stream counter: first value of `SupervisorMetrics::snapshot()`.
    pub streams_total: u64,
    /// Inbound byte counter: second value of `SupervisorMetrics::snapshot()`.
    pub bytes_in: u64,
    /// Outbound byte counter: third value of `SupervisorMetrics::snapshot()`.
    pub bytes_out: u64,
    /// Number of successful reconnects performed by the reactor.
    pub reconnects: u64,
}
50
51pub struct QuickTunnelHandle {
52    pub url: String,
53    pub tunnel_id: Uuid,
54    pub account_tag: String,
55    pub location: String,
56    shutdown_tx: Option<oneshot::Sender<()>>,
57    reactor: Option<tokio::task::JoinHandle<()>>,
58    metrics_view: SupervisorMetrics,
59    reconnects: std::sync::Arc<std::sync::atomic::AtomicU64>,
60}
61
62impl QuickTunnelHandle {
63    pub fn metrics(&self) -> TunnelMetrics {
64        let (s, i, o) = self.metrics_view.snapshot();
65        TunnelMetrics {
66            streams_total: s,
67            bytes_in: i,
68            bytes_out: o,
69            reconnects: self.reconnects.load(std::sync::atomic::Ordering::Relaxed),
70        }
71    }
72
73    /// Signal the reactor to drain + unregister + close. Awaits
74    /// the reactor task to fully finish.
75    pub async fn shutdown_with(mut self, _grace: Duration) -> Result<(), TunnelError> {
76        // `_grace` is honoured inside the reactor — it calls
77        // `ControlSession::shutdown_graceful(DEFAULT_GRACE_PERIOD)`
78        // unconditionally on the way out. We keep the grace param
79        // in the API for forward compatibility; once it matters,
80        // pass it through via a richer shutdown command.
81        if let Some(tx) = self.shutdown_tx.take() {
82            let _ = tx.send(());
83        }
84        if let Some(j) = self.reactor.take() {
85            j.await
86                .map_err(|e| TunnelError::Internal(format!("reactor join: {e}")))?;
87        }
88        Ok(())
89    }
90
91    pub async fn shutdown(self) -> Result<(), TunnelError> {
92        self.shutdown_with(DEFAULT_GRACE_PERIOD).await
93    }
94}
95
96impl Drop for QuickTunnelHandle {
97    fn drop(&mut self) {
98        if let Some(tx) = self.shutdown_tx.take() {
99            let _ = tx.send(());
100        }
101        // We can't await the reactor here — Drop is sync. The
102        // detached task winds down on its own.
103    }
104}
105
/// Configuration for starting a quick tunnel; consumed by `start()`.
pub struct QuickTunnelManager {
    /// Local port handed to the supervisor for every connection cycle.
    pub local_port: u16,
    /// Upper bound for the whole of `start()` (tunnel request, endpoint
    /// setup and the first connect cycle), not just edge discovery.
    pub discovery_timeout: Duration,
    /// URL passed to `request_tunnel` when asking for a new tunnel.
    pub service_url: String,
    /// User agent passed to `request_tunnel`.
    pub user_agent: String,
}
112
113impl QuickTunnelManager {
114    pub fn new(local_port: u16) -> Self {
115        Self {
116            local_port,
117            discovery_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
118            service_url: DEFAULT_SERVICE_URL.into(),
119            user_agent: DEFAULT_USER_AGENT.into(),
120        }
121    }
122
123    pub fn with_timeout(mut self, d: Duration) -> Self {
124        self.discovery_timeout = d;
125        self
126    }
127
128    pub fn with_service_url(mut self, url: impl Into<String>) -> Self {
129        self.service_url = url.into();
130        self
131    }
132
133    pub fn with_user_agent(mut self, ua: impl Into<String>) -> Self {
134        self.user_agent = ua.into();
135        self
136    }
137
138    pub async fn start(self) -> Result<QuickTunnelHandle, TunnelError> {
139        // Only the first connect cycle is bounded by `discovery_timeout`.
140        // Subsequent reconnects from the reactor have their own backoff.
141        tokio::time::timeout(self.discovery_timeout, self.start_inner())
142            .await
143            .map_err(|_| TunnelError::Internal("start() exceeded discovery_timeout".into()))?
144    }
145
146    async fn start_inner(self) -> Result<QuickTunnelHandle, TunnelError> {
147        // 1. POST /tunnel — single call, the same credentials are
148        //    reused on every reconnect (the edge keys routing off
149        //    `account_tag + tunnel_id`).
150        let tunnel = request_tunnel(&self.service_url, &self.user_agent).await?;
151        info!(hostname = %tunnel.hostname, id = %tunnel.id, "got quick tunnel");
152        let tunnel_id = Uuid::parse_str(&tunnel.id)
153            .map_err(|e| TunnelError::Internal(format!("tunnel.id is not a uuid: {e}")))?;
154        let url = if tunnel.hostname.starts_with("https://") {
155            tunnel.hostname.clone()
156        } else {
157            format!("https://{}", tunnel.hostname)
158        };
159
160        let auth = TunnelAuth {
161            account_tag: tunnel.account_tag.clone(),
162            tunnel_secret: tunnel.secret.clone(),
163        };
164
165        // 2. Build the long-lived QUIC client endpoint. Reused
166        //    across reconnect cycles so the UDP socket stays stable.
167        let endpoint = build_endpoint()?;
168
169        // 3. First connect cycle — `replace_existing=false`.
170        let (conn, control, location) =
171            connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, false).await?;
172        info!(%location, "first registration succeeded");
173
174        let metrics = SupervisorMetrics::default();
175        let reconnects = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
176
177        let (shutdown_tx, shutdown_rx) = oneshot::channel();
178
179        let reactor = tokio::spawn(reactor_loop(
180            self.local_port,
181            endpoint,
182            auth,
183            tunnel_id,
184            metrics.clone(),
185            reconnects.clone(),
186            conn,
187            control,
188            shutdown_rx,
189        ));
190
191        Ok(QuickTunnelHandle {
192            url,
193            tunnel_id,
194            account_tag: tunnel.account_tag,
195            location,
196            shutdown_tx: Some(shutdown_tx),
197            reactor: Some(reactor),
198            metrics_view: metrics,
199            reconnects,
200        })
201    }
202}
203
204/// Single attempt: dial the next edge, send `RegisterConnection`.
205/// `replace_existing=true` on reconnects so the edge accepts the
206/// new conn for `conn_index=0` (the previous one was dropped).
207async fn connect_cycle(
208    endpoint: &Endpoint,
209    auth: &TunnelAuth,
210    tunnel_id: Uuid,
211    client_version: &str,
212    replace_existing: bool,
213) -> Result<(quinn::Connection, ControlSession, String), TunnelError> {
214    let edges = discover(IpVersionFilter::Auto).await?;
215    let cap = edges.len().min(5);
216    let conn = dial_any(endpoint, &edges[..cap]).await?;
217
218    let mut options = ConnectionOptions::default_for_quick_tunnel(client_version);
219    options.replace_existing = replace_existing;
220
221    let (details, control) = register_connection(&conn, auth, tunnel_id, 0, &options).await?;
222    Ok((conn, control, details.location))
223}
224
225#[allow(clippy::too_many_arguments)]
226async fn reactor_loop(
227    local_port: u16,
228    endpoint: Endpoint,
229    auth: TunnelAuth,
230    tunnel_id: Uuid,
231    metrics: SupervisorMetrics,
232    reconnects: std::sync::Arc<std::sync::atomic::AtomicU64>,
233    mut conn: quinn::Connection,
234    mut control: ControlSession,
235    mut shutdown_rx: oneshot::Receiver<()>,
236) {
237    debug!("reactor loop started");
238    loop {
239        // ── Supervise current connection ─────────────────────────────────────
240        let (sup_tx, sup_rx) = oneshot::channel();
241        let metrics_for_cycle = metrics.clone();
242        let exit = tokio::select! {
243            biased;
244            _ = &mut shutdown_rx => {
245                // Caller wants us out. Forward to the supervisor
246                // so its accept loop sees the shutdown branch and
247                // closes the QUIC connection cleanly.
248                let _ = sup_tx.send(());
249                SupervisorExit::Shutdown
250            }
251            exit = supervisor::run(conn, local_port, metrics_for_cycle, sup_rx) => exit,
252        };
253
254        match exit {
255            SupervisorExit::Shutdown => {
256                // Graceful unregister on the way out. Best-effort
257                // — the edge may have closed already.
258                control.shutdown_graceful(DEFAULT_GRACE_PERIOD).await;
259                debug!("reactor: clean shutdown");
260                return;
261            }
262            SupervisorExit::ConnectionLost => {
263                // Throw away the dead control session — its RPC
264                // stream lives on a connection that's gone.
265                drop(control);
266
267                // ── Reconnect with exponential backoff ────────────────────
268                let mut attempt = 0u32;
269                loop {
270                    attempt += 1;
271                    if attempt > MAX_RECONNECT_ATTEMPTS {
272                        warn!(
273                            "reactor: giving up after {} reconnect attempts",
274                            MAX_RECONNECT_ATTEMPTS
275                        );
276                        return;
277                    }
278                    let delay = backoff(attempt);
279                    warn!(attempt, ?delay, "reactor: scheduling reconnect");
280                    tokio::select! {
281                        biased;
282                        _ = &mut shutdown_rx => {
283                            debug!("reactor: shutdown signal during reconnect backoff");
284                            return;
285                        }
286                        _ = tokio::time::sleep(delay) => {}
287                    }
288                    match connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, true).await {
289                        Ok((new_conn, new_control, new_loc)) => {
290                            info!(attempt, location = %new_loc, "reactor: reconnect succeeded");
291                            reconnects.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
292                            conn = new_conn;
293                            control = new_control;
294                            break;
295                        }
296                        Err(e) => {
297                            warn!(attempt, error = %e, "reactor: reconnect failed");
298                            // continue inner loop with bigger backoff
299                        }
300                    }
301                }
302            }
303        }
304    }
305}
306
/// Delay before reconnect attempt number `attempt` (1-based).
///
/// Doubles from one second and is clamped to a 30-second ceiling:
/// 1s, 2s, 4s, 8s, 16s, 30s, 30s, … An `attempt` of 0 is treated like
/// attempt 1.
fn backoff(attempt: u32) -> Duration {
    const CEILING_SECS: u64 = 30;

    let doublings = attempt.saturating_sub(1);
    // `checked_shl` yields `None` once the shift reaches the bit width;
    // those attempts are far beyond the ceiling anyway.
    let secs = match 1u64.checked_shl(doublings) {
        Some(uncapped) if uncapped < CEILING_SECS => uncapped,
        _ => CEILING_SECS,
    };
    Duration::from_secs(secs)
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the documented schedule: doubling from 1s, clamped at 30s.
    #[test]
    fn backoff_curve() {
        let expected = [(1, 1), (2, 2), (3, 4), (4, 8), (5, 16), (6, 30), (20, 30)];
        for (attempt, secs) in expected {
            assert_eq!(
                backoff(attempt),
                Duration::from_secs(secs),
                "attempt {attempt}"
            );
        }
    }

    /// Covers the edges of the input range: attempt 0 (no doubling yet)
    /// and shift amounts at or beyond the width of `u64`.
    #[test]
    fn backoff_extremes() {
        assert_eq!(backoff(0), Duration::from_secs(1));
        assert_eq!(backoff(64), Duration::from_secs(30));
        assert_eq!(backoff(65), Duration::from_secs(30));
        assert_eq!(backoff(u32::MAX), Duration::from_secs(30));
    }

    /// Every attempt the reactor can actually make stays within 1s..=30s.
    #[test]
    fn backoff_stays_within_bounds_for_all_reactor_attempts() {
        for attempt in 1..=MAX_RECONNECT_ATTEMPTS {
            let delay = backoff(attempt);
            assert!(delay >= Duration::from_secs(1), "attempt {attempt}");
            assert!(delay <= Duration::from_secs(30), "attempt {attempt}");
        }
    }
}
327}