Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::io;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
20use crate::shared::SharedState;
21use crate::tls::sni;
22
23//--------------------------------------------------------------------------------------------------
24// Constants
25//--------------------------------------------------------------------------------------------------
26
/// Size, in bytes, of the scratch buffer used for each read from the
/// upstream server socket.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Byte threshold at which buffering for the ClientHello's SNI gives up.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Longest time the first flight may be buffered before the proxy
/// stops waiting for an SNI and falls back to a cache-only egress
/// decision.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
36
37//--------------------------------------------------------------------------------------------------
38// Functions
39//--------------------------------------------------------------------------------------------------
40
41/// Spawn a TCP proxy task for a newly established connection.
42///
43/// `guest_dst` is what the guest dialed — the address policy rules
44/// match against. `connect_dst` is the host-side address tokio actually
45/// dials; for host-alias connections it's loopback (gateway rewritten).
46/// For everything else the two are identical.
47///
48/// `upstream_connected` is flipped to `true` after the upstream
49/// `TcpStream::connect` succeeds. The connection tracker reads this
50/// on proxy exit to decide between FIN (clean close) and RST
51/// (upstream never reached, e.g. connect failure or policy denial).
52#[allow(clippy::too_many_arguments)]
53pub fn spawn_tcp_proxy(
54    handle: &tokio::runtime::Handle,
55    guest_dst: SocketAddr,
56    connect_dst: SocketAddr,
57    from_smoltcp: mpsc::Receiver<Bytes>,
58    to_smoltcp: mpsc::Sender<Bytes>,
59    shared: Arc<SharedState>,
60    network_policy: Arc<NetworkPolicy>,
61    upstream_connected: Arc<AtomicBool>,
62) {
63    handle.spawn(async move {
64        if let Err(e) = tcp_proxy_task(
65            guest_dst,
66            connect_dst,
67            from_smoltcp,
68            to_smoltcp,
69            shared,
70            network_policy,
71            upstream_connected,
72        )
73        .await
74        {
75            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
76        }
77    });
78}
79
80/// Core TCP proxy: peek for SNI, evaluate egress policy, then either
81/// connect and relay or drop the channels.
82async fn tcp_proxy_task(
83    guest_dst: SocketAddr,
84    connect_dst: SocketAddr,
85    mut from_smoltcp: mpsc::Receiver<Bytes>,
86    to_smoltcp: mpsc::Sender<Bytes>,
87    shared: Arc<SharedState>,
88    network_policy: Arc<NetworkPolicy>,
89    upstream_connected: Arc<AtomicBool>,
90) -> io::Result<()> {
91    // Peek only when there's a Domain/DomainSuffix rule that could
92    // need an SNI to refine. Otherwise the SYN handler's decision is
93    // authoritative.
94    let (initial_buf, sni) = if network_policy.has_domain_rules() {
95        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
96    } else {
97        (Vec::new(), None)
98    };
99
100    // Re-evaluate egress against the *guest* dst — the address the
101    // guest dialed, not the post-rewrite host-side address. SNI
102    // refines over-allow when the cache matched a shared CDN IP;
103    // CacheOnly is the non-TLS fallback path so Domain rules still
104    // gate plain HTTP / SSH / etc.
105    if network_policy.has_domain_rules() {
106        let source = match sni.as_deref() {
107            Some(name) => HostnameSource::Sni(name),
108            None => HostnameSource::CacheOnly,
109        };
110        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
111        {
112            EgressEvaluation::Allow => {}
113            EgressEvaluation::Deny => {
114                tracing::debug!(
115                    dst = %guest_dst,
116                    source = source.label(),
117                    "TCP egress denied by domain policy",
118                );
119                return Ok(());
120            }
121            EgressEvaluation::DeferUntilHostname => {
122                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
123                return Ok(());
124            }
125        }
126    }
127
128    let stream = TcpStream::connect(connect_dst).await?;
129    upstream_connected.store(true, Ordering::Release);
130    let (mut server_rx, mut server_tx) = stream.into_split();
131
132    // Replay the buffered first flight before relay starts.
133    if !initial_buf.is_empty()
134        && let Err(e) = server_tx.write_all(&initial_buf).await
135    {
136        tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
137        return Ok(());
138    }
139
140    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
141
142    // Bidirectional relay using tokio::select!.
143    //
144    // guest → server: receive from channel, write to server socket.
145    // server → guest: read from server socket, send via channel + wake poll.
146    loop {
147        tokio::select! {
148            // Guest → server.
149            data = from_smoltcp.recv() => {
150                match data {
151                    Some(bytes) => {
152                        if let Err(e) = server_tx.write_all(&bytes).await {
153                            tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
154                            break;
155                        }
156                    }
157                    // Channel closed — smoltcp socket was closed by guest.
158                    None => break,
159                }
160            }
161
162            // Server → guest.
163            result = server_rx.read(&mut server_buf) => {
164                match result {
165                    Ok(0) => break, // Server closed connection.
166                    Ok(n) => {
167                        let data = Bytes::copy_from_slice(&server_buf[..n]);
168                        if to_smoltcp.send(data).await.is_err() {
169                            // Channel closed — poll loop dropped the receiver.
170                            break;
171                        }
172                        // Wake the poll thread so it writes data to the
173                        // smoltcp socket.
174                        shared.proxy_wake.wake();
175                    }
176                    Err(e) => {
177                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
178                        break;
179                    }
180                }
181            }
182        }
183    }
184
185    Ok(())
186}
187
188/// Buffer the first flight until SNI can be extracted, or until one
189/// of the bail-out conditions hits (channel close, buffer cap,
190/// timeout). Never errors; non-TLS / slow / malformed input all
191/// fall through to `None`.
192///
193/// On hit, the SNI is canonicalized (lowercase + trim trailing dot)
194/// for byte-equal matching against rule destinations. The returned
195/// buffer must be replayed verbatim to upstream before the caller
196/// starts its relay loop.
197async fn peek_for_sni(
198    rx: &mut mpsc::Receiver<Bytes>,
199    max: usize,
200    budget: Duration,
201) -> (Vec<u8>, Option<String>) {
202    let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
203    let timeout_fut = tokio::time::sleep(budget);
204    tokio::pin!(timeout_fut);
205
206    let raw_sni = loop {
207        tokio::select! {
208            biased;
209            _ = &mut timeout_fut => break None,
210            data = rx.recv() => {
211                match data {
212                    Some(bytes) => {
213                        buf.extend_from_slice(&bytes);
214                        // First byte of a TLS record is the ContentType;
215                        // 0x16 is handshake. Anything else can't be a
216                        // ClientHello, so don't burn the full budget on
217                        // plain HTTP / SSH / etc.
218                        if buf.first() != Some(&0x16) {
219                            break None;
220                        }
221                        if let Some(name) = sni::extract_sni(&buf) {
222                            break Some(name);
223                        }
224                        if buf.len() >= max {
225                            break None;
226                        }
227                    }
228                    None => break None,
229                }
230            }
231        }
232    };
233
234    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
235    (buf, canonical)
236}
237
238//--------------------------------------------------------------------------------------------------
239// Tests
240//--------------------------------------------------------------------------------------------------
241
242#[cfg(test)]
243mod tests {
244    use super::*;
245
246    /// Synthetic TLS ClientHello carrying SNI `example.com`. Bytes
247    /// borrowed from `tls::sni` test fixtures so the parser sees a
248    /// well-formed record.
249    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
250        // Minimal but valid TLS 1.2 ClientHello with one SNI entry.
251        // Layout: record header (5) + handshake header (4) + body.
252        let host_bytes = sni.as_bytes();
253        let host_len = host_bytes.len() as u16;
254        let server_name_list_len = 3 + host_len; // type(1) + len(2) + host
255        let extension_data_len = 2 + server_name_list_len; // list-len(2) + list
256        let extensions_total = 4 + extension_data_len; // type(2) + len(2) + data
257
258        let mut body = Vec::new();
259        // Client version
260        body.extend_from_slice(&[0x03, 0x03]);
261        // Random (32 bytes)
262        body.extend_from_slice(&[0u8; 32]);
263        // Session id length + (empty)
264        body.push(0);
265        // Cipher suites length + one cipher
266        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
267        // Compression methods length + null
268        body.extend_from_slice(&[0x01, 0x00]);
269        // Extensions length
270        body.extend_from_slice(&extensions_total.to_be_bytes());
271        // SNI extension: type 0x0000
272        body.extend_from_slice(&[0x00, 0x00]);
273        body.extend_from_slice(&extension_data_len.to_be_bytes());
274        body.extend_from_slice(&server_name_list_len.to_be_bytes());
275        body.push(0x00); // host_name type
276        body.extend_from_slice(&host_len.to_be_bytes());
277        body.extend_from_slice(host_bytes);
278
279        let handshake_len = body.len() as u32;
280        let mut hs = Vec::new();
281        hs.push(0x01); // ClientHello
282        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); // 24-bit length
283        hs.extend_from_slice(&body);
284
285        let record_len = hs.len() as u16;
286        let mut record = Vec::new();
287        record.extend_from_slice(&[0x16, 0x03, 0x01]); // Handshake, TLS 1.0
288        record.extend_from_slice(&record_len.to_be_bytes());
289        record.extend_from_slice(&hs);
290
291        record
292    }
293
294    #[tokio::test]
295    async fn peek_for_sni_extracts_and_canonicalizes() {
296        let (tx, mut rx) = mpsc::channel(4);
297        let hello = synthetic_client_hello("Example.COM");
298        tx.send(Bytes::from(hello.clone())).await.unwrap();
299        drop(tx); // close so peek returns even if SNI didn't satisfy
300
301        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
302        assert_eq!(sni.as_deref(), Some("example.com"));
303        assert_eq!(buf, hello);
304    }
305
306    #[tokio::test]
307    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
308        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
309        drop(tx);
310        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
311        assert!(buf.is_empty());
312        assert_eq!(sni, None);
313    }
314
315    #[tokio::test]
316    async fn peek_for_sni_returns_none_on_non_tls_data() {
317        let (tx, mut rx) = mpsc::channel(4);
318        // Plaintext HTTP request; not a TLS record so extract_sni returns None.
319        tx.send(Bytes::from_static(
320            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
321        ))
322        .await
323        .unwrap();
324        drop(tx);
325        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
326        assert!(
327            !buf.is_empty(),
328            "buffered bytes must be returned for replay"
329        );
330        assert_eq!(sni, None);
331    }
332
333    #[tokio::test]
334    async fn peek_for_sni_falls_back_on_timeout() {
335        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
336        // Hold the sender open but send nothing — peek must time out.
337        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
338        drop(tx);
339        assert!(buf.is_empty());
340        assert_eq!(sni, None);
341    }
342
343    #[tokio::test]
344    async fn peek_for_sni_caps_at_max_bytes() {
345        let (tx, mut rx) = mpsc::channel(4);
346        // First byte 0x16 keeps the peek collecting past the early
347        // non-TLS bail. Padding bytes are zero so the SNI parser never
348        // matches and the loop drives to the size cap.
349        let mut first = vec![0u8; 8192];
350        first[0] = 0x16;
351        tx.send(Bytes::from(first)).await.unwrap();
352        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
353        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
354        drop(tx);
355
356        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
357        assert_eq!(sni, None, "no SNI in non-TLS data");
358        assert!(
359            buf.len() >= PEEK_BUF_SIZE,
360            "buffer must hit the cap before bail-out: got {}",
361            buf.len()
362        );
363    }
364
365    #[tokio::test]
366    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
367        let (tx, mut rx) = mpsc::channel(4);
368        // Plain HTTP request: first byte 'G' (0x47) — clearly not TLS.
369        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
370            .await
371            .unwrap();
372        drop(tx);
373
374        // 5-second nominal budget; assert we returned in well under
375        // that — the early-bail must not wait for the full window.
376        let started = std::time::Instant::now();
377        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
378        let elapsed = started.elapsed();
379        assert_eq!(sni, None);
380        assert!(buf.starts_with(b"GET"));
381        assert!(
382            elapsed < Duration::from_millis(500),
383            "non-TLS bail must be fast: took {elapsed:?}"
384        );
385    }
386
387    //----------------------------------------------------------------------------------------------
388    // peek_for_sni × evaluate_egress_with_source — combined integration tests
389    //----------------------------------------------------------------------------------------------
390
391    use std::net::IpAddr;
392    use std::time::Duration as StdDuration;
393
394    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
395    use crate::shared::{ResolvedHostnameFamily, SharedState};
396
397    const SHARED_FASTLY_IP: &str = "151.101.0.223";
398
399    fn shared_with(host: &str, ip: &str) -> SharedState {
400        let shared = SharedState::new(4);
401        shared.cache_resolved_hostname(
402            host,
403            ResolvedHostnameFamily::Ipv4,
404            [ip.parse::<IpAddr>().unwrap()],
405            StdDuration::from_secs(60),
406        );
407        shared
408    }
409
410    fn allow_https(domain: &str) -> Rule {
411        Rule {
412            direction: crate::policy::Direction::Egress,
413            destination: Destination::Domain(domain.parse().unwrap()),
414            protocols: vec![Protocol::Tcp],
415            ports: vec![PortRange::single(443)],
416            action: Action::Allow,
417        }
418    }
419
420    /// Over-allow case: cache says IP X is `pypi.org` (allowed); SNI
421    /// is `evil.com`. SNI must override the cache and deny.
422    #[tokio::test]
423    async fn integration_sni_overrides_cache_for_over_allow() {
424        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
425        let policy = NetworkPolicy {
426            default_egress: Action::Deny,
427            default_ingress: Action::Allow,
428            rules: vec![allow_https("pypi.org")],
429        };
430        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
431
432        let (tx, mut rx) = mpsc::channel(4);
433        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
434            .await
435            .unwrap();
436        drop(tx);
437
438        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
439        assert_eq!(sni.as_deref(), Some("evil.com"));
440        assert!(!initial_buf.is_empty());
441
442        let source = sni
443            .as_deref()
444            .map(HostnameSource::Sni)
445            .unwrap_or(HostnameSource::CacheOnly);
446        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
447        assert_eq!(
448            eval,
449            EgressEvaluation::Deny,
450            "SNI=evil.com must not piggy-back on the cached pypi.org match",
451        );
452    }
453
454    /// Over-block case: cache says IP X is `ads.example.com` (denied);
455    /// SNI is `api.example.com`. SNI must override the cache and allow.
456    #[tokio::test]
457    async fn integration_sni_overrides_cache_for_over_block() {
458        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
459        let policy = NetworkPolicy {
460            default_egress: Action::Allow,
461            default_ingress: Action::Allow,
462            rules: vec![Rule::deny_egress(Destination::Domain(
463                "ads.example.com".parse().unwrap(),
464            ))],
465        };
466        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
467
468        let (tx, mut rx) = mpsc::channel(4);
469        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
470            .await
471            .unwrap();
472        drop(tx);
473
474        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
475        assert_eq!(sni.as_deref(), Some("api.example.com"));
476
477        let source = sni
478            .as_deref()
479            .map(HostnameSource::Sni)
480            .unwrap_or(HostnameSource::CacheOnly);
481        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
482        assert_eq!(
483            eval,
484            EgressEvaluation::Allow,
485            "SNI=api.example.com must not be caught by the deny on ads.example.com",
486        );
487    }
488
489    /// Non-TLS first-flight falls back to `CacheOnly`; the cache
490    /// match decides.
491    #[tokio::test]
492    async fn integration_non_tls_falls_back_to_cache() {
493        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
494        let policy = NetworkPolicy {
495            default_egress: Action::Deny,
496            default_ingress: Action::Allow,
497            rules: vec![allow_https("pypi.org")],
498        };
499        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
500
501        let (tx, mut rx) = mpsc::channel(4);
502        // Plain HTTP request; not a TLS record.
503        tx.send(Bytes::from_static(
504            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
505        ))
506        .await
507        .unwrap();
508        drop(tx);
509
510        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
511        assert_eq!(sni, None, "non-TLS data → no SNI");
512        assert!(
513            !initial_buf.is_empty(),
514            "buffered bytes must survive for replay"
515        );
516
517        let source = sni
518            .as_deref()
519            .map(HostnameSource::Sni)
520            .unwrap_or(HostnameSource::CacheOnly);
521        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
522        assert_eq!(
523            eval,
524            EgressEvaluation::Allow,
525            "cache-only fallback must still allow the cached hostname's IP",
526        );
527    }
528
529    /// SNI matches a `DomainSuffix` rule with a cache binding for the
530    /// claimed name. Genuine pre-resolved traffic passes.
531    #[tokio::test]
532    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
533        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
534        let policy = NetworkPolicy {
535            default_egress: Action::Deny,
536            default_ingress: Action::Allow,
537            rules: vec![Rule {
538                direction: crate::policy::Direction::Egress,
539                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
540                protocols: vec![Protocol::Tcp],
541                ports: vec![PortRange::single(443)],
542                action: Action::Allow,
543            }],
544        };
545        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
546
547        let (tx, mut rx) = mpsc::channel(4);
548        tx.send(Bytes::from(synthetic_client_hello(
549            "files.pythonhosted.org",
550        )))
551        .await
552        .unwrap();
553        drop(tx);
554
555        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
556        let source = sni
557            .as_deref()
558            .map(HostnameSource::Sni)
559            .unwrap_or(HostnameSource::CacheOnly);
560        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
561        assert_eq!(eval, EgressEvaluation::Allow);
562    }
563
564    /// Spoofed SNI on an IP with no cache binding for any matching
565    /// name: byte-equality with the suffix passes, but no DNS lookup
566    /// ever tied a `*.pythonhosted.org` name to the destination, so
567    /// the AND-check fails and the connection is denied.
568    #[tokio::test]
569    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
570        let shared = SharedState::new(4); // empty cache
571        let policy = NetworkPolicy {
572            default_egress: Action::Deny,
573            default_ingress: Action::Allow,
574            rules: vec![Rule {
575                direction: crate::policy::Direction::Egress,
576                destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
577                protocols: vec![Protocol::Tcp],
578                ports: vec![PortRange::single(443)],
579                action: Action::Allow,
580            }],
581        };
582        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
583
584        let (tx, mut rx) = mpsc::channel(4);
585        tx.send(Bytes::from(synthetic_client_hello(
586            "files.pythonhosted.org",
587        )))
588        .await
589        .unwrap();
590        drop(tx);
591
592        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
593        let source = sni
594            .as_deref()
595            .map(HostnameSource::Sni)
596            .unwrap_or(HostnameSource::CacheOnly);
597        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
598        assert_eq!(eval, EgressEvaluation::Deny);
599    }
600}