Skip to main content

microsandbox_network/
proxy.rs

1//! Bidirectional TCP proxy: smoltcp socket ↔ channels ↔ tokio socket.
2//!
3//! Each outbound guest TCP connection gets a proxy task that opens a real
4//! TCP connection to the destination via tokio and relays data between the
5//! channel pair (connected to the smoltcp socket in the poll loop) and the
6//! real server.
7
8use std::io;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12
13use bytes::Bytes;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpStream;
16use tokio::sync::mpsc;
17
18use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
19use crate::shared::SharedState;
20use crate::tls::sni;
21
22//--------------------------------------------------------------------------------------------------
23// Constants
24//--------------------------------------------------------------------------------------------------
25
/// Size (16 KiB) of the scratch buffer used for each read from the
/// upstream server.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Buffering cap (16 KiB) for the guest's first flight while the proxy
/// looks for a ClientHello SNI; reaching it ends the peek.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Longest time the first flight is held back waiting for an SNI before
/// the egress decision falls back to the DNS cache alone.
const PEEK_BUDGET: Duration = Duration::from_millis(5_000);
35
36//--------------------------------------------------------------------------------------------------
37// Functions
38//--------------------------------------------------------------------------------------------------
39
40/// Spawn a TCP proxy task for a newly established connection.
41///
42/// `guest_dst` is what the guest dialed — the address policy rules
43/// match against. `connect_dst` is the host-side address tokio actually
44/// dials; for host-alias connections it's loopback (gateway rewritten).
45/// For everything else the two are identical.
46pub fn spawn_tcp_proxy(
47    handle: &tokio::runtime::Handle,
48    guest_dst: SocketAddr,
49    connect_dst: SocketAddr,
50    from_smoltcp: mpsc::Receiver<Bytes>,
51    to_smoltcp: mpsc::Sender<Bytes>,
52    shared: Arc<SharedState>,
53    network_policy: Arc<NetworkPolicy>,
54) {
55    handle.spawn(async move {
56        if let Err(e) = tcp_proxy_task(
57            guest_dst,
58            connect_dst,
59            from_smoltcp,
60            to_smoltcp,
61            shared,
62            network_policy,
63        )
64        .await
65        {
66            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
67        }
68    });
69}
70
71/// Core TCP proxy: peek for SNI, evaluate egress policy, then either
72/// connect and relay or drop the channels.
73async fn tcp_proxy_task(
74    guest_dst: SocketAddr,
75    connect_dst: SocketAddr,
76    mut from_smoltcp: mpsc::Receiver<Bytes>,
77    to_smoltcp: mpsc::Sender<Bytes>,
78    shared: Arc<SharedState>,
79    network_policy: Arc<NetworkPolicy>,
80) -> io::Result<()> {
81    // Peek only when there's a Domain/DomainSuffix rule that could
82    // need an SNI to refine. Otherwise the SYN handler's decision is
83    // authoritative.
84    let (initial_buf, sni) = if network_policy.has_domain_rules() {
85        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
86    } else {
87        (Vec::new(), None)
88    };
89
90    // Re-evaluate egress against the *guest* dst — the address the
91    // guest dialed, not the post-rewrite host-side address. SNI
92    // refines over-allow when the cache matched a shared CDN IP;
93    // CacheOnly is the non-TLS fallback path so Domain rules still
94    // gate plain HTTP / SSH / etc.
95    if network_policy.has_domain_rules() {
96        let source = match sni.as_deref() {
97            Some(name) => HostnameSource::Sni(name),
98            None => HostnameSource::CacheOnly,
99        };
100        match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
101        {
102            EgressEvaluation::Allow => {}
103            EgressEvaluation::Deny => {
104                tracing::debug!(
105                    dst = %guest_dst,
106                    source = source.label(),
107                    "TCP egress denied by domain policy",
108                );
109                return Ok(());
110            }
111            EgressEvaluation::DeferUntilHostname => {
112                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
113                return Ok(());
114            }
115        }
116    }
117
118    let stream = TcpStream::connect(connect_dst).await?;
119    let (mut server_rx, mut server_tx) = stream.into_split();
120
121    // Replay the buffered first flight before relay starts.
122    if !initial_buf.is_empty()
123        && let Err(e) = server_tx.write_all(&initial_buf).await
124    {
125        tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
126        return Ok(());
127    }
128
129    let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
130
131    // Bidirectional relay using tokio::select!.
132    //
133    // guest → server: receive from channel, write to server socket.
134    // server → guest: read from server socket, send via channel + wake poll.
135    loop {
136        tokio::select! {
137            // Guest → server.
138            data = from_smoltcp.recv() => {
139                match data {
140                    Some(bytes) => {
141                        if let Err(e) = server_tx.write_all(&bytes).await {
142                            tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
143                            break;
144                        }
145                    }
146                    // Channel closed — smoltcp socket was closed by guest.
147                    None => break,
148                }
149            }
150
151            // Server → guest.
152            result = server_rx.read(&mut server_buf) => {
153                match result {
154                    Ok(0) => break, // Server closed connection.
155                    Ok(n) => {
156                        let data = Bytes::copy_from_slice(&server_buf[..n]);
157                        if to_smoltcp.send(data).await.is_err() {
158                            // Channel closed — poll loop dropped the receiver.
159                            break;
160                        }
161                        // Wake the poll thread so it writes data to the
162                        // smoltcp socket.
163                        shared.proxy_wake.wake();
164                    }
165                    Err(e) => {
166                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
167                        break;
168                    }
169                }
170            }
171        }
172    }
173
174    Ok(())
175}
176
177/// Buffer the first flight until SNI can be extracted, or until one
178/// of the bail-out conditions hits (channel close, buffer cap,
179/// timeout). Never errors; non-TLS / slow / malformed input all
180/// fall through to `None`.
181///
182/// On hit, the SNI is canonicalized (lowercase + trim trailing dot)
183/// for byte-equal matching against rule destinations. The returned
184/// buffer must be replayed verbatim to upstream before the caller
185/// starts its relay loop.
186async fn peek_for_sni(
187    rx: &mut mpsc::Receiver<Bytes>,
188    max: usize,
189    budget: Duration,
190) -> (Vec<u8>, Option<String>) {
191    let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
192    let timeout_fut = tokio::time::sleep(budget);
193    tokio::pin!(timeout_fut);
194
195    let raw_sni = loop {
196        tokio::select! {
197            biased;
198            _ = &mut timeout_fut => break None,
199            data = rx.recv() => {
200                match data {
201                    Some(bytes) => {
202                        buf.extend_from_slice(&bytes);
203                        // First byte of a TLS record is the ContentType;
204                        // 0x16 is handshake. Anything else can't be a
205                        // ClientHello, so don't burn the full budget on
206                        // plain HTTP / SSH / etc.
207                        if buf.first() != Some(&0x16) {
208                            break None;
209                        }
210                        if let Some(name) = sni::extract_sni(&buf) {
211                            break Some(name);
212                        }
213                        if buf.len() >= max {
214                            break None;
215                        }
216                    }
217                    None => break None,
218                }
219            }
220        }
221    };
222
223    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
224    (buf, canonical)
225}
226
227//--------------------------------------------------------------------------------------------------
228// Tests
229//--------------------------------------------------------------------------------------------------
230
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use super::*;
    use crate::policy::{Action, Destination, Direction, PortRange, Rule};
    use crate::shared::ResolvedHostnameFamily;

    /// Build a minimal but well-formed TLS 1.2 ClientHello record whose
    /// only extension is a server_name (SNI) entry for `sni`.
    ///
    /// All length fields are computed from `sni`, so any host name short
    /// enough for the `u16` length fields produces a parseable record.
    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
        // Layout: record header (5) + handshake header (4) + body.
        let host_bytes = sni.as_bytes();
        let host_len = host_bytes.len() as u16;
        let server_name_list_len = 3 + host_len; // type(1) + len(2) + host
        let extension_data_len = 2 + server_name_list_len; // list-len(2) + list
        let extensions_total = 4 + extension_data_len; // type(2) + len(2) + data

        let mut body = Vec::new();
        // Client version
        body.extend_from_slice(&[0x03, 0x03]);
        // Random (32 bytes)
        body.extend_from_slice(&[0u8; 32]);
        // Session id length + (empty)
        body.push(0);
        // Cipher suites length + one cipher
        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
        // Compression methods length + null
        body.extend_from_slice(&[0x01, 0x00]);
        // Extensions length
        body.extend_from_slice(&extensions_total.to_be_bytes());
        // SNI extension: type 0x0000
        body.extend_from_slice(&[0x00, 0x00]);
        body.extend_from_slice(&extension_data_len.to_be_bytes());
        body.extend_from_slice(&server_name_list_len.to_be_bytes());
        body.push(0x00); // host_name type
        body.extend_from_slice(&host_len.to_be_bytes());
        body.extend_from_slice(host_bytes);

        let handshake_len = body.len() as u32;
        let mut hs = Vec::new();
        hs.push(0x01); // ClientHello
        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); // 24-bit length
        hs.extend_from_slice(&body);

        let record_len = hs.len() as u16;
        let mut record = Vec::new();
        record.extend_from_slice(&[0x16, 0x03, 0x01]); // Handshake, TLS 1.0
        record.extend_from_slice(&record_len.to_be_bytes());
        record.extend_from_slice(&hs);

        record
    }

    #[tokio::test]
    async fn peek_for_sni_extracts_and_canonicalizes() {
        let (tx, mut rx) = mpsc::channel(4);
        let hello = synthetic_client_hello("Example.COM");
        tx.send(Bytes::from(hello.clone())).await.unwrap();
        drop(tx); // close so peek returns even if SNI didn't satisfy

        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("example.com"));
        assert_eq!(buf, hello);
    }

    #[tokio::test]
    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_returns_none_on_non_tls_data() {
        let (tx, mut rx) = mpsc::channel(4);
        // Plaintext HTTP request; not a TLS record so extract_sni returns None.
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(
            !buf.is_empty(),
            "buffered bytes must be returned for replay"
        );
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_falls_back_on_timeout() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        // Hold the sender open but send nothing — peek must time out.
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
        // Dropped only after the peek so the channel stays open throughout.
        drop(tx);
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_caps_at_max_bytes() {
        let (tx, mut rx) = mpsc::channel(4);
        // First byte 0x16 keeps the peek collecting past the early
        // non-TLS bail. Padding bytes are zero so the SNI parser never
        // matches and the loop drives to the size cap.
        let mut first = vec![0u8; 8192];
        first[0] = 0x16;
        tx.send(Bytes::from(first)).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        drop(tx);

        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni, None, "zero-padded record never yields an SNI");
        assert!(
            buf.len() >= PEEK_BUF_SIZE,
            "buffer must hit the cap before bail-out: got {}",
            buf.len()
        );
    }

    #[tokio::test]
    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
        let (tx, mut rx) = mpsc::channel(4);
        // Plain HTTP request: first byte 'G' (0x47) — clearly not TLS.
        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
            .await
            .unwrap();
        drop(tx);

        // 5-second nominal budget; assert we returned in well under
        // that — the early-bail must not wait for the full window.
        let started = std::time::Instant::now();
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let elapsed = started.elapsed();
        assert_eq!(sni, None);
        assert!(buf.starts_with(b"GET"));
        assert!(
            elapsed < Duration::from_millis(500),
            "non-TLS bail must be fast: took {elapsed:?}"
        );
    }

    //----------------------------------------------------------------------------------------------
    // peek_for_sni × evaluate_egress_with_source — combined integration tests
    //----------------------------------------------------------------------------------------------

    const SHARED_FASTLY_IP: &str = "151.101.0.223";

    /// Shared state whose DNS cache binds `host` to the IPv4 address `ip`
    /// for 60 seconds.
    fn shared_with(host: &str, ip: &str) -> SharedState {
        let shared = SharedState::new(4);
        shared.cache_resolved_hostname(
            host,
            ResolvedHostnameFamily::Ipv4,
            [ip.parse::<IpAddr>().unwrap()],
            Duration::from_secs(60),
        );
        shared
    }

    /// Egress rule allowing TCP/443 to `destination`.
    fn allow_https_to(destination: Destination) -> Rule {
        Rule {
            direction: Direction::Egress,
            destination,
            protocols: vec![Protocol::Tcp],
            ports: vec![PortRange::single(443)],
            action: Action::Allow,
        }
    }

    /// Egress rule allowing TCP/443 to the exact domain `domain`.
    fn allow_https(domain: &str) -> Rule {
        allow_https_to(Destination::Domain(domain.parse().unwrap()))
    }

    /// Egress rule allowing TCP/443 to any name under `suffix`.
    fn allow_https_suffix(suffix: &str) -> Rule {
        allow_https_to(Destination::DomainSuffix(suffix.parse().unwrap()))
    }

    /// Pick the hostname source as the proxy does: the peeked SNI when
    /// present, otherwise the DNS cache alone.
    fn source_for(sni: Option<&str>) -> HostnameSource<'_> {
        sni.map_or(HostnameSource::CacheOnly, HostnameSource::Sni)
    }

    /// Over-allow case: cache says IP X is `pypi.org` (allowed); SNI
    /// is `evil.com`. SNI must override the cache and deny.
    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_allow() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
            .await
            .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("evil.com"));
        assert!(!initial_buf.is_empty());

        let eval = policy.evaluate_egress_with_source(
            dst,
            Protocol::Tcp,
            &shared,
            source_for(sni.as_deref()),
        );
        assert_eq!(
            eval,
            EgressEvaluation::Deny,
            "SNI=evil.com must not piggy-back on the cached pypi.org match",
        );
    }

    /// Over-block case: cache says IP X is `ads.example.com` (denied);
    /// SNI is `api.example.com`. SNI must override the cache and allow.
    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_block() {
        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Allow,
            default_ingress: Action::Allow,
            rules: vec![Rule::deny_egress(Destination::Domain(
                "ads.example.com".parse().unwrap(),
            ))],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
            .await
            .unwrap();
        drop(tx);

        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("api.example.com"));

        let eval = policy.evaluate_egress_with_source(
            dst,
            Protocol::Tcp,
            &shared,
            source_for(sni.as_deref()),
        );
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "SNI=api.example.com must not be caught by the deny on ads.example.com",
        );
    }

    /// Non-TLS first-flight falls back to `CacheOnly`; the cache
    /// match decides.
    #[tokio::test]
    async fn integration_non_tls_falls_back_to_cache() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        // Plain HTTP request; not a TLS record.
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni, None, "non-TLS data → no SNI");
        assert!(
            !initial_buf.is_empty(),
            "buffered bytes must survive for replay"
        );

        let eval = policy.evaluate_egress_with_source(
            dst,
            Protocol::Tcp,
            &shared,
            source_for(sni.as_deref()),
        );
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "cache-only fallback must still allow the cached hostname's IP",
        );
    }

    /// SNI matches a `DomainSuffix` rule with a cache binding for the
    /// claimed name. Genuine pre-resolved traffic passes.
    #[tokio::test]
    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https_suffix(".pythonhosted.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let eval = policy.evaluate_egress_with_source(
            dst,
            Protocol::Tcp,
            &shared,
            source_for(sni.as_deref()),
        );
        assert_eq!(eval, EgressEvaluation::Allow);
    }

    /// Spoofed SNI on an IP with no cache binding for any matching name.
    ///
    /// Byte-equality with the suffix passes. But no DNS lookup ever tied a
    /// `*.pythonhosted.org` name to the destination, so the AND-check fails
    /// and the connection is denied.
    #[tokio::test]
    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
        let shared = SharedState::new(4); // empty cache
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https_suffix(".pythonhosted.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let eval = policy.evaluate_egress_with_source(
            dst,
            Protocol::Tcp,
            &shared,
            source_for(sni.as_deref()),
        );
        assert_eq!(eval, EgressEvaluation::Deny);
    }
}