Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::borrow::Cow;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23    SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::proxy::{TlsProxyContext, tls_proxy_task};
27use crate::tls::sni;
28use crate::tls::state::TlsState;
29
30//--------------------------------------------------------------------------------------------------
31// Constants
32//--------------------------------------------------------------------------------------------------
33
34/// Buffer size for reading from the real server.
35const SERVER_READ_BUF_SIZE: usize = 16384;
36
37/// Max bytes buffered while reading the proxy's CONNECT response headers.
38const CONNECT_RESP_LIMIT: usize = 8192;
39
40/// Max bytes to buffer while peeking for the ClientHello's SNI.
41const PEEK_BUF_SIZE: usize = 16384;
42
43/// Upper bound on time spent buffering the first flight before
44/// falling back to a cache-only egress decision.
45const PEEK_BUDGET: Duration = Duration::from_secs(5);
46
47//--------------------------------------------------------------------------------------------------
48// Types
49//--------------------------------------------------------------------------------------------------
50
/// A buffered HTTP CONNECT request, split at the end of its header block.
#[derive(Debug)]
struct ConnectRequest {
    /// Everything buffered so far: the header block plus any bytes after it.
    bytes: Vec<u8>,
    /// Offset into `bytes` just past the `\r\n\r\n` header terminator.
    header_end: usize,
    /// Destination parsed from the authority on the request line.
    target: ConnectTarget,
}
57
/// Destination named by the `host:port` authority of a CONNECT request.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnectTarget {
    /// Lowercased host with surrounding whitespace and trailing dots removed;
    /// a bracketed IPv6 literal is stored without its brackets.
    host: String,
    /// Port parsed from the authority.
    port: u16,
    /// Lowercased hostname, or `None` when `host` parses as an IP address.
    expected_sni: Option<String>,
}
64
65//--------------------------------------------------------------------------------------------------
66// Methods
67//--------------------------------------------------------------------------------------------------
68
69impl ConnectRequest {
70    fn header_bytes(&self) -> &[u8] {
71        &self.bytes[..self.header_end]
72    }
73
74    fn post_header_bytes(&self) -> &[u8] {
75        &self.bytes[self.header_end..]
76    }
77}
78
79impl ConnectTarget {
80    fn is_intercepted(&self, tls_state: &TlsState) -> bool {
81        tls_state.config.intercepted_ports.contains(&self.port)
82    }
83
84    fn guest_dst(&self, fallback: SocketAddr, shared: &SharedState) -> SocketAddr {
85        if let Ok(ip) = self.host.parse::<IpAddr>() {
86            return SocketAddr::new(ip, self.port);
87        }
88
89        if self.host.eq_ignore_ascii_case(crate::HOST_ALIAS) {
90            match fallback.ip() {
91                IpAddr::V4(_) => {
92                    if let Some(ip) = shared.gateway_ipv4() {
93                        return SocketAddr::new(IpAddr::V4(ip), self.port);
94                    }
95                }
96                IpAddr::V6(_) => {
97                    if let Some(ip) = shared.gateway_ipv6() {
98                        return SocketAddr::new(IpAddr::V6(ip), self.port);
99                    }
100                }
101            }
102            if let Some(ip) = shared.gateway_ipv4() {
103                return SocketAddr::new(IpAddr::V4(ip), self.port);
104            }
105            if let Some(ip) = shared.gateway_ipv6() {
106                return SocketAddr::new(IpAddr::V6(ip), self.port);
107            }
108        }
109
110        SocketAddr::new(fallback.ip(), self.port)
111    }
112}
113
114//--------------------------------------------------------------------------------------------------
115// Functions
116//--------------------------------------------------------------------------------------------------
117
118/// Dial `dst` and update proxy state; wakes the poll thread on failure.
119pub(crate) async fn connect_upstream(
120    dst: SocketAddr,
121    proxy_connect: &ProxyConnectState,
122    shared: &SharedState,
123) -> io::Result<TcpStream> {
124    match TcpStream::connect(dst).await {
125        Ok(s) => {
126            proxy_connect.mark_connected();
127            Ok(s)
128        }
129        Err(e) => {
130            proxy_connect.mark_upstream_connect_failed();
131            shared.proxy_wake.wake();
132            Err(e)
133        }
134    }
135}
136
137/// Spawn a TCP proxy task for a newly established connection.
138///
139/// `guest_dst` is what the guest dialed — the address policy rules
140/// match against. `connect_dst` is the host-side address tokio actually
141/// dials; for host-alias connections it's loopback (gateway rewritten).
142/// For everything else the two are identical.
143///
144/// `proxy_connect` is updated before the task exits so the connection
145/// tracker can decide between FIN (clean close) and RST (upstream
146/// connect failure).
147#[allow(clippy::too_many_arguments)]
148pub fn spawn_tcp_proxy(
149    handle: &tokio::runtime::Handle,
150    guest_dst: SocketAddr,
151    connect_dst: SocketAddr,
152    from_smoltcp: mpsc::Receiver<Bytes>,
153    to_smoltcp: mpsc::Sender<Bytes>,
154    shared: Arc<SharedState>,
155    network_policy: Arc<NetworkPolicy>,
156    secrets: Arc<SecretsConfig>,
157    tls_state: Option<Arc<TlsState>>,
158    proxy_connect: Arc<ProxyConnectState>,
159) {
160    handle.spawn(async move {
161        if let Err(e) = tcp_proxy_task(
162            guest_dst,
163            connect_dst,
164            from_smoltcp,
165            to_smoltcp,
166            shared,
167            network_policy,
168            secrets,
169            tls_state,
170            proxy_connect,
171        )
172        .await
173        {
174            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
175        }
176    });
177}
178
179/// Core TCP proxy: peek for SNI, evaluate egress policy, then either
180/// connect and relay or drop the channels.
181#[allow(clippy::too_many_arguments)]
182async fn tcp_proxy_task(
183    guest_dst: SocketAddr,
184    connect_dst: SocketAddr,
185    mut from_smoltcp: mpsc::Receiver<Bytes>,
186    to_smoltcp: mpsc::Sender<Bytes>,
187    shared: Arc<SharedState>,
188    network_policy: Arc<NetworkPolicy>,
189    secrets: Arc<SecretsConfig>,
190    tls_state: Option<Arc<TlsState>>,
191    proxy_connect: Arc<ProxyConnectState>,
192) -> io::Result<()> {
193    // Pre-connect peek is only for domain policy: the hostname has to be known
194    // before we dial upstream so a Deny never opens a connection. Secrets do
195    // *not* gate the connect, so they no longer force a peek here — that work is
196    // deferred to `classify_first_flight` after the socket is open, where it can
197    // run without stalling server-first protocols (see below).
198    let (mut initial_buf, sni) = if network_policy.has_domain_rules() {
199        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
200    } else {
201        (Vec::new(), None)
202    };
203
204    // Re-evaluate egress against the *guest* dst — the address the
205    // guest dialed, not the post-rewrite host-side address. SNI
206    // refines over-allow when the cache matched a shared CDN IP;
207    // CacheOnly is the non-TLS fallback path so Domain rules still
208    // gate plain HTTP / SSH / etc.
209    if network_policy.has_domain_rules() {
210        let source = match sni.as_deref() {
211            Some(name) => HostnameSource::Sni(name),
212            None => HostnameSource::CacheOnly,
213        };
214        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
215        {
216            EgressEvaluation::Allow => {}
217            EgressEvaluation::Deny => {
218                tracing::debug!(
219                    dst = %guest_dst,
220                    source = source.label(),
221                    "TCP egress denied by domain policy",
222                );
223                proxy_connect.mark_policy_denied();
224                shared.proxy_wake.wake();
225                return Ok(());
226            }
227            EgressEvaluation::DeferUntilHostname => {
228                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
229                proxy_connect.mark_policy_denied();
230                shared.proxy_wake.wake();
231                return Ok(());
232            }
233        }
234    }
235
236    // Peek for HTTP CONNECT before dialing upstream; hand off if detected.
237    if let Some(tls_state) = tls_state.clone() {
238        if initial_buf.is_empty() {
239            let (peeked, _) = peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await;
240            initial_buf = peeked;
241        }
242        if could_be_connect_request(&initial_buf) {
243            return handle_connect_tunnel(
244                guest_dst,
245                connect_dst,
246                initial_buf,
247                from_smoltcp,
248                to_smoltcp,
249                shared,
250                network_policy,
251                tls_state,
252                proxy_connect,
253                None,
254            )
255            .await;
256        }
257    }
258
259    // Connect upstream *before* finishing the secrets-side classification. A
260    // server-first protocol (SSH, SMTP, a database) sends nothing until it has
261    // seen the server's banner; with the socket already open we can relay that
262    // banner while we wait, instead of burning the peek budget pre-connect.
263    let stream = connect_upstream(connect_dst, &proxy_connect, &shared).await?;
264    let (mut server_rx, mut server_tx) = stream.into_split();
265
266    // Finish classifying the first flight (TLS vs plain HTTP) and, for
267    // plain-HTTP candidates, gather a full header block — without blocking the
268    // server→guest direction. When domain rules already peeked, `initial_buf`
269    // is reused and this is cheap; with no secrets it is skipped entirely
270    // (`is_tls` only matters for deciding whether to build the handler).
271    let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
272    let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
273        classify_first_flight(
274            initial_buf,
275            &mut from_smoltcp,
276            &mut server_rx,
277            &to_smoltcp,
278            &shared,
279            want_headers,
280            PEEK_BUF_SIZE,
281            PEEK_BUDGET,
282        )
283        .await?
284    } else {
285        (initial_buf, false)
286    };
287
288    if let Some(tls_state) = tls_state.clone()
289        && could_be_connect_request(&initial_buf)
290    {
291        // The pre-connect CONNECT peek can miss a client whose first bytes arrive
292        // after we dial upstream. Once classify_first_flight has captured that
293        // request, rejoin the already-open proxy socket and use the CONNECT path
294        // so intercepted tunnels still get TLS substitution and policy checks.
295        let proxy_stream = server_rx
296            .reunite(server_tx)
297            .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
298        return handle_connect_tunnel(
299            guest_dst,
300            connect_dst,
301            initial_buf,
302            from_smoltcp,
303            to_smoltcp,
304            shared,
305            network_policy,
306            tls_state,
307            proxy_connect,
308            Some(proxy_stream),
309        )
310        .await;
311    }
312
313    let mut late_connect_state = tls_state;
314    let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
315        Some(match extract_http_host(&initial_buf) {
316            Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
317            None => SecretsHandler::new_plain_http_invalid_host(&secrets),
318        })
319    } else {
320        None
321    };
322
323    // Replay the buffered first flight — run through secrets handler first.
324    if !initial_buf.is_empty() {
325        let out: Cow<[u8]> = match secrets_handler.as_mut() {
326            Some(h) => match h.substitute(&initial_buf) {
327                // Borrow the input when nothing was substituted; only a chunk
328                // that actually carries a placeholder is reallocated.
329                Ok(cow) => cow,
330                Err(action) => {
331                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
332                    if matches!(action, ViolationAction::BlockAndTerminate) {
333                        shared.trigger_termination();
334                    }
335                    return Ok(());
336                }
337            },
338            None => Cow::Borrowed(&initial_buf),
339        };
340        if !out.is_empty() {
341            if let Err(e) = server_tx.write_all(&out).await {
342                tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
343                return Ok(());
344            }
345            if let Err(e) = server_tx.flush().await {
346                tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
347                return Ok(());
348            }
349        }
350    }
351
352    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
353
354    // Bidirectional relay using tokio::select!.
355    //
356    // guest → server: receive from channel, write to server socket.
357    // server → guest: read from server socket, send via channel + wake poll.
358    loop {
359        tokio::select! {
360            // Guest → server: substitute placeholders before forwarding.
361            data = from_smoltcp.recv() => {
362                match data {
363                    Some(bytes) => {
364                        if let Some(tls_state) = late_connect_state.take()
365                            && could_be_connect_request(&bytes)
366                        {
367                            // The first guest bytes can arrive after both peek
368                            // windows have completed. Nothing has been written
369                            // to the proxy socket yet, so this is still a valid
370                            // point to switch into CONNECT tunnel handling.
371                            let proxy_stream = server_rx
372                                .reunite(server_tx)
373                                .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
374                            return handle_connect_tunnel(
375                                guest_dst,
376                                connect_dst,
377                                bytes.to_vec(),
378                                from_smoltcp,
379                                to_smoltcp,
380                                shared,
381                                network_policy,
382                                tls_state,
383                                proxy_connect,
384                                Some(proxy_stream),
385                            )
386                            .await;
387                        }
388                        // No handler (no secrets / TLS) is the common path: forward
389                        // the chunk borrowed, with no per-chunk allocation or copy.
390                        let out: Cow<[u8]> = match secrets_handler.as_mut() {
391                            Some(h) => match h.substitute(&bytes) {
392                                Ok(cow) => cow,
393                                Err(action) => {
394                                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
395                                    if matches!(action, ViolationAction::BlockAndTerminate) {
396                                        shared.trigger_termination();
397                                    }
398                                    break;
399                                }
400                            },
401                            None => Cow::Borrowed(&bytes),
402                        };
403                        if !out.is_empty() {
404                            if let Err(e) = server_tx.write_all(&out).await {
405                                tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
406                                break;
407                            }
408                            if let Err(e) = server_tx.flush().await {
409                                tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
410                                break;
411                            }
412                        }
413                    }
414                    // Channel closed — smoltcp socket was closed by guest.
415                    None => break,
416                }
417            }
418
419            // Server → guest: no substitution — server never sends placeholders.
420            result = server_rx.read(&mut server_buf) => {
421                match result {
422                    Ok(0) => break, // Server closed connection.
423                    Ok(n) => {
424                        // A server-first byte means this is not an HTTP CONNECT
425                        // tunnel to a proxy. Keep relaying normally afterward.
426                        late_connect_state = None;
427                        let data = Bytes::copy_from_slice(&server_buf[..n]);
428                        if to_smoltcp.send(data).await.is_err() {
429                            // Channel closed — poll loop dropped the receiver.
430                            break;
431                        }
432                        // Wake the poll thread so it writes data to the
433                        // smoltcp socket.
434                        shared.proxy_wake.wake();
435                    }
436                    Err(e) => {
437                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
438                        break;
439                    }
440                }
441            }
442        }
443    }
444
445    Ok(())
446}
447
448/// Forward an HTTP CONNECT tunnel: dial the proxy, splice the handshake,
449/// then hand the established stream to `tls_proxy_task` for TLS MITM.
450///
451/// `guest_dst` is what the guest dialed; `proxy_dst` is the rewritten
452/// loopback address the gateway actually connects to.
453#[allow(clippy::too_many_arguments)]
454async fn handle_connect_tunnel(
455    guest_dst: SocketAddr,
456    proxy_dst: SocketAddr,
457    initial_buf: Vec<u8>,
458    mut from_smoltcp: mpsc::Receiver<Bytes>,
459    to_smoltcp: mpsc::Sender<Bytes>,
460    shared: Arc<SharedState>,
461    network_policy: Arc<NetworkPolicy>,
462    tls_state: Arc<TlsState>,
463    proxy_connect: Arc<ProxyConnectState>,
464    preconnected_proxy: Option<TcpStream>,
465) -> io::Result<()> {
466    let connect_req =
467        parse_connect_request(buffer_connect_request(initial_buf, &mut from_smoltcp).await?)?;
468
469    let connect_headers = match sanitize_connect_headers(
470        connect_req.header_bytes(),
471        &tls_state.secrets,
472    ) {
473        Ok(headers) => headers,
474        Err(action) => {
475            tracing::warn!(dst = %proxy_dst, violation = ?action, "secret violation in CONNECT headers");
476            if matches!(action, ViolationAction::BlockAndTerminate) {
477                shared.trigger_termination();
478            }
479            return Ok(());
480        }
481    };
482
483    // Dial the proxy and forward the CONNECT request so it opens the tunnel.
484    let mut proxy_stream = match preconnected_proxy {
485        Some(stream) => stream,
486        None => match TcpStream::connect(proxy_dst).await {
487            Ok(s) => s,
488            Err(e) => {
489                proxy_connect.mark_upstream_connect_failed();
490                shared.proxy_wake.wake();
491                return Err(e);
492            }
493        },
494    };
495
496    if !connect_req.target.is_intercepted(&tls_state) {
497        proxy_stream.write_all(&connect_headers).await?;
498        proxy_stream.flush().await?;
499        let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
500        if to_smoltcp
501            .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
502            .await
503            .is_err()
504        {
505            return Ok(());
506        }
507        if !proxy_resp[header_end..].is_empty()
508            && to_smoltcp
509                .send(Bytes::copy_from_slice(&proxy_resp[header_end..]))
510                .await
511                .is_err()
512        {
513            return Ok(());
514        }
515        shared.proxy_wake.wake();
516        if !connect_response_is_success(&proxy_resp[..header_end]) {
517            proxy_connect.mark_connected();
518            return Ok(());
519        }
520        if !connect_req.post_header_bytes().is_empty() {
521            proxy_stream
522                .write_all(connect_req.post_header_bytes())
523                .await?;
524        }
525        proxy_stream.flush().await?;
526        proxy_connect.mark_connected();
527        return relay_connected_stream(proxy_stream, from_smoltcp, to_smoltcp, shared).await;
528    }
529
530    proxy_stream.write_all(&connect_headers).await?;
531    proxy_stream.flush().await?;
532
533    let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
534    if !connect_response_is_success(&proxy_resp[..header_end]) {
535        return Err(io::Error::new(
536            io::ErrorKind::ConnectionRefused,
537            "proxy rejected CONNECT",
538        ));
539    }
540    if !proxy_resp[header_end..].is_empty() {
541        return Err(io::Error::new(
542            io::ErrorKind::InvalidData,
543            "proxy sent unexpected bytes after CONNECT response headers",
544        ));
545    }
546    proxy_connect.mark_connected();
547
548    if to_smoltcp
549        .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
550        .await
551        .is_err()
552    {
553        return Ok(());
554    }
555    shared.proxy_wake.wake();
556
557    let tls_seed = connect_req.post_header_bytes().to_vec();
558    let tls_guest_dst = connect_req.target.guest_dst(guest_dst, &shared);
559    let expected_sni = connect_req.target.expected_sni.clone();
560
561    tls_proxy_task(
562        TlsProxyContext {
563            guest_dst: tls_guest_dst,
564            connect_dst: proxy_dst,
565            shared,
566            tls_state,
567            network_policy,
568            proxy_connect,
569            upstream_stream: Some(proxy_stream),
570            expected_sni,
571        },
572        from_smoltcp,
573        to_smoltcp,
574        tls_seed,
575    )
576    .await
577}
578
579/// Relay an established TCP stream without inspecting or substituting bytes.
580async fn relay_connected_stream(
581    stream: TcpStream,
582    mut from_smoltcp: mpsc::Receiver<Bytes>,
583    to_smoltcp: mpsc::Sender<Bytes>,
584    shared: Arc<SharedState>,
585) -> io::Result<()> {
586    let (mut server_rx, mut server_tx) = stream.into_split();
587    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
588
589    loop {
590        tokio::select! {
591            data = from_smoltcp.recv() => {
592                match data {
593                    Some(bytes) => {
594                        server_tx.write_all(&bytes).await?;
595                        server_tx.flush().await?;
596                    }
597                    None => break,
598                }
599            }
600            result = server_rx.read(&mut server_buf) => {
601                match result {
602                    Ok(0) => break,
603                    Ok(n) => {
604                        if to_smoltcp
605                            .send(Bytes::copy_from_slice(&server_buf[..n]))
606                            .await
607                            .is_err()
608                        {
609                            break;
610                        }
611                        shared.proxy_wake.wake();
612                    }
613                    Err(e) => return Err(e),
614                }
615            }
616        }
617    }
618
619    Ok(())
620}
621
622async fn buffer_connect_request(
623    mut buf: Vec<u8>,
624    from_smoltcp: &mut mpsc::Receiver<Bytes>,
625) -> io::Result<Vec<u8>> {
626    let timeout_fut = tokio::time::sleep(PEEK_BUDGET);
627    tokio::pin!(timeout_fut);
628
629    loop {
630        if !could_be_connect_request(&buf) {
631            return Err(io::Error::new(
632                io::ErrorKind::InvalidData,
633                "malformed CONNECT request prefix",
634            ));
635        }
636        if headers_end(&buf).is_some() {
637            return Ok(buf);
638        }
639        if buf.len() >= PEEK_BUF_SIZE {
640            return Err(io::Error::new(
641                io::ErrorKind::InvalidData,
642                "CONNECT request headers too large",
643            ));
644        }
645
646        tokio::select! {
647            biased;
648            _ = &mut timeout_fut => {
649                return Err(io::Error::new(
650                    io::ErrorKind::TimedOut,
651                    "timed out waiting for complete CONNECT request headers",
652                ));
653            }
654            data = from_smoltcp.recv() => match data {
655                Some(bytes) => {
656                    buf.extend_from_slice(&bytes);
657                }
658                None => {
659                    return Err(io::Error::new(
660                        io::ErrorKind::UnexpectedEof,
661                        "channel closed before complete CONNECT request headers",
662                    ));
663                }
664            }
665        }
666    }
667}
668
669async fn read_connect_response_headers(stream: &mut TcpStream) -> io::Result<(Vec<u8>, usize)> {
670    tokio::time::timeout(PEEK_BUDGET, async {
671        let mut proxy_resp = Vec::with_capacity(256);
672        let mut buf = [0u8; 4096];
673        loop {
674            let n = stream.read(&mut buf).await?;
675            if n == 0 {
676                return Err(io::Error::new(
677                    io::ErrorKind::UnexpectedEof,
678                    "proxy closed before sending CONNECT response",
679                ));
680            }
681            proxy_resp.extend_from_slice(&buf[..n]);
682            if let Some(end) = headers_end(&proxy_resp) {
683                return Ok((proxy_resp, end));
684            }
685            if proxy_resp.len() > CONNECT_RESP_LIMIT {
686                return Err(io::Error::new(
687                    io::ErrorKind::InvalidData,
688                    "proxy CONNECT response too large",
689                ));
690            }
691        }
692    })
693    .await
694    .map_err(|_| {
695        io::Error::new(
696            io::ErrorKind::TimedOut,
697            "timed out waiting for proxy CONNECT response",
698        )
699    })?
700}
701
702fn sanitize_connect_headers<'a>(
703    header_bytes: &'a [u8],
704    secrets: &SecretsConfig,
705) -> Result<Cow<'a, [u8]>, ViolationAction> {
706    if secrets.secrets.is_empty() {
707        return Ok(Cow::Borrowed(header_bytes));
708    }
709
710    let mut handler = SecretsHandler::new_plain_http_untrusted_metadata(secrets);
711    handler.substitute(header_bytes)
712}
713
/// Locate the end of an HTTP header block.
///
/// Returns the offset of the first byte after the first `\r\n\r\n` in `buf`,
/// or `None` when the terminator has not been seen.
fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let start = buf
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}
718
/// Whether `buf` is consistent with the start of an HTTP `CONNECT` request.
///
/// The comparison is ASCII case-insensitive and covers only as many bytes as
/// are available, so a short buffer such as `b"CO"` still qualifies. An empty
/// buffer never does.
fn could_be_connect_request(buf: &[u8]) -> bool {
    const PREFIX: &[u8] = b"CONNECT ";

    // `zip` stops at the shorter side, which compares exactly the overlap.
    !buf.is_empty()
        && buf
            .iter()
            .zip(PREFIX)
            .all(|(got, want)| got.eq_ignore_ascii_case(want))
}
727
728fn parse_connect_request(bytes: Vec<u8>) -> io::Result<ConnectRequest> {
729    let header_end = headers_end(&bytes).ok_or_else(|| {
730        io::Error::new(
731            io::ErrorKind::InvalidData,
732            "incomplete CONNECT request headers",
733        )
734    })?;
735    let target = {
736        let request_line = bytes[..header_end]
737            .split(|&b| b == b'\n')
738            .next()
739            .unwrap_or(&[]);
740        let request_line = std::str::from_utf8(request_line)
741            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "CONNECT line is not UTF-8"))?
742            .trim_end_matches('\r');
743        let mut parts = request_line.split_ascii_whitespace();
744        let method = parts.next().unwrap_or_default();
745        let authority = parts.next().unwrap_or_default();
746        let version = parts.next().unwrap_or_default();
747        if !method.eq_ignore_ascii_case("CONNECT")
748            || authority.is_empty()
749            || !is_http_version(version)
750            || parts.next().is_some()
751        {
752            return Err(io::Error::new(
753                io::ErrorKind::InvalidData,
754                "malformed CONNECT request line",
755            ));
756        }
757        parse_connect_target(authority)?
758    };
759
760    Ok(ConnectRequest {
761        bytes,
762        header_end,
763        target,
764    })
765}
766
767fn parse_connect_target(authority: &str) -> io::Result<ConnectTarget> {
768    let authority = authority.trim();
769    let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
770        let (host, rest) = rest.split_once(']').ok_or_else(|| {
771            io::Error::new(
772                io::ErrorKind::InvalidData,
773                "malformed CONNECT IPv6 authority",
774            )
775        })?;
776        let port = rest.strip_prefix(':').ok_or_else(|| {
777            io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
778        })?;
779        (host, port)
780    } else {
781        let (host, port) = authority.rsplit_once(':').ok_or_else(|| {
782            io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
783        })?;
784        if host.contains(':') {
785            return Err(io::Error::new(
786                io::ErrorKind::InvalidData,
787                "CONNECT IPv6 authority must be bracketed",
788            ));
789        }
790        (host, port)
791    };
792    let host = host.trim().trim_end_matches('.');
793    if host.is_empty() {
794        return Err(io::Error::new(
795            io::ErrorKind::InvalidData,
796            "CONNECT authority missing host",
797        ));
798    }
799    let port = port
800        .parse::<u16>()
801        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid CONNECT port"))?;
802    let expected_sni = host
803        .parse::<IpAddr>()
804        .is_err()
805        .then(|| host.to_ascii_lowercase());
806
807    Ok(ConnectTarget {
808        host: host.to_ascii_lowercase(),
809        port,
810        expected_sni,
811    })
812}
813
/// Whether `version` has the exact shape `HTTP/<digits>.<digits>`.
///
/// The `HTTP/` prefix is case-sensitive and both numeric parts must be
/// non-empty runs of ASCII digits.
fn is_http_version(version: &str) -> bool {
    fn all_digits(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit())
    }

    version
        .strip_prefix("HTTP/")
        .and_then(|numbers| numbers.split_once('.'))
        .is_some_and(|(major, minor)| all_digits(major) && all_digits(minor))
}
826
827fn connect_response_is_success(headers: &[u8]) -> bool {
828    let Some(status_line) = headers.split(|&b| b == b'\n').next() else {
829        return false;
830    };
831    let Ok(status_line) = std::str::from_utf8(status_line) else {
832        return false;
833    };
834    let mut parts = status_line.trim_end_matches('\r').split_ascii_whitespace();
835    let version = parts.next().unwrap_or_default();
836    let status = parts.next().unwrap_or_default();
837    is_http_version(version)
838        && status.len() == 3
839        && status
840            .parse::<u16>()
841            .is_ok_and(|code| (200..300).contains(&code))
842}
843
844/// Extract the `Host:` header value from an already-buffered HTTP header block.
845///
846/// Returns `None` if:
847/// - The first byte is `0x16` (TLS — not HTTP)
848/// - The buffer does not yet contain `\r\n\r\n` (headers incomplete)
849/// - No `Host:` header is present
850///
851/// Strips port suffix, lowercases, and trims whitespace. Result is
852/// ready for byte-equal matching against `SecretEntry::allowed_hosts`.
853fn extract_http_host(buf: &[u8]) -> Option<String> {
854    if buf.first() == Some(&0x16) {
855        return None;
856    }
857    // Size the header pool to the buffer rather than a fixed array: a header
858    // line is at least four bytes (`a:\r\n`), so `len / 4` always covers the
859    // real header count, and `httparse` never reports `TooManyHeaders` (which
860    // would make a request with many headers look hostless). The first flight
861    // is capped at PEEK_BUF_SIZE, so this stays bounded.
862    let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
863    let mut req = httparse::Request::new(&mut headers);
864    req.parse(buf).ok()?;
865    req.headers
866        .iter()
867        .find(|h| h.name.eq_ignore_ascii_case("host"))
868        .and_then(|h| std::str::from_utf8(h.value).ok())
869        .map(|v| {
870            let host = v.trim();
871            // Strip port suffix.
872            host.rsplit_once(':')
873                .map(|(h, _)| h)
874                .unwrap_or(host)
875                .to_ascii_lowercase()
876        })
877        .filter(|h| !h.is_empty())
878}
879
880/// Finish classifying the guest's first flight after the upstream socket is
881/// open, returning the (possibly extended) first-flight buffer and whether it
882/// is a TLS record.
883///
884/// `buf` carries whatever a pre-connect domain-rule peek already captured; when
885/// it is non-empty the TLS/plain decision is already settled and only header
886/// top-up runs. `want_headers` is set when at least one secret can be
887/// substituted over plain HTTP (`SecretsConfig::has_plain_http_candidates`); it
888/// makes the peek keep reading a non-TLS flight until `\r\n\r\n` so
889/// [`extract_http_host`] sees a complete header block.
890///
891/// Crucially, this relays server→guest while it waits. Server-first protocols
892/// (SSH, SMTP, databases) send nothing until they have seen the server's
893/// banner; draining the server side here lets the banner reach the guest
894/// immediately, so the guest's eventual first flight — not a 5s timeout — is
895/// what ends the peek.
896#[allow(clippy::too_many_arguments)]
897async fn classify_first_flight(
898    mut buf: Vec<u8>,
899    from_smoltcp: &mut mpsc::Receiver<Bytes>,
900    server_rx: &mut tokio::net::tcp::OwnedReadHalf,
901    to_smoltcp: &mpsc::Sender<Bytes>,
902    shared: &SharedState,
903    want_headers: bool,
904    max: usize,
905    budget: Duration,
906) -> io::Result<(Vec<u8>, bool)> {
907    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
908    let timeout_fut = tokio::time::sleep(budget);
909    tokio::pin!(timeout_fut);
910
911    loop {
912        // Stop as soon as the protocol class is known and — for plain-HTTP
913        // candidates — a full header block has arrived. Bail the moment a
914        // non-TLS flight stops looking like an HTTP request so non-HTTP
915        // protocols (SSH, Postgres) aren't withheld from upstream for the
916        // whole budget while we wait for a `\r\n\r\n` that never comes.
917        if !buf.is_empty() {
918            let is_tls = buf.first() == Some(&0x16);
919            let not_http = !is_tls
920                && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
921            let done = !want_headers
922                || is_tls
923                || not_http
924                || buf.len() >= max
925                || buf.windows(4).any(|w| w == b"\r\n\r\n");
926            if done {
927                return Ok((buf, is_tls));
928            }
929        }
930
931        tokio::select! {
932            biased;
933            _ = &mut timeout_fut => {
934                let is_tls = buf.first() == Some(&0x16);
935                return Ok((buf, is_tls));
936            }
937            // Guest → buffer (not forwarded here; the caller replays it once the
938            // handler is built, so substitution applies to the first flight too).
939            guest = from_smoltcp.recv() => match guest {
940                Some(bytes) => buf.extend_from_slice(&bytes),
941                None => {
942                    let is_tls = buf.first() == Some(&0x16);
943                    return Ok((buf, is_tls));
944                }
945            },
946            // Server → guest: relay immediately so a server-first banner is never
947            // held hostage by the peek.
948            server = server_rx.read(&mut server_buf) => match server {
949                Ok(0) => {
950                    let is_tls = buf.first() == Some(&0x16);
951                    return Ok((buf, is_tls));
952                }
953                Ok(n) => {
954                    let data = Bytes::copy_from_slice(&server_buf[..n]);
955                    if to_smoltcp.send(data).await.is_err() {
956                        let is_tls = buf.first() == Some(&0x16);
957                        return Ok((buf, is_tls));
958                    }
959                    shared.proxy_wake.wake();
960                }
961                Err(e) => return Err(e),
962            },
963        }
964    }
965}
966
967/// Buffer the first flight until SNI can be extracted, or until one
968/// of the bail-out conditions hits (channel close, buffer cap,
969/// timeout). Never errors; non-TLS / slow / malformed input all
970/// fall through to `None`.
971///
972/// On hit, the SNI is canonicalized (lowercase + trim trailing dot)
973/// for byte-equal matching against rule destinations. The returned
974/// buffer must be replayed verbatim to upstream before the caller
975/// starts its relay loop.
976async fn peek_for_sni(
977    rx: &mut mpsc::Receiver<Bytes>,
978    max: usize,
979    budget: Duration,
980) -> (Vec<u8>, Option<String>) {
981    let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
982    let timeout_fut = tokio::time::sleep(budget);
983    tokio::pin!(timeout_fut);
984
985    let raw_sni = loop {
986        tokio::select! {
987            biased;
988            _ = &mut timeout_fut => break None,
989            data = rx.recv() => {
990                match data {
991                    Some(bytes) => {
992                        buf.extend_from_slice(&bytes);
993                        // First byte of a TLS record is the ContentType;
994                        // 0x16 is handshake. Anything else can't be a
995                        // ClientHello, so don't burn the full budget on
996                        // plain HTTP / SSH / etc.
997                        if buf.first() != Some(&0x16) {
998                            break None;
999                        }
1000                        if let Some(name) = sni::extract_sni(&buf) {
1001                            break Some(name);
1002                        }
1003                        if buf.len() >= max {
1004                            break None;
1005                        }
1006                    }
1007                    None => break None,
1008                }
1009            }
1010        }
1011    };
1012
1013    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
1014    (buf, canonical)
1015}
1016
1017//--------------------------------------------------------------------------------------------------
1018// Tests
1019//--------------------------------------------------------------------------------------------------
1020
1021#[cfg(test)]
1022mod tests {
1023    use super::*;
1024
1025    /// Synthetic TLS ClientHello carrying SNI `example.com`. Bytes
1026    /// borrowed from `tls::sni` test fixtures so the parser sees a
1027    /// well-formed record.
1028    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
1029        // Minimal but valid TLS 1.2 ClientHello with one SNI entry.
1030        // Layout: record header (5) + handshake header (4) + body.
1031        let host_bytes = sni.as_bytes();
1032        let host_len = host_bytes.len() as u16;
1033        let server_name_list_len = 3 + host_len; // type(1) + len(2) + host
1034        let extension_data_len = 2 + server_name_list_len; // list-len(2) + list
1035        let extensions_total = 4 + extension_data_len; // type(2) + len(2) + data
1036
1037        let mut body = Vec::new();
1038        // Client version
1039        body.extend_from_slice(&[0x03, 0x03]);
1040        // Random (32 bytes)
1041        body.extend_from_slice(&[0u8; 32]);
1042        // Session id length + (empty)
1043        body.push(0);
1044        // Cipher suites length + one cipher
1045        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
1046        // Compression methods length + null
1047        body.extend_from_slice(&[0x01, 0x00]);
1048        // Extensions length
1049        body.extend_from_slice(&extensions_total.to_be_bytes());
1050        // SNI extension: type 0x0000
1051        body.extend_from_slice(&[0x00, 0x00]);
1052        body.extend_from_slice(&extension_data_len.to_be_bytes());
1053        body.extend_from_slice(&server_name_list_len.to_be_bytes());
1054        body.push(0x00); // host_name type
1055        body.extend_from_slice(&host_len.to_be_bytes());
1056        body.extend_from_slice(host_bytes);
1057
1058        let handshake_len = body.len() as u32;
1059        let mut hs = Vec::new();
1060        hs.push(0x01); // ClientHello
1061        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); // 24-bit length
1062        hs.extend_from_slice(&body);
1063
1064        let record_len = hs.len() as u16;
1065        let mut record = Vec::new();
1066        record.extend_from_slice(&[0x16, 0x03, 0x01]); // Handshake, TLS 1.0
1067        record.extend_from_slice(&record_len.to_be_bytes());
1068        record.extend_from_slice(&hs);
1069
1070        record
1071    }
1072
1073    #[test]
1074    fn could_be_connect_request_matches_split_prefixes_only() {
1075        assert!(could_be_connect_request(b"C"));
1076        assert!(could_be_connect_request(b"connect "));
1077        assert!(could_be_connect_request(b"CONNECT example.com:443"));
1078        assert!(!could_be_connect_request(b"CLIENT"));
1079        assert!(!could_be_connect_request(b"GET / HTTP/1.1\r\n"));
1080    }
1081
1082    #[tokio::test]
1083    async fn buffer_connect_request_reads_split_headers() {
1084        let (tx, mut rx) = mpsc::channel(4);
1085        tx.send(Bytes::from_static(b"NECT example.com:443 HTTP/1.1\r\n"))
1086            .await
1087            .unwrap();
1088        tx.send(Bytes::from_static(b"Host: example.com\r\n\r\n"))
1089            .await
1090            .unwrap();
1091        drop(tx);
1092
1093        let buffered = buffer_connect_request(b"CON".to_vec(), &mut rx)
1094            .await
1095            .unwrap();
1096        let parsed = parse_connect_request(buffered).unwrap();
1097
1098        assert_eq!(parsed.target.host, "example.com");
1099        assert_eq!(parsed.target.port, 443);
1100        assert_eq!(parsed.target.expected_sni.as_deref(), Some("example.com"));
1101        assert!(parsed.post_header_bytes().is_empty());
1102    }
1103
1104    #[test]
1105    fn parse_connect_request_preserves_post_header_tls_seed() {
1106        let mut request = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec();
1107        request.extend_from_slice(b"\x16\x03\x01client-hello");
1108
1109        let parsed = parse_connect_request(request).unwrap();
1110
1111        assert_eq!(
1112            parsed.header_bytes(),
1113            b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n"
1114        );
1115        assert_eq!(parsed.post_header_bytes(), b"\x16\x03\x01client-hello");
1116    }
1117
1118    #[test]
1119    fn parse_connect_target_requires_authority_port() {
1120        assert!(parse_connect_target("example.com").is_err());
1121        assert!(parse_connect_target("2001:db8::1:443").is_err());
1122
1123        let target = parse_connect_target("[2001:db8::1]:8443").unwrap();
1124        assert_eq!(target.host, "2001:db8::1");
1125        assert_eq!(target.port, 8443);
1126        assert_eq!(target.expected_sni, None);
1127    }
1128
1129    #[test]
1130    fn connect_response_success_requires_exact_2xx_status_code() {
1131        assert!(connect_response_is_success(
1132            b"HTTP/1.1 200 Connection Established\r\n\r\n"
1133        ));
1134        assert!(connect_response_is_success(
1135            b"HTTP/1.1 204 Connection Established\r\n\r\n"
1136        ));
1137        assert!(!connect_response_is_success(b"HTTP/1.1 2000 Weird\r\n\r\n"));
1138        assert!(!connect_response_is_success(b"HTTP/1.1 199 Nope\r\n\r\n"));
1139        assert!(!connect_response_is_success(b"NOTHTTP 200 OK\r\n\r\n"));
1140    }
1141
1142    #[tokio::test]
1143    async fn peek_for_sni_extracts_and_canonicalizes() {
1144        let (tx, mut rx) = mpsc::channel(4);
1145        let hello = synthetic_client_hello("Example.COM");
1146        tx.send(Bytes::from(hello.clone())).await.unwrap();
1147        drop(tx); // close so peek returns even if SNI didn't satisfy
1148
1149        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1150        assert_eq!(sni.as_deref(), Some("example.com"));
1151        assert_eq!(buf, hello);
1152    }
1153
1154    #[tokio::test]
1155    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
1156        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1157        drop(tx);
1158        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1159        assert!(buf.is_empty());
1160        assert_eq!(sni, None);
1161    }
1162
1163    #[tokio::test]
1164    async fn peek_for_sni_returns_none_on_non_tls_data() {
1165        let (tx, mut rx) = mpsc::channel(4);
1166        // Plaintext HTTP request; not a TLS record so extract_sni returns None.
1167        tx.send(Bytes::from_static(
1168            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
1169        ))
1170        .await
1171        .unwrap();
1172        drop(tx);
1173        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1174        assert!(
1175            !buf.is_empty(),
1176            "buffered bytes must be returned for replay"
1177        );
1178        assert_eq!(sni, None);
1179    }
1180
1181    #[tokio::test]
1182    async fn peek_for_sni_falls_back_on_timeout() {
1183        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1184        // Hold the sender open but send nothing — peek must time out.
1185        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
1186        drop(tx);
1187        assert!(buf.is_empty());
1188        assert_eq!(sni, None);
1189    }
1190
1191    #[tokio::test]
1192    async fn peek_for_sni_caps_at_max_bytes() {
1193        let (tx, mut rx) = mpsc::channel(4);
1194        // First byte 0x16 keeps the peek collecting past the early
1195        // non-TLS bail. Padding bytes are zero so the SNI parser never
1196        // matches and the loop drives to the size cap.
1197        let mut first = vec![0u8; 8192];
1198        first[0] = 0x16;
1199        tx.send(Bytes::from(first)).await.unwrap();
1200        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1201        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1202        drop(tx);
1203
1204        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1205        assert_eq!(sni, None, "no SNI in non-TLS data");
1206        assert!(
1207            buf.len() >= PEEK_BUF_SIZE,
1208            "buffer must hit the cap before bail-out: got {}",
1209            buf.len()
1210        );
1211    }
1212
1213    #[tokio::test]
1214    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
1215        let (tx, mut rx) = mpsc::channel(4);
1216        // Plain HTTP request: first byte 'G' (0x47) — clearly not TLS.
1217        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
1218            .await
1219            .unwrap();
1220        drop(tx);
1221
1222        // 5-second nominal budget; assert we returned in well under
1223        // that — the early-bail must not wait for the full window.
1224        let started = std::time::Instant::now();
1225        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1226        let elapsed = started.elapsed();
1227        assert_eq!(sni, None);
1228        assert!(buf.starts_with(b"GET"));
1229        assert!(
1230            elapsed < Duration::from_millis(500),
1231            "non-TLS bail must be fast: took {elapsed:?}"
1232        );
1233    }
1234
1235    //----------------------------------------------------------------------------------------------
1236    // peek_for_sni × evaluate_egress_with_source — combined integration tests
1237    //----------------------------------------------------------------------------------------------
1238
1239    use std::net::IpAddr;
1240    use std::time::Duration as StdDuration;
1241
1242    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
1243    use crate::shared::{ResolvedHostnameFamily, SharedState};
1244
    /// Destination address shared by the integration tests below: each test
    /// caches it for a hostname (or leaves the cache empty) and then uses it
    /// as the egress destination.
    // NOTE(review): presumably a Fastly edge address, per the name — confirm.
    const SHARED_FASTLY_IP: &str = "151.101.0.223";
1246
1247    fn shared_with(host: &str, ip: &str) -> SharedState {
1248        let shared = SharedState::new(4);
1249        shared.cache_resolved_hostname(
1250            host,
1251            ResolvedHostnameFamily::Ipv4,
1252            [ip.parse::<IpAddr>().unwrap()],
1253            StdDuration::from_secs(60),
1254        );
1255        shared
1256    }
1257
1258    fn allow_https(domain: &str) -> Rule {
1259        Rule {
1260            direction: crate::policy::Direction::Egress,
1261            destination: Destination::Domain(domain.parse().unwrap()),
1262            protocols: vec![Protocol::Tcp],
1263            ports: vec![PortRange::single(443)],
1264            action: Action::Allow,
1265        }
1266    }
1267
1268    /// Over-allow case: cache says IP X is `pypi.org` (allowed); SNI
1269    /// is `evil.com`. SNI must override the cache and deny.
1270    #[tokio::test]
1271    async fn integration_sni_overrides_cache_for_over_allow() {
1272        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1273        let policy = NetworkPolicy {
1274            default_egress: Action::Deny,
1275            default_ingress: Action::Allow,
1276            rules: vec![allow_https("pypi.org")],
1277        };
1278        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1279
1280        let (tx, mut rx) = mpsc::channel(4);
1281        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
1282            .await
1283            .unwrap();
1284        drop(tx);
1285
1286        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1287        assert_eq!(sni.as_deref(), Some("evil.com"));
1288        assert!(!initial_buf.is_empty());
1289
1290        let source = sni
1291            .as_deref()
1292            .map(HostnameSource::Sni)
1293            .unwrap_or(HostnameSource::CacheOnly);
1294        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1295        assert_eq!(
1296            eval,
1297            EgressEvaluation::Deny,
1298            "SNI=evil.com must not piggy-back on the cached pypi.org match",
1299        );
1300    }
1301
1302    /// Over-block case: cache says IP X is `ads.example.com` (denied);
1303    /// SNI is `api.example.com`. SNI must override the cache and allow.
1304    #[tokio::test]
1305    async fn integration_sni_overrides_cache_for_over_block() {
1306        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
1307        let policy = NetworkPolicy {
1308            default_egress: Action::Allow,
1309            default_ingress: Action::Allow,
1310            rules: vec![Rule::deny_egress(Destination::Domain(
1311                "ads.example.com".parse().unwrap(),
1312            ))],
1313        };
1314        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1315
1316        let (tx, mut rx) = mpsc::channel(4);
1317        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
1318            .await
1319            .unwrap();
1320        drop(tx);
1321
1322        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1323        assert_eq!(sni.as_deref(), Some("api.example.com"));
1324
1325        let source = sni
1326            .as_deref()
1327            .map(HostnameSource::Sni)
1328            .unwrap_or(HostnameSource::CacheOnly);
1329        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1330        assert_eq!(
1331            eval,
1332            EgressEvaluation::Allow,
1333            "SNI=api.example.com must not be caught by the deny on ads.example.com",
1334        );
1335    }
1336
1337    /// Non-TLS first-flight falls back to `CacheOnly`; the cache
1338    /// match decides.
1339    #[tokio::test]
1340    async fn integration_non_tls_falls_back_to_cache() {
1341        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1342        let policy = NetworkPolicy {
1343            default_egress: Action::Deny,
1344            default_ingress: Action::Allow,
1345            rules: vec![allow_https("pypi.org")],
1346        };
1347        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1348
1349        let (tx, mut rx) = mpsc::channel(4);
1350        // Plain HTTP request; not a TLS record.
1351        tx.send(Bytes::from_static(
1352            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
1353        ))
1354        .await
1355        .unwrap();
1356        drop(tx);
1357
1358        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1359        assert_eq!(sni, None, "non-TLS data → no SNI");
1360        assert!(
1361            !initial_buf.is_empty(),
1362            "buffered bytes must survive for replay"
1363        );
1364
1365        let source = sni
1366            .as_deref()
1367            .map(HostnameSource::Sni)
1368            .unwrap_or(HostnameSource::CacheOnly);
1369        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1370        assert_eq!(
1371            eval,
1372            EgressEvaluation::Allow,
1373            "cache-only fallback must still allow the cached hostname's IP",
1374        );
1375    }
1376
1377    /// SNI matches a `DomainSuffix` rule with a cache binding for the
1378    /// claimed name. Genuine pre-resolved traffic passes.
1379    #[tokio::test]
1380    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
1381        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
1382        let policy = NetworkPolicy {
1383            default_egress: Action::Deny,
1384            default_ingress: Action::Allow,
1385            rules: vec![Rule {
1386                direction: crate::policy::Direction::Egress,
1387                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1388                protocols: vec![Protocol::Tcp],
1389                ports: vec![PortRange::single(443)],
1390                action: Action::Allow,
1391            }],
1392        };
1393        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1394
1395        let (tx, mut rx) = mpsc::channel(4);
1396        tx.send(Bytes::from(synthetic_client_hello(
1397            "files.pythonhosted.org",
1398        )))
1399        .await
1400        .unwrap();
1401        drop(tx);
1402
1403        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1404        let source = sni
1405            .as_deref()
1406            .map(HostnameSource::Sni)
1407            .unwrap_or(HostnameSource::CacheOnly);
1408        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1409        assert_eq!(eval, EgressEvaluation::Allow);
1410    }
1411
1412    /// Spoofed SNI on an IP with no cache binding for any matching
1413    /// name: byte-equality with the suffix passes, but no DNS lookup
1414    /// ever tied a `*.pythonhosted.org` name to the destination, so
1415    /// the AND-check fails and the connection is denied.
1416    #[tokio::test]
1417    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
1418        let shared = SharedState::new(4); // empty cache
1419        let policy = NetworkPolicy {
1420            default_egress: Action::Deny,
1421            default_ingress: Action::Allow,
1422            rules: vec![Rule {
1423                direction: crate::policy::Direction::Egress,
1424                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1425                protocols: vec![Protocol::Tcp],
1426                ports: vec![PortRange::single(443)],
1427                action: Action::Allow,
1428            }],
1429        };
1430        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1431
1432        let (tx, mut rx) = mpsc::channel(4);
1433        tx.send(Bytes::from(synthetic_client_hello(
1434            "files.pythonhosted.org",
1435        )))
1436        .await
1437        .unwrap();
1438        drop(tx);
1439
1440        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1441        let source = sni
1442            .as_deref()
1443            .map(HostnameSource::Sni)
1444            .unwrap_or(HostnameSource::CacheOnly);
1445        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1446        assert_eq!(eval, EgressEvaluation::Deny);
1447    }
1448
1449    // ── extract_http_host ──────────────────────────────────────────────────────
1450
1451    #[test]
1452    fn extract_http_host_basic() {
1453        let buf = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
1454        assert_eq!(extract_http_host(buf), Some("example.com".into()));
1455    }
1456
1457    #[test]
1458    fn extract_http_host_strips_port() {
1459        let buf = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
1460        assert_eq!(extract_http_host(buf), Some("api.company.com".into()));
1461    }
1462
1463    #[test]
1464    fn extract_http_host_case_insensitive_lowercased() {
1465        let buf = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
1466        assert_eq!(extract_http_host(buf), Some("example.com".into()));
1467    }
1468
1469    #[test]
1470    fn extract_http_host_no_host_header() {
1471        let buf = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
1472        assert_eq!(extract_http_host(buf), None);
1473    }
1474
1475    #[test]
1476    fn extract_http_host_incomplete_headers() {
1477        let buf = b"GET / HTTP/1.1\r\nHost: x";
1478        assert_eq!(extract_http_host(buf), None);
1479    }
1480
1481    #[test]
1482    fn extract_http_host_tls_first_byte() {
1483        let buf = [0x16u8, 0x03, 0x01, 0x00, 0x01];
1484        assert_eq!(extract_http_host(&buf), None);
1485    }
1486
1487    #[test]
1488    fn extract_http_host_with_many_headers() {
1489        // Far more headers than a small fixed parse array would hold: the Host
1490        // must still be found rather than the request looking hostless.
1491        let mut req = Vec::from(&b"GET / HTTP/1.1\r\n"[..]);
1492        for i in 0..100 {
1493            req.extend_from_slice(format!("X-Pad-{i}: v\r\n").as_bytes());
1494        }
1495        req.extend_from_slice(b"Host: example.com\r\n\r\n");
1496        assert_eq!(extract_http_host(&req), Some("example.com".into()));
1497    }
1498
1499    // ── plain-HTTP secret substitution ────────────────────────────────────────
1500
1501    use std::sync::Arc;
1502    use tokio::io::AsyncReadExt;
1503    use tokio::net::TcpListener;
1504    use tokio::task::JoinHandle;
1505
1506    use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};
1507
1508    fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
1509        SecretsConfig {
1510            secrets: vec![SecretEntry {
1511                env_var: "API_KEY".into(),
1512                value: value.into(),
1513                placeholder: placeholder.into(),
1514                allowed_hosts: vec![HostPattern::Any],
1515                injection: SecretInjection {
1516                    headers: true,
1517                    basic_auth: false,
1518                    query_params: false,
1519                    body: false,
1520                },
1521                on_violation: None,
1522                require_tls_identity: require_tls,
1523            }],
1524            ..Default::default()
1525        }
1526    }
1527
1528    fn make_host_bound_secret(placeholder: &str, value: &str, host: &str) -> SecretsConfig {
1529        SecretsConfig {
1530            secrets: vec![SecretEntry {
1531                env_var: "API_KEY".into(),
1532                value: value.into(),
1533                placeholder: placeholder.into(),
1534                allowed_hosts: vec![HostPattern::Exact(host.into())],
1535                injection: SecretInjection::default(),
1536                on_violation: None,
1537                require_tls_identity: true,
1538            }],
1539            ..Default::default()
1540        }
1541    }
1542
1543    #[test]
1544    fn sanitize_connect_headers_blocks_placeholder_metadata_header_by_default() {
1545        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1546        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\nUser-Agent: curl\r\n\r\n";
1547
1548        assert_eq!(
1549            sanitize_connect_headers(headers, &secrets),
1550            Err(ViolationAction::BlockAndLog)
1551        );
1552    }
1553
1554    #[test]
1555    fn sanitize_connect_headers_respects_block_and_terminate() {
1556        let mut secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1557        secrets.on_violation = ViolationAction::BlockAndTerminate;
1558        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
1559
1560        assert_eq!(
1561            sanitize_connect_headers(headers, &secrets),
1562            Err(ViolationAction::BlockAndTerminate)
1563        );
1564    }
1565
1566    #[test]
1567    fn sanitize_connect_headers_respects_explicit_passthrough() {
1568        let mut secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1569        secrets.on_violation = ViolationAction::Passthrough(vec![HostPattern::Any]);
1570        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
1571
1572        let sanitized = sanitize_connect_headers(headers, &secrets).unwrap();
1573
1574        assert_eq!(sanitized.as_ref(), headers);
1575        assert!(
1576            !String::from_utf8_lossy(sanitized.as_ref()).contains("real-secret-value"),
1577            "passthrough must never substitute real secrets into CONNECT metadata"
1578        );
1579    }
1580
1581    #[test]
1582    fn sanitize_connect_headers_keeps_safe_metadata_headers() {
1583        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1584        let headers =
1585            b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nUser-Agent: curl\r\n\r\n";
1586
1587        let sanitized = sanitize_connect_headers(headers, &secrets).unwrap();
1588
1589        assert_eq!(sanitized.as_ref(), headers);
1590    }
1591
1592    #[test]
1593    fn sanitize_connect_headers_blocks_placeholder_in_request_line() {
1594        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1595        let headers = b"CONNECT $MSB_KEY:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n";
1596
1597        assert_eq!(
1598            sanitize_connect_headers(headers, &secrets),
1599            Err(ViolationAction::BlockAndLog)
1600        );
1601    }
1602
1603    async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
1604        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1605        let addr = listener.local_addr().unwrap();
1606        let handle = tokio::spawn(async move {
1607            let (mut stream, _) = listener.accept().await.unwrap();
1608            let mut received = Vec::new();
1609            let mut buf = vec![0u8; 4096];
1610            loop {
1611                match stream.read(&mut buf).await {
1612                    Ok(0) | Err(_) => break,
1613                    Ok(n) => received.extend_from_slice(&buf[..n]),
1614                }
1615            }
1616            received
1617        });
1618        (addr, handle)
1619    }
1620
1621    async fn relay_through_proxy(
1622        request: Vec<u8>,
1623        secrets: SecretsConfig,
1624        handle: JoinHandle<Vec<u8>>,
1625        server_addr: SocketAddr,
1626    ) -> Vec<u8> {
1627        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1628        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1629        let shared = SharedState::new(4);
1630        let policy = Arc::new(NetworkPolicy::default());
1631        let secrets = Arc::new(secrets);
1632        let proxy_connect = Arc::new(ProxyConnectState::new());
1633
1634        from_tx.send(Bytes::from(request)).await.unwrap();
1635        drop(from_tx);
1636
1637        tcp_proxy_task(
1638            server_addr,
1639            server_addr,
1640            from_rx,
1641            to_tx,
1642            Arc::new(shared),
1643            policy,
1644            secrets,
1645            None,
1646            proxy_connect,
1647        )
1648        .await
1649        .unwrap();
1650
1651        handle.await.unwrap()
1652    }
1653
1654    #[tokio::test]
1655    async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
1656        // Host header split across TCP segments — classify_first_flight must keep
1657        // reading until \r\n\r\n before extract_http_host is called.
1658        let (addr, sink) = spawn_sink().await;
1659        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1660
1661        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1662        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1663        let proxy_connect = Arc::new(ProxyConnectState::new());
1664
1665        from_tx
1666            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1667            .await
1668            .unwrap();
1669        from_tx
1670            .send(Bytes::from_static(
1671                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1672            ))
1673            .await
1674            .unwrap();
1675        drop(from_tx);
1676
1677        tcp_proxy_task(
1678            addr,
1679            addr,
1680            from_rx,
1681            to_tx,
1682            Arc::new(SharedState::new(4)),
1683            Arc::new(NetworkPolicy::default()),
1684            Arc::new(secrets),
1685            None,
1686            proxy_connect,
1687        )
1688        .await
1689        .unwrap();
1690
1691        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1692        assert!(wire.contains("real-secret-value"), "got: {wire:?}");
1693        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1694    }
1695
1696    #[tokio::test]
1697    async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
1698        // A default (require_tls_identity = true) host-bound secret is never
1699        // substituted over plain HTTP, but a request to its allowed host must
1700        // have the placeholder forwarded unchanged — not blocked as a violation
1701        // — even when the Host arrives in a later segment than the request line.
1702        let (addr, sink) = spawn_sink().await;
1703
1704        let shared = SharedState::new(4);
1705        shared.cache_resolved_hostname(
1706            "example.com",
1707            ResolvedHostnameFamily::Ipv4,
1708            ["127.0.0.1".parse::<IpAddr>().unwrap()],
1709            StdDuration::from_secs(60),
1710        );
1711
1712        let secrets = SecretsConfig {
1713            secrets: vec![SecretEntry {
1714                env_var: "API_KEY".into(),
1715                value: "real-secret-value".into(),
1716                placeholder: "$MSB_KEY".into(),
1717                allowed_hosts: vec![HostPattern::Exact("example.com".into())],
1718                injection: SecretInjection {
1719                    headers: true,
1720                    basic_auth: false,
1721                    query_params: false,
1722                    body: false,
1723                },
1724                on_violation: None,
1725                require_tls_identity: true,
1726            }],
1727            ..Default::default()
1728        };
1729
1730        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1731        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1732        let proxy_connect = Arc::new(ProxyConnectState::new());
1733
1734        from_tx
1735            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1736            .await
1737            .unwrap();
1738        from_tx
1739            .send(Bytes::from_static(
1740                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1741            ))
1742            .await
1743            .unwrap();
1744        drop(from_tx);
1745
1746        tcp_proxy_task(
1747            addr,
1748            addr,
1749            from_rx,
1750            to_tx,
1751            Arc::new(shared),
1752            Arc::new(NetworkPolicy::default()),
1753            Arc::new(secrets),
1754            None,
1755            proxy_connect,
1756        )
1757        .await
1758        .unwrap();
1759
1760        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1761        assert!(
1762            wire.contains("Host: example.com"),
1763            "request must reach the allowed host, got: {wire:?}"
1764        );
1765        assert!(
1766            wire.contains("$MSB_KEY"),
1767            "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
1768        );
1769        assert!(
1770            !wire.contains("real-secret-value"),
1771            "secret must never be substituted over plain HTTP, got: {wire:?}"
1772        );
1773    }
1774
1775    #[tokio::test]
1776    async fn plain_http_substitutes_placeholder_in_first_flight() {
1777        let (addr, sink) = spawn_sink().await;
1778
1779        let request =
1780            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1781                .to_vec();
1782        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1783
1784        let wire =
1785            String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
1786        assert!(
1787            wire.contains("real-secret-value"),
1788            "real value must reach server, got: {wire:?}"
1789        );
1790        assert!(
1791            !wire.contains("$MSB_KEY"),
1792            "placeholder must not reach server, got: {wire:?}"
1793        );
1794    }
1795
1796    #[tokio::test]
1797    async fn plain_http_no_substitution_when_require_tls_identity_true() {
1798        let (addr, sink) = spawn_sink().await;
1799
1800        let request =
1801            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1802                .to_vec();
1803        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);
1804
1805        let wire =
1806            String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
1807                .into_owned();
1808        assert!(
1809            wire.contains("$MSB_KEY"),
1810            "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
1811        );
1812        assert!(
1813            !wire.contains("real-secret-value"),
1814            "real value must not leak when require_tls_identity=true, got: {wire:?}"
1815        );
1816    }
1817
1818    #[tokio::test]
1819    async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
1820        // Body arrives in a separate segment after headers — flows through the relay
1821        // loop, not the peek path. Ensures no bytes are dropped and header substitution
1822        // still happens.
1823        let (addr, sink) = spawn_sink().await;
1824        let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);
1825
1826        let body = "x".repeat(32_000);
1827        let header = format!(
1828            "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
1829            body.len()
1830        );
1831
1832        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1833        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1834        let proxy_connect = Arc::new(ProxyConnectState::new());
1835
1836        from_tx
1837            .send(Bytes::from(header.into_bytes()))
1838            .await
1839            .unwrap();
1840        from_tx
1841            .send(Bytes::from(body.clone().into_bytes()))
1842            .await
1843            .unwrap();
1844        drop(from_tx);
1845
1846        tcp_proxy_task(
1847            addr,
1848            addr,
1849            from_rx,
1850            to_tx,
1851            Arc::new(SharedState::new(4)),
1852            Arc::new(NetworkPolicy::default()),
1853            Arc::new(secrets),
1854            None,
1855            proxy_connect,
1856        )
1857        .await
1858        .unwrap();
1859
1860        let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
1861        assert!(wire.contains(&body), "got {} bytes", wire.len());
1862        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1863    }
1864}