1use std::time::Duration;
22
23use futures::{AsyncReadExt, AsyncWriteExt};
24use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
25use tokio::net::TcpStream;
26use tracing::{debug, warn};
27
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::Arc;
30
31use crate::error::TunnelError;
32use crate::stream::{
33 self, ConnectRequest, ConnectionType, HTTP_HEADER_KEY, HTTP_HOST_KEY, HTTP_METHOD_KEY,
34 HTTP_STATUS_KEY,
35};
36
/// Byte counters for one proxied stream.
///
/// Every field is an `Arc`, so a clone is a second handle onto the *same*
/// counters rather than an independent copy.
#[derive(Clone, Debug, Default)]
pub struct StreamCounters {
    /// Bytes read from the edge and written to the local origin.
    pub bytes_in: Arc<AtomicU64>,
    /// Bytes read from the local origin and written back to the edge.
    pub bytes_out: Arc<AtomicU64>,
}
46
/// Upper bound on establishing the TCP connection to the local origin.
pub const LOCAL_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Largest HTTP response head (status line plus headers) accepted from the
/// local origin: 32 KiB.
const MAX_HEADER_BYTES: usize = 32_768;
57
58pub async fn handle_inbound_stream(
62 local_port: u16,
63 send: quinn::SendStream,
64 recv: quinn::RecvStream,
65 counters: StreamCounters,
66) -> Result<(), TunnelError> {
67 let (mut reader, mut writer) = stream::split(send, recv);
68 let req = stream::read_connect_request(&mut reader).await?;
69 debug!(dest = %req.dest, ty = ?req.conn_type, "inbound stream");
70
71 match req.conn_type {
72 ConnectionType::Http | ConnectionType::Websocket => {
73 proxy_http(local_port, req, reader, writer, counters).await
74 }
75 ConnectionType::Tcp => {
76 proxy_tcp(local_port, &req, &mut reader, &mut writer, &counters).await
77 }
78 }
79}
80
81async fn proxy_http<R, W>(
84 local_port: u16,
85 request: ConnectRequest,
86 mut from_edge: R,
87 mut to_edge: W,
88 counters: StreamCounters,
89) -> Result<(), TunnelError>
90where
91 R: futures::io::AsyncRead + Unpin,
92 W: futures::io::AsyncWrite + Unpin,
93{
94 let tcp = match tokio::time::timeout(
96 LOCAL_CONNECT_TIMEOUT,
97 TcpStream::connect(("127.0.0.1", local_port)),
98 )
99 .await
100 {
101 Ok(Ok(s)) => s,
102 Ok(Err(e)) => {
103 warn!(error = %e, local_port, "TCP connect refused");
104 return write_error_response(&mut to_edge, 502, &format!("local connect: {e}")).await;
105 }
106 Err(_) => {
107 warn!(local_port, "TCP connect timed out");
108 return write_error_response(&mut to_edge, 504, "local connect timed out").await;
109 }
110 };
111
112 let (mut tcp_read, mut tcp_write) = tcp.into_split();
113
114 let head = build_request_head(&request)?;
116 tcp_write
117 .write_all(head.as_bytes())
118 .await
119 .map_err(|e| TunnelError::Internal(format!("tcp write head: {e}")))?;
120
121 let in_counter = counters.bytes_in.clone();
127 let body_pump = async {
128 let _ = pump_futures_to_tokio_counted(&mut from_edge, &mut tcp_write, &in_counter).await;
129 let _ = tcp_write.shutdown().await;
130 };
131 let head_read = read_http_response_head(&mut tcp_read);
132 let (_, head) = tokio::join!(body_pump, head_read);
133 let (status, headers, leftover) = head?;
134 debug!(status, header_count = headers.len(), "origin response");
135
136 let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
139 meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
140 for (name, value) in &headers {
141 meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
142 }
143 let meta_refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
144 stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
145
146 if !leftover.is_empty() {
150 to_edge
151 .write_all(&leftover)
152 .await
153 .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
154 counters
155 .bytes_out
156 .fetch_add(leftover.len() as u64, Ordering::Relaxed);
157 }
158
159 pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &counters.bytes_out)
163 .await
164 .ok();
165
166 to_edge
167 .close()
168 .await
169 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
170 Ok(())
171}
172
173fn build_request_head(req: &ConnectRequest) -> Result<String, TunnelError> {
174 let method = req.meta(HTTP_METHOD_KEY).unwrap_or("GET");
175 let host = req.meta(HTTP_HOST_KEY).unwrap_or("");
176 let path = extract_path(&req.dest);
177
178 let mut head = String::with_capacity(256);
179 head.push_str(method);
180 head.push(' ');
181 head.push_str(&path);
182 head.push_str(" HTTP/1.1\r\n");
183 if !host.is_empty() {
184 head.push_str("Host: ");
185 head.push_str(host);
186 head.push_str("\r\n");
187 }
188
189 let mut saw_connection = false;
190 let mut saw_content_length = false;
191 let mut saw_transfer_encoding = false;
192 for (k, v) in &req.metadata {
193 if let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) {
194 if name.eq_ignore_ascii_case("host") {
196 continue;
197 }
198 if name.eq_ignore_ascii_case("connection") {
199 saw_connection = true;
200 }
201 if name.eq_ignore_ascii_case("content-length") {
202 saw_content_length = true;
203 }
204 if name.eq_ignore_ascii_case("transfer-encoding") {
205 saw_transfer_encoding = true;
206 }
207 head.push_str(name);
208 head.push_str(": ");
209 head.push_str(v);
210 head.push_str("\r\n");
211 }
212 }
213 if !saw_connection {
216 head.push_str("Connection: close\r\n");
217 }
218 let _ = (saw_content_length, saw_transfer_encoding);
219
220 head.push_str("\r\n");
221 Ok(head)
222}
223
/// Derives the HTTP/1.1 origin-form request-target (path plus optional query)
/// from the edge-supplied destination.
///
/// - A destination that already starts with `/` is used as-is. This is checked
///   first, so a `://` inside its query is never mistaken for a scheme.
/// - An absolute URL (`scheme://authority...`) loses its scheme and authority.
///   The authority ends at the first `/` or `?`, so a `/` inside the query
///   (`https://h?next=/x`) does not start the path. A query with no path gets a
///   leading `/` (`/?next=/x`), and a bare authority becomes `/`.
/// - Anything else falls back to `/`.
///
/// Any `#fragment` is removed first, since fragments never appear in a
/// request-target.
fn extract_path(dest: &str) -> String {
    let dest = match dest.find('#') {
        Some(hash) => &dest[..hash],
        None => dest,
    };
    if dest.starts_with('/') {
        return dest.to_string();
    }
    let after_scheme = match dest.find("://") {
        Some(idx) => &dest[idx + 3..],
        None => return "/".into(),
    };
    match after_scheme.find(|c: char| c == '/' || c == '?') {
        Some(idx) if after_scheme[idx..].starts_with('/') => after_scheme[idx..].to_string(),
        Some(idx) => format!("/{}", &after_scheme[idx..]),
        None => "/".into(),
    }
}
240
241async fn write_error_response<W>(writer: &mut W, status: u16, msg: &str) -> Result<(), TunnelError>
242where
243 W: futures::io::AsyncWrite + Unpin,
244{
245 let meta = [(HTTP_STATUS_KEY, status.to_string())];
246 let refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (*k, v.as_str())).collect();
247 stream::write_connect_response(writer, msg, &refs).await?;
248 Ok(())
249}
250
251async fn read_http_response_head(
252 tcp: &mut (impl tokio::io::AsyncRead + Unpin),
253) -> Result<(u16, Vec<(String, String)>, Vec<u8>), TunnelError> {
254 let mut buf = Vec::with_capacity(4096);
255 let mut tmp = [0u8; 2048];
256 loop {
257 let n = tcp
258 .read(&mut tmp)
259 .await
260 .map_err(|e| TunnelError::Internal(format!("tcp read head: {e}")))?;
261 if n == 0 {
262 return Err(TunnelError::Internal(
263 "local origin closed before sending response head".into(),
264 ));
265 }
266 buf.extend_from_slice(&tmp[..n]);
267 if buf.len() > MAX_HEADER_BYTES {
268 return Err(TunnelError::Internal(format!(
269 "response header exceeds {MAX_HEADER_BYTES} bytes"
270 )));
271 }
272
273 let mut headers = [httparse::EMPTY_HEADER; 64];
274 let mut resp = httparse::Response::new(&mut headers);
275 match resp
276 .parse(&buf)
277 .map_err(|e| TunnelError::Internal(format!("httparse: {e}")))?
278 {
279 httparse::Status::Complete(consumed) => {
280 let status = resp
281 .code
282 .ok_or_else(|| TunnelError::Internal("response had no status code".into()))?;
283 let pairs = resp
284 .headers
285 .iter()
286 .map(|h| {
287 let v = String::from_utf8_lossy(h.value).into_owned();
288 (h.name.to_string(), v)
289 })
290 .collect::<Vec<_>>();
291 let leftover = buf.split_off(consumed);
292 return Ok((status, pairs, leftover));
293 }
294 httparse::Status::Partial => {
295 }
297 }
298 }
299}
300
301async fn proxy_tcp<R, W>(
304 local_port: u16,
305 _request: &ConnectRequest,
306 from_edge: &mut R,
307 to_edge: &mut W,
308 counters: &StreamCounters,
309) -> Result<(), TunnelError>
310where
311 R: futures::io::AsyncRead + Unpin,
312 W: futures::io::AsyncWrite + Unpin,
313{
314 let tcp = TcpStream::connect(("127.0.0.1", local_port))
315 .await
316 .map_err(|e| TunnelError::Internal(format!("tcp connect: {e}")))?;
317 let (mut r, mut w) = tcp.into_split();
318
319 stream::write_connect_response(to_edge, "", &[]).await?;
321
322 let edge_to_local = pump_futures_to_tokio_counted(from_edge, &mut w, &counters.bytes_in);
323 let local_to_edge = pump_tokio_to_futures_counted(&mut r, to_edge, &counters.bytes_out);
324 let _ = futures::future::join(edge_to_local, local_to_edge).await;
325 Ok(())
326}
327
328async fn pump_futures_to_tokio_counted<R, W>(
331 mut src: R,
332 dst: &mut W,
333 counter: &AtomicU64,
334) -> Result<(), TunnelError>
335where
336 R: futures::io::AsyncRead + Unpin,
337 W: tokio::io::AsyncWrite + Unpin,
338{
339 let mut buf = [0u8; 16 * 1024];
340 loop {
341 let n = src
342 .read(&mut buf)
343 .await
344 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
345 if n == 0 {
346 break;
347 }
348 dst.write_all(&buf[..n])
349 .await
350 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
351 counter.fetch_add(n as u64, Ordering::Relaxed);
352 }
353 Ok(())
354}
355
356async fn pump_tokio_to_futures_counted<R, W>(
357 src: &mut R,
358 dst: &mut W,
359 counter: &AtomicU64,
360) -> Result<(), TunnelError>
361where
362 R: tokio::io::AsyncRead + Unpin,
363 W: futures::io::AsyncWrite + Unpin,
364{
365 let mut buf = [0u8; 16 * 1024];
366 loop {
367 let n = src
368 .read(&mut buf)
369 .await
370 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
371 if n == 0 {
372 break;
373 }
374 dst.write_all(&buf[..n])
375 .await
376 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
377 counter.fetch_add(n as u64, Ordering::Relaxed);
378 }
379 Ok(())
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_path_strips_scheme() {
        let cases = [
            ("https://abc.trycloudflare.com/path?q=1", "/path?q=1"),
            ("https://abc.trycloudflare.com", "/"),
            ("/relative/x", "/relative/x"),
        ];
        for &(dest, expected) in &cases {
            assert_eq!(extract_path(dest), expected, "dest = {dest}");
        }
    }

    #[test]
    fn build_head_includes_method_host_path() {
        let header = |name: &str| format!("{HTTP_HEADER_KEY}:{name}");
        let req = ConnectRequest {
            dest: "https://abc.trycloudflare.com/foo".into(),
            conn_type: ConnectionType::Http,
            metadata: vec![
                (HTTP_METHOD_KEY.into(), "POST".into()),
                (HTTP_HOST_KEY.into(), "abc.trycloudflare.com".into()),
                (header("User-Agent"), "x/1".into()),
                (header("X-Stuff"), "yo".into()),
            ],
        };

        let head = build_request_head(&req).expect("request head should build");

        assert!(head.starts_with("POST /foo HTTP/1.1\r\n"));
        for line in [
            "Host: abc.trycloudflare.com\r\n",
            "User-Agent: x/1\r\n",
            "X-Stuff: yo\r\n",
        ] {
            assert!(head.contains(line), "missing {line:?} in {head:?}");
        }
        assert!(head.ends_with("\r\n\r\n"));
    }
}