Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::borrow::Cow;
9use std::io;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::conn::ProxyConnectState;
20use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
21use crate::secrets::config::{SecretsConfig, ViolationAction};
22use crate::secrets::handler::{
23    SecretsHandler, first_line_is_not_http_request, looks_like_http_request_prefix,
24};
25use crate::shared::SharedState;
26use crate::tls::sni;
27
28//--------------------------------------------------------------------------------------------------
29// Constants
30//--------------------------------------------------------------------------------------------------
31
/// Size, in bytes, of the scratch buffer used for each read from the real
/// server.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Byte threshold at which buffering of the guest's first flight stops while
/// peeking for the ClientHello's SNI.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Longest time spent buffering the first flight before falling back to a
/// cache-only egress decision.
const PEEK_BUDGET: Duration = Duration::from_millis(5_000);
41
42//--------------------------------------------------------------------------------------------------
43// Functions
44//--------------------------------------------------------------------------------------------------
45
46/// Spawn a TCP proxy task for a newly established connection.
47///
48/// `guest_dst` is what the guest dialed — the address policy rules
49/// match against. `connect_dst` is the host-side address tokio actually
50/// dials; for host-alias connections it's loopback (gateway rewritten).
51/// For everything else the two are identical.
52///
53/// `proxy_connect` is updated before the task exits so the connection
54/// tracker can decide between FIN (clean close) and RST (upstream
55/// connect failure).
56#[allow(clippy::too_many_arguments)]
57pub fn spawn_tcp_proxy(
58    handle: &tokio::runtime::Handle,
59    guest_dst: SocketAddr,
60    connect_dst: SocketAddr,
61    from_smoltcp: mpsc::Receiver<Bytes>,
62    to_smoltcp: mpsc::Sender<Bytes>,
63    shared: Arc<SharedState>,
64    network_policy: Arc<NetworkPolicy>,
65    secrets: Arc<SecretsConfig>,
66    proxy_connect: Arc<ProxyConnectState>,
67) {
68    handle.spawn(async move {
69        if let Err(e) = tcp_proxy_task(
70            guest_dst,
71            connect_dst,
72            from_smoltcp,
73            to_smoltcp,
74            shared,
75            network_policy,
76            secrets,
77            proxy_connect,
78        )
79        .await
80        {
81            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
82        }
83    });
84}
85
86/// Core TCP proxy: peek for SNI, evaluate egress policy, then either
87/// connect and relay or drop the channels.
88#[allow(clippy::too_many_arguments)]
89async fn tcp_proxy_task(
90    guest_dst: SocketAddr,
91    connect_dst: SocketAddr,
92    mut from_smoltcp: mpsc::Receiver<Bytes>,
93    to_smoltcp: mpsc::Sender<Bytes>,
94    shared: Arc<SharedState>,
95    network_policy: Arc<NetworkPolicy>,
96    secrets: Arc<SecretsConfig>,
97    proxy_connect: Arc<ProxyConnectState>,
98) -> io::Result<()> {
99    // Pre-connect peek is only for domain policy: the hostname has to be known
100    // before we dial upstream so a Deny never opens a connection. Secrets do
101    // *not* gate the connect, so they no longer force a peek here — that work is
102    // deferred to `classify_first_flight` after the socket is open, where it can
103    // run without stalling server-first protocols (see below).
104    let (initial_buf, sni) = if network_policy.has_domain_rules() {
105        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
106    } else {
107        (Vec::new(), None)
108    };
109
110    // Re-evaluate egress against the *guest* dst — the address the
111    // guest dialed, not the post-rewrite host-side address. SNI
112    // refines over-allow when the cache matched a shared CDN IP;
113    // CacheOnly is the non-TLS fallback path so Domain rules still
114    // gate plain HTTP / SSH / etc.
115    if network_policy.has_domain_rules() {
116        let source = match sni.as_deref() {
117            Some(name) => HostnameSource::Sni(name),
118            None => HostnameSource::CacheOnly,
119        };
120        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
121        {
122            EgressEvaluation::Allow => {}
123            EgressEvaluation::Deny => {
124                tracing::debug!(
125                    dst = %guest_dst,
126                    source = source.label(),
127                    "TCP egress denied by domain policy",
128                );
129                proxy_connect.mark_policy_denied();
130                shared.proxy_wake.wake();
131                return Ok(());
132            }
133            EgressEvaluation::DeferUntilHostname => {
134                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
135                proxy_connect.mark_policy_denied();
136                shared.proxy_wake.wake();
137                return Ok(());
138            }
139        }
140    }
141
142    // Connect upstream *before* finishing the secrets-side classification. A
143    // server-first protocol (SSH, SMTP, a database) sends nothing until it has
144    // seen the server's banner; with the socket already open we can relay that
145    // banner while we wait, instead of burning the peek budget pre-connect.
146    let stream = match TcpStream::connect(connect_dst).await {
147        Ok(stream) => {
148            proxy_connect.mark_connected();
149            stream
150        }
151        Err(e) => {
152            proxy_connect.mark_upstream_connect_failed();
153            shared.proxy_wake.wake();
154            return Err(e);
155        }
156    };
157    let (mut server_rx, mut server_tx) = stream.into_split();
158
159    // Finish classifying the first flight (TLS vs plain HTTP) and, for
160    // plain-HTTP candidates, gather a full header block — without blocking the
161    // server→guest direction. When domain rules already peeked, `initial_buf`
162    // is reused and this is cheap; with no secrets it is skipped entirely
163    // (`is_tls` only matters for deciding whether to build the handler).
164    let want_headers = secrets.has_plain_http_candidates() || secrets.has_host_scoped_secrets();
165    let (initial_buf, is_tls) = if !secrets.secrets.is_empty() {
166        classify_first_flight(
167            initial_buf,
168            &mut from_smoltcp,
169            &mut server_rx,
170            &to_smoltcp,
171            &shared,
172            want_headers,
173            PEEK_BUF_SIZE,
174            PEEK_BUDGET,
175        )
176        .await?
177    } else {
178        (initial_buf, false)
179    };
180
181    let mut secrets_handler: Option<SecretsHandler> = if !secrets.secrets.is_empty() && !is_tls {
182        Some(match extract_http_host(&initial_buf) {
183            Some(host) => SecretsHandler::new_plain_http(&secrets, &host, guest_dst.ip(), &shared),
184            None => SecretsHandler::new_plain_http_invalid_host(&secrets),
185        })
186    } else {
187        None
188    };
189
190    // Replay the buffered first flight — run through secrets handler first.
191    if !initial_buf.is_empty() {
192        let out: Cow<[u8]> = match secrets_handler.as_mut() {
193            Some(h) => match h.substitute(&initial_buf) {
194                // Borrow the input when nothing was substituted; only a chunk
195                // that actually carries a placeholder is reallocated.
196                Ok(cow) => cow,
197                Err(action) => {
198                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation in first flight");
199                    if matches!(action, ViolationAction::BlockAndTerminate) {
200                        shared.trigger_termination();
201                    }
202                    return Ok(());
203                }
204            },
205            None => Cow::Borrowed(&initial_buf),
206        };
207        if !out.is_empty() {
208            if let Err(e) = server_tx.write_all(&out).await {
209                tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
210                return Ok(());
211            }
212            if let Err(e) = server_tx.flush().await {
213                tracing::debug!(dst = %connect_dst, error = %e, "flush after first flight failed");
214                return Ok(());
215            }
216        }
217    }
218
219    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
220
221    // Bidirectional relay using tokio::select!.
222    //
223    // guest → server: receive from channel, write to server socket.
224    // server → guest: read from server socket, send via channel + wake poll.
225    loop {
226        tokio::select! {
227            // Guest → server: substitute placeholders before forwarding.
228            data = from_smoltcp.recv() => {
229                match data {
230                    Some(bytes) => {
231                        // No handler (no secrets / TLS) is the common path: forward
232                        // the chunk borrowed, with no per-chunk allocation or copy.
233                        let out: Cow<[u8]> = match secrets_handler.as_mut() {
234                            Some(h) => match h.substitute(&bytes) {
235                                Ok(cow) => cow,
236                                Err(action) => {
237                                    tracing::warn!(dst = %connect_dst, violation = ?action, "secret violation");
238                                    if matches!(action, ViolationAction::BlockAndTerminate) {
239                                        shared.trigger_termination();
240                                    }
241                                    break;
242                                }
243                            },
244                            None => Cow::Borrowed(&bytes),
245                        };
246                        if !out.is_empty() {
247                            if let Err(e) = server_tx.write_all(&out).await {
248                                tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
249                                break;
250                            }
251                            if let Err(e) = server_tx.flush().await {
252                                tracing::debug!(dst = %connect_dst, error = %e, "flush to server failed");
253                                break;
254                            }
255                        }
256                    }
257                    // Channel closed — smoltcp socket was closed by guest.
258                    None => break,
259                }
260            }
261
262            // Server → guest: no substitution — server never sends placeholders.
263            result = server_rx.read(&mut server_buf) => {
264                match result {
265                    Ok(0) => break, // Server closed connection.
266                    Ok(n) => {
267                        let data = Bytes::copy_from_slice(&server_buf[..n]);
268                        if to_smoltcp.send(data).await.is_err() {
269                            // Channel closed — poll loop dropped the receiver.
270                            break;
271                        }
272                        // Wake the poll thread so it writes data to the
273                        // smoltcp socket.
274                        shared.proxy_wake.wake();
275                    }
276                    Err(e) => {
277                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
278                        break;
279                    }
280                }
281            }
282        }
283    }
284
285    Ok(())
286}
287
288/// Extract the `Host:` header value from an already-buffered HTTP header block.
289///
290/// Returns `None` if:
291/// - The first byte is `0x16` (TLS — not HTTP)
292/// - The buffer does not yet contain `\r\n\r\n` (headers incomplete)
293/// - No `Host:` header is present
294///
295/// Strips port suffix, lowercases, and trims whitespace. Result is
296/// ready for byte-equal matching against `SecretEntry::allowed_hosts`.
297fn extract_http_host(buf: &[u8]) -> Option<String> {
298    if buf.first() == Some(&0x16) {
299        return None;
300    }
301    // Size the header pool to the buffer rather than a fixed array: a header
302    // line is at least four bytes (`a:\r\n`), so `len / 4` always covers the
303    // real header count, and `httparse` never reports `TooManyHeaders` (which
304    // would make a request with many headers look hostless). The first flight
305    // is capped at PEEK_BUF_SIZE, so this stays bounded.
306    let mut headers = vec![httparse::EMPTY_HEADER; (buf.len() / 4).max(16)];
307    let mut req = httparse::Request::new(&mut headers);
308    req.parse(buf).ok()?;
309    req.headers
310        .iter()
311        .find(|h| h.name.eq_ignore_ascii_case("host"))
312        .and_then(|h| std::str::from_utf8(h.value).ok())
313        .map(|v| {
314            let host = v.trim();
315            // Strip port suffix.
316            host.rsplit_once(':')
317                .map(|(h, _)| h)
318                .unwrap_or(host)
319                .to_ascii_lowercase()
320        })
321        .filter(|h| !h.is_empty())
322}
323
324/// Finish classifying the guest's first flight after the upstream socket is
325/// open, returning the (possibly extended) first-flight buffer and whether it
326/// is a TLS record.
327///
328/// `buf` carries whatever a pre-connect domain-rule peek already captured; when
329/// it is non-empty the TLS/plain decision is already settled and only header
330/// top-up runs. `want_headers` is set when at least one secret can be
331/// substituted over plain HTTP (`SecretsConfig::has_plain_http_candidates`); it
332/// makes the peek keep reading a non-TLS flight until `\r\n\r\n` so
333/// [`extract_http_host`] sees a complete header block.
334///
335/// Crucially, this relays server→guest while it waits. Server-first protocols
336/// (SSH, SMTP, databases) send nothing until they have seen the server's
337/// banner; draining the server side here lets the banner reach the guest
338/// immediately, so the guest's eventual first flight — not a 5s timeout — is
339/// what ends the peek.
340#[allow(clippy::too_many_arguments)]
341async fn classify_first_flight(
342    mut buf: Vec<u8>,
343    from_smoltcp: &mut mpsc::Receiver<Bytes>,
344    server_rx: &mut tokio::net::tcp::OwnedReadHalf,
345    to_smoltcp: &mpsc::Sender<Bytes>,
346    shared: &SharedState,
347    want_headers: bool,
348    max: usize,
349    budget: Duration,
350) -> io::Result<(Vec<u8>, bool)> {
351    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
352    let timeout_fut = tokio::time::sleep(budget);
353    tokio::pin!(timeout_fut);
354
355    loop {
356        // Stop as soon as the protocol class is known and — for plain-HTTP
357        // candidates — a full header block has arrived. Bail the moment a
358        // non-TLS flight stops looking like an HTTP request so non-HTTP
359        // protocols (SSH, Postgres) aren't withheld from upstream for the
360        // whole budget while we wait for a `\r\n\r\n` that never comes.
361        if !buf.is_empty() {
362            let is_tls = buf.first() == Some(&0x16);
363            let not_http = !is_tls
364                && (!looks_like_http_request_prefix(&buf) || first_line_is_not_http_request(&buf));
365            let done = !want_headers
366                || is_tls
367                || not_http
368                || buf.len() >= max
369                || buf.windows(4).any(|w| w == b"\r\n\r\n");
370            if done {
371                return Ok((buf, is_tls));
372            }
373        }
374
375        tokio::select! {
376            biased;
377            _ = &mut timeout_fut => {
378                let is_tls = buf.first() == Some(&0x16);
379                return Ok((buf, is_tls));
380            }
381            // Guest → buffer (not forwarded here; the caller replays it once the
382            // handler is built, so substitution applies to the first flight too).
383            guest = from_smoltcp.recv() => match guest {
384                Some(bytes) => buf.extend_from_slice(&bytes),
385                None => {
386                    let is_tls = buf.first() == Some(&0x16);
387                    return Ok((buf, is_tls));
388                }
389            },
390            // Server → guest: relay immediately so a server-first banner is never
391            // held hostage by the peek.
392            server = server_rx.read(&mut server_buf) => match server {
393                Ok(0) => {
394                    let is_tls = buf.first() == Some(&0x16);
395                    return Ok((buf, is_tls));
396                }
397                Ok(n) => {
398                    let data = Bytes::copy_from_slice(&server_buf[..n]);
399                    if to_smoltcp.send(data).await.is_err() {
400                        let is_tls = buf.first() == Some(&0x16);
401                        return Ok((buf, is_tls));
402                    }
403                    shared.proxy_wake.wake();
404                }
405                Err(e) => return Err(e),
406            },
407        }
408    }
409}
410
411/// Buffer the first flight until SNI can be extracted, or until one
412/// of the bail-out conditions hits (channel close, buffer cap,
413/// timeout). Never errors; non-TLS / slow / malformed input all
414/// fall through to `None`.
415///
416/// On hit, the SNI is canonicalized (lowercase + trim trailing dot)
417/// for byte-equal matching against rule destinations. The returned
418/// buffer must be replayed verbatim to upstream before the caller
419/// starts its relay loop.
420async fn peek_for_sni(
421    rx: &mut mpsc::Receiver<Bytes>,
422    max: usize,
423    budget: Duration,
424) -> (Vec<u8>, Option<String>) {
425    let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
426    let timeout_fut = tokio::time::sleep(budget);
427    tokio::pin!(timeout_fut);
428
429    let raw_sni = loop {
430        tokio::select! {
431            biased;
432            _ = &mut timeout_fut => break None,
433            data = rx.recv() => {
434                match data {
435                    Some(bytes) => {
436                        buf.extend_from_slice(&bytes);
437                        // First byte of a TLS record is the ContentType;
438                        // 0x16 is handshake. Anything else can't be a
439                        // ClientHello, so don't burn the full budget on
440                        // plain HTTP / SSH / etc.
441                        if buf.first() != Some(&0x16) {
442                            break None;
443                        }
444                        if let Some(name) = sni::extract_sni(&buf) {
445                            break Some(name);
446                        }
447                        if buf.len() >= max {
448                            break None;
449                        }
450                    }
451                    None => break None,
452                }
453            }
454        }
455    };
456
457    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
458    (buf, canonical)
459}
460
461//--------------------------------------------------------------------------------------------------
462// Tests
463//--------------------------------------------------------------------------------------------------
464
465#[cfg(test)]
466mod tests {
467    use super::*;
468
469    /// Synthetic TLS ClientHello carrying SNI `example.com`. Bytes
470    /// borrowed from `tls::sni` test fixtures so the parser sees a
471    /// well-formed record.
472    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
473        // Minimal but valid TLS 1.2 ClientHello with one SNI entry.
474        // Layout: record header (5) + handshake header (4) + body.
475        let host_bytes = sni.as_bytes();
476        let host_len = host_bytes.len() as u16;
477        let server_name_list_len = 3 + host_len; // type(1) + len(2) + host
478        let extension_data_len = 2 + server_name_list_len; // list-len(2) + list
479        let extensions_total = 4 + extension_data_len; // type(2) + len(2) + data
480
481        let mut body = Vec::new();
482        // Client version
483        body.extend_from_slice(&[0x03, 0x03]);
484        // Random (32 bytes)
485        body.extend_from_slice(&[0u8; 32]);
486        // Session id length + (empty)
487        body.push(0);
488        // Cipher suites length + one cipher
489        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
490        // Compression methods length + null
491        body.extend_from_slice(&[0x01, 0x00]);
492        // Extensions length
493        body.extend_from_slice(&extensions_total.to_be_bytes());
494        // SNI extension: type 0x0000
495        body.extend_from_slice(&[0x00, 0x00]);
496        body.extend_from_slice(&extension_data_len.to_be_bytes());
497        body.extend_from_slice(&server_name_list_len.to_be_bytes());
498        body.push(0x00); // host_name type
499        body.extend_from_slice(&host_len.to_be_bytes());
500        body.extend_from_slice(host_bytes);
501
502        let handshake_len = body.len() as u32;
503        let mut hs = Vec::new();
504        hs.push(0x01); // ClientHello
505        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); // 24-bit length
506        hs.extend_from_slice(&body);
507
508        let record_len = hs.len() as u16;
509        let mut record = Vec::new();
510        record.extend_from_slice(&[0x16, 0x03, 0x01]); // Handshake, TLS 1.0
511        record.extend_from_slice(&record_len.to_be_bytes());
512        record.extend_from_slice(&hs);
513
514        record
515    }
516
517    #[tokio::test]
518    async fn peek_for_sni_extracts_and_canonicalizes() {
519        let (tx, mut rx) = mpsc::channel(4);
520        let hello = synthetic_client_hello("Example.COM");
521        tx.send(Bytes::from(hello.clone())).await.unwrap();
522        drop(tx); // close so peek returns even if SNI didn't satisfy
523
524        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
525        assert_eq!(sni.as_deref(), Some("example.com"));
526        assert_eq!(buf, hello);
527    }
528
529    #[tokio::test]
530    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
531        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
532        drop(tx);
533        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
534        assert!(buf.is_empty());
535        assert_eq!(sni, None);
536    }
537
538    #[tokio::test]
539    async fn peek_for_sni_returns_none_on_non_tls_data() {
540        let (tx, mut rx) = mpsc::channel(4);
541        // Plaintext HTTP request; not a TLS record so extract_sni returns None.
542        tx.send(Bytes::from_static(
543            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
544        ))
545        .await
546        .unwrap();
547        drop(tx);
548        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
549        assert!(
550            !buf.is_empty(),
551            "buffered bytes must be returned for replay"
552        );
553        assert_eq!(sni, None);
554    }
555
556    #[tokio::test]
557    async fn peek_for_sni_falls_back_on_timeout() {
558        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
559        // Hold the sender open but send nothing — peek must time out.
560        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
561        drop(tx);
562        assert!(buf.is_empty());
563        assert_eq!(sni, None);
564    }
565
566    #[tokio::test]
567    async fn peek_for_sni_caps_at_max_bytes() {
568        let (tx, mut rx) = mpsc::channel(4);
569        // First byte 0x16 keeps the peek collecting past the early
570        // non-TLS bail. Padding bytes are zero so the SNI parser never
571        // matches and the loop drives to the size cap.
572        let mut first = vec![0u8; 8192];
573        first[0] = 0x16;
574        tx.send(Bytes::from(first)).await.unwrap();
575        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
576        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
577        drop(tx);
578
579        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
580        assert_eq!(sni, None, "no SNI in non-TLS data");
581        assert!(
582            buf.len() >= PEEK_BUF_SIZE,
583            "buffer must hit the cap before bail-out: got {}",
584            buf.len()
585        );
586    }
587
588    #[tokio::test]
589    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
590        let (tx, mut rx) = mpsc::channel(4);
591        // Plain HTTP request: first byte 'G' (0x47) — clearly not TLS.
592        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
593            .await
594            .unwrap();
595        drop(tx);
596
597        // 5-second nominal budget; assert we returned in well under
598        // that — the early-bail must not wait for the full window.
599        let started = std::time::Instant::now();
600        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
601        let elapsed = started.elapsed();
602        assert_eq!(sni, None);
603        assert!(buf.starts_with(b"GET"));
604        assert!(
605            elapsed < Duration::from_millis(500),
606            "non-TLS bail must be fast: took {elapsed:?}"
607        );
608    }
609
610    //----------------------------------------------------------------------------------------------
611    // peek_for_sni × evaluate_egress_with_source — combined integration tests
612    //----------------------------------------------------------------------------------------------
613
614    use std::net::IpAddr;
615    use std::time::Duration as StdDuration;
616
617    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
618    use crate::shared::{ResolvedHostnameFamily, SharedState};
619
620    const SHARED_FASTLY_IP: &str = "151.101.0.223";
621
622    fn shared_with(host: &str, ip: &str) -> SharedState {
623        let shared = SharedState::new(4);
624        shared.cache_resolved_hostname(
625            host,
626            ResolvedHostnameFamily::Ipv4,
627            [ip.parse::<IpAddr>().unwrap()],
628            StdDuration::from_secs(60),
629        );
630        shared
631    }
632
633    fn allow_https(domain: &str) -> Rule {
634        Rule {
635            direction: crate::policy::Direction::Egress,
636            destination: Destination::Domain(domain.parse().unwrap()),
637            protocols: vec![Protocol::Tcp],
638            ports: vec![PortRange::single(443)],
639            action: Action::Allow,
640        }
641    }
642
643    /// Over-allow case: cache says IP X is `pypi.org` (allowed); SNI
644    /// is `evil.com`. SNI must override the cache and deny.
645    #[tokio::test]
646    async fn integration_sni_overrides_cache_for_over_allow() {
647        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
648        let policy = NetworkPolicy {
649            default_egress: Action::Deny,
650            default_ingress: Action::Allow,
651            rules: vec![allow_https("pypi.org")],
652        };
653        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
654
655        let (tx, mut rx) = mpsc::channel(4);
656        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
657            .await
658            .unwrap();
659        drop(tx);
660
661        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
662        assert_eq!(sni.as_deref(), Some("evil.com"));
663        assert!(!initial_buf.is_empty());
664
665        let source = sni
666            .as_deref()
667            .map(HostnameSource::Sni)
668            .unwrap_or(HostnameSource::CacheOnly);
669        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
670        assert_eq!(
671            eval,
672            EgressEvaluation::Deny,
673            "SNI=evil.com must not piggy-back on the cached pypi.org match",
674        );
675    }
676
677    /// Over-block case: cache says IP X is `ads.example.com` (denied);
678    /// SNI is `api.example.com`. SNI must override the cache and allow.
679    #[tokio::test]
680    async fn integration_sni_overrides_cache_for_over_block() {
681        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
682        let policy = NetworkPolicy {
683            default_egress: Action::Allow,
684            default_ingress: Action::Allow,
685            rules: vec![Rule::deny_egress(Destination::Domain(
686                "ads.example.com".parse().unwrap(),
687            ))],
688        };
689        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
690
691        let (tx, mut rx) = mpsc::channel(4);
692        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
693            .await
694            .unwrap();
695        drop(tx);
696
697        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
698        assert_eq!(sni.as_deref(), Some("api.example.com"));
699
700        let source = sni
701            .as_deref()
702            .map(HostnameSource::Sni)
703            .unwrap_or(HostnameSource::CacheOnly);
704        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
705        assert_eq!(
706            eval,
707            EgressEvaluation::Allow,
708            "SNI=api.example.com must not be caught by the deny on ads.example.com",
709        );
710    }
711
712    /// Non-TLS first-flight falls back to `CacheOnly`; the cache
713    /// match decides.
714    #[tokio::test]
715    async fn integration_non_tls_falls_back_to_cache() {
716        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
717        let policy = NetworkPolicy {
718            default_egress: Action::Deny,
719            default_ingress: Action::Allow,
720            rules: vec![allow_https("pypi.org")],
721        };
722        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
723
724        let (tx, mut rx) = mpsc::channel(4);
725        // Plain HTTP request; not a TLS record.
726        tx.send(Bytes::from_static(
727            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
728        ))
729        .await
730        .unwrap();
731        drop(tx);
732
733        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
734        assert_eq!(sni, None, "non-TLS data → no SNI");
735        assert!(
736            !initial_buf.is_empty(),
737            "buffered bytes must survive for replay"
738        );
739
740        let source = sni
741            .as_deref()
742            .map(HostnameSource::Sni)
743            .unwrap_or(HostnameSource::CacheOnly);
744        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
745        assert_eq!(
746            eval,
747            EgressEvaluation::Allow,
748            "cache-only fallback must still allow the cached hostname's IP",
749        );
750    }
751
752    /// SNI matches a `DomainSuffix` rule with a cache binding for the
753    /// claimed name. Genuine pre-resolved traffic passes.
754    #[tokio::test]
755    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
756        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
757        let policy = NetworkPolicy {
758            default_egress: Action::Deny,
759            default_ingress: Action::Allow,
760            rules: vec![Rule {
761                direction: crate::policy::Direction::Egress,
762                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
763                protocols: vec![Protocol::Tcp],
764                ports: vec![PortRange::single(443)],
765                action: Action::Allow,
766            }],
767        };
768        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
769
770        let (tx, mut rx) = mpsc::channel(4);
771        tx.send(Bytes::from(synthetic_client_hello(
772            "files.pythonhosted.org",
773        )))
774        .await
775        .unwrap();
776        drop(tx);
777
778        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
779        let source = sni
780            .as_deref()
781            .map(HostnameSource::Sni)
782            .unwrap_or(HostnameSource::CacheOnly);
783        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
784        assert_eq!(eval, EgressEvaluation::Allow);
785    }
786
787    /// Spoofed SNI on an IP with no cache binding for any matching
788    /// name: byte-equality with the suffix passes, but no DNS lookup
789    /// ever tied a `*.pythonhosted.org` name to the destination, so
790    /// the AND-check fails and the connection is denied.
791    #[tokio::test]
792    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
793        let shared = SharedState::new(4); // empty cache
794        let policy = NetworkPolicy {
795            default_egress: Action::Deny,
796            default_ingress: Action::Allow,
797            rules: vec![Rule {
798                direction: crate::policy::Direction::Egress,
799                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
800                protocols: vec![Protocol::Tcp],
801                ports: vec![PortRange::single(443)],
802                action: Action::Allow,
803            }],
804        };
805        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
806
807        let (tx, mut rx) = mpsc::channel(4);
808        tx.send(Bytes::from(synthetic_client_hello(
809            "files.pythonhosted.org",
810        )))
811        .await
812        .unwrap();
813        drop(tx);
814
815        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
816        let source = sni
817            .as_deref()
818            .map(HostnameSource::Sni)
819            .unwrap_or(HostnameSource::CacheOnly);
820        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
821        assert_eq!(eval, EgressEvaluation::Deny);
822    }
823
824    // ── extract_http_host ──────────────────────────────────────────────────────
825
826    #[test]
827    fn extract_http_host_basic() {
828        let buf = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
829        assert_eq!(extract_http_host(buf), Some("example.com".into()));
830    }
831
832    #[test]
833    fn extract_http_host_strips_port() {
834        let buf = b"POST /api HTTP/1.1\r\nHost: api.company.com:8080\r\n\r\n";
835        assert_eq!(extract_http_host(buf), Some("api.company.com".into()));
836    }
837
838    #[test]
839    fn extract_http_host_case_insensitive_lowercased() {
840        let buf = b"GET / HTTP/1.1\r\nhost: Example.COM\r\n\r\n";
841        assert_eq!(extract_http_host(buf), Some("example.com".into()));
842    }
843
844    #[test]
845    fn extract_http_host_no_host_header() {
846        let buf = b"GET / HTTP/1.1\r\nX-Other: foo\r\n\r\n";
847        assert_eq!(extract_http_host(buf), None);
848    }
849
850    #[test]
851    fn extract_http_host_incomplete_headers() {
852        let buf = b"GET / HTTP/1.1\r\nHost: x";
853        assert_eq!(extract_http_host(buf), None);
854    }
855
856    #[test]
857    fn extract_http_host_tls_first_byte() {
858        let buf = [0x16u8, 0x03, 0x01, 0x00, 0x01];
859        assert_eq!(extract_http_host(&buf), None);
860    }
861
862    #[test]
863    fn extract_http_host_with_many_headers() {
864        // Far more headers than a small fixed parse array would hold: the Host
865        // must still be found rather than the request looking hostless.
866        let mut req = Vec::from(&b"GET / HTTP/1.1\r\n"[..]);
867        for i in 0..100 {
868            req.extend_from_slice(format!("X-Pad-{i}: v\r\n").as_bytes());
869        }
870        req.extend_from_slice(b"Host: example.com\r\n\r\n");
871        assert_eq!(extract_http_host(&req), Some("example.com".into()));
872    }
873
874    // ── plain-HTTP secret substitution ────────────────────────────────────────
875
876    use std::sync::Arc;
877    use tokio::io::AsyncReadExt;
878    use tokio::net::TcpListener;
879    use tokio::task::JoinHandle;
880
881    use crate::secrets::config::{HostPattern, SecretEntry, SecretInjection, SecretsConfig};
882
883    fn make_plain_http_secret(placeholder: &str, value: &str, require_tls: bool) -> SecretsConfig {
884        SecretsConfig {
885            secrets: vec![SecretEntry {
886                env_var: "API_KEY".into(),
887                value: value.into(),
888                placeholder: placeholder.into(),
889                allowed_hosts: vec![HostPattern::Any],
890                injection: SecretInjection {
891                    headers: true,
892                    basic_auth: false,
893                    query_params: false,
894                    body: false,
895                },
896                on_violation: None,
897                require_tls_identity: require_tls,
898            }],
899            ..Default::default()
900        }
901    }
902
903    async fn spawn_sink() -> (SocketAddr, JoinHandle<Vec<u8>>) {
904        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
905        let addr = listener.local_addr().unwrap();
906        let handle = tokio::spawn(async move {
907            let (mut stream, _) = listener.accept().await.unwrap();
908            let mut received = Vec::new();
909            let mut buf = vec![0u8; 4096];
910            loop {
911                match stream.read(&mut buf).await {
912                    Ok(0) | Err(_) => break,
913                    Ok(n) => received.extend_from_slice(&buf[..n]),
914                }
915            }
916            received
917        });
918        (addr, handle)
919    }
920
921    async fn relay_through_proxy(
922        request: Vec<u8>,
923        secrets: SecretsConfig,
924        handle: JoinHandle<Vec<u8>>,
925        server_addr: SocketAddr,
926    ) -> Vec<u8> {
927        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
928        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
929        let shared = SharedState::new(4);
930        let policy = Arc::new(NetworkPolicy::default());
931        let secrets = Arc::new(secrets);
932        let proxy_connect = Arc::new(ProxyConnectState::new());
933
934        from_tx.send(Bytes::from(request)).await.unwrap();
935        drop(from_tx);
936
937        tcp_proxy_task(
938            server_addr,
939            server_addr,
940            from_rx,
941            to_tx,
942            Arc::new(shared),
943            policy,
944            secrets,
945            proxy_connect,
946        )
947        .await
948        .unwrap();
949
950        handle.await.unwrap()
951    }
952
953    #[tokio::test]
954    async fn plain_http_substitutes_placeholder_when_host_arrives_in_second_segment() {
955        // Host header split across TCP segments — classify_first_flight must keep
956        // reading until \r\n\r\n before extract_http_host is called.
957        let (addr, sink) = spawn_sink().await;
958        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
959
960        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
961        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
962        let proxy_connect = Arc::new(ProxyConnectState::new());
963
964        from_tx
965            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
966            .await
967            .unwrap();
968        from_tx
969            .send(Bytes::from_static(
970                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
971            ))
972            .await
973            .unwrap();
974        drop(from_tx);
975
976        tcp_proxy_task(
977            addr,
978            addr,
979            from_rx,
980            to_tx,
981            Arc::new(SharedState::new(4)),
982            Arc::new(NetworkPolicy::default()),
983            Arc::new(secrets),
984            proxy_connect,
985        )
986        .await
987        .unwrap();
988
989        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
990        assert!(wire.contains("real-secret-value"), "got: {wire:?}");
991        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
992    }
993
994    #[tokio::test]
995    async fn plain_http_forwards_placeholder_to_allowed_host_with_split_headers() {
996        // A default (require_tls_identity = true) host-bound secret is never
997        // substituted over plain HTTP, but a request to its allowed host must
998        // have the placeholder forwarded unchanged — not blocked as a violation
999        // — even when the Host arrives in a later segment than the request line.
1000        let (addr, sink) = spawn_sink().await;
1001
1002        let shared = SharedState::new(4);
1003        shared.cache_resolved_hostname(
1004            "example.com",
1005            ResolvedHostnameFamily::Ipv4,
1006            ["127.0.0.1".parse::<IpAddr>().unwrap()],
1007            StdDuration::from_secs(60),
1008        );
1009
1010        let secrets = SecretsConfig {
1011            secrets: vec![SecretEntry {
1012                env_var: "API_KEY".into(),
1013                value: "real-secret-value".into(),
1014                placeholder: "$MSB_KEY".into(),
1015                allowed_hosts: vec![HostPattern::Exact("example.com".into())],
1016                injection: SecretInjection {
1017                    headers: true,
1018                    basic_auth: false,
1019                    query_params: false,
1020                    body: false,
1021                },
1022                on_violation: None,
1023                require_tls_identity: true,
1024            }],
1025            ..Default::default()
1026        };
1027
1028        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1029        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1030        let proxy_connect = Arc::new(ProxyConnectState::new());
1031
1032        from_tx
1033            .send(Bytes::from_static(b"GET /api HTTP/1.1\r\n"))
1034            .await
1035            .unwrap();
1036        from_tx
1037            .send(Bytes::from_static(
1038                b"Host: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n",
1039            ))
1040            .await
1041            .unwrap();
1042        drop(from_tx);
1043
1044        tcp_proxy_task(
1045            addr,
1046            addr,
1047            from_rx,
1048            to_tx,
1049            Arc::new(shared),
1050            Arc::new(NetworkPolicy::default()),
1051            Arc::new(secrets),
1052            proxy_connect,
1053        )
1054        .await
1055        .unwrap();
1056
1057        let wire = String::from_utf8(sink.await.unwrap()).unwrap();
1058        assert!(
1059            wire.contains("Host: example.com"),
1060            "request must reach the allowed host, got: {wire:?}"
1061        );
1062        assert!(
1063            wire.contains("$MSB_KEY"),
1064            "placeholder must be forwarded unchanged for a require_tls_identity secret, got: {wire:?}"
1065        );
1066        assert!(
1067            !wire.contains("real-secret-value"),
1068            "secret must never be substituted over plain HTTP, got: {wire:?}"
1069        );
1070    }
1071
1072    #[tokio::test]
1073    async fn plain_http_substitutes_placeholder_in_first_flight() {
1074        let (addr, sink) = spawn_sink().await;
1075
1076        let request =
1077            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1078                .to_vec();
1079        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", false);
1080
1081        let wire =
1082            String::from_utf8(relay_through_proxy(request, secrets, sink, addr).await).unwrap();
1083        assert!(
1084            wire.contains("real-secret-value"),
1085            "real value must reach server, got: {wire:?}"
1086        );
1087        assert!(
1088            !wire.contains("$MSB_KEY"),
1089            "placeholder must not reach server, got: {wire:?}"
1090        );
1091    }
1092
1093    #[tokio::test]
1094    async fn plain_http_no_substitution_when_require_tls_identity_true() {
1095        let (addr, sink) = spawn_sink().await;
1096
1097        let request =
1098            b"GET /api HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\n\r\n"
1099                .to_vec();
1100        let secrets = make_plain_http_secret("$MSB_KEY", "real-secret-value", true);
1101
1102        let wire =
1103            String::from_utf8_lossy(&relay_through_proxy(request, secrets, sink, addr).await)
1104                .into_owned();
1105        assert!(
1106            wire.contains("$MSB_KEY"),
1107            "placeholder must be forwarded unchanged when require_tls_identity=true, got: {wire:?}"
1108        );
1109        assert!(
1110            !wire.contains("real-secret-value"),
1111            "real value must not leak when require_tls_identity=true, got: {wire:?}"
1112        );
1113    }
1114
1115    #[tokio::test]
1116    async fn plain_http_large_body_forwarded_verbatim_in_relay_loop() {
1117        // Body arrives in a separate segment after headers — flows through the relay
1118        // loop, not the peek path. Ensures no bytes are dropped and header substitution
1119        // still happens.
1120        let (addr, sink) = spawn_sink().await;
1121        let secrets = make_plain_http_secret("$MSB_KEY", "real-value", false);
1122
1123        let body = "x".repeat(32_000);
1124        let header = format!(
1125            "POST /upload HTTP/1.1\r\nHost: example.com\r\nAuthorization: Bearer $MSB_KEY\r\nContent-Length: {}\r\n\r\n",
1126            body.len()
1127        );
1128
1129        let (from_tx, from_rx) = mpsc::channel::<Bytes>(8);
1130        let (to_tx, _to_rx) = mpsc::channel::<Bytes>(8);
1131        let proxy_connect = Arc::new(ProxyConnectState::new());
1132
1133        from_tx
1134            .send(Bytes::from(header.into_bytes()))
1135            .await
1136            .unwrap();
1137        from_tx
1138            .send(Bytes::from(body.clone().into_bytes()))
1139            .await
1140            .unwrap();
1141        drop(from_tx);
1142
1143        tcp_proxy_task(
1144            addr,
1145            addr,
1146            from_rx,
1147            to_tx,
1148            Arc::new(SharedState::new(4)),
1149            Arc::new(NetworkPolicy::default()),
1150            Arc::new(secrets),
1151            proxy_connect,
1152        )
1153        .await
1154        .unwrap();
1155
1156        let wire = String::from_utf8_lossy(&sink.await.unwrap()).into_owned();
1157        assert!(wire.contains(&body), "got {} bytes", wire.len());
1158        assert!(!wire.contains("$MSB_KEY"), "got: {wire:?}");
1159    }
1160}