Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::borrow::Cow;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23    SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::proxy::{TlsProxyContext, tls_proxy_task};
27use crate::tls::sni;
28use crate::tls::state::TlsState;
29
30//--------------------------------------------------------------------------------------------------
31// Constants
32//--------------------------------------------------------------------------------------------------
33
/// Size of the scratch buffer used for each read from the upstream server.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Cap on the bytes accumulated while looking for the end of the proxy's
/// CONNECT response headers.
const CONNECT_RESP_LIMIT: usize = 8 * 1024;

/// Cap on the guest bytes buffered while peeking at the first flight
/// (for example to find the ClientHello's SNI).
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Time budget for buffering the first flight; once it elapses the egress
/// decision falls back to the cache-only path.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
46
47//--------------------------------------------------------------------------------------------------
48// Types
49//--------------------------------------------------------------------------------------------------
50
/// A buffered HTTP `CONNECT` request from the guest, split at the end of
/// its header block.
#[derive(Debug)]
struct ConnectRequest {
    /// Everything buffered so far: the request line and headers, followed
    /// by any bytes that arrived after the header terminator.
    bytes: Vec<u8>,
    /// Offset into `bytes` just past the `\r\n\r\n` header terminator.
    header_end: usize,
    /// Destination parsed from the request line's authority.
    target: ConnectTarget,
}
57
/// Destination named by a `CONNECT` request's authority (`host:port`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct ConnectTarget {
    /// Host part of the authority, lowercased and without trailing dots.
    /// May be a hostname or an IP literal (IPv6 brackets removed).
    host: String,
    /// Port part of the authority.
    port: u16,
    /// Lowercased hostname when `host` is not an IP literal; `None` for
    /// IP-literal targets.
    expected_sni: Option<String>,
}
64
65//--------------------------------------------------------------------------------------------------
66// Methods
67//--------------------------------------------------------------------------------------------------
68
69impl ConnectRequest {
70    fn header_bytes(&self) -> &[u8] {
71        &self.bytes[..self.header_end]
72    }
73
74    fn post_header_bytes(&self) -> &[u8] {
75        &self.bytes[self.header_end..]
76    }
77}
78
79impl ConnectTarget {
80    fn is_intercepted(&self, tls_state: &TlsState) -> bool {
81        tls_state.config.intercepted_ports.contains(&self.port)
82    }
83
84    fn guest_dst(&self, fallback: SocketAddr, shared: &SharedState) -> SocketAddr {
85        if let Ok(ip) = self.host.parse::<IpAddr>() {
86            return SocketAddr::new(ip, self.port);
87        }
88
89        if self.host.eq_ignore_ascii_case(crate::HOST_ALIAS) {
90            match fallback.ip() {
91                IpAddr::V4(_) => {
92                    if let Some(ip) = shared.gateway_ipv4() {
93                        return SocketAddr::new(IpAddr::V4(ip), self.port);
94                    }
95                }
96                IpAddr::V6(_) => {
97                    if let Some(ip) = shared.gateway_ipv6() {
98                        return SocketAddr::new(IpAddr::V6(ip), self.port);
99                    }
100                }
101            }
102            if let Some(ip) = shared.gateway_ipv4() {
103                return SocketAddr::new(IpAddr::V4(ip), self.port);
104            }
105            if let Some(ip) = shared.gateway_ipv6() {
106                return SocketAddr::new(IpAddr::V6(ip), self.port);
107            }
108        }
109
110        SocketAddr::new(fallback.ip(), self.port)
111    }
112}
113
114//--------------------------------------------------------------------------------------------------
115// Functions
116//--------------------------------------------------------------------------------------------------
117
118/// Dial `dst` and update proxy state; wakes the poll thread on failure.
119pub(crate) async fn connect_upstream(
120    dst: SocketAddr,
121    proxy_connect: &ProxyConnectState,
122    shared: &SharedState,
123) -> io::Result<TcpStream> {
124    match TcpStream::connect(dst).await {
125        Ok(s) => {
126            proxy_connect.mark_connected();
127            Ok(s)
128        }
129        Err(e) => {
130            proxy_connect.mark_upstream_connect_failed();
131            shared.proxy_wake.wake();
132            Err(e)
133        }
134    }
135}
136
137/// Spawn a TCP proxy task for a newly established connection.
138///
139/// `guest_dst` is what the guest dialed — the address policy rules
140/// match against. `connect_dst` is the host-side address tokio actually
141/// dials; for host-alias connections it's loopback (gateway rewritten).
142/// For everything else the two are identical.
143///
144/// `proxy_connect` is updated before the task exits so the connection
145/// tracker can decide between FIN (clean close) and RST (upstream
146/// connect failure).
147#[allow(clippy::too_many_arguments)]
148pub fn spawn_tcp_proxy(
149    handle: &tokio::runtime::Handle,
150    guest_dst: SocketAddr,
151    connect_dst: SocketAddr,
152    from_smoltcp: mpsc::Receiver<Bytes>,
153    to_smoltcp: mpsc::Sender<Bytes>,
154    shared: Arc<SharedState>,
155    network_policy: Arc<NetworkPolicy>,
156    secrets: Arc<SecretsConfig>,
157    tls_state: Option<Arc<TlsState>>,
158    proxy_connect: Arc<ProxyConnectState>,
159) {
160    handle.spawn(async move {
161        if let Err(e) = tcp_proxy_task(
162            guest_dst,
163            connect_dst,
164            from_smoltcp,
165            to_smoltcp,
166            shared,
167            network_policy,
168            secrets,
169            tls_state,
170            proxy_connect,
171        )
172        .await
173        {
174            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
175        }
176    });
177}
178
179/// Core TCP proxy: peek for SNI, evaluate egress policy, then either
180/// connect and relay or drop the channels.
181#[allow(clippy::too_many_arguments)]
182async fn tcp_proxy_task(
183    guest_dst: SocketAddr,
184    connect_dst: SocketAddr,
185    mut from_smoltcp: mpsc::Receiver<Bytes>,
186    to_smoltcp: mpsc::Sender<Bytes>,
187    shared: Arc<SharedState>,
188    network_policy: Arc<NetworkPolicy>,
189    secrets: Arc<SecretsConfig>,
190    tls_state: Option<Arc<TlsState>>,
191    proxy_connect: Arc<ProxyConnectState>,
192) -> io::Result<()> {
193    // Pre-connect peek is only for domain policy: the hostname has to be known
194    // before we dial upstream so a Deny never opens a connection. Secrets do
195    // *not* gate the connect, so they no longer force a peek here — that work is
196    // deferred to `classify_first_flight` after the socket is open, where it can
197    // run without stalling server-first protocols (see below).
198    let (mut initial_buf, sni) = if network_policy.has_domain_rules() {
199        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
200    } else {
201        (Vec::new(), None)
202    };
203
204    // Re-evaluate egress against the *guest* dst — the address the
205    // guest dialed, not the post-rewrite host-side address. SNI
206    // refines over-allow when the cache matched a shared CDN IP;
207    // CacheOnly is the non-TLS fallback path so Domain rules still
208    // gate plain HTTP / SSH / etc.
209    if network_policy.has_domain_rules() {
210        let source = match sni.as_deref() {
211            Some(name) => HostnameSource::Sni(name),
212            None => HostnameSource::CacheOnly,
213        };
214        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
215        {
216            EgressEvaluation::Allow => {}
217            EgressEvaluation::Deny => {
218                tracing::debug!(
219                    dst = %guest_dst,
220                    source = source.label(),
221                    "TCP egress denied by domain policy",
222                );
223                proxy_connect.mark_policy_denied();
224                shared.proxy_wake.wake();
225                return Ok(());
226            }
227            EgressEvaluation::DeferUntilHostname => {
228                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
229                proxy_connect.mark_policy_denied();
230                shared.proxy_wake.wake();
231                return Ok(());
232            }
233        }
234    }
235
236    // Peek for HTTP CONNECT before dialing upstream; hand off if detected.
237    if let Some(tls_state) = tls_state.clone() {
238        if initial_buf.is_empty() {
239            let (peeked, _) = peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await;
240            initial_buf = peeked;
241        }
242        if could_be_connect_request(&initial_buf) {
243            return handle_connect_tunnel(
244                guest_dst,
245                connect_dst,
246                initial_buf,
247                from_smoltcp,
248                to_smoltcp,
249                shared,
250                network_policy,
251                tls_state,
252                proxy_connect,
253                None,
254            )
255            .await;
256        }
257    }
258
259    // Connect upstream *before* finishing the secrets-side classification. A
260    // server-first protocol (SSH, SMTP, a database) sends nothing until it has
261    // seen the server's banner; with the socket already open we can relay that
262    // banner while we wait, instead of burning the peek budget pre-connect.
263    let stream = connect_upstream(connect_dst, &proxy_connect, &shared).await?;
264    let (mut server_rx, mut server_tx) = stream.into_split();
265
266    // Finish classifying the first flight (TLS vs plain HTTP) and, for
267    // plain-HTTP candidates, gather a full header block — without blocking the
268    // server→guest direction. When domain rules already peeked, `initial_buf`
269    // is reused and this is cheap; with no secrets it is skipped entirely
270    // (`is_tls` only matters for deciding whether to build the handler).
271    let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
272    let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
273        classify_first_flight(
274            initial_buf,
275            &mut from_smoltcp,
276            &mut server_rx,
277            &to_smoltcp,
278            &shared,
279            want_headers,
280            PEEK_BUF_SIZE,
281            PEEK_BUDGET,
282        )
283        .await?
284    } else {
285        (initial_buf, false)
286    };
287
288    if let Some(tls_state) = tls_state.clone()
289        && could_be_connect_request(&initial_buf)
290    {
291        // The pre-connect CONNECT peek can miss a client whose first bytes arrive
292        // after we dial upstream. Once classify_first_flight has captured that
293        // request, rejoin the already-open proxy socket and use the CONNECT path
294        // so intercepted tunnels still get TLS substitution and policy checks.
295        let proxy_stream = server_rx
296            .reunite(server_tx)
297            .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
298        return handle_connect_tunnel(
299            guest_dst,
300            connect_dst,
301            initial_buf,
302            from_smoltcp,
303            to_smoltcp,
304            shared,
305            network_policy,
306            tls_state,
307            proxy_connect,
308            Some(proxy_stream),
309        )
310        .await;
311    }
312
313    let mut late_connect_state = tls_state;
314    let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
315        Some(match extract_http_host(&initial_buf) {
316            Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
317            None => SecretsHandler::new_plain_http_invalid_host(&secrets),
318        })
319    } else {
320        None
321    };
322
323    // Replay the buffered first flight — run through secrets handler first.
324    if !initial_buf.is_empty() {
325        let out: Cow<[u8]> = match secrets_handler.as_mut() {
326            Some(h) => match h.substitute(&initial_buf) {
327                // Borrow the input when nothing was substituted; only a chunk
328                // that actually carries a placeholder is reallocated.
329                Ok(cow) => cow,
330                Err(action) => {
331                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
332                    if matches!(action, ViolationAction::BlockAndTerminate) {
333                        shared.trigger_termination();
334                    }
335                    return Ok(());
336                }
337            },
338            None => Cow::Borrowed(&initial_buf),
339        };
340        if !out.is_empty() {
341            if let Err(e) = server_tx.write_all(&out).await {
342                tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
343                return Ok(());
344            }
345            if let Err(e) = server_tx.flush().await {
346                tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
347                return Ok(());
348            }
349        }
350    }
351
352    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
353
354    // Bidirectional relay using tokio::select!.
355    //
356    // guest → server: receive from channel, write to server socket.
357    // server → guest: read from server socket, send via channel + wake poll.
358    loop {
359        tokio::select! {
360            // Guest → server: substitute placeholders before forwarding.
361            data = from_smoltcp.recv() => {
362                match data {
363                    Some(bytes) => {
364                        if let Some(tls_state) = late_connect_state.take()
365                            && could_be_connect_request(&bytes)
366                        {
367                            // The first guest bytes can arrive after both peek
368                            // windows have completed. Nothing has been written
369                            // to the proxy socket yet, so this is still a valid
370                            // point to switch into CONNECT tunnel handling.
371                            let proxy_stream = server_rx
372                                .reunite(server_tx)
373                                .map_err(|_| io::Error::other("failed to reunite proxy stream halves"))?;
374                            return handle_connect_tunnel(
375                                guest_dst,
376                                connect_dst,
377                                bytes.to_vec(),
378                                from_smoltcp,
379                                to_smoltcp,
380                                shared,
381                                network_policy,
382                                tls_state,
383                                proxy_connect,
384                                Some(proxy_stream),
385                            )
386                            .await;
387                        }
388                        // No handler (no secrets / TLS) is the common path: forward
389                        // the chunk borrowed, with no per-chunk allocation or copy.
390                        let out: Cow<[u8]> = match secrets_handler.as_mut() {
391                            Some(h) => match h.substitute(&bytes) {
392                                Ok(cow) => cow,
393                                Err(action) => {
394                                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
395                                    if matches!(action, ViolationAction::BlockAndTerminate) {
396                                        shared.trigger_termination();
397                                    }
398                                    break;
399                                }
400                            },
401                            None => Cow::Borrowed(&bytes),
402                        };
403                        if !out.is_empty() {
404                            if let Err(e) = server_tx.write_all(&out).await {
405                                tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
406                                break;
407                            }
408                            if let Err(e) = server_tx.flush().await {
409                                tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
410                                break;
411                            }
412                        }
413                    }
414                    // Channel closed — smoltcp socket was closed by guest.
415                    None => break,
416                }
417            }
418
419            // Server → guest: no substitution — server never sends placeholders.
420            result = server_rx.read(&mut server_buf) => {
421                match result {
422                    Ok(0) => break, // Server closed connection.
423                    Ok(n) => {
424                        // A server-first byte means this is not an HTTP CONNECT
425                        // tunnel to a proxy. Keep relaying normally afterward.
426                        late_connect_state = None;
427                        let data = Bytes::copy_from_slice(&server_buf[..n]);
428                        if to_smoltcp.send(data).await.is_err() {
429                            // Channel closed — poll loop dropped the receiver.
430                            break;
431                        }
432                        // Wake the poll thread so it writes data to the
433                        // smoltcp socket.
434                        shared.proxy_wake.wake();
435                    }
436                    Err(e) => {
437                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
438                        break;
439                    }
440                }
441            }
442        }
443    }
444
445    Ok(())
446}
447
448/// Forward an HTTP CONNECT tunnel: dial the proxy, splice the handshake,
449/// then hand the established stream to `tls_proxy_task` for TLS MITM.
450///
451/// `guest_dst` is what the guest dialed; `proxy_dst` is the rewritten
452/// loopback address the gateway actually connects to.
453#[allow(clippy::too_many_arguments)]
454async fn handle_connect_tunnel(
455    guest_dst: SocketAddr,
456    proxy_dst: SocketAddr,
457    initial_buf: Vec<u8>,
458    mut from_smoltcp: mpsc::Receiver<Bytes>,
459    to_smoltcp: mpsc::Sender<Bytes>,
460    shared: Arc<SharedState>,
461    network_policy: Arc<NetworkPolicy>,
462    tls_state: Arc<TlsState>,
463    proxy_connect: Arc<ProxyConnectState>,
464    preconnected_proxy: Option<TcpStream>,
465) -> io::Result<()> {
466    let connect_req =
467        parse_connect_request(buffer_connect_request(initial_buf, &mut from_smoltcp).await?)?;
468
469    let connect_headers = match sanitize_connect_headers(
470        connect_req.header_bytes(),
471        &tls_state.secrets,
472    ) {
473        Ok(headers) => headers,
474        Err(action) => {
475            tracing::warn!(dst = %proxy_dst, violation = ?action, "secret violation in CONNECT headers");
476            if matches!(action, ViolationAction::BlockAndTerminate) {
477                shared.trigger_termination();
478            }
479            return Ok(());
480        }
481    };
482
483    // Dial the proxy and forward the CONNECT request so it opens the tunnel.
484    let mut proxy_stream = match preconnected_proxy {
485        Some(stream) => stream,
486        None => match TcpStream::connect(proxy_dst).await {
487            Ok(s) => s,
488            Err(e) => {
489                proxy_connect.mark_upstream_connect_failed();
490                shared.proxy_wake.wake();
491                return Err(e);
492            }
493        },
494    };
495
496    if !connect_req.target.is_intercepted(&tls_state) {
497        proxy_stream.write_all(&connect_headers).await?;
498        proxy_stream.flush().await?;
499        let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
500        if to_smoltcp
501            .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
502            .await
503            .is_err()
504        {
505            return Ok(());
506        }
507        if !proxy_resp[header_end..].is_empty()
508            && to_smoltcp
509                .send(Bytes::copy_from_slice(&proxy_resp[header_end..]))
510                .await
511                .is_err()
512        {
513            return Ok(());
514        }
515        shared.proxy_wake.wake();
516        if !connect_response_is_success(&proxy_resp[..header_end]) {
517            proxy_connect.mark_connected();
518            return Ok(());
519        }
520        if !connect_req.post_header_bytes().is_empty() {
521            proxy_stream
522                .write_all(connect_req.post_header_bytes())
523                .await?;
524        }
525        proxy_stream.flush().await?;
526        proxy_connect.mark_connected();
527        return relay_connected_stream(proxy_stream, from_smoltcp, to_smoltcp, shared).await;
528    }
529
530    proxy_stream.write_all(&connect_headers).await?;
531    proxy_stream.flush().await?;
532
533    let (proxy_resp, header_end) = read_connect_response_headers(&mut proxy_stream).await?;
534    if !connect_response_is_success(&proxy_resp[..header_end]) {
535        return Err(io::Error::new(
536            io::ErrorKind::ConnectionRefused,
537            "proxy rejected CONNECT",
538        ));
539    }
540    if !proxy_resp[header_end..].is_empty() {
541        return Err(io::Error::new(
542            io::ErrorKind::InvalidData,
543            "proxy sent unexpected bytes after CONNECT response headers",
544        ));
545    }
546    proxy_connect.mark_connected();
547
548    if to_smoltcp
549        .send(Bytes::copy_from_slice(&proxy_resp[..header_end]))
550        .await
551        .is_err()
552    {
553        return Ok(());
554    }
555    shared.proxy_wake.wake();
556
557    let tls_seed = connect_req.post_header_bytes().to_vec();
558    let tls_guest_dst = connect_req.target.guest_dst(guest_dst, &shared);
559    let expected_sni = connect_req.target.expected_sni.clone();
560
561    tls_proxy_task(
562        TlsProxyContext {
563            guest_dst: tls_guest_dst,
564            connect_dst: proxy_dst,
565            shared,
566            tls_state,
567            network_policy,
568            proxy_connect,
569            upstream_stream: Some(proxy_stream),
570            via_connect: expected_sni.is_some(),
571            expected_sni,
572        },
573        from_smoltcp,
574        to_smoltcp,
575        tls_seed,
576    )
577    .await
578}
579
580/// Relay an established TCP stream without inspecting or substituting bytes.
581async fn relay_connected_stream(
582    stream: TcpStream,
583    mut from_smoltcp: mpsc::Receiver<Bytes>,
584    to_smoltcp: mpsc::Sender<Bytes>,
585    shared: Arc<SharedState>,
586) -> io::Result<()> {
587    let (mut server_rx, mut server_tx) = stream.into_split();
588    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
589
590    loop {
591        tokio::select! {
592            data = from_smoltcp.recv() => {
593                match data {
594                    Some(bytes) => {
595                        server_tx.write_all(&bytes).await?;
596                        server_tx.flush().await?;
597                    }
598                    None => break,
599                }
600            }
601            result = server_rx.read(&mut server_buf) => {
602                match result {
603                    Ok(0) => break,
604                    Ok(n) => {
605                        if to_smoltcp
606                            .send(Bytes::copy_from_slice(&server_buf[..n]))
607                            .await
608                            .is_err()
609                        {
610                            break;
611                        }
612                        shared.proxy_wake.wake();
613                    }
614                    Err(e) => return Err(e),
615                }
616            }
617        }
618    }
619
620    Ok(())
621}
622
623async fn buffer_connect_request(
624    mut buf: Vec<u8>,
625    from_smoltcp: &mut mpsc::Receiver<Bytes>,
626) -> io::Result<Vec<u8>> {
627    let timeout_fut = tokio::time::sleep(PEEK_BUDGET);
628    tokio::pin!(timeout_fut);
629
630    loop {
631        if !could_be_connect_request(&buf) {
632            return Err(io::Error::new(
633                io::ErrorKind::InvalidData,
634                "malformed CONNECT request prefix",
635            ));
636        }
637        if headers_end(&buf).is_some() {
638            return Ok(buf);
639        }
640        if buf.len() >= PEEK_BUF_SIZE {
641            return Err(io::Error::new(
642                io::ErrorKind::InvalidData,
643                "CONNECT request headers too large",
644            ));
645        }
646
647        tokio::select! {
648            biased;
649            _ = &mut timeout_fut => {
650                return Err(io::Error::new(
651                    io::ErrorKind::TimedOut,
652                    "timed out waiting for complete CONNECT request headers",
653                ));
654            }
655            data = from_smoltcp.recv() => match data {
656                Some(bytes) => {
657                    buf.extend_from_slice(&bytes);
658                }
659                None => {
660                    return Err(io::Error::new(
661                        io::ErrorKind::UnexpectedEof,
662                        "channel closed before complete CONNECT request headers",
663                    ));
664                }
665            }
666        }
667    }
668}
669
670async fn read_connect_response_headers(stream: &mut TcpStream) -> io::Result<(Vec<u8>, usize)> {
671    tokio::time::timeout(PEEK_BUDGET, async {
672        let mut proxy_resp = Vec::with_capacity(256);
673        let mut buf = [0u8; 4096];
674        loop {
675            let n = stream.read(&mut buf).await?;
676            if n == 0 {
677                return Err(io::Error::new(
678                    io::ErrorKind::UnexpectedEof,
679                    "proxy closed before sending CONNECT response",
680                ));
681            }
682            proxy_resp.extend_from_slice(&buf[..n]);
683            if let Some(end) = headers_end(&proxy_resp) {
684                return Ok((proxy_resp, end));
685            }
686            if proxy_resp.len() > CONNECT_RESP_LIMIT {
687                return Err(io::Error::new(
688                    io::ErrorKind::InvalidData,
689                    "proxy CONNECT response too large",
690                ));
691            }
692        }
693    })
694    .await
695    .map_err(|_| {
696        io::Error::new(
697            io::ErrorKind::TimedOut,
698            "timed out waiting for proxy CONNECT response",
699        )
700    })?
701}
702
703fn sanitize_connect_headers<'a>(
704    header_bytes: &'a [u8],
705    secrets: &SecretsConfig,
706) -> Result<Cow<'a, [u8]>, ViolationAction> {
707    if secrets.secrets.is_empty() {
708        return Ok(Cow::Borrowed(header_bytes));
709    }
710
711    let mut handler = SecretsHandler::new_plain_http_untrusted_metadata(secrets);
712    handler.substitute(header_bytes)
713}
714
/// Locate the end of an HTTP header block.
///
/// Returns the offset of the first byte after the earliest `\r\n\r\n`
/// sequence in `buf`, or `None` when no such sequence is present.
fn headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let start = buf
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}
719
/// Whether `buf` is consistent with the start of an HTTP `CONNECT` request.
///
/// The check is a case-insensitive prefix match against `"CONNECT "`: a
/// non-empty buffer shorter than the prefix passes if it matches as far as
/// it goes. An empty buffer never matches.
fn could_be_connect_request(buf: &[u8]) -> bool {
    const PREFIX: &[u8] = b"CONNECT ";

    // `zip` stops at the shorter side, so only the overlapping bytes are
    // compared.
    !buf.is_empty()
        && buf
            .iter()
            .zip(PREFIX)
            .all(|(seen, wanted)| seen.eq_ignore_ascii_case(wanted))
}
728
729fn parse_connect_request(bytes: Vec<u8>) -> io::Result<ConnectRequest> {
730    let header_end = headers_end(&bytes).ok_or_else(|| {
731        io::Error::new(
732            io::ErrorKind::InvalidData,
733            "incomplete CONNECT request headers",
734        )
735    })?;
736    let target = {
737        let request_line = bytes[..header_end]
738            .split(|&b| b == b'\n')
739            .next()
740            .unwrap_or(&[]);
741        let request_line = std::str::from_utf8(request_line)
742            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "CONNECT line is not UTF-8"))?
743            .trim_end_matches('\r');
744        let mut parts = request_line.split_ascii_whitespace();
745        let method = parts.next().unwrap_or_default();
746        let authority = parts.next().unwrap_or_default();
747        let version = parts.next().unwrap_or_default();
748        if !method.eq_ignore_ascii_case("CONNECT")
749            || authority.is_empty()
750            || !is_http_version(version)
751            || parts.next().is_some()
752        {
753            return Err(io::Error::new(
754                io::ErrorKind::InvalidData,
755                "malformed CONNECT request line",
756            ));
757        }
758        parse_connect_target(authority)?
759    };
760
761    Ok(ConnectRequest {
762        bytes,
763        header_end,
764        target,
765    })
766}
767
768fn parse_connect_target(authority: &str) -> io::Result<ConnectTarget> {
769    let authority = authority.trim();
770    let (host, port) = if let Some(rest) = authority.strip_prefix('[') {
771        let (host, rest) = rest.split_once(']').ok_or_else(|| {
772            io::Error::new(
773                io::ErrorKind::InvalidData,
774                "malformed CONNECT IPv6 authority",
775            )
776        })?;
777        let port = rest.strip_prefix(':').ok_or_else(|| {
778            io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
779        })?;
780        (host, port)
781    } else {
782        let (host, port) = authority.rsplit_once(':').ok_or_else(|| {
783            io::Error::new(io::ErrorKind::InvalidData, "CONNECT authority missing port")
784        })?;
785        if host.contains(':') {
786            return Err(io::Error::new(
787                io::ErrorKind::InvalidData,
788                "CONNECT IPv6 authority must be bracketed",
789            ));
790        }
791        (host, port)
792    };
793    let host = host.trim().trim_end_matches('.');
794    if host.is_empty() {
795        return Err(io::Error::new(
796            io::ErrorKind::InvalidData,
797            "CONNECT authority missing host",
798        ));
799    }
800    let port = port
801        .parse::<u16>()
802        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid CONNECT port"))?;
803    let expected_sni = host
804        .parse::<IpAddr>()
805        .is_err()
806        .then(|| host.to_ascii_lowercase());
807
808    Ok(ConnectTarget {
809        host: host.to_ascii_lowercase(),
810        port,
811        expected_sni,
812    })
813}
814
/// Whether `version` has the shape `HTTP/<digits>.<digits>`.
///
/// The `HTTP/` prefix is case-sensitive, and both the major and minor parts
/// must be non-empty runs of ASCII digits.
fn is_http_version(version: &str) -> bool {
    fn all_digits(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit())
    }

    version
        .strip_prefix("HTTP/")
        .and_then(|numbers| numbers.split_once('.'))
        .is_some_and(|(major, minor)| all_digits(major) && all_digits(minor))
}
827
828fn connect_response_is_success(headers: &[u8]) -> bool {
829    let Some(status_line) = headers.split(|&b| b == b'\n').next() else {
830        return false;
831    };
832    let Ok(status_line) = std::str::from_utf8(status_line) else {
833        return false;
834    };
835    let mut parts = status_line.trim_end_matches('\r').split_ascii_whitespace();
836    let version = parts.next().unwrap_or_default();
837    let status = parts.next().unwrap_or_default();
838    is_http_version(version)
839        && status.len() == 3
840        && status
841            .parse::<u16>()
842            .is_ok_and(|code| (200..300).contains(&code))
843}
844
845/// Extract the `Host:` header value from an already-buffered HTTP header block.
846///
847/// Returns `None` if:
848/// - The first byte is `0x16` (TLS — not HTTP)
849/// - The buffer does not yet contain `\r\n\r\n` (headers incomplete)
850/// - No `Host:` header is present
851///
852/// Strips port suffix, lowercases, and trims whitespace. Result is
853/// ready for byte-equal matching against `SecretEntry::allowed_hosts`.
854fn extract_http_host(buf: &[u8]) -> Option<String> {
855    if buf.first() == Some(&0x16) {
856        return None;
857    }
858    // Size the header pool to the buffer rather than a fixed array: a header
859    // line is at least four bytes (`a:\r\n`), so `len / 4` always covers the
860    // real header count, and `httparse` never reports `TooManyHeaders` (which
861    // would make a request with many headers look hostless). The first flight
862    // is capped at PEEK_BUF_SIZE, so this stays bounded.
863    let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
864    let mut req = httparse::Request::new(&mut headers);
865    req.parse(buf).ok()?;
866    req.headers
867        .iter()
868        .find(|h| h.name.eq_ignore_ascii_case("host"))
869        .and_then(|h| std::str::from_utf8(h.value).ok())
870        .map(|v| {
871            let host = v.trim();
872            // Strip port suffix.
873            host.rsplit_once(':')
874                .map(|(h, _)| h)
875                .unwrap_or(host)
876                .to_ascii_lowercase()
877        })
878        .filter(|h| !h.is_empty())
879}
880
881/// Finish classifying the guest's first flight after the upstream socket is
882/// open, returning the (possibly extended) first-flight buffer and whether it
883/// is a TLS record.
884///
885/// `buf` carries whatever a pre-connect domain-rule peek already captured; when
886/// it is non-empty the TLS/plain decision is already settled and only header
887/// top-up runs. `want_headers` is set when at least one secret can be
888/// substituted over plain HTTP (`SecretsConfig::has_plain_http_candidates`); it
889/// makes the peek keep reading a non-TLS flight until `\r\n\r\n` so
890/// [`extract_http_host`] sees a complete header block.
891///
892/// Crucially, this relays server→guest while it waits. Server-first protocols
893/// (SSH, SMTP, databases) send nothing until they have seen the server's
894/// banner; draining the server side here lets the banner reach the guest
895/// immediately, so the guest's eventual first flight — not a 5s timeout — is
896/// what ends the peek.
897#[allow(clippy::too_many_arguments)]
898async fn classify_first_flight(
899    mut buf: Vec<u8>,
900    from_smoltcp: &mut mpsc::Receiver<Bytes>,
901    server_rx: &mut tokio::net::tcp::OwnedReadHalf,
902    to_smoltcp: &mpsc::Sender<Bytes>,
903    shared: &SharedState,
904    want_headers: bool,
905    max: usize,
906    budget: Duration,
907) -> io::Result<(Vec<u8>, bool)> {
908    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
909    let timeout_fut = tokio::time::sleep(budget);
910    tokio::pin!(timeout_fut);
911
912    loop {
913        // Stop as soon as the protocol class is known and — for plain-HTTP
914        // candidates — a full header block has arrived. Bail the moment a
915        // non-TLS flight stops looking like an HTTP request so non-HTTP
916        // protocols (SSH, Postgres) aren't withheld from upstream for the
917        // whole budget while we wait for a `\r\n\r\n` that never comes.
918        if !buf.is_empty() {
919            let is_tls = buf.first() == Some(&0x16);
920            let not_http = !is_tls
921                && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
922            let done = !want_headers
923                || is_tls
924                || not_http
925                || buf.len() >= max
926                || buf.windows(4).any(|w| w == b"\r\n\r\n");
927            if done {
928                return Ok((buf, is_tls));
929            }
930        }
931
932        tokio::select! {
933            biased;
934            _ = &mut timeout_fut => {
935                let is_tls = buf.first() == Some(&0x16);
936                return Ok((buf, is_tls));
937            }
938            // Guest → buffer (not forwarded here; the caller replays it once the
939            // handler is built, so substitution applies to the first flight too).
940            guest = from_smoltcp.recv() => match guest {
941                Some(bytes) => buf.extend_from_slice(&bytes),
942                None => {
943                    let is_tls = buf.first() == Some(&0x16);
944                    return Ok((buf, is_tls));
945                }
946            },
947            // Server → guest: relay immediately so a server-first banner is never
948            // held hostage by the peek.
949            server = server_rx.read(&mut server_buf) => match server {
950                Ok(0) => {
951                    let is_tls = buf.first() == Some(&0x16);
952                    return Ok((buf, is_tls));
953                }
954                Ok(n) => {
955                    let data = Bytes::copy_from_slice(&server_buf[..n]);
956                    if to_smoltcp.send(data).await.is_err() {
957                        let is_tls = buf.first() == Some(&0x16);
958                        return Ok((buf, is_tls));
959                    }
960                    shared.proxy_wake.wake();
961                }
962                Err(e) => return Err(e),
963            },
964        }
965    }
966}
967
968/// Buffer the first flight until SNI can be extracted, or until one
969/// of the bail-out conditions hits (channel close, buffer cap,
970/// timeout). Never errors; non-TLS / slow / malformed input all
971/// fall through to `None`.
972///
973/// On hit, the SNI is canonicalized (lowercase + trim trailing dot)
974/// for byte-equal matching against rule destinations. The returned
975/// buffer must be replayed verbatim to upstream before the caller
976/// starts its relay loop.
977async fn peek_for_sni(
978    rx: &mut mpsc::Receiver<Bytes>,
979    max: usize,
980    budget: Duration,
981) -> (Vec<u8>, Option<String>) {
982    let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
983    let timeout_fut = tokio::time::sleep(budget);
984    tokio::pin!(timeout_fut);
985
986    let raw_sni = loop {
987        tokio::select! {
988            biased;
989            _ = &mut timeout_fut => break None,
990            data = rx.recv() => {
991                match data {
992                    Some(bytes) => {
993                        buf.extend_from_slice(&bytes);
994                        // First byte of a TLS record is the ContentType;
995                        // 0x16 is handshake. Anything else can't be a
996                        // ClientHello, so don't burn the full budget on
997                        // plain HTTP / SSH / etc.
998                        if buf.first() != Some(&0x16) {
999                            break None;
1000                        }
1001                        if let Some(name) = sni::extract_sni(&buf) {
1002                            break Some(name);
1003                        }
1004                        if buf.len() >= max {
1005                            break None;
1006                        }
1007                    }
1008                    None => break None,
1009                }
1010            }
1011        }
1012    };
1013
1014    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
1015    (buf, canonical)
1016}
1017
1018//--------------------------------------------------------------------------------------------------
1019// Tests
1020//--------------------------------------------------------------------------------------------------
1021
1022#[cfg(test)]
1023mod tests {
1024    use super::*;
1025
1026    /// Synthetic TLS ClientHello carrying SNI `example.com`. Bytes
1027    /// borrowed from `tls::sni` test fixtures so the parser sees a
1028    /// well-formed record.
1029    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
1030        // Minimal but valid TLS 1.2 ClientHello with one SNI entry.
1031        // Layout: record header (5) + handshake header (4) + body.
1032        let host_bytes = sni.as_bytes();
1033        let host_len = host_bytes.len() as u16;
1034        let server_name_list_len = 3 + host_len; // type(1) + len(2) + host
1035        let extension_data_len = 2 + server_name_list_len; // list-len(2) + list
1036        let extensions_total = 4 + extension_data_len; // type(2) + len(2) + data
1037
1038        let mut body = Vec::new();
1039        // Client version
1040        body.extend_from_slice(&[0x03, 0x03]);
1041        // Random (32 bytes)
1042        body.extend_from_slice(&[0u8; 32]);
1043        // Session id length + (empty)
1044        body.push(0);
1045        // Cipher suites length + one cipher
1046        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
1047        // Compression methods length + null
1048        body.extend_from_slice(&[0x01, 0x00]);
1049        // Extensions length
1050        body.extend_from_slice(&extensions_total.to_be_bytes());
1051        // SNI extension: type 0x0000
1052        body.extend_from_slice(&[0x00, 0x00]);
1053        body.extend_from_slice(&extension_data_len.to_be_bytes());
1054        body.extend_from_slice(&server_name_list_len.to_be_bytes());
1055        body.push(0x00); // host_name type
1056        body.extend_from_slice(&host_len.to_be_bytes());
1057        body.extend_from_slice(host_bytes);
1058
1059        let handshake_len = body.len() as u32;
1060        let mut hs = Vec::new();
1061        hs.push(0x01); // ClientHello
1062        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); // 24-bit length
1063        hs.extend_from_slice(&body);
1064
1065        let record_len = hs.len() as u16;
1066        let mut record = Vec::new();
1067        record.extend_from_slice(&[0x16, 0x03, 0x01]); // Handshake, TLS 1.0
1068        record.extend_from_slice(&record_len.to_be_bytes());
1069        record.extend_from_slice(&hs);
1070
1071        record
1072    }
1073
1074    #[test]
1075    fn could_be_connect_request_matches_split_prefixes_only() {
1076        assert!(could_be_connect_request(b"C"));
1077        assert!(could_be_connect_request(b"connect "));
1078        assert!(could_be_connect_request(b"CONNECT example.com:443"));
1079        assert!(!could_be_connect_request(b"CLIENT"));
1080        assert!(!could_be_connect_request(b"GET / HTTP/1.1\r\n"));
1081    }
1082
1083    #[tokio::test]
1084    async fn buffer_connect_request_reads_split_headers() {
1085        let (tx, mut rx) = mpsc::channel(4);
1086        tx.send(Bytes::from_static(b"NECT example.com:443 HTTP/1.1\r\n"))
1087            .await
1088            .unwrap();
1089        tx.send(Bytes::from_static(b"Host: example.com\r\n\r\n"))
1090            .await
1091            .unwrap();
1092        drop(tx);
1093
1094        let buffered = buffer_connect_request(b"CON".to_vec(), &mut rx)
1095            .await
1096            .unwrap();
1097        let parsed = parse_connect_request(buffered).unwrap();
1098
1099        assert_eq!(parsed.target.host, "example.com");
1100        assert_eq!(parsed.target.port, 443);
1101        assert_eq!(parsed.target.expected_sni.as_deref(), Some("example.com"));
1102        assert!(parsed.post_header_bytes().is_empty());
1103    }
1104
1105    #[test]
1106    fn parse_connect_request_preserves_post_header_tls_seed() {
1107        let mut request = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n".to_vec();
1108        request.extend_from_slice(b"\x16\x03\x01client-hello");
1109
1110        let parsed = parse_connect_request(request).unwrap();
1111
1112        assert_eq!(
1113            parsed.header_bytes(),
1114            b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n\r\n"
1115        );
1116        assert_eq!(parsed.post_header_bytes(), b"\x16\x03\x01client-hello");
1117    }
1118
1119    #[test]
1120    fn parse_connect_target_requires_authority_port() {
1121        assert!(parse_connect_target("example.com").is_err());
1122        assert!(parse_connect_target("2001:db8::1:443").is_err());
1123
1124        let target = parse_connect_target("[2001:db8::1]:8443").unwrap();
1125        assert_eq!(target.host, "2001:db8::1");
1126        assert_eq!(target.port, 8443);
1127        assert_eq!(target.expected_sni, None);
1128    }
1129
1130    #[test]
1131    fn connect_response_success_requires_exact_2xx_status_code() {
1132        assert!(connect_response_is_success(
1133            b"HTTP/1.1 200 Connection Established\r\n\r\n"
1134        ));
1135        assert!(connect_response_is_success(
1136            b"HTTP/1.1 204 Connection Established\r\n\r\n"
1137        ));
1138        assert!(!connect_response_is_success(b"HTTP/1.1 2000 Weird\r\n\r\n"));
1139        assert!(!connect_response_is_success(b"HTTP/1.1 199 Nope\r\n\r\n"));
1140        assert!(!connect_response_is_success(b"NOTHTTP 200 OK\r\n\r\n"));
1141    }
1142
1143    #[tokio::test]
1144    async fn peek_for_sni_extracts_and_canonicalizes() {
1145        let (tx, mut rx) = mpsc::channel(4);
1146        let hello = synthetic_client_hello("Example.COM");
1147        tx.send(Bytes::from(hello.clone())).await.unwrap();
1148        drop(tx); // close so peek returns even if SNI didn't satisfy
1149
1150        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1151        assert_eq!(sni.as_deref(), Some("example.com"));
1152        assert_eq!(buf, hello);
1153    }
1154
1155    #[tokio::test]
1156    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
1157        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1158        drop(tx);
1159        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1160        assert!(buf.is_empty());
1161        assert_eq!(sni, None);
1162    }
1163
1164    #[tokio::test]
1165    async fn peek_for_sni_returns_none_on_non_tls_data() {
1166        let (tx, mut rx) = mpsc::channel(4);
1167        // Plaintext HTTP request; not a TLS record so extract_sni returns None.
1168        tx.send(Bytes::from_static(
1169            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
1170        ))
1171        .await
1172        .unwrap();
1173        drop(tx);
1174        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1175        assert!(
1176            !buf.is_empty(),
1177            "buffered bytes must be returned for replay"
1178        );
1179        assert_eq!(sni, None);
1180    }
1181
1182    #[tokio::test]
1183    async fn peek_for_sni_falls_back_on_timeout() {
1184        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
1185        // Hold the sender open but send nothing — peek must time out.
1186        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
1187        drop(tx);
1188        assert!(buf.is_empty());
1189        assert_eq!(sni, None);
1190    }
1191
1192    #[tokio::test]
1193    async fn peek_for_sni_caps_at_max_bytes() {
1194        let (tx, mut rx) = mpsc::channel(4);
1195        // First byte 0x16 keeps the peek collecting past the early
1196        // non-TLS bail. Padding bytes are zero so the SNI parser never
1197        // matches and the loop drives to the size cap.
1198        let mut first = vec![0u8; 8192];
1199        first[0] = 0x16;
1200        tx.send(Bytes::from(first)).await.unwrap();
1201        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1202        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
1203        drop(tx);
1204
1205        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1206        assert_eq!(sni, None, "no SNI in non-TLS data");
1207        assert!(
1208            buf.len() >= PEEK_BUF_SIZE,
1209            "buffer must hit the cap before bail-out: got {}",
1210            buf.len()
1211        );
1212    }
1213
1214    #[tokio::test]
1215    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
1216        let (tx, mut rx) = mpsc::channel(4);
1217        // Plain HTTP request: first byte 'G' (0x47) — clearly not TLS.
1218        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
1219            .await
1220            .unwrap();
1221        drop(tx);
1222
1223        // 5-second nominal budget; assert we returned in well under
1224        // that — the early-bail must not wait for the full window.
1225        let started = std::time::Instant::now();
1226        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1227        let elapsed = started.elapsed();
1228        assert_eq!(sni, None);
1229        assert!(buf.starts_with(b"GET"));
1230        assert!(
1231            elapsed < Duration::from_millis(500),
1232            "non-TLS bail must be fast: took {elapsed:?}"
1233        );
1234    }
1235
1236    //----------------------------------------------------------------------------------------------
1237    // peek_for_sni × evaluate_egress_with_source — combined integration tests
1238    //----------------------------------------------------------------------------------------------
1239
1240    use std::net::IpAddr;
1241    use std::time::Duration as StdDuration;
1242
1243    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
1244    use crate::shared::{ResolvedHostnameFamily, SharedState};
1245
1246    const SHARED_FASTLY_IP: &str = "151.101.0.223";
1247
1248    fn shared_with(host: &str, ip: &str) -> SharedState {
1249        let shared = SharedState::new(4);
1250        shared.cache_resolved_hostname(
1251            host,
1252            ResolvedHostnameFamily::Ipv4,
1253            [ip.parse::<IpAddr>().unwrap()],
1254            StdDuration::from_secs(60),
1255        );
1256        shared
1257    }
1258
1259    fn allow_https(domain: &str) -> Rule {
1260        Rule {
1261            direction: crate::policy::Direction::Egress,
1262            destination: Destination::Domain(domain.parse().unwrap()),
1263            protocols: vec![Protocol::Tcp],
1264            ports: vec![PortRange::single(443)],
1265            action: Action::Allow,
1266        }
1267    }
1268
1269    /// Over-allow case: cache says IP X is `pypi.org` (allowed); SNI
1270    /// is `evil.com`. SNI must override the cache and deny.
1271    #[tokio::test]
1272    async fn integration_sni_overrides_cache_for_over_allow() {
1273        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1274        let policy = NetworkPolicy {
1275            default_egress: Action::Deny,
1276            default_ingress: Action::Allow,
1277            rules: vec![allow_https("pypi.org")],
1278        };
1279        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1280
1281        let (tx, mut rx) = mpsc::channel(4);
1282        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
1283            .await
1284            .unwrap();
1285        drop(tx);
1286
1287        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1288        assert_eq!(sni.as_deref(), Some("evil.com"));
1289        assert!(!initial_buf.is_empty());
1290
1291        let source = sni
1292            .as_deref()
1293            .map(HostnameSource::Sni)
1294            .unwrap_or(HostnameSource::CacheOnly);
1295        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1296        assert_eq!(
1297            eval,
1298            EgressEvaluation::Deny,
1299            "SNI=evil.com must not piggy-back on the cached pypi.org match",
1300        );
1301    }
1302
1303    /// Over-block case: cache says IP X is `ads.example.com` (denied);
1304    /// SNI is `api.example.com`. SNI must override the cache and allow.
1305    #[tokio::test]
1306    async fn integration_sni_overrides_cache_for_over_block() {
1307        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
1308        let policy = NetworkPolicy {
1309            default_egress: Action::Allow,
1310            default_ingress: Action::Allow,
1311            rules: vec![Rule::deny_egress(Destination::Domain(
1312                "ads.example.com".parse().unwrap(),
1313            ))],
1314        };
1315        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1316
1317        let (tx, mut rx) = mpsc::channel(4);
1318        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
1319            .await
1320            .unwrap();
1321        drop(tx);
1322
1323        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1324        assert_eq!(sni.as_deref(), Some("api.example.com"));
1325
1326        let source = sni
1327            .as_deref()
1328            .map(HostnameSource::Sni)
1329            .unwrap_or(HostnameSource::CacheOnly);
1330        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1331        assert_eq!(
1332            eval,
1333            EgressEvaluation::Allow,
1334            "SNI=api.example.com must not be caught by the deny on ads.example.com",
1335        );
1336    }
1337
1338    /// Non-TLS first-flight falls back to `CacheOnly`; the cache
1339    /// match decides.
1340    #[tokio::test]
1341    async fn integration_non_tls_falls_back_to_cache() {
1342        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
1343        let policy = NetworkPolicy {
1344            default_egress: Action::Deny,
1345            default_ingress: Action::Allow,
1346            rules: vec![allow_https("pypi.org")],
1347        };
1348        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1349
1350        let (tx, mut rx) = mpsc::channel(4);
1351        // Plain HTTP request; not a TLS record.
1352        tx.send(Bytes::from_static(
1353            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
1354        ))
1355        .await
1356        .unwrap();
1357        drop(tx);
1358
1359        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1360        assert_eq!(sni, None, "non-TLS data → no SNI");
1361        assert!(
1362            !initial_buf.is_empty(),
1363            "buffered bytes must survive for replay"
1364        );
1365
1366        let source = sni
1367            .as_deref()
1368            .map(HostnameSource::Sni)
1369            .unwrap_or(HostnameSource::CacheOnly);
1370        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1371        assert_eq!(
1372            eval,
1373            EgressEvaluation::Allow,
1374            "cache-only fallback must still allow the cached hostname's IP",
1375        );
1376    }
1377
1378    /// SNI matches a `DomainSuffix` rule with a cache binding for the
1379    /// claimed name. Genuine pre-resolved traffic passes.
1380    #[tokio::test]
1381    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
1382        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
1383        let policy = NetworkPolicy {
1384            default_egress: Action::Deny,
1385            default_ingress: Action::Allow,
1386            rules: vec![Rule {
1387                direction: crate::policy::Direction::Egress,
1388                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1389                protocols: vec![Protocol::Tcp],
1390                ports: vec![PortRange::single(443)],
1391                action: Action::Allow,
1392            }],
1393        };
1394        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1395
1396        let (tx, mut rx) = mpsc::channel(4);
1397        tx.send(Bytes::from(synthetic_client_hello(
1398            "files.pythonhosted.org",
1399        )))
1400        .await
1401        .unwrap();
1402        drop(tx);
1403
1404        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1405        let source = sni
1406            .as_deref()
1407            .map(HostnameSource::Sni)
1408            .unwrap_or(HostnameSource::CacheOnly);
1409        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1410        assert_eq!(eval, EgressEvaluation::Allow);
1411    }
1412
1413    /// Spoofed SNI on an IP with no cache binding for any matching
1414    /// name: byte-equality with the suffix passes, but no DNS lookup
1415    /// ever tied a `*.pythonhosted.org` name to the destination, so
1416    /// the AND-check fails and the connection is denied.
1417    #[tokio::test]
1418    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
1419        let shared = SharedState::new(4); // empty cache
1420        let policy = NetworkPolicy {
1421            default_egress: Action::Deny,
1422            default_ingress: Action::Allow,
1423            rules: vec![Rule {
1424                direction: crate::policy::Direction::Egress,
1425                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
1426                protocols: vec![Protocol::Tcp],
1427                ports: vec![PortRange::single(443)],
1428                action: Action::Allow,
1429            }],
1430        };
1431        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
1432
1433        let (tx, mut rx) = mpsc::channel(4);
1434        tx.send(Bytes::from(synthetic_client_hello(
1435            "files.pythonhosted.org",
1436        )))
1437        .await
1438        .unwrap();
1439        drop(tx);
1440
1441        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
1442        let source = sni
1443            .as_deref()
1444            .map(HostnameSource::Sni)
1445            .unwrap_or(HostnameSource::CacheOnly);
1446        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
1447        assert_eq!(eval, EgressEvaluation::Deny);
1448    }
1449
1450    // ── extract_http_host ──────────────────────────────────────────────────────
1451
1452    #[test]
1453    fn extract_http_host_basic() {
1454        let buf = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
1455        assert_eq!(extract_http_host(buf), Some("example.com".into()));
1456    }
1457
1458    #[test]
1459    fn extract_http_host_strips_port() {
1460        let buf = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
1461        assert_eq!(extract_http_host(buf), Some("api.company.com".into()));
1462    }
1463
1464    #[test]
1465    fn extract_http_host_case_insensitive_lowercased() {
1466        let buf = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
1467        assert_eq!(extract_http_host(buf), Some("example.com".into()));
1468    }
1469
1470    #[test]
1471    fn extract_http_host_no_host_header() {
1472        let buf = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
1473        assert_eq!(extract_http_host(buf), None);
1474    }
1475
1476    #[test]
1477    fn extract_http_host_incomplete_headers() {
1478        let buf = b"GET / HTTP/1.1\r\nHost: x";
1479        assert_eq!(extract_http_host(buf), None);
1480    }
1481
1482    #[test]
1483    fn extract_http_host_tls_first_byte() {
1484        let buf = [0x16u8, 0x03, 0x01, 0x00, 0x01];
1485        assert_eq!(extract_http_host(&buf), None);
1486    }
1487
1488    #[test]
1489    fn extract_http_host_with_many_headers() {
1490        // Far more headers than a small fixed parse array would hold: the Host
1491        // must still be found rather than the request looking hostless.
1492        let mut req = Vec::from(&b"GET / HTTP/1.1\r\n"[..]);
1493        for i in 0..100 {
1494            req.extend_from_slice(format!("X-Pad-{i}: v\r\n").as_bytes());
1495        }
1496        req.extend_from_slice(b"Host: example.com\r\n\r\n");
1497        assert_eq!(extract_http_host(&req), Some("example.com".into()));
1498    }
1499
1500    // ── plain-HTTP secret substitution ────────────────────────────────────────
1501
1502    use std::sync::Arc;
1503    use tokio::io::AsyncReadExt;
1504    use tokio::net::TcpListener;
1505    use tokio::task::JoinHandle;
1506
1507    use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};
1508
1509    fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
1510        SecretsConfig {
1511            secrets: vec![SecretEntry {
1512                env_var: "API_KEY".into(),
1513                value: value.into(),
1514                placeholder: placeholder.into(),
1515                allowed_hosts: vec![HostPattern::Any],
1516                injection: SecretInjection {
1517                    headers: true,
1518                    basic_auth: false,
1519                    query_params: false,
1520                    body: false,
1521                },
1522                on_violation: None,
1523                require_tls_identity: require_tls,
1524            }],
1525            ..Default::default()
1526        }
1527    }
1528
1529    fn make_host_bound_secret(placeholder: &str, value: &str, host: &str) -> SecretsConfig {
1530        SecretsConfig {
1531            secrets: vec![SecretEntry {
1532                env_var: "API_KEY".into(),
1533                value: value.into(),
1534                placeholder: placeholder.into(),
1535                allowed_hosts: vec![HostPattern::Exact(host.into())],
1536                injection: SecretInjection::default(),
1537                on_violation: None,
1538                require_tls_identity: true,
1539            }],
1540            ..Default::default()
1541        }
1542    }
1543
1544    #[test]
1545    fn sanitize_connect_headers_blocks_placeholder_metadata_header_by_default() {
1546        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1547        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\nUser-Agent: curl\r\n\r\n";
1548
1549        assert_eq!(
1550            sanitize_connect_headers(headers, &secrets),
1551            Err(ViolationAction::BlockAndLog)
1552        );
1553    }
1554
1555    #[test]
1556    fn sanitize_connect_headers_respects_block_and_terminate() {
1557        let mut secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1558        secrets.on_violation = ViolationAction::BlockAndTerminate;
1559        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
1560
1561        assert_eq!(
1562            sanitize_connect_headers(headers, &secrets),
1563            Err(ViolationAction::BlockAndTerminate)
1564        );
1565    }
1566
1567    #[test]
1568    fn sanitize_connect_headers_respects_explicit_passthrough() {
1569        let mut secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1570        secrets.on_violation = ViolationAction::Passthrough(vec![HostPattern::Any]);
1571        let headers = b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nProxy-Authorization: Bearer $MSB_KEY\r\n\r\n";
1572
1573        let sanitized = sanitize_connect_headers(headers, &secrets).unwrap();
1574
1575        assert_eq!(sanitized.as_ref(), headers);
1576        assert!(
1577            !String::from_utf8_lossy(sanitized.as_ref()).contains("real-secret-value"),
1578            "passthrough must never substitute real secrets into CONNECT metadata"
1579        );
1580    }
1581
1582    #[test]
1583    fn sanitize_connect_headers_keeps_safe_metadata_headers() {
1584        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1585        let headers =
1586            b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com:443\r\nUser-Agent: curl\r\n\r\n";
1587
1588        let sanitized = sanitize_connect_headers(headers, &secrets).unwrap();
1589
1590        assert_eq!(sanitized.as_ref(), headers);
1591    }
1592
1593    #[test]
1594    fn sanitize_connect_headers_blocks_placeholder_in_request_line() {
1595        let secrets = make_host_bound_secret("$MSB_KEY", "real-secret-value", "example.com");
1596        let headers = b"CONNECT $MSB_KEY:443 HTTP/1.1\r\nHost: example.com:443\r\n\r\n";
1597
1598        assert_eq!(
1599            sanitize_connect_headers(headers, &secrets),
1600            Err(ViolationAction::BlockAndLog)
1601        );
1602    }
1603
1604    async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
1605        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1606        let addr = listener.local_addr().unwrap();
1607        let handle = tokio::spawn(async move {
1608            let (mut stream, _) = listener.accept().await.unwrap();
1609            let mut received = Vec::new();
1610            let mut buf = vec![0u8; 4096];
1611            loop {
1612                match stream.read(&mut buf).await {
1613                    Ok(0) | Err(_) => break,
1614                    Ok(n) => received.extend_from_slice(&buf[..n]),
1615                }
1616            }
1617            received
1618        });
1619        (addr, handle)
1620    }
1621
1622    async fn relay_through_proxy(
1623        request: Vec<u8>,
1624        secrets: SecretsConfig,
1625        handle: JoinHandle<Vec<u8>>,
1626        server_addr: SocketAddr,
1627    ) -> Vec<u8> {
1628        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1629        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1630        let shared = SharedState::new(4);
1631        let policy = Arc::new(NetworkPolicy::default());
1632        let secrets = Arc::new(secrets);
1633        let proxy_connect = Arc::new(ProxyConnectState::new());
1634
1635        from_tx.send(Bytes::from(request)).await.unwrap();
1636        drop(from_tx);
1637
1638        tcp_proxy_task(
1639            server_addr,
1640            server_addr,
1641            from_rx,
1642            to_tx,
1643            Arc::new(shared),
1644            policy,
1645            secrets,
1646            None,
1647            proxy_connect,
1648        )
1649        .await
1650        .unwrap();
1651
1652        handle.await.unwrap()
1653    }
1654
1655    #[tokio::test]
1656    async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
1657        // Host header split across TCP segments — classify_first_flight must keep
1658        // reading until \r\n\r\n before extract_http_host is called.
1659        let (addr, sink) = spawn_sink().await;
1660        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1661
1662        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1663        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1664        let proxy_connect = Arc::new(ProxyConnectState::new());
1665
1666        from_tx
1667            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1668            .await
1669            .unwrap();
1670        from_tx
1671            .send(Bytes::from_static(
1672                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1673            ))
1674            .await
1675            .unwrap();
1676        drop(from_tx);
1677
1678        tcp_proxy_task(
1679            addr,
1680            addr,
1681            from_rx,
1682            to_tx,
1683            Arc::new(SharedState::new(4)),
1684            Arc::new(NetworkPolicy::default()),
1685            Arc::new(secrets),
1686            None,
1687            proxy_connect,
1688        )
1689        .await
1690        .unwrap();
1691
1692        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1693        assert!(wire.contains("real-secret-value"), "got: {wire:?}");
1694        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1695    }
1696
1697    #[tokio::test]
1698    async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
1699        // A default (require_tls_identity = true) host-bound secret is never
1700        // substituted over plain HTTP, but a request to its allowed host must
1701        // have the placeholder forwarded unchanged — not blocked as a violation
1702        // — even when the Host arrives in a later segment than the request line.
1703        let (addr, sink) = spawn_sink().await;
1704
1705        let shared = SharedState::new(4);
1706        shared.cache_resolved_hostname(
1707            "example.com",
1708            ResolvedHostnameFamily::Ipv4,
1709            ["127.0.0.1".parse::<IpAddr>().unwrap()],
1710            StdDuration::from_secs(60),
1711        );
1712
1713        let secrets = SecretsConfig {
1714            secrets: vec![SecretEntry {
1715                env_var: "API_KEY".into(),
1716                value: "real-secret-value".into(),
1717                placeholder: "$MSB_KEY".into(),
1718                allowed_hosts: vec![HostPattern::Exact("example.com".into())],
1719                injection: SecretInjection {
1720                    headers: true,
1721                    basic_auth: false,
1722                    query_params: false,
1723                    body: false,
1724                },
1725                on_violation: None,
1726                require_tls_identity: true,
1727            }],
1728            ..Default::default()
1729        };
1730
1731        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1732        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1733        let proxy_connect = Arc::new(ProxyConnectState::new());
1734
1735        from_tx
1736            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1737            .await
1738            .unwrap();
1739        from_tx
1740            .send(Bytes::from_static(
1741                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1742            ))
1743            .await
1744            .unwrap();
1745        drop(from_tx);
1746
1747        tcp_proxy_task(
1748            addr,
1749            addr,
1750            from_rx,
1751            to_tx,
1752            Arc::new(shared),
1753            Arc::new(NetworkPolicy::default()),
1754            Arc::new(secrets),
1755            None,
1756            proxy_connect,
1757        )
1758        .await
1759        .unwrap();
1760
1761        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1762        assert!(
1763            wire.contains("Host: example.com"),
1764            "request must reach the allowed host, got: {wire:?}"
1765        );
1766        assert!(
1767            wire.contains("$MSB_KEY"),
1768            "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
1769        );
1770        assert!(
1771            !wire.contains("real-secret-value"),
1772            "secret must never be substituted over plain HTTP, got: {wire:?}"
1773        );
1774    }
1775
1776    #[tokio::test]
1777    async fn plain_http_substitutes_placeholder_in_first_flight() {
1778        let (addr, sink) = spawn_sink().await;
1779
1780        let request =
1781            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1782                .to_vec();
1783        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1784
1785        let wire =
1786            String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
1787        assert!(
1788            wire.contains("real-secret-value"),
1789            "real value must reach server, got: {wire:?}"
1790        );
1791        assert!(
1792            !wire.contains("$MSB_KEY"),
1793            "placeholder must not reach server, got: {wire:?}"
1794        );
1795    }
1796
1797    #[tokio::test]
1798    async fn plain_http_no_substitution_when_require_tls_identity_true() {
1799        let (addr, sink) = spawn_sink().await;
1800
1801        let request =
1802            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1803                .to_vec();
1804        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);
1805
1806        let wire =
1807            String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
1808                .into_owned();
1809        assert!(
1810            wire.contains("$MSB_KEY"),
1811            "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
1812        );
1813        assert!(
1814            !wire.contains("real-secret-value"),
1815            "real value must not leak when require_tls_identity=true, got: {wire:?}"
1816        );
1817    }
1818
1819    #[tokio::test]
1820    async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
1821        // Body arrives in a separate segment after headers — flows through the relay
1822        // loop, not the peek path. Ensures no bytes are dropped and header substitution
1823        // still happens.
1824        let (addr, sink) = spawn_sink().await;
1825        let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);
1826
1827        let body = "x".repeat(32_000);
1828        let header = format!(
1829            "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
1830            body.len()
1831        );
1832
1833        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1834        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1835        let proxy_connect = Arc::new(ProxyConnectState::new());
1836
1837        from_tx
1838            .send(Bytes::from(header.into_bytes()))
1839            .await
1840            .unwrap();
1841        from_tx
1842            .send(Bytes::from(body.clone().into_bytes()))
1843            .await
1844            .unwrap();
1845        drop(from_tx);
1846
1847        tcp_proxy_task(
1848            addr,
1849            addr,
1850            from_rx,
1851            to_tx,
1852            Arc::new(SharedState::new(4)),
1853            Arc::new(NetworkPolicy::default()),
1854            Arc::new(secrets),
1855            None,
1856            proxy_connect,
1857        )
1858        .await
1859        .unwrap();
1860
1861        let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
1862        assert!(wire.contains(&body), "got {} bytes", wire.len());
1863        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1864    }
1865}