1use std::time::Duration;
16
17use quinn::Endpoint;
18use tokio::sync::oneshot;
19use tracing::{debug, info, warn};
20use uuid::Uuid;
21
22use crate::api::{request_tunnel, DEFAULT_SERVICE_URL, DEFAULT_USER_AGENT};
23use crate::edge::{discover, IpVersionFilter};
24use crate::error::TunnelError;
25use crate::quic_dial::{build_endpoint, dial_any};
26use crate::rpc::{register_connection, ConnectionOptions, ControlSession, TunnelAuth};
27use crate::supervisor::{self, SupervisorExit, SupervisorMetrics};
28
/// Default upper bound on the whole start-up sequence of
/// [`QuickTunnelManager::start`] (tunnel request, edge discovery, dial and
/// first registration). Overridable with [`QuickTunnelManager::with_timeout`].
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

/// Default grace period for shutdown: the value [`QuickTunnelHandle::shutdown`]
/// passes along, and the budget the reactor gives the control session's
/// graceful shutdown.
pub const DEFAULT_GRACE_PERIOD: Duration = Duration::from_secs(30);

/// How many times the reactor tries to re-establish the edge connection after
/// losing it before giving up and exiting.
pub const MAX_RECONNECT_ATTEMPTS: u32 = 10;
39
40pub const CLIENT_VERSION: &str = concat!("cloudflare-quick-tunnel/", env!("CARGO_PKG_VERSION"));
42
/// Point-in-time counters for a running quick tunnel, as returned by
/// [`QuickTunnelHandle::metrics`].
///
/// `PartialEq`/`Eq` let callers compare two snapshots directly, for example to
/// detect whether any traffic happened between two polls.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TunnelMetrics {
    /// Stream count reported by the supervisor metrics snapshot.
    pub streams_total: u64,
    /// Inbound byte count reported by the supervisor metrics snapshot.
    pub bytes_in: u64,
    /// Outbound byte count reported by the supervisor metrics snapshot.
    pub bytes_out: u64,
    /// Number of successful re-registrations after a lost edge connection.
    pub reconnects: u64,
}
50
51pub struct QuickTunnelHandle {
52 pub url: String,
53 pub tunnel_id: Uuid,
54 pub account_tag: String,
55 pub location: String,
56 shutdown_tx: Option<oneshot::Sender<()>>,
57 reactor: Option<tokio::task::JoinHandle<()>>,
58 metrics_view: SupervisorMetrics,
59 reconnects: std::sync::Arc<std::sync::atomic::AtomicU64>,
60}
61
62impl QuickTunnelHandle {
63 pub fn metrics(&self) -> TunnelMetrics {
64 let (s, i, o) = self.metrics_view.snapshot();
65 TunnelMetrics {
66 streams_total: s,
67 bytes_in: i,
68 bytes_out: o,
69 reconnects: self.reconnects.load(std::sync::atomic::Ordering::Relaxed),
70 }
71 }
72
73 pub async fn shutdown_with(mut self, _grace: Duration) -> Result<(), TunnelError> {
76 if let Some(tx) = self.shutdown_tx.take() {
82 let _ = tx.send(());
83 }
84 if let Some(j) = self.reactor.take() {
85 j.await
86 .map_err(|e| TunnelError::Internal(format!("reactor join: {e}")))?;
87 }
88 Ok(())
89 }
90
91 pub async fn shutdown(self) -> Result<(), TunnelError> {
92 self.shutdown_with(DEFAULT_GRACE_PERIOD).await
93 }
94}
95
96impl Drop for QuickTunnelHandle {
97 fn drop(&mut self) {
98 if let Some(tx) = self.shutdown_tx.take() {
99 let _ = tx.send(());
100 }
101 }
104}
105
/// Configuration for starting a Cloudflare quick tunnel.
///
/// Build it with [`QuickTunnelManager::new`] plus the `with_*` methods, then
/// call [`QuickTunnelManager::start`]. `Debug` and `Clone` make the settings
/// loggable and reusable across several `start` calls.
#[derive(Debug, Clone)]
pub struct QuickTunnelManager {
    /// Local port handed to the stream supervisor. Presumably this is the local
    /// origin that tunnel traffic is forwarded to; confirm in `supervisor`.
    pub local_port: u16,
    /// Upper bound on the entire `start` call, not only edge discovery.
    pub discovery_timeout: Duration,
    /// Base URL of the quick-tunnel service, passed to `request_tunnel`.
    pub service_url: String,
    /// User-agent value passed to `request_tunnel`.
    pub user_agent: String,
}
112
113impl QuickTunnelManager {
114 pub fn new(local_port: u16) -> Self {
115 Self {
116 local_port,
117 discovery_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
118 service_url: DEFAULT_SERVICE_URL.into(),
119 user_agent: DEFAULT_USER_AGENT.into(),
120 }
121 }
122
123 pub fn with_timeout(mut self, d: Duration) -> Self {
124 self.discovery_timeout = d;
125 self
126 }
127
128 pub fn with_service_url(mut self, url: impl Into<String>) -> Self {
129 self.service_url = url.into();
130 self
131 }
132
133 pub fn with_user_agent(mut self, ua: impl Into<String>) -> Self {
134 self.user_agent = ua.into();
135 self
136 }
137
138 pub async fn start(self) -> Result<QuickTunnelHandle, TunnelError> {
139 tokio::time::timeout(self.discovery_timeout, self.start_inner())
142 .await
143 .map_err(|_| TunnelError::Internal("start() exceeded discovery_timeout".into()))?
144 }
145
146 async fn start_inner(self) -> Result<QuickTunnelHandle, TunnelError> {
147 let tunnel = request_tunnel(&self.service_url, &self.user_agent).await?;
151 info!(hostname = %tunnel.hostname, id = %tunnel.id, "got quick tunnel");
152 let tunnel_id = Uuid::parse_str(&tunnel.id)
153 .map_err(|e| TunnelError::Internal(format!("tunnel.id is not a uuid: {e}")))?;
154 let url = if tunnel.hostname.starts_with("https://") {
155 tunnel.hostname.clone()
156 } else {
157 format!("https://{}", tunnel.hostname)
158 };
159
160 let auth = TunnelAuth {
161 account_tag: tunnel.account_tag.clone(),
162 tunnel_secret: tunnel.secret.clone(),
163 };
164
165 let endpoint = build_endpoint()?;
168
169 let (conn, control, location) =
171 connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, false).await?;
172 info!(%location, "first registration succeeded");
173
174 let metrics = SupervisorMetrics::default();
175 let reconnects = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(0));
176
177 let (shutdown_tx, shutdown_rx) = oneshot::channel();
178
179 let reactor = tokio::spawn(reactor_loop(
180 self.local_port,
181 endpoint,
182 auth,
183 tunnel_id,
184 metrics.clone(),
185 reconnects.clone(),
186 conn,
187 control,
188 shutdown_rx,
189 ));
190
191 Ok(QuickTunnelHandle {
192 url,
193 tunnel_id,
194 account_tag: tunnel.account_tag,
195 location,
196 shutdown_tx: Some(shutdown_tx),
197 reactor: Some(reactor),
198 metrics_view: metrics,
199 reconnects,
200 })
201 }
202}
203
204async fn connect_cycle(
208 endpoint: &Endpoint,
209 auth: &TunnelAuth,
210 tunnel_id: Uuid,
211 client_version: &str,
212 replace_existing: bool,
213) -> Result<(quinn::Connection, ControlSession, String), TunnelError> {
214 let edges = discover(IpVersionFilter::Auto).await?;
215 let cap = edges.len().min(5);
216 let conn = dial_any(endpoint, &edges[..cap]).await?;
217
218 let mut options = ConnectionOptions::default_for_quick_tunnel(client_version);
219 options.replace_existing = replace_existing;
220
221 let (details, control) = register_connection(&conn, auth, tunnel_id, 0, &options).await?;
222 Ok((conn, control, details.location))
223}
224
225#[allow(clippy::too_many_arguments)]
226async fn reactor_loop(
227 local_port: u16,
228 endpoint: Endpoint,
229 auth: TunnelAuth,
230 tunnel_id: Uuid,
231 metrics: SupervisorMetrics,
232 reconnects: std::sync::Arc<std::sync::atomic::AtomicU64>,
233 mut conn: quinn::Connection,
234 mut control: ControlSession,
235 mut shutdown_rx: oneshot::Receiver<()>,
236) {
237 debug!("reactor loop started");
238 loop {
239 let (sup_tx, sup_rx) = oneshot::channel();
241 let metrics_for_cycle = metrics.clone();
242 let exit = tokio::select! {
243 biased;
244 _ = &mut shutdown_rx => {
245 let _ = sup_tx.send(());
249 SupervisorExit::Shutdown
250 }
251 exit = supervisor::run(conn, local_port, metrics_for_cycle, sup_rx) => exit,
252 };
253
254 match exit {
255 SupervisorExit::Shutdown => {
256 control.shutdown_graceful(DEFAULT_GRACE_PERIOD).await;
259 debug!("reactor: clean shutdown");
260 return;
261 }
262 SupervisorExit::ConnectionLost => {
263 drop(control);
266
267 let mut attempt = 0u32;
269 loop {
270 attempt += 1;
271 if attempt > MAX_RECONNECT_ATTEMPTS {
272 warn!(
273 "reactor: giving up after {} reconnect attempts",
274 MAX_RECONNECT_ATTEMPTS
275 );
276 return;
277 }
278 let delay = backoff(attempt);
279 warn!(attempt, ?delay, "reactor: scheduling reconnect");
280 tokio::select! {
281 biased;
282 _ = &mut shutdown_rx => {
283 debug!("reactor: shutdown signal during reconnect backoff");
284 return;
285 }
286 _ = tokio::time::sleep(delay) => {}
287 }
288 match connect_cycle(&endpoint, &auth, tunnel_id, CLIENT_VERSION, true).await {
289 Ok((new_conn, new_control, new_loc)) => {
290 info!(attempt, location = %new_loc, "reactor: reconnect succeeded");
291 reconnects.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
292 conn = new_conn;
293 control = new_control;
294 break;
295 }
296 Err(e) => {
297 warn!(attempt, error = %e, "reactor: reconnect failed");
298 }
300 }
301 }
302 }
303 }
304 }
305}
306
/// Delay before reconnect attempt `attempt` (1-based): 1s, 2s, 4s, …, doubling
/// each time and capped at 30s.
///
/// `attempt == 0` is treated like the first attempt. Exponents too large for a
/// `u64` fall back to the cap instead of overflowing.
fn backoff(attempt: u32) -> Duration {
    const CAP_SECS: u64 = 30;
    let exponent = attempt.saturating_sub(1);
    let secs = 2u64
        .checked_pow(exponent)
        .map_or(CAP_SECS, |raw| raw.min(CAP_SECS));
    Duration::from_secs(secs)
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// The nominal doubling curve and its 30s cap.
    #[test]
    fn backoff_curve() {
        assert_eq!(backoff(1), Duration::from_secs(1));
        assert_eq!(backoff(2), Duration::from_secs(2));
        assert_eq!(backoff(3), Duration::from_secs(4));
        assert_eq!(backoff(4), Duration::from_secs(8));
        assert_eq!(backoff(5), Duration::from_secs(16));
        assert_eq!(backoff(6), Duration::from_secs(30));
        assert_eq!(backoff(20), Duration::from_secs(30));
    }

    /// Out-of-range attempt numbers must neither panic nor wrap around.
    #[test]
    fn backoff_clamps_out_of_range_attempts() {
        // Attempt 0 behaves like the first attempt.
        assert_eq!(backoff(0), Duration::from_secs(1));
        // Exponent 63 is the largest power of two a u64 can hold.
        assert_eq!(backoff(64), Duration::from_secs(30));
        // Exponent 64 and beyond overflow u64 and must fall back to the cap.
        assert_eq!(backoff(65), Duration::from_secs(30));
        assert_eq!(backoff(u32::MAX), Duration::from_secs(30));
    }

    /// Delays never shrink across the attempts the reactor actually makes.
    #[test]
    fn backoff_never_decreases_within_attempt_budget() {
        let delays: Vec<Duration> = (1..=MAX_RECONNECT_ATTEMPTS).map(backoff).collect();
        assert!(delays.windows(2).all(|w| w[0] <= w[1]));
        assert!(delays.iter().all(|d| *d <= Duration::from_secs(30)));
    }
}