1use std::io;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12
13use bytes::Bytes;
14use tokio::io::{AsyncReadExt, AsyncWriteExt};
15use tokio::net::TcpStream;
16use tokio::sync::mpsc;
17
18use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
19use crate::shared::SharedState;
20use crate::tls::sni;
21
/// Size, in bytes, of the scratch buffer that receives each read from the
/// upstream server socket in `tcp_proxy_task`.
const SERVER_READ_BUF_SIZE: usize = 16 * 1024;

/// Byte limit handed to `peek_for_sni`: once this much of the guest's first
/// flight has been buffered without yielding an SNI, peeking stops.
const PEEK_BUF_SIZE: usize = 16 * 1024;

/// Time budget handed to `peek_for_sni` for collecting the guest's first
/// flight before the proxy stops waiting for an SNI.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
35
36pub fn spawn_tcp_proxy(
47 handle: &tokio::runtime::Handle,
48 guest_dst: SocketAddr,
49 connect_dst: SocketAddr,
50 from_smoltcp: mpsc::Receiver<Bytes>,
51 to_smoltcp: mpsc::Sender<Bytes>,
52 shared: Arc<SharedState>,
53 network_policy: Arc<NetworkPolicy>,
54) {
55 handle.spawn(async move {
56 if let Err(e) = tcp_proxy_task(
57 guest_dst,
58 connect_dst,
59 from_smoltcp,
60 to_smoltcp,
61 shared,
62 network_policy,
63 )
64 .await
65 {
66 tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
67 }
68 });
69}
70
71async fn tcp_proxy_task(
74 guest_dst: SocketAddr,
75 connect_dst: SocketAddr,
76 mut from_smoltcp: mpsc::Receiver<Bytes>,
77 to_smoltcp: mpsc::Sender<Bytes>,
78 shared: Arc<SharedState>,
79 network_policy: Arc<NetworkPolicy>,
80) -> io::Result<()> {
81 let (initial_buf, sni) = if network_policy.has_domain_rules() {
85 peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
86 } else {
87 (Vec::new(), None)
88 };
89
90 if network_policy.has_domain_rules() {
96 let source = match sni.as_deref() {
97 Some(name) => HostnameSource::Sni(name),
98 None => HostnameSource::CacheOnly,
99 };
100 match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
101 {
102 EgressEvaluation::Allow => {}
103 EgressEvaluation::Deny => {
104 tracing::debug!(
105 dst = %guest_dst,
106 source = source.label(),
107 "TCP egress denied by domain policy",
108 );
109 return Ok(());
110 }
111 EgressEvaluation::DeferUntilHostname => {
112 debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
113 return Ok(());
114 }
115 }
116 }
117
118 let stream = TcpStream::connect(connect_dst).await?;
119 let (mut server_rx, mut server_tx) = stream.into_split();
120
121 if !initial_buf.is_empty()
123 && let Err(e) = server_tx.write_all(&initial_buf).await
124 {
125 tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
126 return Ok(());
127 }
128
129 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
130
131 loop {
136 tokio::select! {
137 data = from_smoltcp.recv() => {
139 match data {
140 Some(bytes) => {
141 if let Err(e) = server_tx.write_all(&bytes).await {
142 tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
143 break;
144 }
145 }
146 None => break,
148 }
149 }
150
151 result = server_rx.read(&mut server_buf) => {
153 match result {
154 Ok(0) => break, Ok(n) => {
156 let data = Bytes::copy_from_slice(&server_buf[..n]);
157 if to_smoltcp.send(data).await.is_err() {
158 break;
160 }
161 shared.proxy_wake.wake();
164 }
165 Err(e) => {
166 tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
167 break;
168 }
169 }
170 }
171 }
172 }
173
174 Ok(())
175}
176
177async fn peek_for_sni(
187 rx: &mut mpsc::Receiver<Bytes>,
188 max: usize,
189 budget: Duration,
190) -> (Vec<u8>, Option<String>) {
191 let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
192 let timeout_fut = tokio::time::sleep(budget);
193 tokio::pin!(timeout_fut);
194
195 let raw_sni = loop {
196 tokio::select! {
197 biased;
198 _ = &mut timeout_fut => break None,
199 data = rx.recv() => {
200 match data {
201 Some(bytes) => {
202 buf.extend_from_slice(&bytes);
203 if buf.first() != Some(&0x16) {
208 break None;
209 }
210 if let Some(name) = sni::extract_sni(&buf) {
211 break Some(name);
212 }
213 if buf.len() >= max {
214 break None;
215 }
216 }
217 None => break None,
218 }
219 }
220 }
221 };
222
223 let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
224 (buf, canonical)
225}
226
227#[cfg(test)]
232mod tests {
233 use super::*;
234
235 fn synthetic_client_hello(sni: &str) -> Vec<u8> {
239 let host_bytes = sni.as_bytes();
242 let host_len = host_bytes.len() as u16;
243 let server_name_list_len = 3 + host_len; let extension_data_len = 2 + server_name_list_len; let extensions_total = 4 + extension_data_len; let mut body = Vec::new();
248 body.extend_from_slice(&[0x03, 0x03]);
250 body.extend_from_slice(&[0u8; 32]);
252 body.push(0);
254 body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
256 body.extend_from_slice(&[0x01, 0x00]);
258 body.extend_from_slice(&extensions_total.to_be_bytes());
260 body.extend_from_slice(&[0x00, 0x00]);
262 body.extend_from_slice(&extension_data_len.to_be_bytes());
263 body.extend_from_slice(&server_name_list_len.to_be_bytes());
264 body.push(0x00); body.extend_from_slice(&host_len.to_be_bytes());
266 body.extend_from_slice(host_bytes);
267
268 let handshake_len = body.len() as u32;
269 let mut hs = Vec::new();
270 hs.push(0x01); hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]); hs.extend_from_slice(&body);
273
274 let record_len = hs.len() as u16;
275 let mut record = Vec::new();
276 record.extend_from_slice(&[0x16, 0x03, 0x01]); record.extend_from_slice(&record_len.to_be_bytes());
278 record.extend_from_slice(&hs);
279
280 record
281 }
282
283 #[tokio::test]
284 async fn peek_for_sni_extracts_and_canonicalizes() {
285 let (tx, mut rx) = mpsc::channel(4);
286 let hello = synthetic_client_hello("Example.COM");
287 tx.send(Bytes::from(hello.clone())).await.unwrap();
288 drop(tx); let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
291 assert_eq!(sni.as_deref(), Some("example.com"));
292 assert_eq!(buf, hello);
293 }
294
295 #[tokio::test]
296 async fn peek_for_sni_returns_none_on_channel_close_without_data() {
297 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
298 drop(tx);
299 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
300 assert!(buf.is_empty());
301 assert_eq!(sni, None);
302 }
303
304 #[tokio::test]
305 async fn peek_for_sni_returns_none_on_non_tls_data() {
306 let (tx, mut rx) = mpsc::channel(4);
307 tx.send(Bytes::from_static(
309 b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
310 ))
311 .await
312 .unwrap();
313 drop(tx);
314 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
315 assert!(
316 !buf.is_empty(),
317 "buffered bytes must be returned for replay"
318 );
319 assert_eq!(sni, None);
320 }
321
322 #[tokio::test]
323 async fn peek_for_sni_falls_back_on_timeout() {
324 let (tx, mut rx) = mpsc::channel::<Bytes>(1);
325 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
327 drop(tx);
328 assert!(buf.is_empty());
329 assert_eq!(sni, None);
330 }
331
332 #[tokio::test]
333 async fn peek_for_sni_caps_at_max_bytes() {
334 let (tx, mut rx) = mpsc::channel(4);
335 let mut first = vec![0u8; 8192];
339 first[0] = 0x16;
340 tx.send(Bytes::from(first)).await.unwrap();
341 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
342 tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
343 drop(tx);
344
345 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
346 assert_eq!(sni, None, "no SNI in non-TLS data");
347 assert!(
348 buf.len() >= PEEK_BUF_SIZE,
349 "buffer must hit the cap before bail-out: got {}",
350 buf.len()
351 );
352 }
353
354 #[tokio::test]
355 async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
356 let (tx, mut rx) = mpsc::channel(4);
357 tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
359 .await
360 .unwrap();
361 drop(tx);
362
363 let started = std::time::Instant::now();
366 let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
367 let elapsed = started.elapsed();
368 assert_eq!(sni, None);
369 assert!(buf.starts_with(b"GET"));
370 assert!(
371 elapsed < Duration::from_millis(500),
372 "non-TLS bail must be fast: took {elapsed:?}"
373 );
374 }
375
376 use std::net::IpAddr;
381 use std::time::Duration as StdDuration;
382
383 use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
384 use crate::shared::{ResolvedHostnameFamily, SharedState};
385
    /// IPv4 address the integration tests in this module use as the connection
    /// destination and as the cached address for the hostnames they set up.
    /// NOTE(review): the name suggests an address shared by several Fastly-hosted
    /// sites; confirm.
    const SHARED_FASTLY_IP: &str = "151.101.0.223";
387
388 fn shared_with(host: &str, ip: &str) -> SharedState {
389 let shared = SharedState::new(4);
390 shared.cache_resolved_hostname(
391 host,
392 ResolvedHostnameFamily::Ipv4,
393 [ip.parse::<IpAddr>().unwrap()],
394 StdDuration::from_secs(60),
395 );
396 shared
397 }
398
399 fn allow_https(domain: &str) -> Rule {
400 Rule {
401 direction: crate::policy::Direction::Egress,
402 destination: Destination::Domain(domain.parse().unwrap()),
403 protocols: vec![Protocol::Tcp],
404 ports: vec![PortRange::single(443)],
405 action: Action::Allow,
406 }
407 }
408
409 #[tokio::test]
412 async fn integration_sni_overrides_cache_for_over_allow() {
413 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
414 let policy = NetworkPolicy {
415 default_egress: Action::Deny,
416 default_ingress: Action::Allow,
417 rules: vec![allow_https("pypi.org")],
418 };
419 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
420
421 let (tx, mut rx) = mpsc::channel(4);
422 tx.send(Bytes::from(synthetic_client_hello("evil.com")))
423 .await
424 .unwrap();
425 drop(tx);
426
427 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
428 assert_eq!(sni.as_deref(), Some("evil.com"));
429 assert!(!initial_buf.is_empty());
430
431 let source = sni
432 .as_deref()
433 .map(HostnameSource::Sni)
434 .unwrap_or(HostnameSource::CacheOnly);
435 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
436 assert_eq!(
437 eval,
438 EgressEvaluation::Deny,
439 "SNI=evil.com must not piggy-back on the cached pypi.org match",
440 );
441 }
442
443 #[tokio::test]
446 async fn integration_sni_overrides_cache_for_over_block() {
447 let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
448 let policy = NetworkPolicy {
449 default_egress: Action::Allow,
450 default_ingress: Action::Allow,
451 rules: vec![Rule::deny_egress(Destination::Domain(
452 "ads.example.com".parse().unwrap(),
453 ))],
454 };
455 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
456
457 let (tx, mut rx) = mpsc::channel(4);
458 tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
459 .await
460 .unwrap();
461 drop(tx);
462
463 let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
464 assert_eq!(sni.as_deref(), Some("api.example.com"));
465
466 let source = sni
467 .as_deref()
468 .map(HostnameSource::Sni)
469 .unwrap_or(HostnameSource::CacheOnly);
470 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
471 assert_eq!(
472 eval,
473 EgressEvaluation::Allow,
474 "SNI=api.example.com must not be caught by the deny on ads.example.com",
475 );
476 }
477
478 #[tokio::test]
481 async fn integration_non_tls_falls_back_to_cache() {
482 let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
483 let policy = NetworkPolicy {
484 default_egress: Action::Deny,
485 default_ingress: Action::Allow,
486 rules: vec![allow_https("pypi.org")],
487 };
488 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
489
490 let (tx, mut rx) = mpsc::channel(4);
491 tx.send(Bytes::from_static(
493 b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
494 ))
495 .await
496 .unwrap();
497 drop(tx);
498
499 let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
500 assert_eq!(sni, None, "non-TLS data → no SNI");
501 assert!(
502 !initial_buf.is_empty(),
503 "buffered bytes must survive for replay"
504 );
505
506 let source = sni
507 .as_deref()
508 .map(HostnameSource::Sni)
509 .unwrap_or(HostnameSource::CacheOnly);
510 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
511 assert_eq!(
512 eval,
513 EgressEvaluation::Allow,
514 "cache-only fallback must still allow the cached hostname's IP",
515 );
516 }
517
518 #[tokio::test]
521 async fn integration_sni_matches_domain_suffix_with_cache_binding() {
522 let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
523 let policy = NetworkPolicy {
524 default_egress: Action::Deny,
525 default_ingress: Action::Allow,
526 rules: vec![Rule {
527 direction: crate::policy::Direction::Egress,
528 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
529 protocols: vec![Protocol::Tcp],
530 ports: vec![PortRange::single(443)],
531 action: Action::Allow,
532 }],
533 };
534 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
535
536 let (tx, mut rx) = mpsc::channel(4);
537 tx.send(Bytes::from(synthetic_client_hello(
538 "files.pythonhosted.org",
539 )))
540 .await
541 .unwrap();
542 drop(tx);
543
544 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
545 let source = sni
546 .as_deref()
547 .map(HostnameSource::Sni)
548 .unwrap_or(HostnameSource::CacheOnly);
549 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
550 assert_eq!(eval, EgressEvaluation::Allow);
551 }
552
553 #[tokio::test]
558 async fn integration_sni_denies_domain_suffix_without_cache_binding() {
559 let shared = SharedState::new(4); let policy = NetworkPolicy {
561 default_egress: Action::Deny,
562 default_ingress: Action::Allow,
563 rules: vec![Rule {
564 direction: crate::policy::Direction::Egress,
565 destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
566 protocols: vec![Protocol::Tcp],
567 ports: vec![PortRange::single(443)],
568 action: Action::Allow,
569 }],
570 };
571 let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);
572
573 let (tx, mut rx) = mpsc::channel(4);
574 tx.send(Bytes::from(synthetic_client_hello(
575 "files.pythonhosted.org",
576 )))
577 .await
578 .unwrap();
579 drop(tx);
580
581 let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
582 let source = sni
583 .as_deref()
584 .map(HostnameSource::Sni)
585 .unwrap_or(HostnameSource::CacheOnly);
586 let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
587 assert_eq!(eval, EgressEvaluation::Deny);
588 }
589}