1use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::Arc;
21use std::time::Duration;
22
23use futures::{AsyncReadExt, AsyncWriteExt};
24use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
25use tokio::net::TcpStream;
26use tracing::{debug, warn};
27
28use crate::error::TunnelError;
29use crate::pool::Pool;
30use crate::stream::{
31 self, ConnectRequest, ConnectionType, HTTP_HEADER_KEY, HTTP_HOST_KEY, HTTP_METHOD_KEY,
32 HTTP_STATUS_KEY,
33};
34
/// Byte counters for one proxied stream.
///
/// Both totals live behind `Arc`, so every clone of a `StreamCounters` observes
/// and updates the same values.
#[derive(Clone, Debug, Default)]
pub struct StreamCounters {
    /// Bytes forwarded from the edge towards the local origin.
    pub bytes_in: Arc<AtomicU64>,
    /// Bytes forwarded from the local origin back to the edge.
    pub bytes_out: Arc<AtomicU64>,
}
41
/// Time allowed for obtaining a connection to the local origin.
pub const LOCAL_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Maximum number of bytes buffered while waiting for a complete response head
/// from the local origin.
const MAX_HEADER_BYTES: usize = 32 * 1024;
49
50pub async fn handle_inbound_stream(
54 local_port: u16,
55 send: quinn::SendStream,
56 recv: quinn::RecvStream,
57 counters: StreamCounters,
58 pool: Arc<Pool>,
59) -> Result<(), TunnelError> {
60 let (mut reader, mut writer) = stream::split(send, recv);
61 let req = stream::read_connect_request(&mut reader).await?;
62 debug!(dest = %req.dest, ty = ?req.conn_type, "inbound stream");
63
64 match req.conn_type {
65 ConnectionType::Http | ConnectionType::Websocket => {
66 proxy_http(local_port, req, reader, writer, counters, pool).await
67 }
68 ConnectionType::Tcp => {
69 proxy_tcp(local_port, &req, &mut reader, &mut writer, &counters).await
70 }
71 }
72}
73
/// Facts about an edge request's headers that decide how it is relayed.
#[derive(Clone, Copy, Debug)]
struct RequestShape {
    /// Value of the request's `Content-Length` header, when known.
    content_length: Option<u64>,
    /// The request body uses chunked transfer coding.
    is_chunked: bool,
    /// The request asks for a protocol upgrade (e.g. WebSocket).
    is_upgrade: bool,
    /// The client asked for `Connection: close`.
    wants_close: bool,
}

impl RequestShape {
    /// True unless the request is chunked, is an upgrade, or asked to close the
    /// connection; only such requests may run over a reusable origin socket.
    fn poolable(&self) -> bool {
        !(self.is_chunked || self.is_upgrade || self.wants_close)
    }
}
95
96fn analyse_request(req: &ConnectRequest) -> RequestShape {
97 let mut shape = RequestShape {
98 content_length: None,
99 is_chunked: false,
100 is_upgrade: false,
101 wants_close: false,
102 };
103 for (k, v) in &req.metadata {
104 let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) else {
105 continue;
106 };
107 let lname = name.to_ascii_lowercase();
108 let lval = v.to_ascii_lowercase();
109 match lname.as_str() {
110 "content-length" => {
111 shape.content_length = v.parse().ok();
112 }
113 "transfer-encoding" if lval.contains("chunked") => {
114 shape.is_chunked = true;
115 }
116 "upgrade" => {
117 shape.is_upgrade = true;
118 }
119 "connection" => {
120 if lval.contains("upgrade") {
121 shape.is_upgrade = true;
122 }
123 if lval.contains("close") {
124 shape.wants_close = true;
125 }
126 }
127 _ => {}
128 }
129 }
130 shape
131}
132
/// Facts about an origin response's headers that decide how its body is relayed
/// and whether the origin socket can be reused.
#[derive(Debug, Clone)]
struct ResponseShape {
    /// Parsed `Content-Length`. `None` when absent, unparsable, or overridden by a
    /// `Transfer-Encoding` header.
    content_length: Option<u64>,
    /// `Transfer-Encoding` mentions `chunked`.
    is_chunked: bool,
    /// Status is `101`, an `Upgrade` header is present, or `Connection` mentions
    /// `upgrade`.
    is_upgrade: bool,
    /// `Connection` mentions `close`.
    wants_close: bool,
}

impl ResponseShape {
    /// True when the body is delimited by `content_length` and the origin expects
    /// the connection to stay open afterwards.
    fn poolable(&self) -> bool {
        self.content_length.is_some() && !self.is_chunked && !self.is_upgrade && !self.wants_close
    }
}

/// Classifies an origin response from its status code and header pairs.
///
/// Header names and values are compared case-insensitively. A `Content-Length` is
/// ignored whenever any `Transfer-Encoding` header is present (see the comment at
/// the end of the function).
fn analyse_response(status: u16, headers: &[(String, String)]) -> ResponseShape {
    let mut shape = ResponseShape {
        content_length: None,
        is_chunked: false,
        is_upgrade: status == 101,
        wants_close: false,
    };
    let mut has_transfer_encoding = false;
    for (name, value) in headers {
        let lname = name.to_ascii_lowercase();
        let lval = value.to_ascii_lowercase();
        match lname.as_str() {
            "content-length" => shape.content_length = value.trim().parse().ok(),
            "transfer-encoding" => {
                has_transfer_encoding = true;
                if lval.contains("chunked") {
                    shape.is_chunked = true;
                }
            }
            "connection" => {
                if lval.contains("close") {
                    shape.wants_close = true;
                }
                if lval.contains("upgrade") {
                    shape.is_upgrade = true;
                }
            }
            "upgrade" => shape.is_upgrade = true,
            _ => {}
        }
    }
    // Transfer-Encoding overrides Content-Length (RFC 9112 §6.3). A body whose
    // final coding is not chunked is delimited by connection close, so the
    // Content-Length must not be used for framing in either case.
    if has_transfer_encoding {
        shape.content_length = None;
    }
    shape
}
176
177async fn proxy_http<R, W>(
180 local_port: u16,
181 request: ConnectRequest,
182 from_edge: R,
183 mut to_edge: W,
184 counters: StreamCounters,
185 pool: Arc<Pool>,
186) -> Result<(), TunnelError>
187where
188 R: futures::io::AsyncRead + Unpin,
189 W: futures::io::AsyncWrite + Unpin,
190{
191 let req_shape = analyse_request(&request);
192
193 let tcp = match tokio::time::timeout(LOCAL_CONNECT_TIMEOUT, pool.acquire()).await {
195 Ok(Ok(s)) => s,
196 Ok(Err(e)) => {
197 warn!(error = %e, local_port, "TCP connect refused");
198 return write_error_response(&mut to_edge, 502, &format!("local connect: {e}")).await;
199 }
200 Err(_) => {
201 warn!(local_port, "TCP connect timed out");
202 return write_error_response(&mut to_edge, 504, "local connect timed out").await;
203 }
204 };
205
206 let (tcp_read, mut tcp_write) = tcp.into_split();
207
208 let head = build_request_head(&request, req_shape.poolable());
210 tcp_write
211 .write_all(head.as_bytes())
212 .await
213 .map_err(|e| TunnelError::Internal(format!("tcp write head: {e}")))?;
214
215 if req_shape.poolable() {
216 run_pooled(
218 req_shape, from_edge, to_edge, tcp_read, tcp_write, counters, &pool, local_port,
219 )
220 .await
221 } else {
222 run_bidi(from_edge, to_edge, tcp_read, tcp_write, counters).await
224 }
225}
226
227#[allow(clippy::too_many_arguments)]
228async fn run_pooled<R, W>(
229 req_shape: RequestShape,
230 mut from_edge: R,
231 mut to_edge: W,
232 mut tcp_read: tokio::net::tcp::OwnedReadHalf,
233 mut tcp_write: tokio::net::tcp::OwnedWriteHalf,
234 counters: StreamCounters,
235 pool: &Pool,
236 local_port: u16,
237) -> Result<(), TunnelError>
238where
239 R: futures::io::AsyncRead + Unpin,
240 W: futures::io::AsyncWrite + Unpin,
241{
242 let in_counter = counters.bytes_in.clone();
243 let out_counter = counters.bytes_out.clone();
244
245 if let Some(n) = req_shape.content_length {
248 if n > 0 {
249 pump_n_futures_to_tokio(&mut from_edge, &mut tcp_write, n, &in_counter).await?;
250 }
251 }
252 let (status, headers, leftover) = read_http_response_head(&mut tcp_read).await?;
257 debug!(status, header_count = headers.len(), "origin response");
258 let resp_shape = analyse_response(status, &headers);
259
260 let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
262 meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
263 for (name, value) in &headers {
264 meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
265 }
266 let meta_refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
267 stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
268
269 if !leftover.is_empty() {
271 to_edge
272 .write_all(&leftover)
273 .await
274 .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
275 out_counter.fetch_add(leftover.len() as u64, Ordering::Relaxed);
276 }
277
278 if let Some(total) = resp_shape.content_length.filter(|_| resp_shape.poolable()) {
279 let remaining = total.saturating_sub(leftover.len() as u64);
282 if remaining > 0 {
283 pump_n_tokio_to_futures(&mut tcp_read, &mut to_edge, remaining, &out_counter).await?;
284 }
285 to_edge
286 .close()
287 .await
288 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
289
290 match tcp_read.reunite(tcp_write) {
294 Ok(socket) => pool.release(socket).await,
295 Err(e) => {
296 warn!(error = %e, "tcp halves did not reunite; dropping socket");
297 }
298 }
299 let _ = local_port; Ok(())
301 } else {
302 pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &out_counter)
306 .await
307 .ok();
308 to_edge
309 .close()
310 .await
311 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
312 Ok(())
313 }
314}
315
316async fn run_bidi<R, W>(
317 mut from_edge: R,
318 mut to_edge: W,
319 mut tcp_read: tokio::net::tcp::OwnedReadHalf,
320 mut tcp_write: tokio::net::tcp::OwnedWriteHalf,
321 counters: StreamCounters,
322) -> Result<(), TunnelError>
323where
324 R: futures::io::AsyncRead + Unpin,
325 W: futures::io::AsyncWrite + Unpin,
326{
327 let in_counter = counters.bytes_in.clone();
332 let out_counter = counters.bytes_out.clone();
333 let edge_to_local = async {
334 let _ = pump_futures_to_tokio_counted(&mut from_edge, &mut tcp_write, &in_counter).await;
335 let _ = tcp_write.shutdown().await;
336 Ok::<(), TunnelError>(())
337 };
338 let local_to_edge = async {
339 let (status, headers, leftover) = read_http_response_head(&mut tcp_read).await?;
340 debug!(
341 status,
342 header_count = headers.len(),
343 "origin response (bidi)"
344 );
345 let mut meta: Vec<(String, String)> = Vec::with_capacity(headers.len() + 1);
346 meta.push((HTTP_STATUS_KEY.into(), status.to_string()));
347 for (name, value) in &headers {
348 meta.push((format!("{HTTP_HEADER_KEY}:{name}"), value.clone()));
349 }
350 let meta_refs: Vec<(&str, &str)> =
351 meta.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
352 stream::write_connect_response(&mut to_edge, "", &meta_refs).await?;
353 if !leftover.is_empty() {
354 to_edge
355 .write_all(&leftover)
356 .await
357 .map_err(|e| TunnelError::Internal(format!("write leftover body: {e}")))?;
358 out_counter.fetch_add(leftover.len() as u64, Ordering::Relaxed);
359 }
360 pump_tokio_to_futures_counted(&mut tcp_read, &mut to_edge, &out_counter).await
361 };
362
363 let (_, response_result) = tokio::join!(edge_to_local, local_to_edge);
364 response_result?;
365 to_edge
366 .close()
367 .await
368 .map_err(|e| TunnelError::Internal(format!("close to_edge: {e}")))?;
369 Ok(())
370}
371
372fn build_request_head(req: &ConnectRequest, keep_alive: bool) -> String {
375 let method = req.meta(HTTP_METHOD_KEY).unwrap_or("GET");
376 let host = req.meta(HTTP_HOST_KEY).unwrap_or("");
377 let path = extract_path(&req.dest);
378
379 let mut head = String::with_capacity(256);
380 head.push_str(method);
381 head.push(' ');
382 head.push_str(&path);
383 head.push_str(" HTTP/1.1\r\n");
384 if !host.is_empty() {
385 head.push_str("Host: ");
386 head.push_str(host);
387 head.push_str("\r\n");
388 }
389
390 let mut saw_connection = false;
391 for (k, v) in &req.metadata {
392 if let Some(name) = k.strip_prefix(&format!("{HTTP_HEADER_KEY}:")) {
393 if name.eq_ignore_ascii_case("host") {
394 continue;
395 }
396 if name.eq_ignore_ascii_case("connection") {
397 saw_connection = true;
398 }
399 head.push_str(name);
400 head.push_str(": ");
401 head.push_str(v);
402 head.push_str("\r\n");
403 }
404 }
405 if !saw_connection {
407 if keep_alive {
408 head.push_str("Connection: keep-alive\r\n");
409 } else {
410 head.push_str("Connection: close\r\n");
411 }
412 }
413 head.push_str("\r\n");
414 head
415}
416
/// Derives the origin-form request-target (`/path?query`) for the origin from the
/// edge's `dest`.
///
/// * A `dest` that already starts with `/` is used as-is, even when its query
///   contains `://`.
/// * For an absolute URL (`scheme://authority...`), everything from the first `/`
///   or `?` after the authority is kept. A query directly after the authority is
///   prefixed with `/`, and a bare authority yields `/`.
/// * Anything else yields `/`.
///
/// Any `#fragment` is dropped, because fragments are never part of a request-target
/// (RFC 9112 §3.2.1).
fn extract_path(dest: &str) -> String {
    let dest = dest.split_once('#').map_or(dest, |(before, _)| before);
    // Checked first so a relative target like `/go?to=https://x/y` is not
    // mistaken for an absolute URL.
    if dest.starts_with('/') {
        return dest.to_string();
    }
    let Some(scheme_end) = dest.find("://") else {
        return "/".into();
    };
    let after_scheme = &dest[scheme_end + 3..];
    match after_scheme.find(|c: char| c == '/' || c == '?') {
        Some(i) if after_scheme[i..].starts_with('/') => after_scheme[i..].to_string(),
        Some(i) => format!("/{}", &after_scheme[i..]),
        None => "/".into(),
    }
}
430
431async fn write_error_response<W>(writer: &mut W, status: u16, msg: &str) -> Result<(), TunnelError>
432where
433 W: futures::io::AsyncWrite + Unpin,
434{
435 let meta = [(HTTP_STATUS_KEY, status.to_string())];
436 let refs: Vec<(&str, &str)> = meta.iter().map(|(k, v)| (*k, v.as_str())).collect();
437 stream::write_connect_response(writer, msg, &refs).await?;
438 Ok(())
439}
440
441async fn read_http_response_head(
442 tcp: &mut (impl tokio::io::AsyncRead + Unpin),
443) -> Result<(u16, Vec<(String, String)>, Vec<u8>), TunnelError> {
444 let mut buf = Vec::with_capacity(4096);
445 let mut tmp = [0u8; 2048];
446 loop {
447 let n = tcp
448 .read(&mut tmp)
449 .await
450 .map_err(|e| TunnelError::Internal(format!("tcp read head: {e}")))?;
451 if n == 0 {
452 return Err(TunnelError::Internal(
453 "local origin closed before sending response head".into(),
454 ));
455 }
456 buf.extend_from_slice(&tmp[..n]);
457 if buf.len() > MAX_HEADER_BYTES {
458 return Err(TunnelError::Internal(format!(
459 "response header exceeds {MAX_HEADER_BYTES} bytes"
460 )));
461 }
462 let mut headers = [httparse::EMPTY_HEADER; 64];
463 let mut resp = httparse::Response::new(&mut headers);
464 match resp
465 .parse(&buf)
466 .map_err(|e| TunnelError::Internal(format!("httparse: {e}")))?
467 {
468 httparse::Status::Complete(consumed) => {
469 let status = resp
470 .code
471 .ok_or_else(|| TunnelError::Internal("response had no status code".into()))?;
472 let pairs = resp
473 .headers
474 .iter()
475 .map(|h| {
476 let v = String::from_utf8_lossy(h.value).into_owned();
477 (h.name.to_string(), v)
478 })
479 .collect::<Vec<_>>();
480 let leftover = buf.split_off(consumed);
481 return Ok((status, pairs, leftover));
482 }
483 httparse::Status::Partial => {}
484 }
485 }
486}
487
488async fn proxy_tcp<R, W>(
491 local_port: u16,
492 _request: &ConnectRequest,
493 from_edge: &mut R,
494 to_edge: &mut W,
495 counters: &StreamCounters,
496) -> Result<(), TunnelError>
497where
498 R: futures::io::AsyncRead + Unpin,
499 W: futures::io::AsyncWrite + Unpin,
500{
501 let tcp = TcpStream::connect(("127.0.0.1", local_port))
502 .await
503 .map_err(|e| TunnelError::Internal(format!("tcp connect: {e}")))?;
504 let (mut r, mut w) = tcp.into_split();
505 stream::write_connect_response(to_edge, "", &[]).await?;
506 let edge_to_local = pump_futures_to_tokio_counted(from_edge, &mut w, &counters.bytes_in);
507 let local_to_edge = pump_tokio_to_futures_counted(&mut r, to_edge, &counters.bytes_out);
508 let _ = futures::future::join(edge_to_local, local_to_edge).await;
509 Ok(())
510}
511
512async fn pump_futures_to_tokio_counted<R, W>(
515 mut src: R,
516 dst: &mut W,
517 counter: &AtomicU64,
518) -> Result<(), TunnelError>
519where
520 R: futures::io::AsyncRead + Unpin,
521 W: tokio::io::AsyncWrite + Unpin,
522{
523 let mut buf = [0u8; 16 * 1024];
524 loop {
525 let n = src
526 .read(&mut buf)
527 .await
528 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
529 if n == 0 {
530 break;
531 }
532 dst.write_all(&buf[..n])
533 .await
534 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
535 counter.fetch_add(n as u64, Ordering::Relaxed);
536 }
537 Ok(())
538}
539
540async fn pump_tokio_to_futures_counted<R, W>(
541 src: &mut R,
542 dst: &mut W,
543 counter: &AtomicU64,
544) -> Result<(), TunnelError>
545where
546 R: tokio::io::AsyncRead + Unpin,
547 W: futures::io::AsyncWrite + Unpin,
548{
549 let mut buf = [0u8; 16 * 1024];
550 loop {
551 let n = src
552 .read(&mut buf)
553 .await
554 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
555 if n == 0 {
556 break;
557 }
558 dst.write_all(&buf[..n])
559 .await
560 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
561 counter.fetch_add(n as u64, Ordering::Relaxed);
562 }
563 Ok(())
564}
565
566async fn pump_n_futures_to_tokio<R, W>(
569 src: &mut R,
570 dst: &mut W,
571 mut n: u64,
572 counter: &AtomicU64,
573) -> Result<(), TunnelError>
574where
575 R: futures::io::AsyncRead + Unpin,
576 W: tokio::io::AsyncWrite + Unpin,
577{
578 let mut buf = [0u8; 16 * 1024];
579 while n > 0 {
580 let want = std::cmp::min(buf.len() as u64, n) as usize;
581 let read = src
582 .read(&mut buf[..want])
583 .await
584 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
585 if read == 0 {
586 return Err(TunnelError::Internal(format!(
587 "source EOF with {n} bytes still expected"
588 )));
589 }
590 dst.write_all(&buf[..read])
591 .await
592 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
593 counter.fetch_add(read as u64, Ordering::Relaxed);
594 n -= read as u64;
595 }
596 Ok(())
597}
598
599async fn pump_n_tokio_to_futures<R, W>(
602 src: &mut R,
603 dst: &mut W,
604 mut n: u64,
605 counter: &AtomicU64,
606) -> Result<(), TunnelError>
607where
608 R: tokio::io::AsyncRead + Unpin,
609 W: futures::io::AsyncWrite + Unpin,
610{
611 let mut buf = [0u8; 16 * 1024];
612 while n > 0 {
613 let want = std::cmp::min(buf.len() as u64, n) as usize;
614 let read = src
615 .read(&mut buf[..want])
616 .await
617 .map_err(|e| TunnelError::Internal(format!("read: {e}")))?;
618 if read == 0 {
619 return Err(TunnelError::Internal(format!(
620 "tcp EOF with {n} bytes still expected"
621 )));
622 }
623 dst.write_all(&buf[..read])
624 .await
625 .map_err(|e| TunnelError::Internal(format!("write: {e}")))?;
626 counter.fetch_add(read as u64, Ordering::Relaxed);
627 n -= read as u64;
628 }
629 Ok(())
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_path_strips_scheme() {
        let cases = [
            ("https://abc.trycloudflare.com/path?q=1", "/path?q=1"),
            ("https://abc.trycloudflare.com", "/"),
            ("/relative/x", "/relative/x"),
        ];
        for (dest, expected) in cases {
            assert_eq!(extract_path(dest), expected, "dest = {dest:?}");
        }
    }

    #[test]
    fn build_head_includes_method_host_path() {
        let req = ConnectRequest {
            dest: "https://abc.trycloudflare.com/foo".into(),
            conn_type: ConnectionType::Http,
            metadata: vec![
                (HTTP_METHOD_KEY.into(), "POST".into()),
                (HTTP_HOST_KEY.into(), "abc.trycloudflare.com".into()),
                (format!("{HTTP_HEADER_KEY}:User-Agent"), "x/1".into()),
                (format!("{HTTP_HEADER_KEY}:X-Stuff"), "yo".into()),
            ],
        };
        let head = build_request_head(&req, true);

        assert!(head.starts_with("POST /foo HTTP/1.1\r\n"));
        for expected_line in [
            "Host: abc.trycloudflare.com\r\n",
            "User-Agent: x/1\r\n",
            "X-Stuff: yo\r\n",
            "Connection: keep-alive\r\n",
        ] {
            assert!(
                head.contains(expected_line),
                "missing {expected_line:?} in {head:?}"
            );
        }
        assert!(head.ends_with("\r\n\r\n"));
    }

    #[test]
    fn poolable_request_default() {
        let shape = analyse_request(&ConnectRequest {
            dest: "https://x/".into(),
            conn_type: ConnectionType::Http,
            metadata: vec![
                (HTTP_METHOD_KEY.into(), "GET".into()),
                (HTTP_HOST_KEY.into(), "x".into()),
            ],
        });
        assert_eq!(shape.content_length, None);
        assert!(shape.poolable());
    }

    #[test]
    fn websocket_request_not_poolable() {
        let shape = analyse_request(&ConnectRequest {
            dest: "https://x/ws".into(),
            conn_type: ConnectionType::Websocket,
            metadata: vec![
                (HTTP_METHOD_KEY.into(), "GET".into()),
                (HTTP_HOST_KEY.into(), "x".into()),
                (format!("{HTTP_HEADER_KEY}:Upgrade"), "websocket".into()),
                (format!("{HTTP_HEADER_KEY}:Connection"), "Upgrade".into()),
            ],
        });
        assert!(shape.is_upgrade);
        assert!(!shape.poolable());
    }

    #[test]
    fn chunked_request_not_poolable() {
        let shape = analyse_request(&ConnectRequest {
            dest: "https://x/upload".into(),
            conn_type: ConnectionType::Http,
            metadata: vec![
                (HTTP_METHOD_KEY.into(), "POST".into()),
                (
                    format!("{HTTP_HEADER_KEY}:Transfer-Encoding"),
                    "chunked".into(),
                ),
            ],
        });
        assert!(shape.is_chunked);
        assert!(!shape.poolable());
    }

    #[test]
    fn response_with_content_length_is_poolable() {
        let headers = [("Content-Length".to_string(), "42".to_string())];
        let shape = analyse_response(200, &headers);
        assert_eq!(shape.content_length, Some(42));
        assert!(shape.poolable());
    }

    #[test]
    fn response_101_never_poolable() {
        let headers = [("Upgrade".to_string(), "websocket".to_string())];
        let shape = analyse_response(101, &headers);
        assert!(shape.is_upgrade);
        assert!(!shape.poolable());
    }
}