Skip to main content

cloudflare_quick_tunnel/
api.rs

1//! POST `/tunnel` client for `api.trycloudflare.com`.
2//!
3//! Returns the credentials the edge expects on the subsequent
4//! `RegisterConnection` RPC: a UUID-shaped `id`, the public
5//! `hostname` (`<sub>.trycloudflare.com`), the `account_tag` to
6//! quote on RPC, and 32 random bytes of `secret` that double as the
7//! `TunnelSecret` in the auth blob.
8//!
9//! Mirrors `cmd/cloudflared/tunnel/quick_tunnel.go` upstream.
10
11use std::time::Duration;
12
13use serde::Deserialize;
14use tokio::time::sleep;
15use tracing::{debug, warn};
16
17use crate::error::{QuickTunnelApiError, TunnelError};
18
/// Public-facing JSON envelope returned by `POST /tunnel`.
///
/// Only `success` is mandatory when deserializing; the other two fields
/// fall back to their `Default` (`None` / empty `Vec`) when the key is
/// absent from the body.
#[derive(Debug, Deserialize)]
pub struct QuickTunnelResponse {
    /// Outcome flag. When `false`, `try_once` reports `errors` to the
    /// caller as `TunnelError::ApiBusiness`.
    pub success: bool,
    /// Tunnel credentials. `try_once` treats `success == true` with no
    /// `result` as an internal error.
    #[serde(default)]
    pub result: Option<QuickTunnel>,
    /// Error entries supplied by the API; empty when the key is missing.
    #[serde(default)]
    pub errors: Vec<QuickTunnelApiError>,
}
28
29/// The bits the QUIC + capnp-RPC dance needs.
30///
31/// `secret` is delivered as a base64 string in the JSON body; the
32/// `serde_bytes_b64` helper decodes it back to raw bytes so callers
33/// can stuff them straight into the capnp `TunnelAuth.tunnelSecret`
34/// field. Mirror of cloudflared's `QuickTunnel` Go struct.
35#[derive(Debug, Deserialize)]
36pub struct QuickTunnel {
37    pub id: String,
38    pub name: String,
39    pub hostname: String,
40    pub account_tag: String,
41    #[serde(with = "serde_bytes_b64")]
42    pub secret: Vec<u8>,
43}
44
/// Default endpoint (the public trycloudflare API).
///
/// Base URL only: `request_tunnel` appends the `/tunnel` path itself
/// (after trimming any trailing `/`) to whatever service URL it is given.
pub const DEFAULT_SERVICE_URL: &str = "https://api.trycloudflare.com";

/// User-Agent we send. Mimic a recent `cloudflared` so the edge
/// doesn't trip a novelty filter. Bump in lockstep with the
/// schema commit pinned in `THIRD_PARTY_NOTICES.md`.
pub const DEFAULT_USER_AGENT: &str = "cloudflared/2024.12.0";

/// HTTP-level deadline for the POST.
///
/// `request_tunnel` installs this as the `timeout` of the client it
/// builds, so it bounds every attempt rather than the whole retry
/// sequence.
pub const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(15);

/// How many times to retry on transient 5xx / network errors.
///
/// This counts *retries*, not attempts: `request_tunnel` issues at most
/// `MAX_RETRIES + 1` POSTs before giving up.
pub const MAX_RETRIES: u32 = 3;
58
59/// Fetch a fresh quick-tunnel handshake. Retries 5xx + network
60/// errors with exponential backoff (1s → 2s → 4s); never retries
61/// 4xx. Business errors inside a 200 response surface as
62/// [`TunnelError::ApiBusiness`].
63pub async fn request_tunnel(
64    service_url: &str,
65    user_agent: &str,
66) -> Result<QuickTunnel, TunnelError> {
67    let url = format!("{}/tunnel", service_url.trim_end_matches('/'));
68    let client = reqwest::Client::builder()
69        .user_agent(user_agent)
70        .timeout(DEFAULT_HTTP_TIMEOUT)
71        .build()
72        .map_err(TunnelError::Api)?;
73
74    let mut backoff = Duration::from_secs(1);
75    let mut last_err: Option<TunnelError> = None;
76
77    for attempt in 0..=MAX_RETRIES {
78        debug!(attempt, %url, "POST /tunnel");
79        match try_once(&client, &url).await {
80            Ok(tunnel) => return Ok(tunnel),
81            Err(err) => {
82                if !err.is_transient() || attempt == MAX_RETRIES {
83                    return Err(err);
84                }
85                warn!(
86                    attempt,
87                    error = %err,
88                    backoff_ms = backoff.as_millis() as u64,
89                    "POST /tunnel transient failure; retrying"
90                );
91                last_err = Some(err);
92                sleep(backoff).await;
93                backoff = backoff.saturating_mul(2);
94            }
95        }
96    }
97    Err(last_err.unwrap_or_else(|| {
98        TunnelError::Internal("request_tunnel: retry loop fell through without an error".into())
99    }))
100}
101
102async fn try_once(client: &reqwest::Client, url: &str) -> Result<QuickTunnel, TunnelError> {
103    let resp = client
104        .post(url)
105        .header("Content-Type", "application/json")
106        .send()
107        .await?;
108
109    let status = resp.status();
110    let body = resp.bytes().await?;
111
112    // 5xx with non-JSON body must surface as a transient error so
113    // the retry loop kicks in. JSON 5xx envelopes are rare but
114    // possible — they go through the parse path below.
115    if status.is_server_error() && !looks_like_json(&body) {
116        let snippet_len = 200usize.min(body.len());
117        let body_snippet = String::from_utf8_lossy(&body[..snippet_len]).into_owned();
118        return Err(TunnelError::ApiNonJson {
119            status: status.as_u16(),
120            body_snippet,
121        });
122    }
123
124    // The edge sometimes hands back HTML when rate-limiting; surface
125    // a snippet so the operator can read the actual reason instead
126    // of staring at a bare "expected value at line 1 column 1".
127    if !looks_like_json(&body) {
128        let snippet_len = 200usize.min(body.len());
129        let body_snippet = String::from_utf8_lossy(&body[..snippet_len]).into_owned();
130        return Err(TunnelError::ApiNonJson {
131            status: status.as_u16(),
132            body_snippet,
133        });
134    }
135
136    let envelope: QuickTunnelResponse = serde_json::from_slice(&body)
137        .map_err(|e| TunnelError::Internal(format!("malformed JSON from /tunnel: {e}")))?;
138
139    if !envelope.success {
140        return Err(TunnelError::ApiBusiness(envelope.errors));
141    }
142
143    envelope.result.ok_or_else(|| {
144        TunnelError::Internal("POST /tunnel returned success=true but no `result` body".into())
145    })
146}
147
/// Cheap sniff test: does `body` start like a JSON object or array?
///
/// Leading ASCII whitespace is skipped; the first remaining byte must be
/// `{` or `[`. An empty or all-whitespace body yields `false`. Nothing
/// past that first byte is inspected, so `true` does not mean the body
/// is valid JSON.
fn looks_like_json(body: &[u8]) -> bool {
    for &byte in body {
        if byte.is_ascii_whitespace() {
            continue;
        }
        // First significant byte decides the answer.
        return matches!(byte, b'{' | b'[');
    }
    false
}
153
154mod serde_bytes_b64 {
155    use base64::engine::general_purpose::STANDARD;
156    use base64::Engine;
157    use serde::{Deserialize, Deserializer};
158
159    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
160        let s: String = Deserialize::deserialize(d)?;
161        STANDARD.decode(s).map_err(serde::de::Error::custom)
162    }
163}
164
165impl TunnelError {
166    /// Errors a retry could plausibly recover (network / 5xx).
167    pub(crate) fn is_transient(&self) -> bool {
168        match self {
169            TunnelError::Api(e) => {
170                e.is_timeout()
171                    || e.is_connect()
172                    || e.is_request()
173                    || e.status().is_some_and(|s| s.is_server_error())
174            }
175            TunnelError::ApiNonJson { status, .. } => (500..600).contains(status),
176            _ => false,
177        }
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Hostname embedded in the canned success payload.
    const SAMPLE_HOSTNAME: &str = "abc-123.trycloudflare.com";

    /// A well-formed success envelope; `secret` decodes to bytes 1..=32.
    fn sample_ok_body() -> serde_json::Value {
        serde_json::json!({
            "success": true,
            "result": {
                "id": "8f6d3c2a-1111-4d2e-9b9b-aaaaaaaaaaaa",
                "name": "quick-tunnel-abc",
                "hostname": SAMPLE_HOSTNAME,
                "account_tag": "deadbeefcafef00d",
                "secret": "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA="
            },
            "errors": []
        })
    }

    /// Matcher shared by every mock in this module: `POST /tunnel`.
    fn post_tunnel() -> wiremock::MockBuilder {
        Mock::given(method("POST")).and(path("/tunnel"))
    }

    /// Run the function under test against the mock server.
    async fn fetch(server: &MockServer) -> Result<QuickTunnel, TunnelError> {
        request_tunnel(&server.uri(), DEFAULT_USER_AGENT).await
    }

    #[tokio::test]
    async fn happy_path_parses_credentials() {
        let server = MockServer::start().await;
        post_tunnel()
            .and(header("Content-Type", "application/json"))
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_ok_body()))
            .expect(1)
            .mount(&server)
            .await;

        let tunnel = fetch(&server).await.expect("happy path");

        assert_eq!(tunnel.hostname, SAMPLE_HOSTNAME);
        assert_eq!(tunnel.account_tag, "deadbeefcafef00d");
        assert_eq!(tunnel.secret.len(), 32);
        assert_eq!(tunnel.secret[0..4], [1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn business_error_does_not_retry() {
        let server = MockServer::start().await;
        let rejection = serde_json::json!({
            "success": false,
            "errors": [{ "code": 1003, "message": "tunnel quota exceeded" }]
        });
        // `expect(1)` is what proves no retry happened.
        post_tunnel()
            .respond_with(ResponseTemplate::new(200).set_body_json(rejection))
            .expect(1)
            .mount(&server)
            .await;

        let err = fetch(&server).await.expect_err("should fail");

        let TunnelError::ApiBusiness(errs) = &err else {
            panic!("unexpected error: {err:?}");
        };
        assert_eq!(errs.len(), 1);
        assert_eq!(errs[0].code, 1003);
    }

    #[tokio::test]
    async fn html_body_surfaces_snippet() {
        let server = MockServer::start().await;
        let rate_limited =
            ResponseTemplate::new(429).set_body_string("<html><body>rate limited</body></html>");
        post_tunnel()
            .respond_with(rate_limited)
            .expect(1)
            .mount(&server)
            .await;

        let err = fetch(&server).await.expect_err("should fail");

        let TunnelError::ApiNonJson {
            status,
            body_snippet,
        } = &err
        else {
            panic!("unexpected error: {err:?}");
        };
        assert_eq!(*status, 429);
        assert!(body_snippet.contains("rate limited"));
    }

    #[tokio::test]
    async fn five_xx_retries_then_succeeds() {
        let server = MockServer::start().await;

        // The first request is answered with a plain-text 503 (transient);
        // once that mock is used up, the 200 mock takes over.
        post_tunnel()
            .respond_with(ResponseTemplate::new(503).set_body_string("service unavailable"))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        post_tunnel()
            .respond_with(ResponseTemplate::new(200).set_body_json(sample_ok_body()))
            .expect(1)
            .mount(&server)
            .await;

        // The initial 1s backoff is not configurable from outside, so
        // this test really does wait about a second. Accepted in
        // exchange for not widening the public API just for tests.
        let tunnel = fetch(&server).await.expect("retry should succeed");

        assert_eq!(tunnel.hostname, SAMPLE_HOSTNAME);
    }

    #[tokio::test]
    async fn four_xx_does_not_retry() {
        let server = MockServer::start().await;
        let bad_request = serde_json::json!({
            "success": false,
            "errors": [{ "code": 400, "message": "bad request" }]
        });
        // Exactly one hit expected: a second request would mean a retry.
        post_tunnel()
            .respond_with(ResponseTemplate::new(400).set_body_json(bad_request))
            .expect(1)
            .mount(&server)
            .await;

        let err = fetch(&server).await.expect_err("should fail");

        // A JSON envelope with success=false is a business error, not a
        // transport one.
        assert!(matches!(err, TunnelError::ApiBusiness(_)));
    }

    #[test]
    fn looks_like_json_handles_leading_whitespace() {
        assert!(looks_like_json(b"  \n  {"));
        assert!(looks_like_json(b"["));
        assert!(!looks_like_json(b"<html>"));
        assert!(!looks_like_json(b""));
    }
}