1use std::borrow::Cow;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23 SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::proxy::{TlsProxyContext, tls_proxy_task};
27use crate::tls::sni;
28use crate::tls::state::TlsState;
29
/// Size of the scratch buffer used for each read from an upstream server.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Once more than this many bytes of a proxy's CONNECT response have arrived without a
/// `\r\n\r\n` header terminator, the response is rejected as too large.
const CONNECT_RESP_LIMIT: usize = 8 * 1024;

/// Cap on guest bytes buffered while peeking/classifying the first flight or buffering a
/// CONNECT request.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Time budget for first-flight peeking, CONNECT request buffering and waiting for a proxy's
/// CONNECT response.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
46
/// A buffered guest `CONNECT` request: the raw bytes received so far, where the header section
/// ends inside them, and the parsed target authority.
#[derive(Debug)]
struct ConnectRequest {
    /// Every byte received from the guest, possibly extending past the header terminator.
    bytes: Vec<u8>,
    /// Offset just past the `\r\n\r\n` that terminates the header section.
    header_end: usize,
    /// Target parsed from the request line's authority.
    target: ConnectTarget,
}

/// The `host:port` a CONNECT request asks the proxy to reach.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnectTarget {
    /// Lowercased host (IPv6 literals without brackets).
    host: String,
    /// Destination port.
    port: u16,
    /// SNI the guest is expected to present; `None` when `host` is an IP literal.
    expected_sni: Option<String>,
}

impl ConnectRequest {
    /// The request line and headers, including the terminating blank line.
    fn header_bytes(&self) -> &[u8] {
        let (head, _) = self.bytes.split_at(self.header_end);
        head
    }

    /// Bytes the guest sent after the header section (e.g. an early TLS ClientHello).
    fn post_header_bytes(&self) -> &[u8] {
        let (_, tail) = self.bytes.split_at(self.header_end);
        tail
    }
}
78
79impl ConnectTarget {
80 fn is_intercepted(&self, tls_state: &TlsState) -> bool {
81 tls_state.config.intercepted_ports.contains(&self.port)
82 }
83
84 fn guest_dst(&self, fallback: SocketAddr, shared: &SharedState) -> SocketAddr {
85 if let Ok(ip) = self.host.parse::<IpAddr>() {
86 return SocketAddr::new(ip, self.port);
87 }
88
89 if self.host.eq_ignore_ascii_case(crate::HOST_ALIAS) {
90 match fallback.ip() {
91 IpAddr::V4(_) => {
92 if let Some(ip) = shared.gateway_ipv4() {
93 return SocketAddr::new(IpAddr::V4(ip), self.port);
94 }
95 }
96 IpAddr::V6(_) => {
97 if let Some(ip) = shared.gateway_ipv6() {
98 return SocketAddr::new(IpAddr::V6(ip), self.port);
99 }
100 }
101 }
102 if let Some(ip) = shared.gateway_ipv4() {
103 return SocketAddr::new(IpAddr::V4(ip), self.port);
104 }
105 if let Some(ip) = shared.gateway_ipv6() {
106 return SocketAddr::new(IpAddr::V6(ip), self.port);
107 }
108 }
109
110 SocketAddr::new(fallback.ip(), self.port)
111 }
112}
113
114pub(crate) async fn connect_upstream(
120 dst: SocketAddr,
121 proxy_connect: &ProxyConnectState,
122 shared: &SharedState,
123) -> io::Result<TcpStream> {
124 match TcpStream::connect(dst).await {
125 Ok(s) => {
126 proxy_connect.mark_connected();
127 Ok(s)
128 }
129 Err(e) => {
130 proxy_connect.mark_upstream_connect_failed();
131 shared.proxy_wake.wake();
132 Err(e)
133 }
134 }
135}
136
137#[allow(clippy::too_many_arguments)]
148pub fn spawn_tcp_proxy(
149 handle: &tokio::runtime::Handle,
150 guest_dst: SocketAddr,
151 connect_dst: SocketAddr,
152 from_smoltcp: mpsc::Receiver<Bytes>,
153 to_smoltcp: mpsc::Sender<Bytes>,
154 shared: Arc<SharedState>,
155 network_policy: Arc<NetworkPolicy>,
156 secrets: Arc<SecretsConfig>,
157 tls_state: Option<Arc<TlsState>>,
158 proxy_connect: Arc<ProxyConnectState>,
159) {
160 handle.spawn(async move {
161 if let Err(e) = tcp_proxy_task(
162 guest_dst,
163 connect_dst,
164 from_smoltcp,
165 to_smoltcp,
166 shared,
167 network_policy,
168 secrets,
169 tls_state,
170 proxy_connect,
171 )
172 .await
173 {
174 tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
175 }
176 });
177}
178
179#[allow(clippy::too_many_arguments)]
182async fn tcp_proxy_task(
183 guest_dst: SocketAddr,
184 connect_dst: SocketAddr,
185 mut from_smoltcp: mpsc::Receiver<Bytes>,
186 to_smoltcp: mpsc::Sender<Bytes>,
187 shared: Arc<SharedState>,
188 network_policy: Arc<NetworkPolicy>,
189 secrets: Arc<SecretsConfig>,
190 tls_state: Option<Arc<TlsState>>,
191 proxy_connect: Arc<ProxyConnectState>,
192) -> io::Result<()> {
193 let (mut initial_buf, sni) = if network_policy.has_domain_rules() {
199 peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
200 } else {
201 (Vec::new(), None)
202 };
203
204 if network_policy.has_domain_rules() {
210 let source = match sni.as_deref() {
211 Some(name) => HostnameSource::Sni(name),
212 None => HostnameSource::CacheOnly,
213 };
214 match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
215 {
216 EgressEvaluation::Allow => {}
217 EgressEvaluation::Deny => {
218 tracing::debug!(
219 dst = %guest_dst,
220 source = source.label(),
221 "TCP egress denied by domain policy",
222 );
223 proxy_connect.mark_policy_denied();
224 shared.proxy_wake.wake();
225 return Ok(());
226 }
227 EgressEvaluation::DeferUntilHostname => {
228 debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
229 proxy_connect.mark_policy_denied();
230 shared.proxy_wake.wake();
231 return Ok(());
232 }
233 }
234 }
235
236 if let Some(tls_state) = tls_state.clone() {
238 if initial_buf.is_empty() {
239 let (peeked, _) = peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await;
240 initial_buf = peeked;
241 }
242 if could_be_connect_request(&initial_buf) {
243 return handle_connect_tunnel(
244 guest_dst,
245 connect_dst,
246 initial_buf,
247 from_smoltcp,
248 to_smoltcp,
249 shared,
250 network_policy,
251 tls_state,
252 proxy_connect,
253 None,
254 )
255 .await;
256 }
257 }
258
259 let stream = connect_upstream(connect_dst, &proxy_connect, &shared).await?;
264 let (mut server_rx, mut server_tx) = stream.into_split();
265
266 let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
272 let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
273 classify_first_flight(
274 initial_buf,
275 &mut from_smoltcp,
276 &mut server_rx,
277 &to_smoltcp,
278 &shared,
279 want_headers,
280 PEEK_BUF_SIZE,
281 PEEK_BUDGET,
282 )
283 .await?
284 } else {
285 (initial_buf, false)
286 };
287
288 if let Some(tls_state) = tls_state.clone()
289 && could_be_connect_request(&initial_buf)
290 {
291 let proxy_stream = server_rx
296 .reunite(server_tx)
297 .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
298 return handle_connect_tunnel(
299 guest_dst,
300 connect_dst,
301 initial_buf,
302 from_smoltcp,
303 to_smoltcp,
304 shared,
305 network_policy,
306 tls_state,
307 proxy_connect,
308 Some(proxy_stream),
309 )
310 .await;
311 }
312
313 let mut late_connect_state = tls_state;
314 let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
315 Some(match extract_http_host(&initial_buf) {
316 Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
317 None => SecretsHandler::new_plain_http_invalid_host(&secrets),
318 })
319 } else {
320 None
321 };
322
323 if !initial_buf.is_empty() {
325 let out: Cow<[u8]> = match secrets_handler.as_mut() {
326 Some(h) => match h.substitute(&initial_buf) {
327 Ok(cow) => cow,
330 Err(action) => {
331 tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
332 if matches!(action, ViolationAction::BlockAndTerminate) {
333 shared.trigger_termination();
334 }
335 return Ok(());
336 }
337 },
338 None => Cow::Borrowed(&initial_buf),
339 };
340 if !out.is_empty() {
341 if let Err(e) = server_tx.write_all(&out).await {
342 tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
343 return Ok(());
344 }
345 if let Err(e) = server_tx.flush().await {
346 tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
347 return Ok(());
348 }
349 }
350 }
351
352 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
353
354 loop {
359 tokio::select! {
360 data = from_smoltcp.recv() => {
362 match data {
363 Some(bytes) => {
364 if let Some(tls_state) = late_connect_state.take()
365 && could_be_connect_request(&bytes)
366 {
367 let proxy_stream = server_rx
372 .reunite(server_tx)
373 .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
374 return handle_connect_tunnel(
375 guest_dst,
376 connect_dst,
377 bytes.to_vec(),
378 from_smoltcp,
379 to_smoltcp,
380 shared,
381 network_policy,
382 tls_state,
383 proxy_connect,
384 Some(proxy_stream),
385 )
386 .await;
387 }
388 let out: Cow<[u8]> = match secrets_handler.as_mut() {
391 Some(h) => match h.substitute(&bytes) {
392 Ok(cow) => cow,
393 Err(action) => {
394 tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
395 if matches!(action, ViolationAction::BlockAndTerminate) {
396 shared.trigger_termination();
397 }
398 break;
399 }
400 },
401 None => Cow::Borrowed(&bytes),
402 };
403 if !out.is_empty() {
404 if let Err(e) = server_tx.write_all(&out).await {
405 tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
406 break;
407 }
408 if let Err(e) = server_tx.flush().await {
409 tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
410 break;
411 }
412 }
413 }
414 None => break,
416 }
417 }
418
419 result = server_rx.read(&mut server_buf) => {
421 match result {
422 Ok(0) => break, Ok(n) => {
424 late_connect_state = None;
427 let data = Bytes::copy_from_slice(&server_buf[..n]);
428 if to_smoltcp.send(data).await.is_err() {
429 break;
431 }
432 shared.proxy_wake.wake();
435 }
436 Err(e) => {
437 tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
438 break;
439 }
440 }
441 }
442 }
443 }
444
445 Ok(())
446}
447
448#[allow(clippy::too_many_arguments)]
454async fn handle_connect_tunnel(
455 guest_dst: SocketAddr,
456 proxy_dst: SocketAddr,
457 initial_buf: Vec<u8>,
458 mut from_smoltcp: mpsc::Receiver<Bytes>,
459 to_smoltcp: mpsc::Sender<Bytes>,
460 shared: Arc<SharedState>,
461 network_policy: Arc<NetworkPolicy>,
462 tls_state: Arc<TlsState>,
463 proxy_connect: Arc<ProxyConnectState>,
464 preconnected_proxy: Option<TcpStream>,
465) -> io::Result<()> {
466 let connect_req =
467 parse_connect_request(buffer_connect_request(initial_buf, &mut from_smoltcp).await?)?;
468
469 let connect_headers = match sanitize_connect_headers(
470 connect_req.header_bytes(),
471 &tls_state.secrets,
472 ) {
473 Ok(headers) => headers,
474 Err(action) => {
475 tracing::warn!(dst = %proxy_dst, violation = ?action, "secret violation in CONNECT headers");
476 if matches!(action, ViolationAction::BlockAndTerminate) {
477 shared.trigger_termination();
478 }
479 return Ok(());
480 }
481 };
482
483 let mut proxy_stream = match preconnected_proxy {
485 Some(stream) => stream,
486 None => match TcpStream::connect(proxy_dst).await {
487 Ok(s) => s,
488 Err(e) => {
489 proxy_connect.mark_upstream_connect_failed();
490 shared.proxy_wake.wake();
491 return Err(e);
492 }
493 },
494 };
495
496 if !connect_req.target.is_intercepted(&tls_state) {
497 proxy_stream.write_all(&connect_headers).await?;
498 proxy_stream.flush().await?;
499 let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
500 if to_smoltcp
501 .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
502 .await
503 .is_err()
504 {
505 return Ok(());
506 }
507 if !proxy_resp[header_end..].is_empty()
508 && to_smoltcp
509 .send(Bytes::copy_from_slice(&proxy_resp[header_end..]))
510 .await
511 .is_err()
512 {
513 return Ok(());
514 }
515 shared.proxy_wake.wake();
516 if !connect_response_is_success(&proxy_resp[..header_end]) {
517 proxy_connect.mark_connected();
518 return Ok(());
519 }
520 if !connect_req.post_header_bytes().is_empty() {
521 proxy_stream
522 .write_all(connect_req.post_header_bytes())
523 .await?;
524 }
525 proxy_stream.flush().await?;
526 proxy_connect.mark_connected();
527 return relay_connected_stream(proxy_stream, from_smoltcp, to_smoltcp, shared).await;
528 }
529
530 proxy_stream.write_all(&connect_headers).await?;
531 proxy_stream.flush().await?;
532
533 let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
534 if !connect_response_is_success(&proxy_resp[..header_end]) {
535 return Err(io::Error::new(
536 io::ErrorKind::ConnectionRefused,
537 "proxy rejected CONNECT",
538 ));
539 }
540 if !proxy_resp[header_end..].is_empty() {
541 return Err(io::Error::new(
542 io::ErrorKind::InvalidData,
543 "proxy sent unexpected bytes after CONNECT response headers",
544 ));
545 }
546 proxy_connect.mark_connected();
547
548 if to_smoltcp
549 .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
550 .await
551 .is_err()
552 {
553 return Ok(());
554 }
555 shared.proxy_wake.wake();
556
557 let tls_seed = connect_req.post_header_bytes().to_vec();
558 let tls_guest_dst = connect_req.target.guest_dst(guest_dst, &shared);
559 let expected_sni = connect_req.target.expected_sni.clone();
560
561 tls_proxy_task(
562 TlsProxyContext {
563 guest_dst: tls_guest_dst,
564 connect_dst: proxy_dst,
565 shared,
566 tls_state,
567 network_policy,
568 proxy_connect,
569 upstream_stream: Some(proxy_stream),
570 expected_sni,
571 },
572 from_smoltcp,
573 to_smoltcp,
574 tls_seed,
575 )
576 .await
577}
578
579async fn relay_connected_stream(
581 stream: TcpStream,
582 mut from_smoltcp: mpsc::Receiver<Bytes>,
583 to_smoltcp: mpsc::Sender<Bytes>,
584 shared: Arc<SharedState>,
585) -> io::Result<()> {
586 let (mut server_rx, mut server_tx) = stream.into_split();
587 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
588
589 loop {
590 tokio::select! {
591 data = from_smoltcp.recv() => {
592 match data {
593 Some(bytes) => {
594 server_tx.write_all(&bytes).await?;
595 server_tx.flush().await?;
596 }
597 None => break,
598 }
599 }
600 result = server_rx.read(&mut server_buf) => {
601 match result {
602 Ok(0) => break,
603 Ok(n) => {
604 if to_smoltcp
605 .send(Bytes::copy_from_slice(&server_buf[..n]))
606 .await
607 .is_err()
608 {
609 break;
610 }
611 shared.proxy_wake.wake();
612 }
613 Err(e) => return Err(e),
614 }
615 }
616 }
617 }
618
619 Ok(())
620}
621
622async fn buffer_connect_request(
623 mut buf: Vec<u8>,
624 from_smoltcp: &mut mpsc::Receiver<Bytes>,
625) -> io::Result<Vec<u8>> {
626 let timeout_fut = tokio::time::sleep(PEEK_BUDGET);
627 tokio::pin!(timeout_fut);
628
629 loop {
630 if !could_be_connect_request(&buf) {
631 return Err(io::Error::new(
632 io::ErrorKind::InvalidData,
633 "malformed CONNECT request prefix",
634 ));
635 }
636 if headers_end(&buf).is_some() {
637 return Ok(buf);
638 }
639 if buf.len() >= PEEK_BUF_SIZE {
640 return Err(io::Error::new(
641 io::ErrorKind::InvalidData,
642 "CONNECT request headers too large",
643 ));
644 }
645
646 tokio::select! {
647 biased;
648 _ = &mut timeout_fut => {
649 return Err(io::Error::new(
650 io::ErrorKind::TimedOut,
651 "timed out waiting for complete CONNECT request headers",
652 ));
653 }
654 data = from_smoltcp.recv() => match data {
655 Some(bytes) => {
656 buf.extend_from_slice(&bytes);
657 }
658 None => {
659 return Err(io::Error::new(
660 io::ErrorKind::UnexpectedEof,
661 "channel closed before complete CONNECT request headers",
662 ));
663 }
664 }
665 }
666 }
667}
668
669async fn read_connect_response_headers(stream: &mut TcpStream) -> io::Result<(Vec<u8>, usize)> {
670 tokio::time::timeout(PEEK_BUDGET, async {
671 let mut proxy_resp = Vec::with_capacity(256);
672 let mut buf = [0u8; 4096];
673 loop {
674 let n = stream.read(&mut buf).await?;
675 if n == 0 {
676 return Err(io::Error::new(
677 io::ErrorKind::UnexpectedEof,
678 "proxy closed before sending CONNECT response",
679 ));
680 }
681 proxy_resp.extend_from_slice(&buf[..n]);
682 if let Some(end) = headers_end(&proxy_resp) {
683 return Ok((proxy_resp, end));
684 }
685 if proxy_resp.len() > CONNECT_RESP_LIMIT {
686 return Err(io::Error::new(
687 io::ErrorKind::InvalidData,
688 "proxy CONNECT response too large",
689 ));
690 }
691 }
692 })
693 .await
694 .map_err(|_| {
695 io::Error::new(
696 io::ErrorKind::TimedOut,
697 "timed out waiting for proxy CONNECT response",
698 )
699 })?
700}
701
702fn sanitize_connect_headers<'a>(
703 header_bytes: &'a [u8],
704 secrets: &SecretsConfig,
705) -> Result<Cow<'a, [u8]>, ViolationAction> {
706 if secrets.secrets.is_empty() {
707 return Ok(Cow::Borrowed(header_bytes));
708 }
709
710 let mut handler = SecretsHandler::new_plain_http_untrusted_metadata(secrets);
711 handler.substitute(header_bytes)
712}
713
/// Offset just past the first `\r\n\r\n` in `buf`, or `None` if the header section is not
/// yet complete.
fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)
        .map(|start| start + TERMINATOR.len())
}
718
/// Whether `buf` is non-empty and consistent with the start of an HTTP `CONNECT ` request.
///
/// Only the overlapping prefix is compared, ASCII case-insensitively, so a partial
/// chunk such as `b"CON"` still qualifies.
fn could_be_connect_request(buf: &[u8]) -> bool {
    const PREFIX: &[u8] = b"CONNECT ";
    !buf.is_empty()
        && buf
            .iter()
            .zip(PREFIX)
            .all(|(got, want)| got.eq_ignore_ascii_case(want))
}
727
728fn parse_connect_request(bytes: Vec<u8>) -> io::Result<ConnectRequest> {
729 let header_end = headers_end(&bytes).ok_or_else(|| {
730 io::Error::new(
731 io::ErrorKind::InvalidData,
732 "incomplete CONNECT request headers",
733 )
734 })?;
735 let target = {
736 let request_line = bytes[..header_end]
737 .split(|&b| b == b'\n')
738 .next()
739 .unwrap_or(&[]);
740 let request_line = std::str::from_utf8(request_line)
741 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "CONNECT line is not UTF-8"))?
742 .trim_end_matches('\r');
743 let mut parts = request_line.split_ascii_whitespace();
744 let method = parts.next().unwrap_or_default();
745 let authority = parts.next().unwrap_or_default();
746 let version = parts.next().unwrap_or_default();
747 if !method.eq_ignore_ascii_case("CONNECT")
748 || authority.is_empty()
749 || !is_http_version(version)
750 || parts.next().is_some()
751 {
752 return Err(io::Error::new(
753 io::ErrorKind::InvalidData,
754 "malformed CONNECT request line",
755 ));
756 }
757 parse_connect_target(authority)?
758 };
759
760 Ok(ConnectRequest {
761 bytes,
762 header_end,
763 target,
764 })
765}
766
767fn parse_connect_target(authority: &str) -> io::Result<ConnectTarget> {
768 let authority = authority.trim();
769 let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
770 let (host, rest) = rest.split_once(']').ok_or_else(|| {
771 io::Error::new(
772 io::ErrorKind::InvalidData,
773 "malformed CONNECT IPv6 authority",
774 )
775 })?;
776 let port = rest.strip_prefix(':').ok_or_else(|| {
777 io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
778 })?;
779 (host, port)
780 } else {
781 let (host, port) = authority.rsplit_once(':').ok_or_else(|| {
782 io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
783 })?;
784 if host.contains(':') {
785 return Err(io::Error::new(
786 io::ErrorKind::InvalidData,
787 "CONNECT IPv6 authority must be bracketed",
788 ));
789 }
790 (host, port)
791 };
792 let host = host.trim().trim_end_matches('.');
793 if host.is_empty() {
794 return Err(io::Error::new(
795 io::ErrorKind::InvalidData,
796 "CONNECT authority missing host",
797 ));
798 }
799 let port = port
800 .parse::<u16>()
801 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid CONNECT port"))?;
802 let expected_sni = host
803 .parse::<IpAddr>()
804 .is_err()
805 .then(|| host.to_ascii_lowercase());
806
807 Ok(ConnectTarget {
808 host: host.to_ascii_lowercase(),
809 port,
810 expected_sni,
811 })
812}
813
/// Whether `version` has the form `HTTP/<digits>.<digits>`. The prefix is case-sensitive and
/// both numbers must be non-empty.
fn is_http_version(version: &str) -> bool {
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    version
        .strip_prefix("HTTP/")
        .and_then(|numbers| numbers.split_once('.'))
        .is_some_and(|(major, minor)| is_number(major) && is_number(minor))
}
826
827fn connect_response_is_success(headers: &[u8]) -> bool {
828 let Some(status_line) = headers.split(|&b| b == b'\n').next() else {
829 return false;
830 };
831 let Ok(status_line) = std::str::from_utf8(status_line) else {
832 return false;
833 };
834 let mut parts = status_line.trim_end_matches('\r').split_ascii_whitespace();
835 let version = parts.next().unwrap_or_default();
836 let status = parts.next().unwrap_or_default();
837 is_http_version(version)
838 && status.len() == 3
839 && status
840 .parse::<u16>()
841 .is_ok_and(|code| (200..300).contains(&code))
842}
843
844fn extract_http_host(buf: &[u8]) -> Option<String> {
854 if buf.first() == Some(&0x16) {
855 return None;
856 }
857 let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
863 let mut req = httparse::Request::new(&mut headers);
864 req.parse(buf).ok()?;
865 req.headers
866 .iter()
867 .find(|h| h.name.eq_ignore_ascii_case("host"))
868 .and_then(|h| std::str::from_utf8(h.value).ok())
869 .map(|v| {
870 let host = v.trim();
871 host.rsplit_once(':')
873 .map(|(h, _)| h)
874 .unwrap_or(host)
875 .to_ascii_lowercase()
876 })
877 .filter(|h| !h.is_empty())
878}
879
880#[allow(clippy::too_many_arguments)]
897async fn classify_first_flight(
898 mut buf: Vec<u8>,
899 from_smoltcp: &mut mpsc::Receiver<Bytes>,
900 server_rx: &mut tokio::net::tcp::OwnedReadHalf,
901 to_smoltcp: &mpsc::Sender<Bytes>,
902 shared: &SharedState,
903 want_headers: bool,
904 max: usize,
905 budget: Duration,
906) -> io::Result<(Vec<u8>, bool)> {
907 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
908 let timeout_fut = tokio::time::sleep(budget);
909 tokio::pin!(timeout_fut);
910
911 loop {
912 if !buf.is_empty() {
918 let is_tls = buf.first() == Some(&0x16);
919 let not_http = !is_tls
920 && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
921 let done = !want_headers
922 || is_tls
923 || not_http
924 || buf.len() >= max
925 || buf.windows(4).any(|w| w == b"\r\n\r\n");
926 if done {
927 return Ok((buf, is_tls));
928 }
929 }
930
931 tokio::select! {
932 biased;
933 _ = &mut timeout_fut => {
934 let is_tls = buf.first() == Some(&0x16);
935 return Ok((buf, is_tls));
936 }
937 guest = from_smoltcp.recv() => match guest {
940 Some(bytes) => buf.extend_from_slice(&bytes),
941 None => {
942 let is_tls = buf.first() == Some(&0x16);
943 return Ok((buf, is_tls));
944 }
945 },
946 server = server_rx.read(&mut server_buf) => match server {
949 Ok(0) => {
950 let is_tls = buf.first() == Some(&0x16);
951 return Ok((buf, is_tls));
952 }
953 Ok(n) => {
954 let data = Bytes::copy_from_slice(&server_buf[..n]);
955 if to_smoltcp.send(data).await.is_err() {
956 let is_tls = buf.first() == Some(&0x16);
957 return Ok((buf, is_tls));
958 }
959 shared.proxy_wake.wake();
960 }
961 Err(e) => return Err(e),
962 },
963 }
964 }
965}
966
967async fn peek_for_sni(
977 rx: &mut mpsc::Receiver<Bytes>,
978 max: usize,
979 budget: Duration,
980) -> (Vec<u8>, Option<String>) {
981 let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
982 let timeout_fut = tokio::time::sleep(budget);
983 tokio::pin!(timeout_fut);
984
985 let raw_sni = loop {
986 tokio::select! {
987 biased;
988 _ = &mut timeout_fut => break None,
989 data = rx.recv() => {
990 match data {
991 Some(bytes) => {
992 buf.extend_from_slice(&bytes);
993 if buf.first() != Some(&0x16) {
998 break None;
999 }
1000 if let Some(name) = sni::extract_sni(&buf) {
1001 break Some(name);
1002 }
1003 if buf.len() >= max {
1004 break None;
1005 }
1006 }
1007 None => break None,
1008 }
1009 }
1010 }
1011 };
1012
1013 let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
1014 (buf, canonical)
1015}
1016
1017#[cfg(test)]
1022mod tests {
1023 use super::*;
1024
1025 fn synthetic_client_hello(sni: &str) -> Vec<u8> {
1029 let host_bytes = sni.as_bytes();
1032 let host_len = host_bytes.len() as u16;
1033 let server_name_list_len = 3 + host_len; let extension_data_len = 2 + server_name_list_len; let extensions_total = 4 + extension_data_len; let mut body = Vec::new();
1038 body.extend_from_slice(&[0x03, 0x03]);
1040 body.extend_from_slice(&[0u8; 32]);
1042 body.push(0);
1044 body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
1046 body.extend_from_slice(&[0x01, 0x00]);
1048 body.extend_from_slice(&extensions_total.to_be_bytes());
1050 body.extend_from_slice(&[0x00, 0x00]);
1052 body.extend_from_slice(&extension_data_len.to_be_bytes());
1053 body.extend_from_slice(&server_name_list_len.to_be_bytes());
1054 body.push(0x00); body.extend_from_slice(&host_len.to_be_bytes());
1056 body.extend_from_slice(host_bytes);
1057
1058 let handshake_len = body.len() as u32;
1059 let mut hs = Vec::new();
1060 hs.push(0x01); hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); hs.extend_from_slice(&body);
1063
1064 let record_len = hs.len() as u16;
1065 let mut record = Vec::new();
1066 record.extend_from_slice(&[0x16, 0x03, 0x01]); record.extend_from_slice(&record_len.to_be_bytes());
1068 record.extend_from_slice(&hs);
1069
1070 record
1071 }
1072
1073 #[test]
1074 fn could_be_connect_request_matches_split_prefixes_only() {
1075 assert!(could_be_connect_request(b"C"));
1076 assert!(could_be_connect_request(b"connect "));
1077 assert!(could_be_connect_request(b"CONNECT example.com:443"));
1078 assert!(!could_be_connect_request(b"CLIENT"));
1079 assert!(!could_be_connect_request(b"GET / HTTP/1.1\r\n"));
1080 }
1081
1082 #[tokio::test]
1083 async fn buffer_connect_request_reads_split_headers() {
1084 let (tx, mut rx) = mpsc::channel(4);
1085 tx.send(Bytes::from_static(b"NECT example.com:443 HTTP/1.1\r\n"))
1086 .await
1087 .unwrap();
1088 tx.send(Bytes::from_static(b"Host: example.com\r\n\r\n"))
1089 .await
1090 .unwrap();
1091 drop(tx);
1092
1093 let buffered = buffer_connect_request(b"CON".to_vec(), &mut rx)
1094 .await
1095 .unwrap();
1096 let parsed = parse_connect_request(buffered).unwrap();
1097
1098 assert_eq!(parsed.target.host, "example.com");
1099 assert_eq!(parsed.target.port, 443);
1100 assert_eq!(parsed.target.expected_sni.as_deref(), Some("example.com"));
1101 assert!(parsed.post_header_bytes().is_empty());
1102 }
1103
1104 #[test]
1105 fn parse_connect_request_preserves_post_header_tls_seed() {
1106 let mut request = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec();
1107 request.extend_from_slice(b"\x16\x03\x01client-hello");
1108
1109 let parsed = parse_connect_request(request).unwrap();
1110
1111 assert_eq!(
1112 parsed.header_bytes(),
1113 b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n"
1114 );
1115 assert_eq!(parsed.post_header_bytes(), b"\x16\x03\x01client-hello");
1116 }
1117
1118 #[test]
1119 fn parse_connect_target_requires_authority_port() {
1120 assert!(parse_connect_target("example.com").is_err());
1121 assert!(parse_connect_target("2001:db8::1:443").is_err());
1122
1123 let target = parse_connect_target("[2001:db8::1]:8443").unwrap();
1124 assert_eq!(target.host, "2001:db8::1");
1125 assert_eq!(target.port, 8443);
1126 assert_eq!(target.expected_sni, None);
1127 }
1128
1129 #[test]
1130 fn connect_response_success_requires_exact_2xx_status_code() {
1131 assert!(connect_response_is_success(
1132 b"HTTP/1.1 200 Connection Established\r\n\r\n"
1133 ));
1134 assert!(connect_response_is_success(
1135 b"HTTP/1.1 204 Connection Established\r\n\r\n"
1136 ));
1137 assert!(!connect_response_is_success(b"HTTP/1.1 2000 Weird\r\n\r\n"));
1138 assert!(!connect_response_is_success(b"HTTP/1.1 199 Nope\r\n\r\n"));
1139 assert!(!connect_response_is_success(b"NOTHTTP 200 OK\r\n\r\n"));
1140 }
1141
1142 #[tokio::test]
1143 async fn peek_for_sni_extracts_and_canonicalizes() {
1144 let (tx, mut rx) = mpsc::channel(4);
1145 let hello = synthetic_client_hello("Example.COM");
1146 tx.send(Bytes::from(hello.clone())).await.unwrap();
1147 drop(tx); let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1150 assert_eq!(sni.as_deref(), Some("example.com"));
1151 assert_eq!(buf, hello);
1152 }
1153
1154 #[tokio::test]
1155 async fn peek_for_sni_returns_none_on_channel_close_without_data() {
1156 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1157 drop(tx);
1158 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1159 assert!(buf.is_empty());
1160 assert_eq!(sni, None);
1161 }
1162
1163 #[tokio::test]
1164 async fn peek_for_sni_returns_none_on_non_tls_data() {
1165 let (tx, mut rx) = mpsc::channel(4);
1166 tx.send(Bytes::from_static(
1168 b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
1169 ))
1170 .await
1171 .unwrap();
1172 drop(tx);
1173 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1174 assert!(
1175 !buf.is_empty(),
1176 "buffered bytes must be returned for replay"
1177 );
1178 assert_eq!(sni, None);
1179 }
1180
1181 #[tokio::test]
1182 async fn peek_for_sni_falls_back_on_timeout() {
1183 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1184 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
1186 drop(tx);
1187 assert!(buf.is_empty());
1188 assert_eq!(sni, None);
1189 }
1190
1191 #[tokio::test]
1192 async fn peek_for_sni_caps_at_max_bytes() {
1193 let (tx, mut rx) = mpsc::channel(4);
1194 let mut first = vec![0u8; 8192];
1198 first[0] = 0x16;
1199 tx.send(Bytes::from(first)).await.unwrap();
1200 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1201 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1202 drop(tx);
1203
1204 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1205 assert_eq!(sni, None, "no SNI in non-TLS data");
1206 assert!(
1207 buf.len() >= PEEK_BUF_SIZE,
1208 "buffer must hit the cap before bail-out: got {}",
1209 buf.len()
1210 );
1211 }
1212
1213 #[tokio::test]
1214 async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
1215 let (tx, mut rx) = mpsc::channel(4);
1216 tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
1218 .await
1219 .unwrap();
1220 drop(tx);
1221
1222 let started = std::time::Instant::now();
1225 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1226 let elapsed = started.elapsed();
1227 assert_eq!(sni, None);
1228 assert!(buf.starts_with(b"GET"));
1229 assert!(
1230 elapsed < Duration::from_millis(500),
1231 "non-TLS bail must be fast: took {elapsed:?}"
1232 );
1233 }
1234
1235 use std::net::IpAddr;
1240 use std::time::Duration as StdDuration;
1241
1242 use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
1243 use crate::shared::{ResolvedHostnameFamily, SharedState};
1244
1245 const SHARED_FASTLY_IP: &str = "151.101.0.223";
1246
1247 fn shared_with(host: &str, ip: &str) -> SharedState {
1248 let shared = SharedState::new(4);
1249 shared.cache_resolved_hostname(
1250 host,
1251 ResolvedHostnameFamily::Ipv4,
1252 [ip.parse::<IpAddr>().unwrap()],
1253 StdDuration::from_secs(60),
1254 );
1255 shared
1256 }
1257
1258 fn allow_https(domain: &str) -> Rule {
1259 Rule {
1260 direction: crate::policy::Direction::Egress,
1261 destination: Destination::Domain(domain.parse().unwrap()),
1262 protocols: vec![Protocol::Tcp],
1263 ports: vec![PortRange::single(443)],
1264 action: Action::Allow,
1265 }
1266 }
1267
1268 #[tokio::test]
1271 async fn integration_sni_overrides_cache_for_over_allow() {
1272 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1273 let policy = NetworkPolicy {
1274 default_egress: Action::Deny,
1275 default_ingress: Action::Allow,
1276 rules: vec![allow_https("pypi.org")],
1277 };
1278 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1279
1280 let (tx, mut rx) = mpsc::channel(4);
1281 tx.send(Bytes::from(synthetic_client_hello("evil.com")))
1282 .await
1283 .unwrap();
1284 drop(tx);
1285
1286 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1287 assert_eq!(sni.as_deref(), Some("evil.com"));
1288 assert!(!initial_buf.is_empty());
1289
1290 let source = sni
1291 .as_deref()
1292 .map(HostnameSource::Sni)
1293 .unwrap_or(HostnameSource::CacheOnly);
1294 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1295 assert_eq!(
1296 eval,
1297 EgressEvaluation::Deny,
1298 "SNI=evil.com must not piggy-back on the cached pypi.org match",
1299 );
1300 }
1301
1302 #[tokio::test]
1305 async fn integration_sni_overrides_cache_for_over_block() {
1306 let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
1307 let policy = NetworkPolicy {
1308 default_egress: Action::Allow,
1309 default_ingress: Action::Allow,
1310 rules: vec![Rule::deny_egress(Destination::Domain(
1311 "ads.example.com".parse().unwrap(),
1312 ))],
1313 };
1314 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1315
1316 let (tx, mut rx) = mpsc::channel(4);
1317 tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
1318 .await
1319 .unwrap();
1320 drop(tx);
1321
1322 let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1323 assert_eq!(sni.as_deref(), Some("api.example.com"));
1324
1325 let source = sni
1326 .as_deref()
1327 .map(HostnameSource::Sni)
1328 .unwrap_or(HostnameSource::CacheOnly);
1329 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1330 assert_eq!(
1331 eval,
1332 EgressEvaluation::Allow,
1333 "SNI=api.example.com must not be caught by the deny on ads.example.com",
1334 );
1335 }
1336
1337 #[tokio::test]
1340 async fn integration_non_tls_falls_back_to_cache() {
1341 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1342 let policy = NetworkPolicy {
1343 default_egress: Action::Deny,
1344 default_ingress: Action::Allow,
1345 rules: vec![allow_https("pypi.org")],
1346 };
1347 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1348
1349 let (tx, mut rx) = mpsc::channel(4);
1350 tx.send(Bytes::from_static(
1352 b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
1353 ))
1354 .await
1355 .unwrap();
1356 drop(tx);
1357
1358 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1359 assert_eq!(sni, None, "non-TLS data → no SNI");
1360 assert!(
1361 !initial_buf.is_empty(),
1362 "buffered bytes must survive for replay"
1363 );
1364
1365 let source = sni
1366 .as_deref()
1367 .map(HostnameSource::Sni)
1368 .unwrap_or(HostnameSource::CacheOnly);
1369 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1370 assert_eq!(
1371 eval,
1372 EgressEvaluation::Allow,
1373 "cache-only fallback must still allow the cached hostname's IP",
1374 );
1375 }
1376
1377 #[tokio::test]
1380 async fn integration_sni_matches_domain_suffix_with_cache_binding() {
1381 let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
1382 let policy = NetworkPolicy {
1383 default_egress: Action::Deny,
1384 default_ingress: Action::Allow,
1385 rules: vec![Rule {
1386 direction: crate::policy::Direction::Egress,
1387 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1388 protocols: vec![Protocol::Tcp],
1389 ports: vec![PortRange::single(443)],
1390 action: Action::Allow,
1391 }],
1392 };
1393 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1394
1395 let (tx, mut rx) = mpsc::channel(4);
1396 tx.send(Bytes::from(synthetic_client_hello(
1397 "files.pythonhosted.org",
1398 )))
1399 .await
1400 .unwrap();
1401 drop(tx);
1402
1403 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1404 let source = sni
1405 .as_deref()
1406 .map(HostnameSource::Sni)
1407 .unwrap_or(HostnameSource::CacheOnly);
1408 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1409 assert_eq!(eval, EgressEvaluation::Allow);
1410 }
1411
1412 #[tokio::test]
1417 async fn integration_sni_denies_domain_suffix_without_cache_binding() {
1418 let shared = SharedState::new(4); let policy = NetworkPolicy {
1420 default_egress: Action::Deny,
1421 default_ingress: Action::Allow,
1422 rules: vec![Rule {
1423 direction: crate::policy::Direction::Egress,
1424 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1425 protocols: vec![Protocol::Tcp],
1426 ports: vec![PortRange::single(443)],
1427 action: Action::Allow,
1428 }],
1429 };
1430 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1431
1432 let (tx, mut rx) = mpsc::channel(4);
1433 tx.send(Bytes::from(synthetic_client_hello(
1434 "files.pythonhosted.org",
1435 )))
1436 .await
1437 .unwrap();
1438 drop(tx);
1439
1440 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1441 let source = sni
1442 .as_deref()
1443 .map(HostnameSource::Sni)
1444 .unwrap_or(HostnameSource::CacheOnly);
1445 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1446 assert_eq!(eval, EgressEvaluation::Deny);
1447 }
1448
/// A minimal request with a plain `Host` header yields that host.
#[test]
fn extract_http_host_basic() {
    let raw: &[u8] = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
    let host = extract_http_host(raw);
    assert_eq!(host.as_deref(), Some("example.com"));
}
1456
/// A `:port` suffix on the `Host` header is not part of the returned host.
#[test]
fn extract_http_host_strips_port() {
    let raw: &[u8] = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
    let host = extract_http_host(raw);
    assert_eq!(host.as_deref(), Some("api.company.com"));
}
1462
/// The header name matches case-insensitively and the value is lowercased.
#[test]
fn extract_http_host_case_insensitive_lowercased() {
    let raw: &[u8] = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
    let host = extract_http_host(raw);
    assert_eq!(host.as_deref(), Some("example.com"));
}
1468
/// A complete header block without any `Host` header yields nothing.
#[test]
fn extract_http_host_no_host_header() {
    let raw: &[u8] = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
    assert!(extract_http_host(raw).is_none());
}
1474
/// Headers lacking the terminating blank line are not parsed for a host.
#[test]
fn extract_http_host_incomplete_headers() {
    let raw: &[u8] = b"GET / HTTP/1.1\r\nHost: x";
    assert!(extract_http_host(raw).is_none());
}
1480
/// Bytes starting with a TLS handshake record type (0x16) yield no host.
#[test]
fn extract_http_host_tls_first_byte() {
    let record: [u8; 5] = [0x16, 0x03, 0x01, 0x00, 0x01];
    assert!(extract_http_host(&record).is_none());
}
1486
/// The `Host` header is still found when it follows 100 padding headers.
#[test]
fn extract_http_host_with_many_headers() {
    let mut req = b"GET / HTTP/1.1\r\n".to_vec();
    req.extend((0..100).flat_map(|i| format!("X-Pad-{i}: v\r\n").into_bytes()));
    req.extend_from_slice(b"Host: example.com\r\n\r\n");

    let host = extract_http_host(&req);
    assert_eq!(host.as_deref(), Some("example.com"));
}
1497 }
1498
1499 use std::sync::Arc;
1502 use tokio::io::AsyncReadExt;
1503 use tokio::net::TcpListener;
1504 use tokio::task::JoinHandle;
1505
1506 use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};
1507
1508 fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
1509 SecretsConfig {
1510 secrets: vec![SecretEntry {
1511 env_var: "API_KEY".into(),
1512 value: value.into(),
1513 placeholder: placeholder.into(),
1514 allowed_hosts: vec![HostPattern::Any],
1515 injection: SecretInjection {
1516 headers: true,
1517 basic_auth: false,
1518 query_params: false,
1519 body: false,
1520 },
1521 on_violation: None,
1522 require_tls_identity: require_tls,
1523 }],
1524 ..Default::default()
1525 }
1526 }
1527
1528 fn make_host_bound_secret(placeholder: &str, value: &str, host: &str) -> SecretsConfig {
1529 SecretsConfig {
1530 secrets: vec![SecretEntry {
1531 env_var: "API_KEY".into(),
1532 value: value.into(),
1533 placeholder: placeholder.into(),
1534 allowed_hosts: vec![HostPattern::Exact(host.into())],
1535 injection: SecretInjection::default(),
1536 on_violation: None,
1537 require_tls_identity: true,
1538 }],
1539 ..Default::default()
1540 }
1541 }
1542
/// A placeholder in a CONNECT header (here `Proxy-Authorization`) is rejected
/// with the config's default violation action, `BlockAndLog`.
#[test]
fn sanitize_connect_headers_blocks_placeholder_metadata_header_by_default() {
    let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\nUser-Agent: curl\r\n\r\n";
    let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");

    let outcome = sanitize_connect_headers(headers, &secrets);
    assert_eq!(outcome, Err(ViolationAction::BlockAndLog));
}
1553
/// A configured `BlockAndTerminate` action is reported instead of the default.
#[test]
fn sanitize_connect_headers_respects_block_and_terminate() {
    let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
    let secrets = SecretsConfig {
        on_violation: ViolationAction::BlockAndTerminate,
        ..make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com")
    };

    let outcome = sanitize_connect_headers(headers, &secrets);
    assert_eq!(outcome, Err(ViolationAction::BlockAndTerminate));
}
1565
/// Under an explicit `Passthrough` policy, the CONNECT headers are returned
/// untouched: the placeholder stays and no real secret is spliced in.
#[test]
fn sanitize_connect_headers_respects_explicit_passthrough() {
    let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
    let secrets = SecretsConfig {
        on_violation: ViolationAction::Passthrough(vec![HostPattern::Any]),
        ..make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com")
    };

    let sanitized = sanitize_connect_headers(headers, &secrets).unwrap();
    let forwarded: &[u8] = &sanitized;

    assert_eq!(forwarded, headers);
    assert!(
        !String::from_utf8_lossy(forwarded).contains("real-secret-value"),
        "passthrough must never substitute real secrets into CONNECT metadata"
    );
}
1580
/// CONNECT headers that contain no placeholder pass through byte-for-byte.
#[test]
fn sanitize_connect_headers_keeps_safe_metadata_headers() {
    let headers =
        b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nUser-Agent: curl\r\n\r\n";
    let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");

    let forwarded = sanitize_connect_headers(headers, &secrets).unwrap();
    assert_eq!(forwarded.as_ref(), headers);
}
1591
/// A placeholder in the CONNECT request line itself (the authority) is also
/// treated as a violation.
#[test]
fn sanitize_connect_headers_blocks_placeholder_in_request_line() {
    let headers = b"CONNECT $MSB_KEY:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n";
    let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");

    let outcome = sanitize_connect_headers(headers, &secrets);
    assert_eq!(outcome, Err(ViolationAction::BlockAndLog));
}
1602
1603 async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
1604 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1605 let addr = listener.local_addr().unwrap();
1606 let handle = tokio::spawn(async move {
1607 let (mut stream, _) = listener.accept().await.unwrap();
1608 let mut received = Vec::new();
1609 let mut buf = vec![0u8; 4096];
1610 loop {
1611 match stream.read(&mut buf).await {
1612 Ok(0) | Err(_) => break,
1613 Ok(n) => received.extend_from_slice(&buf[..n]),
1614 }
1615 }
1616 received
1617 });
1618 (addr, handle)
1619 }
1620
1621 async fn relay_through_proxy(
1622 request: Vec<u8>,
1623 secrets: SecretsConfig,
1624 handle: JoinHandle<Vec<u8>>,
1625 server_addr: SocketAddr,
1626 ) -> Vec<u8> {
1627 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1628 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1629 let shared = SharedState::new(4);
1630 let policy = Arc::new(NetworkPolicy::default());
1631 let secrets = Arc::new(secrets);
1632 let proxy_connect = Arc::new(ProxyConnectState::new());
1633
1634 from_tx.send(Bytes::from(request)).await.unwrap();
1635 drop(from_tx);
1636
1637 tcp_proxy_task(
1638 server_addr,
1639 server_addr,
1640 from_rx,
1641 to_tx,
1642 Arc::new(shared),
1643 policy,
1644 secrets,
1645 None,
1646 proxy_connect,
1647 )
1648 .await
1649 .unwrap();
1650
1651 handle.await.unwrap()
1652 }
1653
1654 #[tokio::test]
1655 async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
1656 let (addr, sink) = spawn_sink().await;
1659 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1660
1661 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1662 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1663 let proxy_connect = Arc::new(ProxyConnectState::new());
1664
1665 from_tx
1666 .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1667 .await
1668 .unwrap();
1669 from_tx
1670 .send(Bytes::from_static(
1671 b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1672 ))
1673 .await
1674 .unwrap();
1675 drop(from_tx);
1676
1677 tcp_proxy_task(
1678 addr,
1679 addr,
1680 from_rx,
1681 to_tx,
1682 Arc::new(SharedState::new(4)),
1683 Arc::new(NetworkPolicy::default()),
1684 Arc::new(secrets),
1685 None,
1686 proxy_connect,
1687 )
1688 .await
1689 .unwrap();
1690
1691 let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1692 assert!(wire.contains("real-secret-value"), "got: {wire:?}");
1693 assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1694 }
1695
1696 #[tokio::test]
1697 async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
1698 let (addr, sink) = spawn_sink().await;
1703
1704 let shared = SharedState::new(4);
1705 shared.cache_resolved_hostname(
1706 "example.com",
1707 ResolvedHostnameFamily::Ipv4,
1708 ["127.0.0.1".parse::<IpAddr>().unwrap()],
1709 StdDuration::from_secs(60),
1710 );
1711
1712 let secrets = SecretsConfig {
1713 secrets: vec![SecretEntry {
1714 env_var: "API_KEY".into(),
1715 value: "real-secret-value".into(),
1716 placeholder: "$MSB_KEY".into(),
1717 allowed_hosts: vec![HostPattern::Exact("example.com".into())],
1718 injection: SecretInjection {
1719 headers: true,
1720 basic_auth: false,
1721 query_params: false,
1722 body: false,
1723 },
1724 on_violation: None,
1725 require_tls_identity: true,
1726 }],
1727 ..Default::default()
1728 };
1729
1730 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1731 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1732 let proxy_connect = Arc::new(ProxyConnectState::new());
1733
1734 from_tx
1735 .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1736 .await
1737 .unwrap();
1738 from_tx
1739 .send(Bytes::from_static(
1740 b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1741 ))
1742 .await
1743 .unwrap();
1744 drop(from_tx);
1745
1746 tcp_proxy_task(
1747 addr,
1748 addr,
1749 from_rx,
1750 to_tx,
1751 Arc::new(shared),
1752 Arc::new(NetworkPolicy::default()),
1753 Arc::new(secrets),
1754 None,
1755 proxy_connect,
1756 )
1757 .await
1758 .unwrap();
1759
1760 let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1761 assert!(
1762 wire.contains("Host: example.com"),
1763 "request must reach the allowed host, got: {wire:?}"
1764 );
1765 assert!(
1766 wire.contains("$MSB_KEY"),
1767 "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
1768 );
1769 assert!(
1770 !wire.contains("real-secret-value"),
1771 "secret must never be substituted over plain HTTP, got: {wire:?}"
1772 );
1773 }
1774
1775 #[tokio::test]
1776 async fn plain_http_substitutes_placeholder_in_first_flight() {
1777 let (addr, sink) = spawn_sink().await;
1778
1779 let request =
1780 b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1781 .to_vec();
1782 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1783
1784 let wire =
1785 String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
1786 assert!(
1787 wire.contains("real-secret-value"),
1788 "real value must reach server, got: {wire:?}"
1789 );
1790 assert!(
1791 !wire.contains("$MSB_KEY"),
1792 "placeholder must not reach server, got: {wire:?}"
1793 );
1794 }
1795
1796 #[tokio::test]
1797 async fn plain_http_no_substitution_when_require_tls_identity_true() {
1798 let (addr, sink) = spawn_sink().await;
1799
1800 let request =
1801 b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1802 .to_vec();
1803 let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);
1804
1805 let wire =
1806 String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
1807 .into_owned();
1808 assert!(
1809 wire.contains("$MSB_KEY"),
1810 "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
1811 );
1812 assert!(
1813 !wire.contains("real-secret-value"),
1814 "real value must not leak when require_tls_identity=true, got: {wire:?}"
1815 );
1816 }
1817
1818 #[tokio::test]
1819 async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
1820 let (addr, sink) = spawn_sink().await;
1824 let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);
1825
1826 let body = "x".repeat(32_000);
1827 let header = format!(
1828 "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
1829 body.len()
1830 );
1831
1832 let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1833 let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1834 let proxy_connect = Arc::new(ProxyConnectState::new());
1835
1836 from_tx
1837 .send(Bytes::from(header.into_bytes()))
1838 .await
1839 .unwrap();
1840 from_tx
1841 .send(Bytes::from(body.clone().into_bytes()))
1842 .await
1843 .unwrap();
1844 drop(from_tx);
1845
1846 tcp_proxy_task(
1847 addr,
1848 addr,
1849 from_rx,
1850 to_tx,
1851 Arc::new(SharedState::new(4)),
1852 Arc::new(NetworkPolicy::default()),
1853 Arc::new(secrets),
1854 None,
1855 proxy_connect,
1856 )
1857 .await
1858 .unwrap();
1859
1860 let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
1861 assert!(wire.contains(&body), "got {} bytes", wire.len());
1862 assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1863 }
1864}