Skip to main content

cloudflare_quick_tunnel/
proxy.rs

1//! Per-request HTTP/1.1 proxy: bridge an inbound capnp-framed
2//! stream from the edge to the local TCP listener the caller
3//! wants to expose at `https://<sub>.trycloudflare.com`.
4//!
5//! Wire flow on the edge → tunnel side:
6//!
7//! 1. Edge writes the data-stream preamble + `ConnectRequest`
8//!    (parsed by `stream::read_connect_request`).
9//! 2. We connect a TCP socket to `127.0.0.1:<local_port>` and
10//!    write a re-synthesised HTTP/1.1 request line + `Host` +
11//!    every `HttpHeader:<Name>` the edge forwarded + an empty
12//!    line. Body bytes (if any) follow from the QUIC stream.
13//! 3. We parse the HTTP/1.1 response status + headers off the
14//!    TCP socket, write them back as a `ConnectResponse` (status
15//!    in `HttpStatus`, each header in `HttpHeader:<Name>`), and
16//!    finally byte-pump the response body from TCP to QUIC.
17//!
18//! Mirrors `cloudflared/connection/quic_connection.go`
19//! (`buildHTTPRequest` + `httpResponseAdapter`).
20
21use std::time::Duration;
22
23use futures::{AsyncReadExt, AsyncWriteExt};
24use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
25use tokio::net::TcpStream;
26use tracing::{debug, warn};
27
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::Arc;
30
31use crate::error::TunnelError;
32use crate::stream::{
33    self, ConnectRequest, ConnectionType, HTTP_HEADER_KEY, HTTP_HOST_KEY, HTTP_METHOD_KEY,
34    HTTP_STATUS_KEY,
35};
36
/// Running byte totals shared by every proxied stream.
///
/// Each field is an `Arc<AtomicU64>`, so a clone of this struct is
/// another handle onto the same counters, and any holder can bump
/// them without extra locking.
#[derive(Clone, Debug, Default)]
pub struct StreamCounters {
    /// Total bytes that arrived from the edge (request bodies + ws frames).
    pub bytes_in: Arc<AtomicU64>,
    /// Total bytes written towards the edge (response bodies + ws frames).
    pub bytes_out: Arc<AtomicU64>,
}
46
/// Upper bound on establishing the TCP connection to the local
/// listener. The target process has often only just started, so
/// five seconds leaves room for a slow boot while still surfacing
/// genuine failures promptly.
pub const LOCAL_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
52
/// Ceiling (32 KiB) on the size of an origin's response header
/// section. An origin that exceeds it is most likely misconfigured,
/// so we bail out before relaying any of its body.
const MAX_HEADER_BYTES: usize = 32 << 10;
57
58/// Drive one inbound request stream to completion. Reads the
59/// `ConnectRequest` off `recv`, runs the matching proxy strategy,
60/// writes the `ConnectResponse` back, and pumps the body bytes.
61pub async fn handle_inbound_stream(
62    local_port: u16,
63    send: quinn::SendStream,
64    recv: quinn::RecvStream,
65    counters: StreamCounters,
66) -> Result<(), TunnelError> {
67    let (mut reader, mut writer) = stream::split(send, recv);
68    let req = stream::read_connect_request(&mut reader).await?;
69    debug!(dest = %req.dest, ty = ?req.conn_type, "inbound stream");
70
71    match req.conn_type {
72        ConnectionType::Http | ConnectionType::Websocket => {
73            proxy_http(local_port, req, reader, writer, counters).await
74        }
75        ConnectionType::Tcp => {
76            proxy_tcp(local_port, &req, &mut reader, &mut writer, &counters).await
77        }
78    }
79}
80
81// ── HTTP / Websocket ─────────────────────────────────────────────────────────
82
83async fn proxy_http<R, W>(
84    local_port: u16,
85    request: ConnectRequest,
86    mut from_edge: R,
87    mut to_edge: W,
88    counters: StreamCounters,
89) -> Result<(), TunnelError>
90where
91    R: futures::io::AsyncRead + Unpin,
92    W: futures::io::AsyncWrite + Unpin,
93{
94    // 1. Connect to local.
95    let tcp = match tokio::time::timeout(
96        LOCAL_CONNECT_TIMEOUT,
97        TcpStream::connect(("127.0.0.1", local_port)),
98    )
99    .await
100    {
101        Ok(Ok(s)) => s,
102        Ok(Err(e)) => {
103            warn!(error = %e, local_port, "TCP connect refused");
104            return write_error_response(&mut to_edge, 502, &format!("local connect: {e}")).await;
105        }
106        Err(_) => {
107            warn!(local_port, "TCP connect timed out");
108            return write_error_response(&mut to_edge, 504, "local connect timed out").await;
109        }
110    };
111
112    let (mut tcp_read, mut tcp_write) = tcp.into_split();
113
114    // 2. Synthesise + send HTTP/1.1 request head.
115    let head = build_request_head(&request)?;
116    tcp_write
117        .write_all(head.as_bytes())
118        .await
119        .map_err(|e| TunnelError::Internal(format!("tcp write head: {e}")))?;
120
121    // 3. Pump request body (QUIC → TCP) concurrently with response
122    //    head read (TCP → us). We can't sequence them because
123    //    HTTP/1.1 in practice often interleaves — and even a tiny
124    //    HEAD/GET wants the request shut to flush. `join!` drives
125    //    both halves; the body pump terminates first on EOF.
126    let in_counter = counters.bytes_in.clone();
127    let body_pump = async {
128        let _ = pump_futures_to_tokio_counted(&mut from_edge, &mut tcp_write, &in_counter).await;
129        let _ = tcp_write.shutdown().await;
130    };
131    let head_read = read_http_response_head(&mut tcp_read);
132    let (_, head) = tokio::join!(body_pump, head_read);
133    let (status, headers, leftover) = head?;
134    debug!(status, header_count = headers.len(), "origin response");
135
136    // 5. Write the ConnectResponse back to the edge with status +
137    //    headers as metadata, then pump leftover + remaining body.
138    let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
139    meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
140    for (name, value) in &headers {
141        meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
142    }
143    let meta_refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
144    stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
145
146    // Leftover bytes from header parse (anything captured past the
147    // \r\n\r\n boundary) must be flushed first so chunked bodies
148    // don't lose their initial frames.
149    if !leftover.is_empty() {
150        to_edge
151            .write_all(&leftover)
152            .await
153            .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
154        counters
155            .bytes_out
156            .fetch_add(leftover.len() as u64, Ordering::Relaxed);
157    }
158
159    // 6. Pump response body TCP → QUIC. Continues until the origin
160    //    closes its half — typical HTTP/1.1 with Content-Length or
161    //    chunked encoding will trigger that.
162    pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &counters.bytes_out)
163        .await
164        .ok();
165
166    to_edge
167        .close()
168        .await
169        .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
170    Ok(())
171}
172
173fn build_request_head(req: &ConnectRequest) -> Result<String, TunnelError> {
174    let method = req.meta(HTTP_METHOD_KEY).unwrap_or("GET");
175    let host = req.meta(HTTP_HOST_KEY).unwrap_or("");
176    let path = extract_path(&req.dest);
177
178    let mut head = String::with_capacity(256);
179    head.push_str(method);
180    head.push(' ');
181    head.push_str(&path);
182    head.push_str(" HTTP/1.1\r\n");
183    if !host.is_empty() {
184        head.push_str("Host: ");
185        head.push_str(host);
186        head.push_str("\r\n");
187    }
188
189    let mut saw_connection = false;
190    let mut saw_content_length = false;
191    let mut saw_transfer_encoding = false;
192    for (k, v) in &req.metadata {
193        if let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) {
194            // Skip Host — already emitted above.
195            if name.eq_ignore_ascii_case("host") {
196                continue;
197            }
198            if name.eq_ignore_ascii_case("connection") {
199                saw_connection = true;
200            }
201            if name.eq_ignore_ascii_case("content-length") {
202                saw_content_length = true;
203            }
204            if name.eq_ignore_ascii_case("transfer-encoding") {
205                saw_transfer_encoding = true;
206            }
207            head.push_str(name);
208            head.push_str(": ");
209            head.push_str(v);
210            head.push_str("\r\n");
211        }
212    }
213    // If the edge didn't tell us about a body, advertise close-on-EOF
214    // so a hung-open keep-alive doesn't stall the response.
215    if !saw_connection {
216        head.push_str("Connection: close\r\n");
217    }
218    let _ = (saw_content_length, saw_transfer_encoding);
219
220    head.push_str("\r\n");
221    Ok(head)
222}
223
/// Extract the request target (`/path?q=1`) from `dest`, which looks
/// like `https://abc-123.trycloudflare.com/path?q=1`. Pure string
/// scan, no `url` crate dep.
///
/// * A `dest` that already starts with `/` is returned unchanged, even
///   if its query string happens to contain `://`.
/// * `://` only counts as the scheme separator when nothing before it
///   is a path/query/fragment delimiter.
/// * An absolute URL with a query but no path (`https://host?q=1`)
///   yields `/?q=1`.
/// * Anything else (no path, or no recognisable scheme) yields `/`.
fn extract_path(dest: &str) -> String {
    // Already a path: must be checked first so an embedded URL in the
    // query string is not mistaken for the scheme separator.
    if dest.starts_with('/') {
        return dest.to_string();
    }

    let is_delim = |c: char| matches!(c, '/' | '?' | '#');

    let after_scheme = match dest.find("://") {
        Some(idx) if !dest[..idx].contains(is_delim) => &dest[idx + 3..],
        _ => return "/".into(),
    };

    // First delimiter after the authority decides what follows.
    match after_scheme.find(is_delim) {
        Some(i) if after_scheme[i..].starts_with('/') => after_scheme[i..].to_string(),
        Some(i) if after_scheme[i..].starts_with('?') => format!("/{}", &after_scheme[i..]),
        _ => "/".into(),
    }
}
240
241async fn write_error_response<W>(writer: &mut W, status: u16, msg: &str) -> Result<(), TunnelError>
242where
243    W: futures::io::AsyncWrite + Unpin,
244{
245    let meta = [(HTTP_STATUS_KEY, status.to_string())];
246    let refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (*k, v.as_str())).collect();
247    stream::write_connect_response(writer, msg, &refs).await?;
248    Ok(())
249}
250
251async fn read_http_response_head(
252    tcp: &mut (impl tokio::io::AsyncRead + Unpin),
253) -> Result<(u16, Vec<(String, String)>, Vec<u8>), TunnelError> {
254    let mut buf = Vec::with_capacity(4096);
255    let mut tmp = [0u8; 2048];
256    loop {
257        let n = tcp
258            .read(&mut tmp)
259            .await
260            .map_err(|e| TunnelError::Internal(format!("tcp read head: {e}")))?;
261        if n == 0 {
262            return Err(TunnelError::Internal(
263                "local origin closed before sending response head".into(),
264            ));
265        }
266        buf.extend_from_slice(&tmp[..n]);
267        if buf.len() > MAX_HEADER_BYTES {
268            return Err(TunnelError::Internal(format!(
269                "response header exceeds {MAX_HEADER_BYTES} bytes"
270            )));
271        }
272
273        let mut headers = [httparse::EMPTY_HEADER; 64];
274        let mut resp = httparse::Response::new(&mut headers);
275        match resp
276            .parse(&buf)
277            .map_err(|e| TunnelError::Internal(format!("httparse: {e}")))?
278        {
279            httparse::Status::Complete(consumed) => {
280                let status = resp
281                    .code
282                    .ok_or_else(|| TunnelError::Internal("response had no status code".into()))?;
283                let pairs = resp
284                    .headers
285                    .iter()
286                    .map(|h| {
287                        let v = String::from_utf8_lossy(h.value).into_owned();
288                        (h.name.to_string(), v)
289                    })
290                    .collect::<Vec<_>>();
291                let leftover = buf.split_off(consumed);
292                return Ok((status, pairs, leftover));
293            }
294            httparse::Status::Partial => {
295                // need more bytes
296            }
297        }
298    }
299}
300
301// ── Plain TCP ────────────────────────────────────────────────────────────────
302
303async fn proxy_tcp<R, W>(
304    local_port: u16,
305    _request: &ConnectRequest,
306    from_edge: &mut R,
307    to_edge: &mut W,
308    counters: &StreamCounters,
309) -> Result<(), TunnelError>
310where
311    R: futures::io::AsyncRead + Unpin,
312    W: futures::io::AsyncWrite + Unpin,
313{
314    let tcp = TcpStream::connect(("127.0.0.1", local_port))
315        .await
316        .map_err(|e| TunnelError::Internal(format!("tcp connect: {e}")))?;
317    let (mut r, mut w) = tcp.into_split();
318
319    // ACK: send empty ConnectResponse before bytes flow.
320    stream::write_connect_response(to_edge, "", &[]).await?;
321
322    let edge_to_local = pump_futures_to_tokio_counted(from_edge, &mut w, &counters.bytes_in);
323    let local_to_edge = pump_tokio_to_futures_counted(&mut r, to_edge, &counters.bytes_out);
324    let _ = futures::future::join(edge_to_local, local_to_edge).await;
325    Ok(())
326}
327
328// ── Cross-IO byte pumps ──────────────────────────────────────────────────────
329
330async fn pump_futures_to_tokio_counted<R, W>(
331    mut src: R,
332    dst: &mut W,
333    counter: &AtomicU64,
334) -> Result<(), TunnelError>
335where
336    R: futures::io::AsyncRead + Unpin,
337    W: tokio::io::AsyncWrite + Unpin,
338{
339    let mut buf = [0u8; 16 * 1024];
340    loop {
341        let n = src
342            .read(&mut buf)
343            .await
344            .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
345        if n == 0 {
346            break;
347        }
348        dst.write_all(&buf[..n])
349            .await
350            .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
351        counter.fetch_add(n as u64, Ordering::Relaxed);
352    }
353    Ok(())
354}
355
356async fn pump_tokio_to_futures_counted<R, W>(
357    src: &mut R,
358    dst: &mut W,
359    counter: &AtomicU64,
360) -> Result<(), TunnelError>
361where
362    R: tokio::io::AsyncRead + Unpin,
363    W: futures::io::AsyncWrite + Unpin,
364{
365    let mut buf = [0u8; 16 * 1024];
366    loop {
367        let n = src
368            .read(&mut buf)
369            .await
370            .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
371        if n == 0 {
372            break;
373        }
374        dst.write_all(&buf[..n])
375            .await
376            .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
377        counter.fetch_add(n as u64, Ordering::Relaxed);
378    }
379    Ok(())
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute URLs are reduced to their path + query; bare hosts
    /// map to `/`; already-relative paths pass through.
    #[test]
    fn extract_path_strips_scheme() {
        let cases = [
            ("https://abc.trycloudflare.com/path?q=1", "/path?q=1"),
            ("https://abc.trycloudflare.com", "/"),
            ("/relative/x", "/relative/x"),
        ];
        for (dest, expected) in cases.iter().copied() {
            assert_eq!(extract_path(dest), expected);
        }
    }

    /// The synthesised head carries the method, path, Host and every
    /// forwarded header, and ends with the blank line.
    #[test]
    fn build_head_includes_method_host_path() {
        let metadata = vec![
            (HTTP_METHOD_KEY.into(), "POST".into()),
            (HTTP_HOST_KEY.into(), "abc.trycloudflare.com".into()),
            (format!("{HTTP_HEADER_KEY}:User-Agent"), "x/1".into()),
            (format!("{HTTP_HEADER_KEY}:X-Stuff"), "yo".into()),
        ];
        let req = ConnectRequest {
            dest: "https://abc.trycloudflare.com/foo".into(),
            conn_type: ConnectionType::Http,
            metadata,
        };

        let head = build_request_head(&req).unwrap();

        assert!(head.starts_with("POST /foo HTTP/1.1\r\n"));
        let expected_lines = [
            "Host: abc.trycloudflare.com\r\n",
            "User-Agent: x/1\r\n",
            "X-Stuff: yo\r\n",
        ];
        for line in expected_lines.iter() {
            assert!(head.contains(line));
        }
        assert!(head.ends_with("\r\n\r\n"));
    }
}
415}