1use std::borrow::Cow;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23 SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::sni;
27
/// Size of the scratch buffer used for each read from the upstream server.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Upper bound, in bytes, on how much of the guest's first flight is buffered
/// while looking for a TLS SNI or a complete set of HTTP request headers.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Maximum time spent buffering the guest's first flight before the proxy
/// carries on with whatever has been collected so far.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
41
42#[allow(clippy::too_many_arguments)]
57pub fn spawn_tcp_proxy(
58 handle: &tokio::runtime::Handle,
59 guest_dst: SocketAddr,
60 connect_dst: SocketAddr,
61 from_smoltcp: mpsc::Receiver<Bytes>,
62 to_smoltcp: mpsc::Sender<Bytes>,
63 shared: Arc<SharedState>,
64 network_policy: Arc<NetworkPolicy>,
65 secrets: Arc<SecretsConfig>,
66 proxy_connect: Arc<ProxyConnectState>,
67) {
68 handle.spawn(async move {
69 if let Err(e) = tcp_proxy_task(
70 guest_dst,
71 connect_dst,
72 from_smoltcp,
73 to_smoltcp,
74 shared,
75 network_policy,
76 secrets,
77 proxy_connect,
78 )
79 .await
80 {
81 tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
82 }
83 });
84}
85
86#[allow(clippy::too_many_arguments)]
89async fn tcp_proxy_task(
90 guest_dst: SocketAddr,
91 connect_dst: SocketAddr,
92 mut from_smoltcp: mpsc::Receiver<Bytes>,
93 to_smoltcp: mpsc::Sender<Bytes>,
94 shared: Arc<SharedState>,
95 network_policy: Arc<NetworkPolicy>,
96 secrets: Arc<SecretsConfig>,
97 proxy_connect: Arc<ProxyConnectState>,
98) -> io::Result<()> {
99 let (initial_buf, sni) = if network_policy.has_domain_rules() {
105 peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
106 } else {
107 (Vec::new(), None)
108 };
109
110 if network_policy.has_domain_rules() {
116 let source = match sni.as_deref() {
117 Some(name) => HostnameSource::Sni(name),
118 None => HostnameSource::CacheOnly,
119 };
120 match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
121 {
122 EgressEvaluation::Allow => {}
123 EgressEvaluation::Deny => {
124 tracing::debug!(
125 dst = %guest_dst,
126 source = source.label(),
127 "TCP egress denied by domain policy",
128 );
129 proxy_connect.mark_policy_denied();
130 shared.proxy_wake.wake();
131 return Ok(());
132 }
133 EgressEvaluation::DeferUntilHostname => {
134 debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
135 proxy_connect.mark_policy_denied();
136 shared.proxy_wake.wake();
137 return Ok(());
138 }
139 }
140 }
141
142 let stream = match TcpStream::connect(connect_dst).await {
147 Ok(stream) => {
148 proxy_connect.mark_connected();
149 stream
150 }
151 Err(e) => {
152 proxy_connect.mark_upstream_connect_failed();
153 shared.proxy_wake.wake();
154 return Err(e);
155 }
156 };
157 let (mut server_rx, mut server_tx) = stream.into_split();
158
159 let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
165 let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
166 classify_first_flight(
167 initial_buf,
168 &mut from_smoltcp,
169 &mut server_rx,
170 &to_smoltcp,
171 &shared,
172 want_headers,
173 PEEK_BUF_SIZE,
174 PEEK_BUDGET,
175 )
176 .await?
177 } else {
178 (initial_buf, false)
179 };
180
181 let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
182 Some(match extract_http_host(&initial_buf) {
183 Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
184 None => SecretsHandler::new_plain_http_invalid_host(&secrets),
185 })
186 } else {
187 None
188 };
189
190 if !initial_buf.is_empty() {
192 let out: Cow<[u8]> = match secrets_handler.as_mut() {
193 Some(h) => match h.substitute(&initial_buf) {
194 Ok(cow) => cow,
197 Err(action) => {
198 tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
199 if matches!(action, ViolationAction::BlockAndTerminate) {
200 shared.trigger_termination();
201 }
202 return Ok(());
203 }
204 },
205 None => Cow::Borrowed(&initial_buf),
206 };
207 if !out.is_empty() {
208 if let Err(e) = server_tx.write_all(&out).await {
209 tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
210 return Ok(());
211 }
212 if let Err(e) = server_tx.flush().await {
213 tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
214 return Ok(());
215 }
216 }
217 }
218
219 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
220
221 loop {
226 tokio::select! {
227 data = from_smoltcp.recv() => {
229 match data {
230 Some(bytes) => {
231 let out: Cow<[u8]> = match secrets_handler.as_mut() {
234 Some(h) => match h.substitute(&bytes) {
235 Ok(cow) => cow,
236 Err(action) => {
237 tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
238 if matches!(action, ViolationAction::BlockAndTerminate) {
239 shared.trigger_termination();
240 }
241 break;
242 }
243 },
244 None => Cow::Borrowed(&bytes),
245 };
246 if !out.is_empty() {
247 if let Err(e) = server_tx.write_all(&out).await {
248 tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
249 break;
250 }
251 if let Err(e) = server_tx.flush().await {
252 tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
253 break;
254 }
255 }
256 }
257 None => break,
259 }
260 }
261
262 result = server_rx.read(&mut server_buf) => {
264 match result {
265 Ok(0) => break, Ok(n) => {
267 let data = Bytes::copy_from_slice(&server_buf[..n]);
268 if to_smoltcp.send(data).await.is_err() {
269 break;
271 }
272 shared.proxy_wake.wake();
275 }
276 Err(e) => {
277 tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
278 break;
279 }
280 }
281 }
282 }
283 }
284
285 Ok(())
286}
287
288fn extract_http_host(buf: &[u8]) -> Option<String> {
298 if buf.first() == Some(&0x16) {
299 return None;
300 }
301 let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
307 let mut req = httparse::Request::new(&mut headers);
308 req.parse(buf).ok()?;
309 req.headers
310 .iter()
311 .find(|h| h.name.eq_ignore_ascii_case("host"))
312 .and_then(|h| std::str::from_utf8(h.value).ok())
313 .map(|v| {
314 let host = v.trim();
315 host.rsplit_once(':')
317 .map(|(h, _)| h)
318 .unwrap_or(host)
319 .to_ascii_lowercase()
320 })
321 .filter(|h| !h.is_empty())
322}
323
324#[allow(clippy::too_many_arguments)]
341async fn classify_first_flight(
342 mut buf: Vec<u8>,
343 from_smoltcp: &mut mpsc::Receiver<Bytes>,
344 server_rx: &mut tokio::net::tcp::OwnedReadHalf,
345 to_smoltcp: &mpsc::Sender<Bytes>,
346 shared: &SharedState,
347 want_headers: bool,
348 max: usize,
349 budget: Duration,
350) -> io::Result<(Vec<u8>, bool)> {
351 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
352 let timeout_fut = tokio::time::sleep(budget);
353 tokio::pin!(timeout_fut);
354
355 loop {
356 if !buf.is_empty() {
362 let is_tls = buf.first() == Some(&0x16);
363 let not_http = !is_tls
364 && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
365 let done = !want_headers
366 || is_tls
367 || not_http
368 || buf.len() >= max
369 || buf.windows(4).any(|w| w == b"\r\n\r\n");
370 if done {
371 return Ok((buf, is_tls));
372 }
373 }
374
375 tokio::select! {
376 biased;
377 _ = &mut timeout_fut => {
378 let is_tls = buf.first() == Some(&0x16);
379 return Ok((buf, is_tls));
380 }
381 guest = from_smoltcp.recv() => match guest {
384 Some(bytes) => buf.extend_from_slice(&bytes),
385 None => {
386 let is_tls = buf.first() == Some(&0x16);
387 return Ok((buf, is_tls));
388 }
389 },
390 server = server_rx.read(&mut server_buf) => match server {
393 Ok(0) => {
394 let is_tls = buf.first() == Some(&0x16);
395 return Ok((buf, is_tls));
396 }
397 Ok(n) => {
398 let data = Bytes::copy_from_slice(&server_buf[..n]);
399 if to_smoltcp.send(data).await.is_err() {
400 let is_tls = buf.first() == Some(&0x16);
401 return Ok((buf, is_tls));
402 }
403 shared.proxy_wake.wake();
404 }
405 Err(e) => return Err(e),
406 },
407 }
408 }
409}
410
411async fn peek_for_sni(
421 rx: &mut mpsc::Receiver<Bytes>,
422 max: usize,
423 budget: Duration,
424) -> (Vec<u8>, Option<String>) {
425 let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
426 let timeout_fut = tokio::time::sleep(budget);
427 tokio::pin!(timeout_fut);
428
429 let raw_sni = loop {
430 tokio::select! {
431 biased;
432 _ = &mut timeout_fut => break None,
433 data = rx.recv() => {
434 match data {
435 Some(bytes) => {
436 buf.extend_from_slice(&bytes);
437 if buf.first() != Some(&0x16) {
442 break None;
443 }
444 if let Some(name) = sni::extract_sni(&buf) {
445 break Some(name);
446 }
447 if buf.len() >= max {
448 break None;
449 }
450 }
451 None => break None,
452 }
453 }
454 }
455 };
456
457 let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
458 (buf, canonical)
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal TLS ClientHello record whose only extension is
    /// `server_name` carrying `sni`. Length fields are narrow (`u16`, and a
    /// `u24` handshake length), so `sni` must be short.
    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
        let host_bytes = sni.as_bytes();
        // server_name extension lengths, innermost first.
        let host_len = host_bytes.len() as u16;
        // name type (1) + name length (2) + name
        let server_name_list_len = 3 + host_len;
        // list length prefix (2) + list
        let extension_data_len = 2 + server_name_list_len;
        // extension type (2) + extension length (2) + data
        let extensions_total = 4 + extension_data_len;
        let mut body = Vec::new();
        // legacy_version: TLS 1.2
        body.extend_from_slice(&[0x03, 0x03]);
        // random
        body.extend_from_slice(&[0u8; 32]);
        // session id length: 0
        body.push(0);
        // cipher suites: length 2, TLS_RSA_WITH_AES_128_CBC_SHA (0x002f)
        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
        // compression methods: length 1, null
        body.extend_from_slice(&[0x01, 0x00]);
        // total extensions length
        body.extend_from_slice(&extensions_total.to_be_bytes());
        // extension type: server_name (0)
        body.extend_from_slice(&[0x00, 0x00]);
        body.extend_from_slice(&extension_data_len.to_be_bytes());
        body.extend_from_slice(&server_name_list_len.to_be_bytes());
        // name type: host_name (0), then length-prefixed name
        body.push(0x00);
        body.extend_from_slice(&host_len.to_be_bytes());
        body.extend_from_slice(host_bytes);

        // Handshake header: client_hello (1) + 24-bit length.
        let handshake_len = body.len() as u32;
        let mut hs = Vec::new();
        hs.push(0x01);
        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]);
        hs.extend_from_slice(&body);

        // Record header: handshake (0x16), version 3.1, 16-bit length.
        let record_len = hs.len() as u16;
        let mut record = Vec::new();
        record.extend_from_slice(&[0x16, 0x03, 0x01]);
        record.extend_from_slice(&record_len.to_be_bytes());
        record.extend_from_slice(&hs);

        record
    }

    /// A complete ClientHello yields its SNI lowercased, and every byte read
    /// is handed back for replay.
    #[tokio::test]
    async fn peek_for_sni_extracts_and_canonicalizes() {
        let (tx, mut rx) = mpsc::channel(4);
        let hello = synthetic_client_hello("Example.COM");
        tx.send(Bytes::from(hello.clone())).await.unwrap();
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("example.com"));
        assert_eq!(buf, hello);
    }

    /// A channel closed before any data arrives yields nothing.
    #[tokio::test]
    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    /// Plain-HTTP bytes produce no SNI but are still returned for replay.
    #[tokio::test]
    async fn peek_for_sni_returns_none_on_non_tls_data() {
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(
            !buf.is_empty(),
            "buffered bytes must be returned for replay"
        );
        assert_eq!(sni, None);
    }

    /// With the sender kept alive but silent, the budget expiry ends the peek.
    #[tokio::test]
    async fn peek_for_sni_falls_back_on_timeout() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
        // Dropped only after the peek so the channel stays open throughout.
        drop(tx);
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    /// A TLS-looking buffer with no parseable ClientHello stops once `max`
    /// bytes have been collected, without draining the rest of the channel.
    #[tokio::test]
    async fn peek_for_sni_caps_at_max_bytes() {
        let (tx, mut rx) = mpsc::channel(4);
        let mut first = vec![0u8; 8192];
        // Leading 0x16 keeps the peek from bailing on the first byte.
        first[0] = 0x16;
        tx.send(Bytes::from(first)).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        drop(tx);

        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(
            sni, None,
            "no SNI in a TLS-looking buffer without a parseable ClientHello"
        );
        assert!(
            buf.len() >= PEEK_BUF_SIZE,
            "buffer must hit the cap before bail-out: got {}",
            buf.len()
        );
    }

    /// A non-0x16 first byte ends the peek immediately instead of waiting out
    /// the budget.
    #[tokio::test]
    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
            .await
            .unwrap();
        drop(tx);

        let started = std::time::Instant::now();
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let elapsed = started.elapsed();
        assert_eq!(sni, None);
        assert!(buf.starts_with(b"GET"));
        assert!(
            elapsed < Duration::from_millis(500),
            "non-TLS bail must be fast: took {elapsed:?}"
        );
    }

    use std::net::IpAddr;
    use std::time::Duration as StdDuration;

    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
    use crate::shared::{ResolvedHostnameFamily, SharedState};

    /// An address shared by several hostnames on a CDN, used to model
    /// virtual-hosting collisions.
    const SHARED_FASTLY_IP: &str = "151.101.0.223";

    /// Returns a `SharedState` whose hostname cache records `host` as
    /// resolving to the IPv4 address `ip`.
    fn shared_with(host: &str, ip: &str) -> SharedState {
        let shared = SharedState::new(4);
        shared.cache_resolved_hostname(
            host,
            ResolvedHostnameFamily::Ipv4,
            [ip.parse::<IpAddr>().unwrap()],
            StdDuration::from_secs(60),
        );
        shared
    }

    /// Egress allow rule for `domain` on TCP port 443.
    fn allow_https(domain: &str) -> Rule {
        Rule {
            direction: crate::policy::Direction::Egress,
            destination: Destination::Domain(domain.parse().unwrap()),
            protocols: vec![Protocol::Tcp],
            ports: vec![PortRange::single(443)],
            action: Action::Allow,
        }
    }

    /// An SNI for a non-allowed host must not inherit the allow granted to a
    /// different hostname cached at the same IP.
    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_allow() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
            .await
            .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("evil.com"));
        assert!(!initial_buf.is_empty());

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Deny,
            "SNI=evil.com must not piggy-back on the cached pypi.org match",
        );
    }

    /// An SNI for an innocent host must not be caught by a deny on a
    /// different hostname cached at the same IP.
    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_block() {
        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Allow,
            default_ingress: Action::Allow,
            rules: vec![Rule::deny_egress(Destination::Domain(
                "ads.example.com".parse().unwrap(),
            ))],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
            .await
            .unwrap();
        drop(tx);

        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("api.example.com"));

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "SNI=api.example.com must not be caught by the deny on ads.example.com",
        );
    }

    /// Without an SNI, evaluation falls back to the hostname cache.
    #[tokio::test]
    async fn integration_non_tls_falls_back_to_cache() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni, None, "non-TLS data → no SNI");
        assert!(
            !initial_buf.is_empty(),
            "buffered bytes must survive for replay"
        );

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "cache-only fallback must still allow the cached hostname's IP",
        );
    }

    /// An SNI matching a domain-suffix rule is allowed when the cache binds
    /// that hostname to the destination IP.
    #[tokio::test]
    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![Rule {
                direction: crate::policy::Direction::Egress,
                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
                protocols: vec![Protocol::Tcp],
                ports: vec![PortRange::single(443)],
                action: Action::Allow,
            }],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(eval, EgressEvaluation::Allow);
    }

    /// The same SNI is denied when the cache has no binding for it, so a
    /// guest cannot claim an allowed name for an arbitrary IP.
    #[tokio::test]
    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
        let shared = SharedState::new(4);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![Rule {
                direction: crate::policy::Direction::Egress,
                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
                protocols: vec![Protocol::Tcp],
                ports: vec![PortRange::single(443)],
                action: Action::Allow,
            }],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(eval, EgressEvaluation::Deny);
    }

    #[test]
    fn extract_http_host_basic() {
        let buf = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        assert_eq!(extract_http_host(buf), Some("example.com".into()));
    }

    #[test]
    fn extract_http_host_strips_port() {
        let buf = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
        assert_eq!(extract_http_host(buf), Some("api.company.com".into()));
    }

    #[test]
    fn extract_http_host_case_insensitive_lowercased() {
        let buf = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
        assert_eq!(extract_http_host(buf), Some("example.com".into()));
    }

    #[test]
    fn extract_http_host_no_host_header() {
        let buf = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
        assert_eq!(extract_http_host(buf), None);
    }

    #[test]
    fn extract_http_host_incomplete_headers() {
        let buf = b"GET / HTTP/1.1\r\nHost: x";
        assert_eq!(extract_http_host(buf), None);
    }

    #[test]
    fn extract_http_host_tls_first_byte() {
        let buf = [0x16u8, 0x03, 0x01, 0x00, 0x01];
        assert_eq!(extract_http_host(&buf), None);
    }

    /// More than 16 headers before `Host` must still parse.
    #[test]
    fn extract_http_host_with_many_headers() {
        let mut req = Vec::from(&b"GET / HTTP/1.1\r\n"[..]);
        for i in 0..100 {
            req.extend_from_slice(format!("X-Pad-{i}: v\r\n").as_bytes());
        }
        req.extend_from_slice(b"Host: example.com\r\n\r\n");
        assert_eq!(extract_http_host(&req), Some("example.com".into()));
    }

    use std::sync::Arc;

    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;

    use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};

    /// Secrets config with one header-injected secret allowed for any host.
    fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
        SecretsConfig {
            secrets: vec![SecretEntry {
                env_var: "API_KEY".into(),
                value: value.into(),
                placeholder: placeholder.into(),
                allowed_hosts: vec![HostPattern::Any],
                injection: SecretInjection {
                    headers: true,
                    basic_auth: false,
                    query_params: false,
                    body: false,
                },
                on_violation: None,
                require_tls_identity: require_tls,
            }],
            ..Default::default()
        }
    }

    /// Binds a loopback listener that accepts one connection and returns
    /// everything it reads until EOF or a read error.
    async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = tokio::spawn(async move {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut received = Vec::new();
            let mut buf = vec![0u8; 4096];
            loop {
                match stream.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(n) => received.extend_from_slice(&buf[..n]),
                }
            }
            received
        });
        (addr, handle)
    }

    /// Feeds `request` as one guest chunk through `tcp_proxy_task` (default
    /// policy) and returns what the sink at `server_addr` received.
    async fn relay_through_proxy(
        request: Vec<u8>,
        secrets: SecretsConfig,
        handle: JoinHandle<Vec<u8>>,
        server_addr: SocketAddr,
    ) -> Vec<u8> {
        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
        // Keep the receiver alive so server→guest sends do not fail.
        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
        let shared = SharedState::new(4);
        let policy = Arc::new(NetworkPolicy::default());
        let secrets = Arc::new(secrets);
        let proxy_connect = Arc::new(ProxyConnectState::new());

        from_tx.send(Bytes::from(request)).await.unwrap();
        drop(from_tx);

        tcp_proxy_task(
            server_addr,
            server_addr,
            from_rx,
            to_tx,
            Arc::new(shared),
            policy,
            secrets,
            proxy_connect,
        )
        .await
        .unwrap();

        handle.await.unwrap()
    }

    /// The Host header arrives in a later chunk than the request line; the
    /// first flight must be buffered until headers complete.
    #[tokio::test]
    async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
        let (addr, sink) = spawn_sink().await;
        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);

        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
        let proxy_connect = Arc::new(ProxyConnectState::new());

        from_tx
            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
            .await
            .unwrap();
        from_tx
            .send(Bytes::from_static(
                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
            ))
            .await
            .unwrap();
        drop(from_tx);

        tcp_proxy_task(
            addr,
            addr,
            from_rx,
            to_tx,
            Arc::new(SharedState::new(4)),
            Arc::new(NetworkPolicy::default()),
            Arc::new(secrets),
            proxy_connect,
        )
        .await
        .unwrap();

        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
        assert!(wire.contains("real-secret-value"), "got: {wire:?}");
        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
    }

    /// A host-scoped secret with `require_tls_identity` reaches the allowed
    /// host unsubstituted over plain HTTP, even with split headers.
    #[tokio::test]
    async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
        let (addr, sink) = spawn_sink().await;

        let shared = SharedState::new(4);
        shared.cache_resolved_hostname(
            "example.com",
            ResolvedHostnameFamily::Ipv4,
            ["127.0.0.1".parse::<IpAddr>().unwrap()],
            StdDuration::from_secs(60),
        );

        let secrets = SecretsConfig {
            secrets: vec![SecretEntry {
                env_var: "API_KEY".into(),
                value: "real-secret-value".into(),
                placeholder: "$MSB_KEY".into(),
                allowed_hosts: vec![HostPattern::Exact("example.com".into())],
                injection: SecretInjection {
                    headers: true,
                    basic_auth: false,
                    query_params: false,
                    body: false,
                },
                on_violation: None,
                require_tls_identity: true,
            }],
            ..Default::default()
        };

        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
        let proxy_connect = Arc::new(ProxyConnectState::new());

        from_tx
            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
            .await
            .unwrap();
        from_tx
            .send(Bytes::from_static(
                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
            ))
            .await
            .unwrap();
        drop(from_tx);

        tcp_proxy_task(
            addr,
            addr,
            from_rx,
            to_tx,
            Arc::new(shared),
            Arc::new(NetworkPolicy::default()),
            Arc::new(secrets),
            proxy_connect,
        )
        .await
        .unwrap();

        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
        assert!(
            wire.contains("Host: example.com"),
            "request must reach the allowed host, got: {wire:?}"
        );
        assert!(
            wire.contains("$MSB_KEY"),
            "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
        );
        assert!(
            !wire.contains("real-secret-value"),
            "secret must never be substituted over plain HTTP, got: {wire:?}"
        );
    }

    #[tokio::test]
    async fn plain_http_substitutes_placeholder_in_first_flight() {
        let (addr, sink) = spawn_sink().await;

        let request =
            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
                .to_vec();
        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);

        let wire =
            String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
        assert!(
            wire.contains("real-secret-value"),
            "real value must reach server, got: {wire:?}"
        );
        assert!(
            !wire.contains("$MSB_KEY"),
            "placeholder must not reach server, got: {wire:?}"
        );
    }

    #[tokio::test]
    async fn plain_http_no_substitution_when_require_tls_identity_true() {
        let (addr, sink) = spawn_sink().await;

        let request =
            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
                .to_vec();
        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);

        let wire =
            String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
                .into_owned();
        assert!(
            wire.contains("$MSB_KEY"),
            "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
        );
        assert!(
            !wire.contains("real-value-never") || !wire.contains("real-secret-value"),
            "real value must not leak when require_tls_identity=true, got: {wire:?}"
        );
    }

    /// Headers in the first flight, then a body larger than the peek cap
    /// that must pass through the relay loop unchanged.
    #[tokio::test]
    async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
        let (addr, sink) = spawn_sink().await;
        let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);

        let body = "x".repeat(32_000);
        let header = format!(
            "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
            body.len()
        );

        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
        let proxy_connect = Arc::new(ProxyConnectState::new());

        from_tx
            .send(Bytes::from(header.into_bytes()))
            .await
            .unwrap();
        from_tx
            .send(Bytes::from(body.clone().into_bytes()))
            .await
            .unwrap();
        drop(from_tx);

        tcp_proxy_task(
            addr,
            addr,
            from_rx,
            to_tx,
            Arc::new(SharedState::new(4)),
            Arc::new(NetworkPolicy::default()),
            Arc::new(secrets),
            proxy_connect,
        )
        .await
        .unwrap();

        let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
        assert!(wire.contains(&body), "got {} bytes", wire.len());
        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
    }
}
1160}