1use std::borrow::Cow;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23 SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::proxy::{TlsProxyContext, tls_proxy_task};
27use crate::tls::sni;
28use crate::tls::state::TlsState;
29
/// Size in bytes of the scratch buffer allocated for reads from the upstream
/// server socket (used by the relay loops and first-flight classification).
const SERVER_READ_BUF_SIZE: usize = 16384;

/// Limit on the bytes accumulated while waiting for the end of a proxy's
/// CONNECT response headers. The check runs after each read is appended, so
/// the buffer can overshoot this value by up to one read.
const CONNECT_RESP_LIMIT: usize = 8192;

/// Cap on guest bytes buffered while inspecting the start of a connection:
/// passed as `max` to the SNI peek and first-flight classification, and used
/// directly as the size limit for CONNECT request headers.
const PEEK_BUF_SIZE: usize = 16384;

/// Time budget applied to each buffering phase (SNI peek, first-flight
/// classification, CONNECT request buffering, CONNECT response read).
const PEEK_BUDGET: Duration = Duration::from_secs(5);
46
/// A buffered HTTP CONNECT request as produced by `parse_connect_request`.
#[derive(Debug)]
struct ConnectRequest {
    /// Every byte buffered from the guest: the request headers followed by
    /// whatever arrived after them.
    bytes: Vec<u8>,
    /// Offset in `bytes` just past the `\r\n\r\n` header terminator.
    header_end: usize,
    /// Destination parsed from the request line's authority.
    target: ConnectTarget,
}
57
/// Destination named by a CONNECT request line, as produced by
/// `parse_connect_target`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnectTarget {
    /// Host from the authority: lowercased, trailing dots removed, and without
    /// the surrounding brackets for bracketed literals.
    host: String,
    /// Port from the authority.
    port: u16,
    /// The lowercased host when it is not an IP literal, `None` otherwise.
    expected_sni: Option<String>,
}
64
65impl ConnectRequest {
70 fn header_bytes(&self) -> &[u8] {
71 &self.bytes[..self.header_end]
72 }
73
74 fn post_header_bytes(&self) -> &[u8] {
75 &self.bytes[self.header_end..]
76 }
77}
78
79impl ConnectTarget {
80 fn is_intercepted(&self, tls_state: &TlsState) -> bool {
81 tls_state.config.intercepted_ports.contains(&self.port)
82 }
83
84 fn guest_dst(&self, fallback: SocketAddr, shared: &SharedState) -> SocketAddr {
85 if let Ok(ip) = self.host.parse::<IpAddr>() {
86 return SocketAddr::new(ip, self.port);
87 }
88
89 if self.host.eq_ignore_ascii_case(crate::HOST_ALIAS) {
90 match fallback.ip() {
91 IpAddr::V4(_) => {
92 if let Some(ip) = shared.gateway_ipv4() {
93 return SocketAddr::new(IpAddr::V4(ip), self.port);
94 }
95 }
96 IpAddr::V6(_) => {
97 if let Some(ip) = shared.gateway_ipv6() {
98 return SocketAddr::new(IpAddr::V6(ip), self.port);
99 }
100 }
101 }
102 if let Some(ip) = shared.gateway_ipv4() {
103 return SocketAddr::new(IpAddr::V4(ip), self.port);
104 }
105 if let Some(ip) = shared.gateway_ipv6() {
106 return SocketAddr::new(IpAddr::V6(ip), self.port);
107 }
108 }
109
110 SocketAddr::new(fallback.ip(), self.port)
111 }
112}
113
114pub(crate) async fn connect_upstream(
120 dst: SocketAddr,
121 proxy_connect: &ProxyConnectState,
122 shared: &SharedState,
123) -> io::Result<TcpStream> {
124 match TcpStream::connect(dst).await {
125 Ok(s) => {
126 proxy_connect.mark_connected();
127 Ok(s)
128 }
129 Err(e) => {
130 proxy_connect.mark_upstream_connect_failed();
131 shared.proxy_wake.wake();
132 Err(e)
133 }
134 }
135}
136
137#[allow(clippy::too_many_arguments)]
148pub fn spawn_tcp_proxy(
149 handle: &tokio::runtime::Handle,
150 guest_dst: SocketAddr,
151 connect_dst: SocketAddr,
152 from_smoltcp: mpsc::Receiver<Bytes>,
153 to_smoltcp: mpsc::Sender<Bytes>,
154 shared: Arc<SharedState>,
155 network_policy: Arc<NetworkPolicy>,
156 secrets: Arc<SecretsConfig>,
157 tls_state: Option<Arc<TlsState>>,
158 proxy_connect: Arc<ProxyConnectState>,
159) {
160 handle.spawn(async move {
161 if let Err(e) = tcp_proxy_task(
162 guest_dst,
163 connect_dst,
164 from_smoltcp,
165 to_smoltcp,
166 shared,
167 network_policy,
168 secrets,
169 tls_state,
170 proxy_connect,
171 )
172 .await
173 {
174 tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
175 }
176 });
177}
178
/// Drives one guest TCP connection from first byte to teardown.
///
/// Phases, in order:
/// 1. When the policy has domain rules, peek at the guest's first bytes for a
///    TLS SNI and evaluate egress policy; a denial marks `proxy_connect` as
///    policy-denied, wakes the proxy and returns `Ok(())`.
/// 2. When `tls_state` is present and the first bytes could be a `CONNECT`
///    request, hand the connection to `handle_connect_tunnel`.
/// 3. Otherwise connect to `connect_dst`, optionally classify the first flight
///    for secrets handling, replay the buffered bytes upstream, and relay in
///    both directions until either side closes or an I/O step fails.
///
/// `guest_dst` is the address used for policy evaluation and secrets host
/// scoping; `connect_dst` is the address actually dialled.
///
/// Returns `Err` only for failures propagated with `?` (upstream connect,
/// first-flight classification, reuniting the stream halves, or the CONNECT
/// handler). Relay-loop I/O failures are logged and end the task with `Ok(())`.
#[allow(clippy::too_many_arguments)]
async fn tcp_proxy_task(
    guest_dst: SocketAddr,
    connect_dst: SocketAddr,
    mut from_smoltcp: mpsc::Receiver<Bytes>,
    to_smoltcp: mpsc::Sender<Bytes>,
    shared: Arc<SharedState>,
    network_policy: Arc<NetworkPolicy>,
    secrets: Arc<SecretsConfig>,
    tls_state: Option<Arc<TlsState>>,
    proxy_connect: Arc<ProxyConnectState>,
) -> io::Result<()> {
    // Only pay for the SNI peek when domain rules exist; the peeked bytes are
    // kept in `initial_buf` so they can be replayed upstream later.
    let (mut initial_buf, sni) = if network_policy.has_domain_rules() {
        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
    } else {
        (Vec::new(), None)
    };

    if network_policy.has_domain_rules() {
        // Use the SNI when one was extracted; otherwise evaluate with the
        // cache-only hostname source.
        let source = match sni.as_deref() {
            Some(name) => HostnameSource::Sni(name),
            None => HostnameSource::CacheOnly,
        };
        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
        {
            EgressEvaluation::Allow => {}
            EgressEvaluation::Deny => {
                tracing::debug!(
                    dst = %guest_dst,
                    source = source.label(),
                    "TCP egress denied by domain policy",
                );
                proxy_connect.mark_policy_denied();
                shared.proxy_wake.wake();
                return Ok(());
            }
            // Not expected at this stage: asserts in debug builds and is
            // handled as a denial in release builds.
            EgressEvaluation::DeferUntilHostname => {
                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
                proxy_connect.mark_policy_denied();
                shared.proxy_wake.wake();
                return Ok(());
            }
        }
    }

    if let Some(tls_state) = tls_state.clone() {
        // Nothing buffered yet (no domain rules, or the first peek returned no
        // data): peek now so a CONNECT request can be detected before dialling.
        if initial_buf.is_empty() {
            let (peeked, _) = peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await;
            initial_buf = peeked;
        }
        if could_be_connect_request(&initial_buf) {
            // No upstream connection exists yet, hence the trailing `None`.
            return handle_connect_tunnel(
                guest_dst,
                connect_dst,
                initial_buf,
                from_smoltcp,
                to_smoltcp,
                shared,
                network_policy,
                tls_state,
                proxy_connect,
                None,
            )
            .await;
        }
    }

    let stream = connect_upstream(connect_dst, &proxy_connect, &shared).await?;
    let (mut server_rx, mut server_tx) = stream.into_split();

    let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
    // With secrets configured, buffer more of the first flight and learn
    // whether it is TLS; without secrets the buffer is used as-is.
    let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
        classify_first_flight(
            initial_buf,
            &mut from_smoltcp,
            &mut server_rx,
            &to_smoltcp,
            &shared,
            want_headers,
            PEEK_BUF_SIZE,
            PEEK_BUDGET,
        )
        .await?
    } else {
        (initial_buf, false)
    };

    // Classification may have buffered guest bytes that were not available
    // during the earlier check, so look for a CONNECT request again. The
    // upstream connection already exists and is passed along.
    if let Some(tls_state) = tls_state.clone()
        && could_be_connect_request(&initial_buf)
    {
        let proxy_stream = server_rx
            .reunite(server_tx)
            .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
        return handle_connect_tunnel(
            guest_dst,
            connect_dst,
            initial_buf,
            from_smoltcp,
            to_smoltcp,
            shared,
            network_policy,
            tls_state,
            proxy_connect,
            Some(proxy_stream),
        )
        .await;
    }

    // NOTE(review): this stays `Some` even when a non-CONNECT first flight was
    // buffered above, so the next guest chunk that starts with a (possibly
    // partial) `CONNECT ` prefix is routed to the CONNECT handler — confirm
    // that is intended for mid-stream data such as request bodies.
    let mut late_connect_state = tls_state;
    // Plain-HTTP secrets handling applies only when secrets exist and the
    // first flight is not TLS; the handler is built from the Host header, or
    // from the invalid-host constructor when no host could be extracted.
    let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
        Some(match extract_http_host(&initial_buf) {
            Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
            None => SecretsHandler::new_plain_http_invalid_host(&secrets),
        })
    } else {
        None
    };

    // Replay the buffered first flight upstream, via the secrets handler when
    // one is active.
    if !initial_buf.is_empty() {
        let out: Cow<[u8]> = match secrets_handler.as_mut() {
            Some(h) => match h.substitute(&initial_buf) {
                Ok(cow) => cow,
                Err(action) => {
                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
                    if matches!(action, ViolationAction::BlockAndTerminate) {
                        shared.trigger_termination();
                    }
                    return Ok(());
                }
            },
            None => Cow::Borrowed(&initial_buf),
        };
        if !out.is_empty() {
            if let Err(e) = server_tx.write_all(&out).await {
                tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
                return Ok(());
            }
            if let Err(e) = server_tx.flush().await {
                tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
                return Ok(());
            }
        }
    }

    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];

    // Bidirectional relay: guest -> server (through the secrets handler) and
    // server -> guest.
    loop {
        tokio::select! {
            data = from_smoltcp.recv() => {
                match data {
                    Some(bytes) => {
                        // `take()` runs before the prefix check, so only the
                        // first guest chunk seen here is ever considered for
                        // late CONNECT detection.
                        if let Some(tls_state) = late_connect_state.take()
                            && could_be_connect_request(&bytes)
                        {
                            let proxy_stream = server_rx
                                .reunite(server_tx)
                                .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
                            return handle_connect_tunnel(
                                guest_dst,
                                connect_dst,
                                bytes.to_vec(),
                                from_smoltcp,
                                to_smoltcp,
                                shared,
                                network_policy,
                                tls_state,
                                proxy_connect,
                                Some(proxy_stream),
                            )
                            .await;
                        }
                        let out: Cow<[u8]> = match secrets_handler.as_mut() {
                            Some(h) => match h.substitute(&bytes) {
                                Ok(cow) => cow,
                                Err(action) => {
                                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
                                    if matches!(action, ViolationAction::BlockAndTerminate) {
                                        shared.trigger_termination();
                                    }
                                    break;
                                }
                            },
                            None => Cow::Borrowed(&bytes),
                        };
                        if !out.is_empty() {
                            if let Err(e) = server_tx.write_all(&out).await {
                                tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
                                break;
                            }
                            if let Err(e) = server_tx.flush().await {
                                tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
                                break;
                            }
                        }
                    }
                    // Guest side of the channel closed.
                    None => break,
                }
            }

            result = server_rx.read(&mut server_buf) => {
                match result {
                    // Server closed its end.
                    Ok(0) => break,
                    Ok(n) => {
                        // Once the server has sent data, stop looking for a
                        // late CONNECT request.
                        late_connect_state = None;
                        let data = Bytes::copy_from_slice(&server_buf[..n]);
                        if to_smoltcp.send(data).await.is_err() {
                            break;
                        }
                        shared.proxy_wake.wake();
                    }
                    Err(e) => {
                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
                        break;
                    }
                }
            }
        }
    }

    Ok(())
}
447
/// Handles a guest connection whose first bytes look like an HTTP `CONNECT`
/// request addressed to an upstream proxy at `proxy_dst`.
///
/// Steps:
/// 1. Buffer and parse the complete CONNECT request headers.
/// 2. Run the headers through `sanitize_connect_headers`; on a violation, log
///    it, trigger termination for `BlockAndTerminate`, and return `Ok(())`.
/// 3. Use `preconnected_proxy` when supplied, otherwise dial `proxy_dst`.
/// 4. If the target port is not intercepted: forward the headers, relay the
///    proxy's response to the guest, and on a 2xx status relay the rest of the
///    stream opaquely via `relay_connected_stream`.
/// 5. If the target port is intercepted: forward the headers, require a 2xx
///    response with no trailing bytes, pass the response headers to the guest,
///    and hand the connection to `tls_proxy_task`, seeded with any bytes the
///    guest sent after the CONNECT headers.
///
/// Errors from buffering/parsing, dialling, proxy I/O, a rejected CONNECT
/// (intercepted path only), or the downstream tasks are returned as
/// `io::Error`.
#[allow(clippy::too_many_arguments)]
async fn handle_connect_tunnel(
    guest_dst: SocketAddr,
    proxy_dst: SocketAddr,
    initial_buf: Vec<u8>,
    mut from_smoltcp: mpsc::Receiver<Bytes>,
    to_smoltcp: mpsc::Sender<Bytes>,
    shared: Arc<SharedState>,
    network_policy: Arc<NetworkPolicy>,
    tls_state: Arc<TlsState>,
    proxy_connect: Arc<ProxyConnectState>,
    preconnected_proxy: Option<TcpStream>,
) -> io::Result<()> {
    let connect_req =
        parse_connect_request(buffer_connect_request(initial_buf, &mut from_smoltcp).await?)?;

    let connect_headers = match sanitize_connect_headers(
        connect_req.header_bytes(),
        &tls_state.secrets,
    ) {
        Ok(headers) => headers,
        Err(action) => {
            tracing::warn!(dst = %proxy_dst, violation = ?action, "secret violation in CONNECT headers");
            if matches!(action, ViolationAction::BlockAndTerminate) {
                shared.trigger_termination();
            }
            return Ok(());
        }
    };

    // Dial the proxy only when the caller has not already connected. Unlike
    // `connect_upstream`, success is not marked here; `mark_connected` is
    // called later, after the proxy's CONNECT response has been read.
    let mut proxy_stream = match preconnected_proxy {
        Some(stream) => stream,
        None => match TcpStream::connect(proxy_dst).await {
            Ok(s) => s,
            Err(e) => {
                proxy_connect.mark_upstream_connect_failed();
                shared.proxy_wake.wake();
                return Err(e);
            }
        },
    };

    // Pass-through path: the target port is not in the interception list.
    if !connect_req.target.is_intercepted(&tls_state) {
        proxy_stream.write_all(&connect_headers).await?;
        proxy_stream.flush().await?;
        let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
        // Relay the response headers first, then any bytes the proxy sent
        // right after them; a closed guest channel ends the task quietly.
        if to_smoltcp
            .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
            .await
            .is_err()
        {
            return Ok(());
        }
        if !proxy_resp[header_end..].is_empty()
            && to_smoltcp
                .send(Bytes::copy_from_slice(&proxy_resp[header_end..]))
                .await
                .is_err()
        {
            return Ok(());
        }
        shared.proxy_wake.wake();
        // A non-2xx response has already been relayed to the guest; the state
        // is still marked connected and the task ends without relaying more.
        if !connect_response_is_success(&proxy_resp[..header_end]) {
            proxy_connect.mark_connected();
            return Ok(());
        }
        // Forward whatever the guest sent after its CONNECT headers.
        if !connect_req.post_header_bytes().is_empty() {
            proxy_stream
                .write_all(connect_req.post_header_bytes())
                .await?;
        }
        proxy_stream.flush().await?;
        proxy_connect.mark_connected();
        return relay_connected_stream(proxy_stream, from_smoltcp, to_smoltcp, shared).await;
    }

    // Intercepted path: establish the tunnel, then hand off to the TLS proxy.
    proxy_stream.write_all(&connect_headers).await?;
    proxy_stream.flush().await?;

    let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
    if !connect_response_is_success(&proxy_resp[..header_end]) {
        return Err(io::Error::new(
            io::ErrorKind::ConnectionRefused,
            "proxy rejected CONNECT",
        ));
    }
    // Bytes after the response headers are treated as an error on this path.
    if !proxy_resp[header_end..].is_empty() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "proxy sent unexpected bytes after CONNECT response headers",
        ));
    }
    proxy_connect.mark_connected();

    if to_smoltcp
        .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
        .await
        .is_err()
    {
        return Ok(());
    }
    shared.proxy_wake.wake();

    // Bytes the guest sent after the CONNECT headers seed the TLS proxy.
    let tls_seed = connect_req.post_header_bytes().to_vec();
    let tls_guest_dst = connect_req.target.guest_dst(guest_dst, &shared);
    let expected_sni = connect_req.target.expected_sni.clone();

    tls_proxy_task(
        TlsProxyContext {
            guest_dst: tls_guest_dst,
            connect_dst: proxy_dst,
            shared,
            tls_state,
            network_policy,
            proxy_connect,
            upstream_stream: Some(proxy_stream),
            // True only when the CONNECT target was a hostname (an IP-literal
            // target leaves `expected_sni` as `None`).
            via_connect: expected_sni.is_some(),
            expected_sni,
        },
        from_smoltcp,
        to_smoltcp,
        tls_seed,
    )
    .await
}
579
580async fn relay_connected_stream(
582 stream: TcpStream,
583 mut from_smoltcp: mpsc::Receiver<Bytes>,
584 to_smoltcp: mpsc::Sender<Bytes>,
585 shared: Arc<SharedState>,
586) -> io::Result<()> {
587 let (mut server_rx, mut server_tx) = stream.into_split();
588 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
589
590 loop {
591 tokio::select! {
592 data = from_smoltcp.recv() => {
593 match data {
594 Some(bytes) => {
595 server_tx.write_all(&bytes).await?;
596 server_tx.flush().await?;
597 }
598 None => break,
599 }
600 }
601 result = server_rx.read(&mut server_buf) => {
602 match result {
603 Ok(0) => break,
604 Ok(n) => {
605 if to_smoltcp
606 .send(Bytes::copy_from_slice(&server_buf[..n]))
607 .await
608 .is_err()
609 {
610 break;
611 }
612 shared.proxy_wake.wake();
613 }
614 Err(e) => return Err(e),
615 }
616 }
617 }
618 }
619
620 Ok(())
621}
622
623async fn buffer_connect_request(
624 mut buf: Vec<u8>,
625 from_smoltcp: &mut mpsc::Receiver<Bytes>,
626) -> io::Result<Vec<u8>> {
627 let timeout_fut = tokio::time::sleep(PEEK_BUDGET);
628 tokio::pin!(timeout_fut);
629
630 loop {
631 if !could_be_connect_request(&buf) {
632 return Err(io::Error::new(
633 io::ErrorKind::InvalidData,
634 "malformed CONNECT request prefix",
635 ));
636 }
637 if headers_end(&buf).is_some() {
638 return Ok(buf);
639 }
640 if buf.len() >= PEEK_BUF_SIZE {
641 return Err(io::Error::new(
642 io::ErrorKind::InvalidData,
643 "CONNECT request headers too large",
644 ));
645 }
646
647 tokio::select! {
648 biased;
649 _ = &mut timeout_fut => {
650 return Err(io::Error::new(
651 io::ErrorKind::TimedOut,
652 "timed out waiting for complete CONNECT request headers",
653 ));
654 }
655 data = from_smoltcp.recv() => match data {
656 Some(bytes) => {
657 buf.extend_from_slice(&bytes);
658 }
659 None => {
660 return Err(io::Error::new(
661 io::ErrorKind::UnexpectedEof,
662 "channel closed before complete CONNECT request headers",
663 ));
664 }
665 }
666 }
667 }
668}
669
670async fn read_connect_response_headers(stream: &mut TcpStream) -> io::Result<(Vec<u8>, usize)> {
671 tokio::time::timeout(PEEK_BUDGET, async {
672 let mut proxy_resp = Vec::with_capacity(256);
673 let mut buf = [0u8; 4096];
674 loop {
675 let n = stream.read(&mut buf).await?;
676 if n == 0 {
677 return Err(io::Error::new(
678 io::ErrorKind::UnexpectedEof,
679 "proxy closed before sending CONNECT response",
680 ));
681 }
682 proxy_resp.extend_from_slice(&buf[..n]);
683 if let Some(end) = headers_end(&proxy_resp) {
684 return Ok((proxy_resp, end));
685 }
686 if proxy_resp.len() > CONNECT_RESP_LIMIT {
687 return Err(io::Error::new(
688 io::ErrorKind::InvalidData,
689 "proxy CONNECT response too large",
690 ));
691 }
692 }
693 })
694 .await
695 .map_err(|_| {
696 io::Error::new(
697 io::ErrorKind::TimedOut,
698 "timed out waiting for proxy CONNECT response",
699 )
700 })?
701}
702
703fn sanitize_connect_headers<'a>(
704 header_bytes: &'a [u8],
705 secrets: &SecretsConfig,
706) -> Result<Cow<'a, [u8]>, ViolationAction> {
707 if secrets.secrets.is_empty() {
708 return Ok(Cow::Borrowed(header_bytes));
709 }
710
711 let mut handler = SecretsHandler::new_plain_http_untrusted_metadata(secrets);
712 handler.substitute(header_bytes)
713}
714
/// Returns the offset just past the first `\r\n\r\n` in `buf`, or `None` when
/// the buffer contains no header terminator.
fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    for (start, window) in buf.windows(TERMINATOR.len()).enumerate() {
        if window == TERMINATOR {
            return Some(start + TERMINATOR.len());
        }
    }
    None
}
719
/// Reports whether `buf` is consistent with the start of a `CONNECT ` request
/// line, ignoring ASCII case.
///
/// A buffer shorter than the prefix matches when it agrees with the prefix so
/// far (so a lone `C` matches); an empty buffer never matches.
fn could_be_connect_request(buf: &[u8]) -> bool {
    const PREFIX: &[u8] = b"CONNECT ";
    // `zip` stops at the shorter side, so only the overlapping bytes are compared.
    !buf.is_empty()
        && buf
            .iter()
            .zip(PREFIX.iter())
            .all(|(got, want)| got.eq_ignore_ascii_case(want))
}
728
729fn parse_connect_request(bytes: Vec<u8>) -> io::Result<ConnectRequest> {
730 let header_end = headers_end(&bytes).ok_or_else(|| {
731 io::Error::new(
732 io::ErrorKind::InvalidData,
733 "incomplete CONNECT request headers",
734 )
735 })?;
736 let target = {
737 let request_line = bytes[..header_end]
738 .split(|&b| b == b'\n')
739 .next()
740 .unwrap_or(&[]);
741 let request_line = std::str::from_utf8(request_line)
742 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "CONNECT line is not UTF-8"))?
743 .trim_end_matches('\r');
744 let mut parts = request_line.split_ascii_whitespace();
745 let method = parts.next().unwrap_or_default();
746 let authority = parts.next().unwrap_or_default();
747 let version = parts.next().unwrap_or_default();
748 if !method.eq_ignore_ascii_case("CONNECT")
749 || authority.is_empty()
750 || !is_http_version(version)
751 || parts.next().is_some()
752 {
753 return Err(io::Error::new(
754 io::ErrorKind::InvalidData,
755 "malformed CONNECT request line",
756 ));
757 }
758 parse_connect_target(authority)?
759 };
760
761 Ok(ConnectRequest {
762 bytes,
763 header_end,
764 target,
765 })
766}
767
768fn parse_connect_target(authority: &str) -> io::Result<ConnectTarget> {
769 let authority = authority.trim();
770 let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
771 let (host, rest) = rest.split_once(']').ok_or_else(|| {
772 io::Error::new(
773 io::ErrorKind::InvalidData,
774 "malformed CONNECT IPv6 authority",
775 )
776 })?;
777 let port = rest.strip_prefix(':').ok_or_else(|| {
778 io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
779 })?;
780 (host, port)
781 } else {
782 let (host, port) = authority.rsplit_once(':').ok_or_else(|| {
783 io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
784 })?;
785 if host.contains(':') {
786 return Err(io::Error::new(
787 io::ErrorKind::InvalidData,
788 "CONNECT IPv6 authority must be bracketed",
789 ));
790 }
791 (host, port)
792 };
793 let host = host.trim().trim_end_matches('.');
794 if host.is_empty() {
795 return Err(io::Error::new(
796 io::ErrorKind::InvalidData,
797 "CONNECT authority missing host",
798 ));
799 }
800 let port = port
801 .parse::<u16>()
802 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid CONNECT port"))?;
803 let expected_sni = host
804 .parse::<IpAddr>()
805 .is_err()
806 .then(|| host.to_ascii_lowercase());
807
808 Ok(ConnectTarget {
809 host: host.to_ascii_lowercase(),
810 port,
811 expected_sni,
812 })
813}
814
/// Reports whether `version` has the shape `HTTP/<digits>.<digits>`.
///
/// The `HTTP/` prefix is case-sensitive and both numeric parts must be
/// non-empty runs of ASCII digits.
fn is_http_version(version: &str) -> bool {
    fn is_number(text: &str) -> bool {
        !text.is_empty() && text.bytes().all(|b| b.is_ascii_digit())
    }

    version
        .strip_prefix("HTTP/")
        .and_then(|numbers| numbers.split_once('.'))
        .is_some_and(|(major, minor)| is_number(major) && is_number(minor))
}
827
828fn connect_response_is_success(headers: &[u8]) -> bool {
829 let Some(status_line) = headers.split(|&b| b == b'\n').next() else {
830 return false;
831 };
832 let Ok(status_line) = std::str::from_utf8(status_line) else {
833 return false;
834 };
835 let mut parts = status_line.trim_end_matches('\r').split_ascii_whitespace();
836 let version = parts.next().unwrap_or_default();
837 let status = parts.next().unwrap_or_default();
838 is_http_version(version)
839 && status.len() == 3
840 && status
841 .parse::<u16>()
842 .is_ok_and(|code| (200..300).contains(&code))
843}
844
845fn extract_http_host(buf: &[u8]) -> Option<String> {
855 if buf.first() == Some(&0x16) {
856 return None;
857 }
858 let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
864 let mut req = httparse::Request::new(&mut headers);
865 req.parse(buf).ok()?;
866 req.headers
867 .iter()
868 .find(|h| h.name.eq_ignore_ascii_case("host"))
869 .and_then(|h| std::str::from_utf8(h.value).ok())
870 .map(|v| {
871 let host = v.trim();
872 host.rsplit_once(':')
874 .map(|(h, _)| h)
875 .unwrap_or(host)
876 .to_ascii_lowercase()
877 })
878 .filter(|h| !h.is_empty())
879}
880
881#[allow(clippy::too_many_arguments)]
898async fn classify_first_flight(
899 mut buf: Vec<u8>,
900 from_smoltcp: &mut mpsc::Receiver<Bytes>,
901 server_rx: &mut tokio::net::tcp::OwnedReadHalf,
902 to_smoltcp: &mpsc::Sender<Bytes>,
903 shared: &SharedState,
904 want_headers: bool,
905 max: usize,
906 budget: Duration,
907) -> io::Result<(Vec<u8>, bool)> {
908 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
909 let timeout_fut = tokio::time::sleep(budget);
910 tokio::pin!(timeout_fut);
911
912 loop {
913 if !buf.is_empty() {
919 let is_tls = buf.first() == Some(&0x16);
920 let not_http = !is_tls
921 && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
922 let done = !want_headers
923 || is_tls
924 || not_http
925 || buf.len() >= max
926 || buf.windows(4).any(|w| w == b"\r\n\r\n");
927 if done {
928 return Ok((buf, is_tls));
929 }
930 }
931
932 tokio::select! {
933 biased;
934 _ = &mut timeout_fut => {
935 let is_tls = buf.first() == Some(&0x16);
936 return Ok((buf, is_tls));
937 }
938 guest = from_smoltcp.recv() => match guest {
941 Some(bytes) => buf.extend_from_slice(&bytes),
942 None => {
943 let is_tls = buf.first() == Some(&0x16);
944 return Ok((buf, is_tls));
945 }
946 },
947 server = server_rx.read(&mut server_buf) => match server {
950 Ok(0) => {
951 let is_tls = buf.first() == Some(&0x16);
952 return Ok((buf, is_tls));
953 }
954 Ok(n) => {
955 let data = Bytes::copy_from_slice(&server_buf[..n]);
956 if to_smoltcp.send(data).await.is_err() {
957 let is_tls = buf.first() == Some(&0x16);
958 return Ok((buf, is_tls));
959 }
960 shared.proxy_wake.wake();
961 }
962 Err(e) => return Err(e),
963 },
964 }
965 }
966}
967
968async fn peek_for_sni(
978 rx: &mut mpsc::Receiver<Bytes>,
979 max: usize,
980 budget: Duration,
981) -> (Vec<u8>, Option<String>) {
982 let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
983 let timeout_fut = tokio::time::sleep(budget);
984 tokio::pin!(timeout_fut);
985
986 let raw_sni = loop {
987 tokio::select! {
988 biased;
989 _ = &mut timeout_fut => break None,
990 data = rx.recv() => {
991 match data {
992 Some(bytes) => {
993 buf.extend_from_slice(&bytes);
994 if buf.first() != Some(&0x16) {
999 break None;
1000 }
1001 if let Some(name) = sni::extract_sni(&buf) {
1002 break Some(name);
1003 }
1004 if buf.len() >= max {
1005 break None;
1006 }
1007 }
1008 None => break None,
1009 }
1010 }
1011 }
1012 };
1013
1014 let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
1015 (buf, canonical)
1016}
1017
1018#[cfg(test)]
1023mod tests {
1024 use super::*;
1025
1026 fn synthetic_client_hello(sni: &str) -> Vec<u8> {
1030 let host_bytes = sni.as_bytes();
1033 let host_len = host_bytes.len() as u16;
1034 let server_name_list_len = 3 + host_len; let extension_data_len = 2 + server_name_list_len; let extensions_total = 4 + extension_data_len; let mut body = Vec::new();
1039 body.extend_from_slice(&[0x03, 0x03]);
1041 body.extend_from_slice(&[0u8; 32]);
1043 body.push(0);
1045 body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
1047 body.extend_from_slice(&[0x01, 0x00]);
1049 body.extend_from_slice(&extensions_total.to_be_bytes());
1051 body.extend_from_slice(&[0x00, 0x00]);
1053 body.extend_from_slice(&extension_data_len.to_be_bytes());
1054 body.extend_from_slice(&server_name_list_len.to_be_bytes());
1055 body.push(0x00); body.extend_from_slice(&host_len.to_be_bytes());
1057 body.extend_from_slice(host_bytes);
1058
1059 let handshake_len = body.len() as u32;
1060 let mut hs = Vec::new();
1061 hs.push(0x01); hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); hs.extend_from_slice(&body);
1064
1065 let record_len = hs.len() as u16;
1066 let mut record = Vec::new();
1067 record.extend_from_slice(&[0x16, 0x03, 0x01]); record.extend_from_slice(&record_len.to_be_bytes());
1069 record.extend_from_slice(&hs);
1070
1071 record
1072 }
1073
1074 #[test]
1075 fn could_be_connect_request_matches_split_prefixes_only() {
1076 assert!(could_be_connect_request(b"C"));
1077 assert!(could_be_connect_request(b"connect "));
1078 assert!(could_be_connect_request(b"CONNECT example.com:443"));
1079 assert!(!could_be_connect_request(b"CLIENT"));
1080 assert!(!could_be_connect_request(b"GET / HTTP/1.1\r\n"));
1081 }
1082
1083 #[tokio::test]
1084 async fn buffer_connect_request_reads_split_headers() {
1085 let (tx, mut rx) = mpsc::channel(4);
1086 tx.send(Bytes::from_static(b"NECT example.com:443 HTTP/1.1\r\n"))
1087 .await
1088 .unwrap();
1089 tx.send(Bytes::from_static(b"Host: example.com\r\n\r\n"))
1090 .await
1091 .unwrap();
1092 drop(tx);
1093
1094 let buffered = buffer_connect_request(b"CON".to_vec(), &mut rx)
1095 .await
1096 .unwrap();
1097 let parsed = parse_connect_request(buffered).unwrap();
1098
1099 assert_eq!(parsed.target.host, "example.com");
1100 assert_eq!(parsed.target.port, 443);
1101 assert_eq!(parsed.target.expected_sni.as_deref(), Some("example.com"));
1102 assert!(parsed.post_header_bytes().is_empty());
1103 }
1104
1105 #[test]
1106 fn parse_connect_request_preserves_post_header_tls_seed() {
1107 let mut request = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec();
1108 request.extend_from_slice(b"\x16\x03\x01client-hello");
1109
1110 let parsed = parse_connect_request(request).unwrap();
1111
1112 assert_eq!(
1113 parsed.header_bytes(),
1114 b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n"
1115 );
1116 assert_eq!(parsed.post_header_bytes(), b"\x16\x03\x01client-hello");
1117 }
1118
1119 #[test]
1120 fn parse_connect_target_requires_authority_port() {
1121 assert!(parse_connect_target("example.com").is_err());
1122 assert!(parse_connect_target("2001:db8::1:443").is_err());
1123
1124 let target = parse_connect_target("[2001:db8::1]:8443").unwrap();
1125 assert_eq!(target.host, "2001:db8::1");
1126 assert_eq!(target.port, 8443);
1127 assert_eq!(target.expected_sni, None);
1128 }
1129
1130 #[test]
1131 fn connect_response_success_requires_exact_2xx_status_code() {
1132 assert!(connect_response_is_success(
1133 b"HTTP/1.1 200 Connection Established\r\n\r\n"
1134 ));
1135 assert!(connect_response_is_success(
1136 b"HTTP/1.1 204 Connection Established\r\n\r\n"
1137 ));
1138 assert!(!connect_response_is_success(b"HTTP/1.1 2000 Weird\r\n\r\n"));
1139 assert!(!connect_response_is_success(b"HTTP/1.1 199 Nope\r\n\r\n"));
1140 assert!(!connect_response_is_success(b"NOTHTTP 200 OK\r\n\r\n"));
1141 }
1142
1143 #[tokio::test]
1144 async fn peek_for_sni_extracts_and_canonicalizes() {
1145 let (tx, mut rx) = mpsc::channel(4);
1146 let hello = synthetic_client_hello("Example.COM");
1147 tx.send(Bytes::from(hello.clone())).await.unwrap();
1148 drop(tx); let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1151 assert_eq!(sni.as_deref(), Some("example.com"));
1152 assert_eq!(buf, hello);
1153 }
1154
1155 #[tokio::test]
1156 async fn peek_for_sni_returns_none_on_channel_close_without_data() {
1157 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1158 drop(tx);
1159 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1160 assert!(buf.is_empty());
1161 assert_eq!(sni, None);
1162 }
1163
1164 #[tokio::test]
1165 async fn peek_for_sni_returns_none_on_non_tls_data() {
1166 let (tx, mut rx) = mpsc::channel(4);
1167 tx.send(Bytes::from_static(
1169 b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
1170 ))
1171 .await
1172 .unwrap();
1173 drop(tx);
1174 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1175 assert!(
1176 !buf.is_empty(),
1177 "buffered bytes must be returned for replay"
1178 );
1179 assert_eq!(sni, None);
1180 }
1181
1182 #[tokio::test]
1183 async fn peek_for_sni_falls_back_on_timeout() {
1184 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1185 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
1187 drop(tx);
1188 assert!(buf.is_empty());
1189 assert_eq!(sni, None);
1190 }
1191
1192 #[tokio::test]
1193 async fn peek_for_sni_caps_at_max_bytes() {
1194 let (tx, mut rx) = mpsc::channel(4);
1195 let mut first = vec![0u8; 8192];
1199 first[0] = 0x16;
1200 tx.send(Bytes::from(first)).await.unwrap();
1201 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1202 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1203 drop(tx);
1204
1205 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1206 assert_eq!(sni, None, "no SNI in non-TLS data");
1207 assert!(
1208 buf.len() >= PEEK_BUF_SIZE,
1209 "buffer must hit the cap before bail-out: got {}",
1210 buf.len()
1211 );
1212 }
1213
1214 #[tokio::test]
1215 async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
1216 let (tx, mut rx) = mpsc::channel(4);
1217 tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
1219 .await
1220 .unwrap();
1221 drop(tx);
1222
1223 let started = std::time::Instant::now();
1226 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1227 let elapsed = started.elapsed();
1228 assert_eq!(sni, None);
1229 assert!(buf.starts_with(b"GET"));
1230 assert!(
1231 elapsed < Duration::from_millis(500),
1232 "non-TLS bail must be fast: took {elapsed:?}"
1233 );
1234 }
1235
1236 use std::net::IpAddr;
1241 use std::time::Duration as StdDuration;
1242
1243 use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
1244 use crate::shared::{ResolvedHostnameFamily, SharedState};
1245
1246 const SHARED_FASTLY_IP: &str = "151.101.0.223";
1247
1248 fn shared_with(host: &str, ip: &str) -> SharedState {
1249 let shared = SharedState::new(4);
1250 shared.cache_resolved_hostname(
1251 host,
1252 ResolvedHostnameFamily::Ipv4,
1253 [ip.parse::<IpAddr>().unwrap()],
1254 StdDuration::from_secs(60),
1255 );
1256 shared
1257 }
1258
1259 fn allow_https(domain: &str) -> Rule {
1260 Rule {
1261 direction: crate::policy::Direction::Egress,
1262 destination: Destination::Domain(domain.parse().unwrap()),
1263 protocols: vec![Protocol::Tcp],
1264 ports: vec![PortRange::single(443)],
1265 action: Action::Allow,
1266 }
1267 }
1268
1269 #[tokio::test]
1272 async fn integration_sni_overrides_cache_for_over_allow() {
1273 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1274 let policy = NetworkPolicy {
1275 default_egress: Action::Deny,
1276 default_ingress: Action::Allow,
1277 rules: vec![allow_https("pypi.org")],
1278 };
1279 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1280
1281 let (tx, mut rx) = mpsc::channel(4);
1282 tx.send(Bytes::from(synthetic_client_hello("evil.com")))
1283 .await
1284 .unwrap();
1285 drop(tx);
1286
1287 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1288 assert_eq!(sni.as_deref(), Some("evil.com"));
1289 assert!(!initial_buf.is_empty());
1290
1291 let source = sni
1292 .as_deref()
1293 .map(HostnameSource::Sni)
1294 .unwrap_or(HostnameSource::CacheOnly);
1295 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1296 assert_eq!(
1297 eval,
1298 EgressEvaluation::Deny,
1299 "SNI=evil.com must not piggy-back on the cached pypi.org match",
1300 );
1301 }
1302
1303 #[tokio::test]
1306 async fn integration_sni_overrides_cache_for_over_block() {
1307 let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
1308 let policy = NetworkPolicy {
1309 default_egress: Action::Allow,
1310 default_ingress: Action::Allow,
1311 rules: vec![Rule::deny_egress(Destination::Domain(
1312 "ads.example.com".parse().unwrap(),
1313 ))],
1314 };
1315 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1316
1317 let (tx, mut rx) = mpsc::channel(4);
1318 tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
1319 .await
1320 .unwrap();
1321 drop(tx);
1322
1323 let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1324 assert_eq!(sni.as_deref(), Some("api.example.com"));
1325
1326 let source = sni
1327 .as_deref()
1328 .map(HostnameSource::Sni)
1329 .unwrap_or(HostnameSource::CacheOnly);
1330 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1331 assert_eq!(
1332 eval,
1333 EgressEvaluation::Allow,
1334 "SNI=api.example.com must not be caught by the deny on ads.example.com",
1335 );
1336 }
1337
1338 #[tokio::test]
1341 async fn integration_non_tls_falls_back_to_cache() {
1342 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1343 let policy = NetworkPolicy {
1344 default_egress: Action::Deny,
1345 default_ingress: Action::Allow,
1346 rules: vec![allow_https("pypi.org")],
1347 };
1348 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1349
1350 let (tx, mut rx) = mpsc::channel(4);
1351 tx.send(Bytes::from_static(
1353 b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
1354 ))
1355 .await
1356 .unwrap();
1357 drop(tx);
1358
1359 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1360 assert_eq!(sni, None, "non-TLS data → no SNI");
1361 assert!(
1362 !initial_buf.is_empty(),
1363 "buffered bytes must survive for replay"
1364 );
1365
1366 let source = sni
1367 .as_deref()
1368 .map(HostnameSource::Sni)
1369 .unwrap_or(HostnameSource::CacheOnly);
1370 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1371 assert_eq!(
1372 eval,
1373 EgressEvaluation::Allow,
1374 "cache-only fallback must still allow the cached hostname's IP",
1375 );
1376 }
1377
1378 #[tokio::test]
1381 async fn integration_sni_matches_domain_suffix_with_cache_binding() {
1382 let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
1383 let policy = NetworkPolicy {
1384 default_egress: Action::Deny,
1385 default_ingress: Action::Allow,
1386 rules: vec![Rule {
1387 direction: crate::policy::Direction::Egress,
1388 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1389 protocols: vec![Protocol::Tcp],
1390 ports: vec![PortRange::single(443)],
1391 action: Action::Allow,
1392 }],
1393 };
1394 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1395
1396 let (tx, mut rx) = mpsc::channel(4);
1397 tx.send(Bytes::from(synthetic_client_hello(
1398 "files.pythonhosted.org",
1399 )))
1400 .await
1401 .unwrap();
1402 drop(tx);
1403
1404 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1405 let source = sni
1406 .as_deref()
1407 .map(HostnameSource::Sni)
1408 .unwrap_or(HostnameSource::CacheOnly);
1409 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1410 assert_eq!(eval, EgressEvaluation::Allow);
1411 }
1412
1413 #[tokio::test]
1418 async fn integration_sni_denies_domain_suffix_without_cache_binding() {
1419 let shared = SharedState::new(4); let policy = NetworkPolicy {
1421 default_egress: Action::Deny,
1422 default_ingress: Action::Allow,
1423 rules: vec![Rule {
1424 direction: crate::policy::Direction::Egress,
1425 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1426 protocols: vec![Protocol::Tcp],
1427 ports: vec![PortRange::single(443)],
1428 action: Action::Allow,
1429 }],
1430 };
1431 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1432
1433 let (tx, mut rx) = mpsc::channel(4);
1434 tx.send(Bytes::from(synthetic_client_hello(
1435 "files.pythonhosted.org",
1436 )))
1437 .await
1438 .unwrap();
1439 drop(tx);
1440
1441 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1442 let source = sni
1443 .as_deref()
1444 .map(HostnameSource::Sni)
1445 .unwrap_or(HostnameSource::CacheOnly);
1446 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1447 assert_eq!(eval, EgressEvaluation::Deny);
1448 }
1449
/// A complete request with a plain `Host` header yields that host.
#[test]
fn extract_http_host_basic() {
    let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
    let host = extract_http_host(request);
    assert_eq!(host, Some("example.com".into()));
}
1457
/// A `Host` header carrying `:8080` is returned without the port.
#[test]
fn extract_http_host_strips_port() {
    let request = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
    let host = extract_http_host(request);
    assert_eq!(host, Some("api.company.com".into()));
}
1463
/// A lowercase `host:` header name is recognised and the mixed-case value is
/// returned lowercased.
#[test]
fn extract_http_host_case_insensitive_lowercased() {
    let request = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
    let host = extract_http_host(request);
    assert_eq!(host, Some("example.com".into()));
}
1469
/// A complete request with no `Host` header yields `None`.
#[test]
fn extract_http_host_no_host_header() {
    let request = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
    let host = extract_http_host(request);
    assert_eq!(host, None);
}
1475
/// A request cut off in the middle of the `Host` line (no terminating CRLF
/// sequence) yields `None`.
#[test]
fn extract_http_host_incomplete_headers() {
    let truncated = b"GET / HTTP/1.1\r\nHost: x";
    let host = extract_http_host(truncated);
    assert_eq!(host, None);
}
1481
/// A buffer starting with `0x16 0x03 0x01` (not an HTTP request line) yields
/// `None`.
#[test]
fn extract_http_host_tls_first_byte() {
    let record_prefix = [0x16u8, 0x03, 0x01, 0x00, 0x01];
    let host = extract_http_host(&record_prefix);
    assert_eq!(host, None);
}
1487
/// The `Host` header is still found when it follows 100 padding headers.
#[test]
fn extract_http_host_with_many_headers() {
    let mut request = b"GET / HTTP/1.1\r\n".to_vec();
    // Append `X-Pad-0` … `X-Pad-99` ahead of the Host header.
    request.extend((0..100).flat_map(|i| format!("X-Pad-{i}: v\r\n").into_bytes()));
    request.extend_from_slice(b"Host: example.com\r\n\r\n");

    let host = extract_http_host(&request);
    assert_eq!(host, Some("example.com".into()));
}
1499
1500 use std::sync::Arc;
1503 use tokio::io::AsyncReadExt;
1504 use tokio::net::TcpListener;
1505 use tokio::task::JoinHandle;
1506
1507 use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};
1508
1509 fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
1510 SecretsConfig {
1511 secrets: vec![SecretEntry {
1512 env_var: "API_KEY".into(),
1513 value: value.into(),
1514 placeholder: placeholder.into(),
1515 allowed_hosts: vec![HostPattern::Any],
1516 injection: SecretInjection {
1517 headers: true,
1518 basic_auth: false,
1519 query_params: false,
1520 body: false,
1521 },
1522 on_violation: None,
1523 require_tls_identity: require_tls,
1524 }],
1525 ..Default::default()
1526 }
1527 }
1528
1529 fn make_host_bound_secret(placeholder: &str, value: &str, host: &str) -> SecretsConfig {
1530 SecretsConfig {
1531 secrets: vec![SecretEntry {
1532 env_var: "API_KEY".into(),
1533 value: value.into(),
1534 placeholder: placeholder.into(),
1535 allowed_hosts: vec![HostPattern::Exact(host.into())],
1536 injection: SecretInjection::default(),
1537 on_violation: None,
1538 require_tls_identity: true,
1539 }],
1540 ..Default::default()
1541 }
1542 }
1543
/// With no config-level violation action set by the test, a placeholder in
/// `Proxy-Authorization` makes sanitisation fail with `BlockAndLog`.
#[test]
fn sanitize_connect_headers_blocks_placeholder_metadata_header_by_default() {
    let config = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
    let raw = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\nUser-Agent: curl\r\n\r\n";

    let outcome = sanitize_connect_headers(raw, &config);
    assert_eq!(outcome, Err(ViolationAction::BlockAndLog));
}
1554
/// When the config's `on_violation` is `BlockAndTerminate`, a placeholder in
/// `Proxy-Authorization` is reported with that same action.
#[test]
fn sanitize_connect_headers_respects_block_and_terminate() {
    let mut config = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
    config.on_violation = ViolationAction::BlockAndTerminate;
    let raw = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";

    let outcome = sanitize_connect_headers(raw, &config);
    assert_eq!(outcome, Err(ViolationAction::BlockAndTerminate));
}
1566
/// With `on_violation` set to `Passthrough` for any host, the headers come
/// back byte-for-byte unchanged and the real secret value is never present.
#[test]
fn sanitize_connect_headers_respects_explicit_passthrough() {
    let mut config = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
    config.on_violation = ViolationAction::Passthrough(vec![HostPattern::Any]);
    let raw = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";

    let cleaned = sanitize_connect_headers(raw, &config).unwrap();
    assert_eq!(cleaned.as_ref(), raw);

    let rendered = String::from_utf8_lossy(cleaned.as_ref());
    assert!(
        !rendered.contains("real-secret-value"),
        "passthrough must never substitute real secrets into CONNECT metadata"
    );
}
1581
/// Headers containing no placeholder are returned unchanged.
#[test]
fn sanitize_connect_headers_keeps_safe_metadata_headers() {
    let config = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
    let raw =
        b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nUser-Agent: curl\r\n\r\n";

    let cleaned = sanitize_connect_headers(raw, &config).unwrap();
    assert_eq!(cleaned.as_ref(), raw);
}
1592
/// A placeholder appearing in the CONNECT request line itself is rejected
/// with `BlockAndLog`.
#[test]
fn sanitize_connect_headers_blocks_placeholder_in_request_line() {
    let config = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
    let raw = b"CONNECT $MSB_KEY:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n";

    let outcome = sanitize_connect_headers(raw, &config);
    assert_eq!(outcome, Err(ViolationAction::BlockAndLog));
}
1603
1604 async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
1605 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1606 let addr = listener.local_addr().unwrap();
1607 let handle = tokio::spawn(async move {
1608 let (mut stream, _) = listener.accept().await.unwrap();
1609 let mut received = Vec::new();
1610 let mut buf = vec![0u8; 4096];
1611 loop {
1612 match stream.read(&mut buf).await {
1613 Ok(0) | Err(_) => break,
1614 Ok(n) => received.extend_from_slice(&buf[..n]),
1615 }
1616 }
1617 received
1618 });
1619 (addr, handle)
1620 }
1621
1622 async fn relay_through_proxy(
1623 request: Vec<u8>,
1624 secrets: SecretsConfig,
1625 handle: JoinHandle<Vec<u8>>,
1626 server_addr: SocketAddr,
1627 ) -> Vec<u8> {
1628 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1629 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1630 let shared = SharedState::new(4);
1631 let policy = Arc::new(NetworkPolicy::default());
1632 let secrets = Arc::new(secrets);
1633 let proxy_connect = Arc::new(ProxyConnectState::new());
1634
1635 from_tx.send(Bytes::from(request)).await.unwrap();
1636 drop(from_tx);
1637
1638 tcp_proxy_task(
1639 server_addr,
1640 server_addr,
1641 from_rx,
1642 to_tx,
1643 Arc::new(shared),
1644 policy,
1645 secrets,
1646 None,
1647 proxy_connect,
1648 )
1649 .await
1650 .unwrap();
1651
1652 handle.await.unwrap()
1653 }
1654
1655 #[tokio::test]
1656 async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
1657 let (addr, sink) = spawn_sink().await;
1660 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1661
1662 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1663 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1664 let proxy_connect = Arc::new(ProxyConnectState::new());
1665
1666 from_tx
1667 .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1668 .await
1669 .unwrap();
1670 from_tx
1671 .send(Bytes::from_static(
1672 b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1673 ))
1674 .await
1675 .unwrap();
1676 drop(from_tx);
1677
1678 tcp_proxy_task(
1679 addr,
1680 addr,
1681 from_rx,
1682 to_tx,
1683 Arc::new(SharedState::new(4)),
1684 Arc::new(NetworkPolicy::default()),
1685 Arc::new(secrets),
1686 None,
1687 proxy_connect,
1688 )
1689 .await
1690 .unwrap();
1691
1692 let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1693 assert!(wire.contains("real-secret-value"), "got: {wire:?}");
1694 assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1695 }
1696
1697 #[tokio::test]
1698 async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
1699 let (addr, sink) = spawn_sink().await;
1704
1705 let shared = SharedState::new(4);
1706 shared.cache_resolved_hostname(
1707 "example.com",
1708 ResolvedHostnameFamily::Ipv4,
1709 ["127.0.0.1".parse::<IpAddr>().unwrap()],
1710 StdDuration::from_secs(60),
1711 );
1712
1713 let secrets = SecretsConfig {
1714 secrets: vec![SecretEntry {
1715 env_var: "API_KEY".into(),
1716 value: "real-secret-value".into(),
1717 placeholder: "$MSB_KEY".into(),
1718 allowed_hosts: vec![HostPattern::Exact("example.com".into())],
1719 injection: SecretInjection {
1720 headers: true,
1721 basic_auth: false,
1722 query_params: false,
1723 body: false,
1724 },
1725 on_violation: None,
1726 require_tls_identity: true,
1727 }],
1728 ..Default::default()
1729 };
1730
1731 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1732 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1733 let proxy_connect = Arc::new(ProxyConnectState::new());
1734
1735 from_tx
1736 .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1737 .await
1738 .unwrap();
1739 from_tx
1740 .send(Bytes::from_static(
1741 b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1742 ))
1743 .await
1744 .unwrap();
1745 drop(from_tx);
1746
1747 tcp_proxy_task(
1748 addr,
1749 addr,
1750 from_rx,
1751 to_tx,
1752 Arc::new(shared),
1753 Arc::new(NetworkPolicy::default()),
1754 Arc::new(secrets),
1755 None,
1756 proxy_connect,
1757 )
1758 .await
1759 .unwrap();
1760
1761 let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1762 assert!(
1763 wire.contains("Host: example.com"),
1764 "request must reach the allowed host, got: {wire:?}"
1765 );
1766 assert!(
1767 wire.contains("$MSB_KEY"),
1768 "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
1769 );
1770 assert!(
1771 !wire.contains("real-secret-value"),
1772 "secret must never be substituted over plain HTTP, got: {wire:?}"
1773 );
1774 }
1775
1776 #[tokio::test]
1777 async fn plain_http_substitutes_placeholder_in_first_flight() {
1778 let (addr, sink) = spawn_sink().await;
1779
1780 let request =
1781 b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1782 .to_vec();
1783 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1784
1785 let wire =
1786 String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
1787 assert!(
1788 wire.contains("real-secret-value"),
1789 "real value must reach server, got: {wire:?}"
1790 );
1791 assert!(
1792 !wire.contains("$MSB_KEY"),
1793 "placeholder must not reach server, got: {wire:?}"
1794 );
1795 }
1796
1797 #[tokio::test]
1798 async fn plain_http_no_substitution_when_require_tls_identity_true() {
1799 let (addr, sink) = spawn_sink().await;
1800
1801 let request =
1802 b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1803 .to_vec();
1804 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);
1805
1806 let wire =
1807 String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
1808 .into_owned();
1809 assert!(
1810 wire.contains("$MSB_KEY"),
1811 "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
1812 );
1813 assert!(
1814 !wire.contains("real-secret-value"),
1815 "real value must not leak when require_tls_identity=true, got: {wire:?}"
1816 );
1817 }
1818
1819 #[tokio::test]
1820 async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
1821 let (addr, sink) = spawn_sink().await;
1825 let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);
1826
1827 let body = "x".repeat(32_000);
1828 let header = format!(
1829 "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
1830 body.len()
1831 );
1832
1833 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1834 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1835 let proxy_connect = Arc::new(ProxyConnectState::new());
1836
1837 from_tx
1838 .send(Bytes::from(header.into_bytes()))
1839 .await
1840 .unwrap();
1841 from_tx
1842 .send(Bytes::from(body.clone().into_bytes()))
1843 .await
1844 .unwrap();
1845 drop(from_tx);
1846
1847 tcp_proxy_task(
1848 addr,
1849 addr,
1850 from_rx,
1851 to_tx,
1852 Arc::new(SharedState::new(4)),
1853 Arc::new(NetworkPolicy::default()),
1854 Arc::new(secrets),
1855 None,
1856 proxy_connect,
1857 )
1858 .await
1859 .unwrap();
1860
1861 let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
1862 assert!(wire.contains(&body), "got {} bytes", wire.len());
1863 assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1864 }
1865}