1use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::Arc;
21use std::time::Duration;
22
23use futures::{AsyncReadExt, AsyncWriteExt};
24use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
25use tokio::net::TcpStream;
26use tracing::{debug, warn};
27
28use crate::error::TunnelError;
29use crate::pool::Pool;
30use crate::stream::{
31 self, ConnectRequest, ConnectionType, HTTP_HEADER_KEY, HTTP_HOST_KEY, HTTP_METHOD_KEY,
32 HTTP_STATUS_KEY,
33};
34
/// Byte counters shared by everything that handles one proxied stream.
///
/// Both counters sit behind an `Arc`, so every clone of a `StreamCounters`
/// observes and updates the same totals.
#[derive(Debug, Default, Clone)]
pub struct StreamCounters {
    /// Bytes copied from the edge towards the local origin.
    pub bytes_in: Arc<AtomicU64>,
    /// Bytes copied from the local origin back to the edge.
    pub bytes_out: Arc<AtomicU64>,
}

/// How long the HTTP path waits to obtain a local origin connection.
pub const LOCAL_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Largest incomplete origin response head buffered before giving up.
const MAX_HEADER_BYTES: usize = 32 << 10;
49
50pub async fn handle_inbound_stream(
54 local_port: u16,
55 send: quinn::SendStream,
56 recv: quinn::RecvStream,
57 counters: StreamCounters,
58 pool: Arc<Pool>,
59) -> Result<(), TunnelError> {
60 let (mut reader, mut writer) = stream::split(send, recv);
61 let req = stream::read_connect_request(&mut reader).await?;
62 debug!(dest = %req.dest, ty = ?req.conn_type, "inbound stream");
63
64 match req.conn_type {
65 ConnectionType::Http | ConnectionType::Websocket => {
66 proxy_http(local_port, req, reader, writer, counters, pool).await
67 }
68 ConnectionType::Tcp => {
69 proxy_tcp(local_port, &req, &mut reader, &mut writer, &counters).await
70 }
71 }
72}
73
/// Framing facts about an inbound HTTP request, used to decide whether its
/// origin connection may be kept alive and reused.
#[derive(Debug, Clone, Copy)]
struct RequestShape {
    /// Parsed `Content-Length`, when present and numeric.
    content_length: Option<u64>,
    /// `Transfer-Encoding` mentions `chunked`.
    is_chunked: bool,
    /// An `Upgrade` header is present or `Connection` mentions `upgrade`.
    is_upgrade: bool,
    /// `Connection` mentions `close`.
    wants_close: bool,
}

impl RequestShape {
    /// True unless the request is chunked, asks for an upgrade, or asks for
    /// the connection to be closed.
    fn poolable(&self) -> bool {
        let disqualified = self.is_chunked || self.is_upgrade || self.wants_close;
        !disqualified
    }
}
95
96fn analyse_request(req: &ConnectRequest) -> RequestShape {
97 let mut shape = RequestShape {
98 content_length: None,
99 is_chunked: false,
100 is_upgrade: false,
101 wants_close: false,
102 };
103 for (k, v) in &req.metadata {
104 let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) else {
105 continue;
106 };
107 let lname = name.to_ascii_lowercase();
108 let lval = v.to_ascii_lowercase();
109 match lname.as_str() {
110 "content-length" => {
111 shape.content_length = v.parse().ok();
112 }
113 "transfer-encoding" => {
114 if lval.contains("chunked") {
115 shape.is_chunked = true;
116 }
117 }
118 "upgrade" => {
119 shape.is_upgrade = true;
120 }
121 "connection" => {
122 if lval.contains("upgrade") {
123 shape.is_upgrade = true;
124 }
125 if lval.contains("close") {
126 shape.wants_close = true;
127 }
128 }
129 _ => {}
130 }
131 }
132 shape
133}
134
/// Framing facts about an origin response: how its body is delimited and
/// whether the connection may be reused once the body has been forwarded.
#[derive(Debug, Clone)]
struct ResponseShape {
    /// Body length in bytes: the parsed `Content-Length`, or `Some(0)` for
    /// statuses that never carry a body (204, 304).
    content_length: Option<u64>,
    /// `Transfer-Encoding` mentions `chunked` (cleared for body-less statuses).
    is_chunked: bool,
    /// Status 101, an `Upgrade` header, or `Connection: upgrade`.
    is_upgrade: bool,
    /// `Connection` mentions `close`.
    wants_close: bool,
}

impl ResponseShape {
    /// True when the body has a known length and nothing (chunking, upgrade,
    /// `Connection: close`) prevents reusing the connection afterwards.
    fn poolable(&self) -> bool {
        self.content_length.is_some() && !self.is_chunked && !self.is_upgrade && !self.wants_close
    }
}

/// Derives the [`ResponseShape`] of an origin response from its status code
/// and header fields. Header names and the `Transfer-Encoding` / `Connection`
/// values are compared case-insensitively.
fn analyse_response(status: u16, headers: &[(String, String)]) -> ResponseShape {
    let mut shape = ResponseShape {
        content_length: None,
        is_chunked: false,
        is_upgrade: status == 101,
        wants_close: false,
    };
    for (name, value) in headers {
        let lval = value.to_ascii_lowercase();
        match name.to_ascii_lowercase().as_str() {
            // Tolerate optional whitespace around the number.
            "content-length" => shape.content_length = value.trim().parse().ok(),
            "transfer-encoding" => {
                if lval.contains("chunked") {
                    shape.is_chunked = true;
                }
            }
            "connection" => {
                if lval.contains("close") {
                    shape.wants_close = true;
                }
                if lval.contains("upgrade") {
                    shape.is_upgrade = true;
                }
            }
            "upgrade" => shape.is_upgrade = true,
            _ => {}
        }
    }
    // RFC 9112 §6.3: 204 and 304 responses end at the blank line after the
    // header section, whatever Content-Length / Transfer-Encoding claim. A 304
    // typically repeats the Content-Length of the full representation, so
    // honouring it would wait for body bytes that never arrive.
    if status == 204 || status == 304 {
        shape.content_length = Some(0);
        shape.is_chunked = false;
    }
    shape
}
180
181async fn proxy_http<R, W>(
184 local_port: u16,
185 request: ConnectRequest,
186 from_edge: R,
187 mut to_edge: W,
188 counters: StreamCounters,
189 pool: Arc<Pool>,
190) -> Result<(), TunnelError>
191where
192 R: futures::io::AsyncRead + Unpin,
193 W: futures::io::AsyncWrite + Unpin,
194{
195 let req_shape = analyse_request(&request);
196
197 let tcp = match tokio::time::timeout(LOCAL_CONNECT_TIMEOUT, pool.acquire()).await {
199 Ok(Ok(s)) => s,
200 Ok(Err(e)) => {
201 warn!(error = %e, local_port, "TCP connect refused");
202 return write_error_response(&mut to_edge, 502, &format!("local connect: {e}")).await;
203 }
204 Err(_) => {
205 warn!(local_port, "TCP connect timed out");
206 return write_error_response(&mut to_edge, 504, "local connect timed out").await;
207 }
208 };
209
210 let (tcp_read, mut tcp_write) = tcp.into_split();
211
212 let head = build_request_head(&request, req_shape.poolable());
214 tcp_write
215 .write_all(head.as_bytes())
216 .await
217 .map_err(|e| TunnelError::Internal(format!("tcp write head: {e}")))?;
218
219 if req_shape.poolable() {
220 run_pooled(
222 req_shape, from_edge, to_edge, tcp_read, tcp_write, counters, &pool, local_port,
223 )
224 .await
225 } else {
226 run_bidi(from_edge, to_edge, tcp_read, tcp_write, counters).await
228 }
229}
230
231#[allow(clippy::too_many_arguments)]
232async fn run_pooled<R, W>(
233 req_shape: RequestShape,
234 mut from_edge: R,
235 mut to_edge: W,
236 mut tcp_read: tokio::net::tcp::OwnedReadHalf,
237 mut tcp_write: tokio::net::tcp::OwnedWriteHalf,
238 counters: StreamCounters,
239 pool: &Pool,
240 local_port: u16,
241) -> Result<(), TunnelError>
242where
243 R: futures::io::AsyncRead + Unpin,
244 W: futures::io::AsyncWrite + Unpin,
245{
246 let in_counter = counters.bytes_in.clone();
247 let out_counter = counters.bytes_out.clone();
248
249 if let Some(n) = req_shape.content_length {
252 if n > 0 {
253 pump_n_futures_to_tokio(&mut from_edge, &mut tcp_write, n, &in_counter).await?;
254 }
255 }
256 let (status, headers, leftover) = read_http_response_head(&mut tcp_read).await?;
261 debug!(status, header_count = headers.len(), "origin response");
262 let resp_shape = analyse_response(status, &headers);
263
264 let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
266 meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
267 for (name, value) in &headers {
268 meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
269 }
270 let meta_refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
271 stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
272
273 if !leftover.is_empty() {
275 to_edge
276 .write_all(&leftover)
277 .await
278 .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
279 out_counter.fetch_add(leftover.len() as u64, Ordering::Relaxed);
280 }
281
282 if let Some(total) = resp_shape.content_length.filter(|_| resp_shape.poolable()) {
283 let remaining = total.saturating_sub(leftover.len() as u64);
286 if remaining > 0 {
287 pump_n_tokio_to_futures(&mut tcp_read, &mut to_edge, remaining, &out_counter).await?;
288 }
289 to_edge
290 .close()
291 .await
292 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
293
294 match tcp_read.reunite(tcp_write) {
298 Ok(socket) => pool.release(socket).await,
299 Err(e) => {
300 warn!(error = %e, "tcp halves did not reunite; dropping socket");
301 }
302 }
303 let _ = local_port; Ok(())
305 } else {
306 pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &out_counter)
310 .await
311 .ok();
312 to_edge
313 .close()
314 .await
315 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
316 Ok(())
317 }
318}
319
320async fn run_bidi<R, W>(
321 mut from_edge: R,
322 mut to_edge: W,
323 mut tcp_read: tokio::net::tcp::OwnedReadHalf,
324 mut tcp_write: tokio::net::tcp::OwnedWriteHalf,
325 counters: StreamCounters,
326) -> Result<(), TunnelError>
327where
328 R: futures::io::AsyncRead + Unpin,
329 W: futures::io::AsyncWrite + Unpin,
330{
331 let in_counter = counters.bytes_in.clone();
336 let out_counter = counters.bytes_out.clone();
337 let edge_to_local = async {
338 let _ = pump_futures_to_tokio_counted(&mut from_edge, &mut tcp_write, &in_counter).await;
339 let _ = tcp_write.shutdown().await;
340 Ok::<(), TunnelError>(())
341 };
342 let local_to_edge = async {
343 let (status, headers, leftover) = read_http_response_head(&mut tcp_read).await?;
344 debug!(
345 status,
346 header_count = headers.len(),
347 "origin response (bidi)"
348 );
349 let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
350 meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
351 for (name, value) in &headers {
352 meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
353 }
354 let meta_refs: Vec<(&str, &str)> =
355 meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
356 stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
357 if !leftover.is_empty() {
358 to_edge
359 .write_all(&leftover)
360 .await
361 .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
362 out_counter.fetch_add(leftover.len() as u64, Ordering::Relaxed);
363 }
364 pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &out_counter).await
365 };
366
367 let (_, response_result) = tokio::join!(edge_to_local, local_to_edge);
368 response_result?;
369 to_edge
370 .close()
371 .await
372 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
373 Ok(())
374}
375
376fn build_request_head(req: &ConnectRequest, keep_alive: bool) -> String {
379 let method = req.meta(HTTP_METHOD_KEY).unwrap_or("GET");
380 let host = req.meta(HTTP_HOST_KEY).unwrap_or("");
381 let path = extract_path(&req.dest);
382
383 let mut head = String::with_capacity(256);
384 head.push_str(method);
385 head.push(' ');
386 head.push_str(&path);
387 head.push_str(" HTTP/1.1\r\n");
388 if !host.is_empty() {
389 head.push_str("Host: ");
390 head.push_str(host);
391 head.push_str("\r\n");
392 }
393
394 let mut saw_connection = false;
395 for (k, v) in &req.metadata {
396 if let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) {
397 if name.eq_ignore_ascii_case("host") {
398 continue;
399 }
400 if name.eq_ignore_ascii_case("connection") {
401 saw_connection = true;
402 }
403 head.push_str(name);
404 head.push_str(": ");
405 head.push_str(v);
406 head.push_str("\r\n");
407 }
408 }
409 if !saw_connection {
411 if keep_alive {
412 head.push_str("Connection: keep-alive\r\n");
413 } else {
414 head.push_str("Connection: close\r\n");
415 }
416 }
417 head.push_str("\r\n");
418 head
419}
420
/// Returns the request-target (path plus query) to put on the request line.
///
/// - An origin-form `dest` (starting with `/`) is returned unchanged. It is
///   checked first, so a `://` inside its query string is not mistaken for a
///   URL scheme.
/// - For an absolute URL, everything after the authority is returned. A query
///   without a path gets a leading `/` (`https://h?q=1` -> `/?q=1`).
/// - Anything else maps to `/`.
fn extract_path(dest: &str) -> String {
    if dest.starts_with('/') {
        return dest.to_string();
    }
    let Some((_, after_scheme)) = dest.split_once("://") else {
        return "/".into();
    };
    match after_scheme.find(|c: char| c == '/' || c == '?') {
        Some(i) if after_scheme[i..].starts_with('/') => after_scheme[i..].to_string(),
        Some(i) => format!("/{}", &after_scheme[i..]),
        None => "/".into(),
    }
}
434
435async fn write_error_response<W>(writer: &mut W, status: u16, msg: &str) -> Result<(), TunnelError>
436where
437 W: futures::io::AsyncWrite + Unpin,
438{
439 let meta = [(HTTP_STATUS_KEY, status.to_string())];
440 let refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (*k, v.as_str())).collect();
441 stream::write_connect_response(writer, msg, &refs).await?;
442 Ok(())
443}
444
445async fn read_http_response_head(
446 tcp: &mut (impl tokio::io::AsyncRead + Unpin),
447) -> Result<(u16, Vec<(String, String)>, Vec<u8>), TunnelError> {
448 let mut buf = Vec::with_capacity(4096);
449 let mut tmp = [0u8; 2048];
450 loop {
451 let n = tcp
452 .read(&mut tmp)
453 .await
454 .map_err(|e| TunnelError::Internal(format!("tcp read head: {e}")))?;
455 if n == 0 {
456 return Err(TunnelError::Internal(
457 "local origin closed before sending response head".into(),
458 ));
459 }
460 buf.extend_from_slice(&tmp[..n]);
461 if buf.len() > MAX_HEADER_BYTES {
462 return Err(TunnelError::Internal(format!(
463 "response header exceeds {MAX_HEADER_BYTES} bytes"
464 )));
465 }
466 let mut headers = [httparse::EMPTY_HEADER; 64];
467 let mut resp = httparse::Response::new(&mut headers);
468 match resp
469 .parse(&buf)
470 .map_err(|e| TunnelError::Internal(format!("httparse: {e}")))?
471 {
472 httparse::Status::Complete(consumed) => {
473 let status = resp
474 .code
475 .ok_or_else(|| TunnelError::Internal("response had no status code".into()))?;
476 let pairs = resp
477 .headers
478 .iter()
479 .map(|h| {
480 let v = String::from_utf8_lossy(h.value).into_owned();
481 (h.name.to_string(), v)
482 })
483 .collect::<Vec<_>>();
484 let leftover = buf.split_off(consumed);
485 return Ok((status, pairs, leftover));
486 }
487 httparse::Status::Partial => {}
488 }
489 }
490}
491
492async fn proxy_tcp<R, W>(
495 local_port: u16,
496 _request: &ConnectRequest,
497 from_edge: &mut R,
498 to_edge: &mut W,
499 counters: &StreamCounters,
500) -> Result<(), TunnelError>
501where
502 R: futures::io::AsyncRead + Unpin,
503 W: futures::io::AsyncWrite + Unpin,
504{
505 let tcp = TcpStream::connect(("127.0.0.1", local_port))
506 .await
507 .map_err(|e| TunnelError::Internal(format!("tcp connect: {e}")))?;
508 let (mut r, mut w) = tcp.into_split();
509 stream::write_connect_response(to_edge, "", &[]).await?;
510 let edge_to_local = pump_futures_to_tokio_counted(from_edge, &mut w, &counters.bytes_in);
511 let local_to_edge = pump_tokio_to_futures_counted(&mut r, to_edge, &counters.bytes_out);
512 let _ = futures::future::join(edge_to_local, local_to_edge).await;
513 Ok(())
514}
515
516async fn pump_futures_to_tokio_counted<R, W>(
519 mut src: R,
520 dst: &mut W,
521 counter: &AtomicU64,
522) -> Result<(), TunnelError>
523where
524 R: futures::io::AsyncRead + Unpin,
525 W: tokio::io::AsyncWrite + Unpin,
526{
527 let mut buf = [0u8; 16 * 1024];
528 loop {
529 let n = src
530 .read(&mut buf)
531 .await
532 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
533 if n == 0 {
534 break;
535 }
536 dst.write_all(&buf[..n])
537 .await
538 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
539 counter.fetch_add(n as u64, Ordering::Relaxed);
540 }
541 Ok(())
542}
543
544async fn pump_tokio_to_futures_counted<R, W>(
545 src: &mut R,
546 dst: &mut W,
547 counter: &AtomicU64,
548) -> Result<(), TunnelError>
549where
550 R: tokio::io::AsyncRead + Unpin,
551 W: futures::io::AsyncWrite + Unpin,
552{
553 let mut buf = [0u8; 16 * 1024];
554 loop {
555 let n = src
556 .read(&mut buf)
557 .await
558 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
559 if n == 0 {
560 break;
561 }
562 dst.write_all(&buf[..n])
563 .await
564 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
565 counter.fetch_add(n as u64, Ordering::Relaxed);
566 }
567 Ok(())
568}
569
570async fn pump_n_futures_to_tokio<R, W>(
573 src: &mut R,
574 dst: &mut W,
575 mut n: u64,
576 counter: &AtomicU64,
577) -> Result<(), TunnelError>
578where
579 R: futures::io::AsyncRead + Unpin,
580 W: tokio::io::AsyncWrite + Unpin,
581{
582 let mut buf = [0u8; 16 * 1024];
583 while n > 0 {
584 let want = std::cmp::min(buf.len() as u64, n) as usize;
585 let read = src
586 .read(&mut buf[..want])
587 .await
588 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
589 if read == 0 {
590 return Err(TunnelError::Internal(format!(
591 "source EOF with {n} bytes still expected"
592 )));
593 }
594 dst.write_all(&buf[..read])
595 .await
596 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
597 counter.fetch_add(read as u64, Ordering::Relaxed);
598 n -= read as u64;
599 }
600 Ok(())
601}
602
603async fn pump_n_tokio_to_futures<R, W>(
606 src: &mut R,
607 dst: &mut W,
608 mut n: u64,
609 counter: &AtomicU64,
610) -> Result<(), TunnelError>
611where
612 R: tokio::io::AsyncRead + Unpin,
613 W: futures::io::AsyncWrite + Unpin,
614{
615 let mut buf = [0u8; 16 * 1024];
616 while n > 0 {
617 let want = std::cmp::min(buf.len() as u64, n) as usize;
618 let read = src
619 .read(&mut buf[..want])
620 .await
621 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
622 if read == 0 {
623 return Err(TunnelError::Internal(format!(
624 "tcp EOF with {n} bytes still expected"
625 )));
626 }
627 dst.write_all(&buf[..read])
628 .await
629 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
630 counter.fetch_add(read as u64, Ordering::Relaxed);
631 n -= read as u64;
632 }
633 Ok(())
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `dest`; metadata values are converted with `Into`.
    fn request(
        dest: &str,
        conn_type: ConnectionType,
        metadata: Vec<(String, &str)>,
    ) -> ConnectRequest {
        ConnectRequest {
            dest: dest.into(),
            conn_type,
            metadata: metadata
                .into_iter()
                .map(|(key, value)| (key, value.into()))
                .collect(),
        }
    }

    /// Metadata key under which the edge forwards HTTP header `name`.
    fn header(name: &str) -> String {
        format!("{HTTP_HEADER_KEY}:{name}")
    }

    /// Owned `(name, value)` header field, as the origin parser produces.
    fn field(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    #[test]
    fn extract_path_strips_scheme() {
        let cases = [
            ("https://abc.trycloudflare.com/path?q=1", "/path?q=1"),
            ("https://abc.trycloudflare.com", "/"),
            ("/relative/x", "/relative/x"),
        ];
        for (dest, want) in cases {
            assert_eq!(extract_path(dest), want);
        }
    }

    #[test]
    fn build_head_includes_method_host_path() {
        let req = request(
            "https://abc.trycloudflare.com/foo",
            ConnectionType::Http,
            vec![
                (HTTP_METHOD_KEY.into(), "POST"),
                (HTTP_HOST_KEY.into(), "abc.trycloudflare.com"),
                (header("User-Agent"), "x/1"),
                (header("X-Stuff"), "yo"),
            ],
        );
        let head = build_request_head(&req, true);
        assert!(head.starts_with("POST /foo HTTP/1.1\r\n"));
        for line in [
            "Host: abc.trycloudflare.com\r\n",
            "User-Agent: x/1\r\n",
            "X-Stuff: yo\r\n",
            "Connection: keep-alive\r\n",
        ] {
            assert!(head.contains(line), "missing {line:?} in {head:?}");
        }
        assert!(head.ends_with("\r\n\r\n"));
    }

    #[test]
    fn poolable_request_default() {
        let req = request(
            "https://x/",
            ConnectionType::Http,
            vec![(HTTP_METHOD_KEY.into(), "GET"), (HTTP_HOST_KEY.into(), "x")],
        );
        let shape = analyse_request(&req);
        assert!(shape.poolable());
        assert_eq!(shape.content_length, None);
    }

    #[test]
    fn websocket_request_not_poolable() {
        let req = request(
            "https://x/ws",
            ConnectionType::Websocket,
            vec![
                (HTTP_METHOD_KEY.into(), "GET"),
                (HTTP_HOST_KEY.into(), "x"),
                (header("Upgrade"), "websocket"),
                (header("Connection"), "Upgrade"),
            ],
        );
        let shape = analyse_request(&req);
        assert!(shape.is_upgrade);
        assert!(!shape.poolable());
    }

    #[test]
    fn chunked_request_not_poolable() {
        let req = request(
            "https://x/upload",
            ConnectionType::Http,
            vec![
                (HTTP_METHOD_KEY.into(), "POST"),
                (header("Transfer-Encoding"), "chunked"),
            ],
        );
        let shape = analyse_request(&req);
        assert!(shape.is_chunked);
        assert!(!shape.poolable());
    }

    #[test]
    fn response_with_content_length_is_poolable() {
        let shape = analyse_response(200, &[field("Content-Length", "42")]);
        assert!(shape.poolable());
        assert_eq!(shape.content_length, Some(42));
    }

    #[test]
    fn response_101_never_poolable() {
        let shape = analyse_response(101, &[field("Upgrade", "websocket")]);
        assert!(shape.is_upgrade);
        assert!(!shape.poolable());
    }
}
736}